# VICREG

## Setup

In [1]:
! nvidia-smi

Thu Nov 28 10:00:21 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN Xp     Off  | 00000000:01:00.0 Off |                  N/A |
| 23%   28C    P8     8W / 250W |   1627MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
import torch.distributed as dist

def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def get_byol_transforms(size, mean, std):
    transformT = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=size, scale=(0.08,1), ratio=(3 / 4, 4 / 3)),
        transforms.RandomRotation((-90, 90)),
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size=(23,23), sigma=(0.1, 2.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.Normalize(mean, std),
        ])

    transformT1 = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=size, scale=(0.08,1), ratio=(3 / 4, 4 / 3)),
        transforms.RandomRotation((-90, 90)),
        transforms.ColorJitter(),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=(23,23), sigma=(0.1, 2.0)),
        transforms.Normalize(mean, std),
        ])

    transformEvalT = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(size=size),
        transforms.Normalize(mean, std),        
    ])

    return transformT, transformT1, transformEvalT

def get_cxr_transforms(size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.3, 0.9), ratio=(3/4, 4/3)),
        transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=25, sigma=(0.1, 2.0))], p=0.5),
        transforms.Resize((256, 256)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    transform_eval = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(size=size),
        transforms.Normalize(mean, std),        
    ])

    return transform, transform, transform_eval


class MultiViewDataInjector(object):
    def __init__(self, *args):
        self.transforms = args[0]
        self.random_flip = transforms.RandomHorizontalFlip()

    def __call__(self, sample, *with_consistent_flipping):
        if with_consistent_flipping:
            sample = self.random_flip(sample)
        output = [transform(sample) for transform in self.transforms]
        return output

In [3]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self,
    input_size = 2048,
    output_size = 8192,
    depth = 3,
    ):
        super().__init__()
        layers = []
        inp = input_size
        for d in range(depth):
            if d == depth - 1:
                layers.append(nn.Linear(inp, output_size))
            else:
                layers.extend([nn.Linear(inp, output_size), nn.BatchNorm1d(output_size), nn.ReLU(inplace=True)])
                inp = output_size
        self.layer = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layer(x)

class VicReg(nn.Module):
    def __init__(self,
    backend = 'resnet50',
    input_size = 2048,
    output_size = 8192,
    depth_projector = 3,
    lmbd = 5e-3, u = 1, v= 1, epsilon = 1e-3):

        super().__init__()
        self.backend = backend
        self.projector = MLP(input_size=input_size, output_size=output_size, depth=depth_projector)
        self.output_size = output_size
        self.epsilon = epsilon
        self.u = u
        self.v = v
        self.lmbd = lmbd
    
    def forward(self, x, y):
        x = self.backend(x)
        y = self.backend(y)
        x = self.projector(x)
        y = self.projector(y)

        bs = x.size(0)
    
        repr_loss = F.mse_loss(x, y)
    
        # x = torch.cat(FullGatherLayer.apply(x), dim=0)
        # y = torch.cat(FullGatherLayer.apply(y), dim=0)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)
    
        std_x = torch.sqrt(x.var(dim=0) + self.epsilon)
        std_y = torch.sqrt(y.var(dim=0) + self.epsilon)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2
    
        cov_x = (x.T @ x) / (bs - 1)
        cov_y = (y.T @ y) / (bs - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(self.output_size) + off_diagonal(cov_y).pow_(2).sum().div(self.output_size)
    
        loss = (
            self.u * repr_loss
            + self.v * std_loss
            + self.lmbd * cov_loss
        )
        
        # Free tensors
        del x, y, cov_x, cov_y, std_x, std_y
        torch.cuda.empty_cache()

        return loss

In [4]:
# From https://github.com/noahgolmant/pytorch-lars
""" Layer-wise adaptive rate scaling for SGD in PyTorch! """
from torch.optim.optimizer import Optimizer, required

class LARS(Optimizer):
    r"""Implements layer-wise adaptive rate scaling for SGD.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): base learning rate (\gamma_0)
        momentum (float, optional): momentum factor (default: 0) ("m")
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            ("\beta")
        eta (float, optional): LARS coefficient
        max_epoch: maximum training epoch to determine polynomial LR decay.

    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    Large Batch Training of Convolutional Networks:
        https://arxiv.org/abs/1708.03888

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """
    def __init__(self, params, lr=required, momentum=.9,
                 weight_decay=.0005, eta=0.001, max_epoch=200):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}"
                             .format(weight_decay))
        if eta < 0.0:
            raise ValueError("Invalid LARS coefficient value: {}".format(eta))

        self.epoch = 0
        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay,
                        eta=eta, max_epoch=max_epoch)
        super(LARS, self).__init__(params, defaults)

    def step(self, epoch=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            epoch: current epoch to calculate polynomial LR decay schedule.
                   if None, uses self.epoch and increments it.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            eta = group['eta']
            lr = group['lr']
            max_epoch = group['max_epoch']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                d_p = p.grad.data

                weight_norm = torch.norm(p.data)
                grad_norm = torch.norm(d_p)

                # Global LR computed on polynomial decay schedule
                decay = (1 - float(epoch) / max_epoch) ** 2
                global_lr = lr * decay

                # Compute local learning rate for this layer
                local_lr = eta * weight_norm / \
                    (grad_norm + weight_decay * weight_norm)

                # Update the momentum term
                actual_lr = local_lr * global_lr

                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = \
                            torch.zeros_like(p.data)
                else:
                    buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data)
                p.data.add_(-buf)

        return loss

## Pre-training

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
from tqdm import tqdm
import time

In [7]:
# Set params
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
epochs = 100
batch_size = 32
base_lr = 0.25
weight_decay = 1e-6

In [8]:
# Load datasets
transform_x1, transform_x2, transform_test = get_cxr_transforms()

dataset_path = './datasets/chestx-ray14-v3'
train_dataset = datasets.ImageFolder(root=f'{dataset_path}/train', transform=MultiViewDataInjector([transform_x1, transform_x2]))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

test_dataset = datasets.ImageFolder(root=f'{dataset_path}/test', transform=MultiViewDataInjector([transform_x1, transform_x2]))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [9]:
# Train loop
def train_loop(model, epoch, optimizer, train_loader, device):
    tk0 = tqdm(train_loader)
    train_loss = []

    start_time = time.time()

    for (x, x1), _ in tk0:
        x = x.to(device)
        x1 = x1.to(device)

        loss = model.forward(x, x1)

        train_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step(epoch)

        # Free tensors
        del x, x1, loss
        torch.cuda.empty_cache()

    return train_loss, time.time() - start_time

In [10]:
from torchvision.models import efficientnet_b0

# Load model
model = efficientnet_b0(pretrained=True)
embed_dim = model.classifier[1].in_features
output_features = embed_dim*4
model.classifier[1] = nn.Identity()
model = VicReg(input_size=embed_dim, output_size=output_features, backend=model, depth_projector=3)
model

VicReg(
  (backend): EfficientNet(
    (features): Sequential(
      (0): ConvNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): ConvNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activation

In [11]:
# Train script
start_time = time.time()

model = model.to(device)
params = model.parameters()
optimizer = LARS(params, lr=(base_lr * batch_size) / 256, weight_decay=weight_decay)

for epoch in range(epochs):
    train_loss, run_time = train_loop(model, epoch, optimizer, train_loader, device)
    print(f'Epoch: [{epoch+1}/{epochs}]  loss: {np.mean(train_loss)}, time: {run_time}')

end_time = time.time()
train_time = end_time - start_time
print(f'Training time {train_time}')

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1055.)
  buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data)
100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:16<00:00,  1.92it/s]


Epoch: [1/100]  loss: 0.8506856423521799, time: 196.4984793663025


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:12<00:00,  1.96it/s]


Epoch: [2/100]  loss: 0.7717970983376579, time: 192.84085035324097


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:14<00:00,  1.94it/s]


Epoch: [3/100]  loss: 0.7427284105744942, time: 194.91769289970398


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.92it/s]


Epoch: [4/100]  loss: 0.7285782631742891, time: 197.13048815727234


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [5/100]  loss: 0.7185948177304848, time: 198.211341381073


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [6/100]  loss: 0.711681333956895, time: 199.28527522087097


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [7/100]  loss: 0.7075205147266388, time: 199.9650321006775


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [8/100]  loss: 0.7030118675458998, time: 200.446209192276


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [9/100]  loss: 0.6985193910422148, time: 200.17217111587524


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [10/100]  loss: 0.6964212703326392, time: 200.21403455734253


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [11/100]  loss: 0.6934452101036355, time: 199.4116337299347


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [12/100]  loss: 0.6907250300917045, time: 199.29016137123108


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [13/100]  loss: 0.689308600608634, time: 199.58272886276245


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [14/100]  loss: 0.6872269717789201, time: 199.0073902606964


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [15/100]  loss: 0.6846566037841575, time: 198.51671290397644


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [16/100]  loss: 0.6837639685661073, time: 198.88416862487793


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [17/100]  loss: 0.6821513614326558, time: 199.1288025379181


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [18/100]  loss: 0.6808242392603052, time: 198.6291708946228


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [19/100]  loss: 0.6803705993468169, time: 199.32719230651855


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [20/100]  loss: 0.678297890399499, time: 199.63642740249634


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [21/100]  loss: 0.6776622861781448, time: 199.80562090873718


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [22/100]  loss: 0.6765590775265264, time: 200.74028182029724


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [23/100]  loss: 0.6749390127166869, time: 200.31053042411804


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [24/100]  loss: 0.6745542872835089, time: 201.27559232711792


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [25/100]  loss: 0.6740481661110328, time: 201.94392919540405


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [26/100]  loss: 0.6728149210965192, time: 202.17541027069092


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [27/100]  loss: 0.6724035413492293, time: 202.13495922088623


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [28/100]  loss: 0.6711558734298383, time: 202.66870379447937


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [29/100]  loss: 0.670596240216462, time: 202.3392996788025


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:23<00:00,  1.86it/s]


Epoch: [30/100]  loss: 0.6699092186317241, time: 203.44974756240845


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:23<00:00,  1.86it/s]


Epoch: [31/100]  loss: 0.6688218802686722, time: 203.65258741378784


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:23<00:00,  1.86it/s]


Epoch: [32/100]  loss: 0.6683356687820777, time: 203.22749733924866


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:24<00:00,  1.85it/s]


Epoch: [33/100]  loss: 0.667772668063956, time: 204.18206405639648


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.86it/s]


Epoch: [34/100]  loss: 0.6672457223216062, time: 202.97179007530212


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [35/100]  loss: 0.6667723824422827, time: 202.62185621261597


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [36/100]  loss: 0.6667826642435064, time: 202.10924530029297


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [37/100]  loss: 0.6658269622653881, time: 202.3049440383911


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [38/100]  loss: 0.6654899117177125, time: 202.4152705669403


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [39/100]  loss: 0.6642668292951331, time: 201.4390037059784


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [40/100]  loss: 0.6639109298034951, time: 201.1891906261444


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [41/100]  loss: 0.6637139046002948, time: 200.64772200584412


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [42/100]  loss: 0.6633012465068272, time: 200.39906096458435


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [43/100]  loss: 0.6636190469618197, time: 200.73982167243958


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [44/100]  loss: 0.662460389749083, time: 200.76992678642273


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [45/100]  loss: 0.662511578627995, time: 200.8782012462616


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [46/100]  loss: 0.662011427223367, time: 201.3355689048767


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [47/100]  loss: 0.6619529735158991, time: 201.72804474830627


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [48/100]  loss: 0.6611580924382285, time: 202.23471355438232


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [49/100]  loss: 0.6608922205589436, time: 202.5001997947693


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [50/100]  loss: 0.6601238011052368, time: 202.3122980594635


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [51/100]  loss: 0.6603769981671893, time: 201.72272181510925


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [52/100]  loss: 0.6598373845456138, time: 201.82702040672302


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [53/100]  loss: 0.6594041213787422, time: 201.43177914619446


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [54/100]  loss: 0.6597755994430925, time: 200.22894549369812


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [55/100]  loss: 0.6589938676546491, time: 199.56179523468018


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [56/100]  loss: 0.6584222674685181, time: 199.43970370292664


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [57/100]  loss: 0.6583894486465152, time: 199.5794439315796


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [58/100]  loss: 0.6592158955871743, time: 199.94313883781433


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [59/100]  loss: 0.6578419758844628, time: 199.677987575531


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [60/100]  loss: 0.6580185644210331, time: 199.98255610466003


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [61/100]  loss: 0.6582884936736374, time: 200.35149788856506


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [62/100]  loss: 0.6578852092778241, time: 201.01632809638977


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [63/100]  loss: 0.6578649328183875, time: 201.50020623207092


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [64/100]  loss: 0.6568007849196278, time: 201.46880745887756


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [65/100]  loss: 0.6563273471183878, time: 200.58674907684326


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [66/100]  loss: 0.6570721904121378, time: 199.88097167015076


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [67/100]  loss: 0.6557208406862127, time: 199.5185809135437


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [68/100]  loss: 0.6568461533576723, time: 199.25972700119019


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [69/100]  loss: 0.6549879068104678, time: 199.78171396255493


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [70/100]  loss: 0.6560425613292311, time: 199.99852228164673


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [71/100]  loss: 0.6562193540668992, time: 200.34220385551453


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [72/100]  loss: 0.6550492537084711, time: 200.46139550209045


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [73/100]  loss: 0.655099386260623, time: 201.0068461894989


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [74/100]  loss: 0.6554081711188826, time: 201.16643905639648


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [75/100]  loss: 0.6546081723990264, time: 201.22228693962097


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [76/100]  loss: 0.6546759301077121, time: 200.77629137039185


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [77/100]  loss: 0.6548590973886863, time: 201.03124904632568


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [78/100]  loss: 0.6534861662085094, time: 200.18369340896606


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [79/100]  loss: 0.6537712284181484, time: 199.56834435462952


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [80/100]  loss: 0.6544949308274284, time: 199.74689388275146


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [81/100]  loss: 0.6534125227776785, time: 200.1022663116455


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [82/100]  loss: 0.6533783450328484, time: 200.46666622161865


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [83/100]  loss: 0.6533455094963154, time: 200.33399939537048


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [84/100]  loss: 0.6539621182850429, time: 200.53995966911316


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [85/100]  loss: 0.6521040204340819, time: 199.9142758846283


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [86/100]  loss: 0.6537699401378632, time: 201.0853886604309


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [87/100]  loss: 0.6528708821251279, time: 200.48093461990356


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [88/100]  loss: 0.6528934544671781, time: 200.84310293197632


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [89/100]  loss: 0.6536994002483509, time: 200.89242434501648


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [90/100]  loss: 0.6525887411738199, time: 201.2083125114441


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [91/100]  loss: 0.6528764373410947, time: 199.8130841255188


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [92/100]  loss: 0.6528252878832439, time: 199.91712045669556


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [93/100]  loss: 0.6522210927867385, time: 199.84474730491638


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [94/100]  loss: 0.6525729262324237, time: 199.393492937088


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [95/100]  loss: 0.651811056666904, time: 199.92905116081238


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [96/100]  loss: 0.6517418832690628, time: 199.710125207901


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [97/100]  loss: 0.6521596443400812, time: 200.38924479484558


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [98/100]  loss: 0.6521531392342199, time: 200.90530848503113


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [99/100]  loss: 0.6517015809104556, time: 200.71443724632263


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]

Epoch: [100/100]  loss: 0.6512183704704204, time: 200.84397220611572
Training time 20049.147972106934





In [12]:
# Save results
model_backbone_weights = model.backend
# print(model_backbone_weights)

In [16]:
# ! mkdir ./output/VICReg

In [13]:
save_model_path = './output/VICReg'
torch.save({ 'model_state_dict': model_backbone_weights.state_dict() }, f'{save_model_path}/efficientnet_b0_backbone_weights_v1.ckpt')

## Fine-tuning

In [14]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import torch

# Path to the dataset
dataset_path = './datasets/COVIDGR_1.0'
positive_path = os.path.join(dataset_path, 'P')
negative_path = os.path.join(dataset_path, 'N')

# Data transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.3, 0.9), ratio=(3/4, 4/3)),
    transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8
        ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=25, sigma=(0.1, 2.0))], p=0.5),
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load dataset with ImageFolder
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)

In [15]:
# Step 4: Map labels
dataset.class_to_idx = {'N': 0, 'P': 1}
dataset

Dataset ImageFolder
    Number of datapoints: 852
    Root location: ./datasets/COVIDGR_1.0
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(224, 224), scale=(0.3, 0.9), ratio=(0.75, 1.3333), interpolation=bilinear)
               RandomApply(
               p=0.8
               ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.8, 1.2], hue=[-0.1, 0.1])
           )
               RandomGrayscale(p=0.2)
               RandomApply(
               p=0.5
               GaussianBlur(kernel_size=(25, 25), sigma=(0.1, 2.0))
           )
               Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
           )

In [16]:
# Splitting the dataset
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
batch_size = 256
print(train_size, batch_size)

# Split off the test set
train_val_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoader for the test set (held out)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

766 256


In [17]:
checkpoint_path = "./output/VICReg/efficientnet_b0_backbone_weights_v1.ckpt"
best_params = {"learning_rate": 0.01, "weight_decay": 0.001}

In [18]:
from torchvision.models import efficientnet_b0
import torch.nn as nn
import torch.optim as optim

def get_model():
    # Load the EfficientNet model
    model = efficientnet_b0()
    
    # Modify the final classification head for your dataset
    embed_dim = model.classifier[1].in_features
    model.classifier[1] = nn.Identity()
    
    # Load the pre-trained weights
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["model_state_dict"]  # Adjust the key if needed
    
    # Load the weights into the model
    msg = model.load_state_dict(state_dict, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(checkpoint_path, msg))

    # Freeze model params
    for param in model.features.parameters():
        param.requires_grad = False
    return model, embed_dim

In [19]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import time
from sklearn.model_selection import KFold

def calculate_metrics(true_labels, predictions):
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    return accuracy, precision, recall, f1


def train_model(model, classifier, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    classifier = classifier.to(device)
    train_losses, val_losses = [], []
    best_val_f1 = 0.0

    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            outputs = classifier(outputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()  # Adjust learning rate

        # Validation phase
        classifier.eval()
        val_loss = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                outputs = classifier(outputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Metrics
        accuracy, precision, recall, f1 = calculate_metrics(all_labels, all_preds)
        train_losses.append(running_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))

        if f1 > best_val_f1:
            best_val_f1 = f1

        print(f"[{device}] Epoch {epoch+1}/{num_epochs}, Train Loss: {running_loss / len(train_loader):.4f}, "
              f"Val Loss: {val_loss / len(val_loader):.4f}, F1 Score: {f1:.4f}")

    return best_val_f1

# Evaluate model on the test set
def evaluate_model(model, classifier, test_loader):
    classifier.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    classifier = classifier.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = classifier(outputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')
    return accuracy, precision, recall, f1

def grid_search(train_loader, val_loader, learning_rates, weight_decays, num_epochs):
    best_model = None
    best_f1 = 0
    best_params = {}
    for lr in learning_rates:
        for wd in weight_decays:
            model = get_model()
            optimizer = optim.SGD(model.classifier[1].parameters(), lr=lr, weight_decay=wd, momentum=0.9)
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / 10))
            criterion = nn.CrossEntropyLoss()
            print(f"\nTraining with lr={lr}, weight_decay={wd}")
            f1_score = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)
            if f1_score > best_f1:
                best_f1 = f1_score
                best_model = model
                best_params = {"learning_rate": lr, "weight_decay": wd}
    print(f"\nBest Model F1: {best_f1} with params {best_params}")
    return best_model, best_params

# Step 9: 5-Fold Cross-Validation
def cross_validation(best_params, dataset, test_loader, num_epochs=50, folds=5):
    fold_metrics = []
    kfold = KFold(n_splits=folds, shuffle=True, random_state=100)

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"\nStarting fold {fold + 1}/{folds}")

        # Split dataset indices for training and validation
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        # Create DataLoaders for this fold
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        model, embed_dim = get_model()
        linear_classifier = nn.Linear(embed_dim, 2) # 2 is the number of features
        optimizer = optim.SGD(linear_classifier.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'], momentum=0.9)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / 10))
        criterion = nn.CrossEntropyLoss()

        start_time = time.time()
        _ = train_model(model, linear_classifier, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)
        end_time = time.time()

        # Evaluate on test set
        accuracy, precision, recall, f1 = evaluate_model(model, linear_classifier, test_loader)
        fold_metrics.append((accuracy, precision, recall, f1, end_time - start_time))

    return np.array(fold_metrics)

In [20]:
# Run 5-fold cross-validation
metrics = cross_validation(best_params, train_val_dataset, test_loader, num_epochs=50, folds=5)

# Calculate average and standard deviation of metrics across folds
avg_metrics = metrics.mean(axis=0)
std_metrics = metrics.std(axis=0)

print(f"\nAverage metrics over 5 folds in test set:\n"
      f"Accuracy: {avg_metrics[0]:.4f} ± {std_metrics[0]:.4f}\n"
      f"Precision: {avg_metrics[1]:.4f} ± {std_metrics[1]:.4f}\n"
      f"Recall: {avg_metrics[2]:.4f} ± {std_metrics[2]:.4f}\n"
      f"F1 Score: {avg_metrics[3]:.4f} ± {std_metrics[3]:.4f}\n"
      f"Training Time per Fold: {avg_metrics[4]:.2f} ± {std_metrics[4]:.2f} seconds")


Starting fold 1/5
Pretrained weights found at ./output/VICReg/efficientnet_b0_backbone_weights_v1.ckpt and loaded with msg: <All keys matched successfully>
[cuda] Epoch 1/50, Train Loss: 0.7362, Val Loss: 0.7302, F1 Score: 0.4029
[cuda] Epoch 2/50, Train Loss: 0.7095, Val Loss: 0.6806, F1 Score: 0.6194
[cuda] Epoch 3/50, Train Loss: 0.6457, Val Loss: 0.6133, F1 Score: 0.7362
[cuda] Epoch 4/50, Train Loss: 0.5918, Val Loss: 0.5617, F1 Score: 0.7296
[cuda] Epoch 5/50, Train Loss: 0.5541, Val Loss: 0.5579, F1 Score: 0.7547
[cuda] Epoch 6/50, Train Loss: 0.5576, Val Loss: 0.5635, F1 Score: 0.7342
[cuda] Epoch 7/50, Train Loss: 0.5860, Val Loss: 0.5636, F1 Score: 0.7320
[cuda] Epoch 8/50, Train Loss: 0.5362, Val Loss: 0.5645, F1 Score: 0.7582
[cuda] Epoch 9/50, Train Loss: 0.5302, Val Loss: 0.5386, F1 Score: 0.7532
[cuda] Epoch 10/50, Train Loss: 0.5287, Val Loss: 0.5515, F1 Score: 0.7403
[cuda] Epoch 11/50, Train Loss: 0.5452, Val Loss: 0.5495, F1 Score: 0.7134
[cuda] Epoch 12/50, Train L

In [21]:
metrics

array([[6.51162791e-01, 5.23809524e-01, 6.87500000e-01, 5.94594595e-01,
        1.90210057e+03],
       [7.09302326e-01, 5.81395349e-01, 7.81250000e-01, 6.66666667e-01,
        1.87772851e+03],
       [6.51162791e-01, 5.22727273e-01, 7.18750000e-01, 6.05263158e-01,
        1.89479169e+03],
       [6.74418605e-01, 5.47619048e-01, 7.18750000e-01, 6.21621622e-01,
        1.84858736e+03],
       [6.86046512e-01, 5.60975610e-01, 7.18750000e-01, 6.30136986e-01,
        1.89730236e+03]])

In [22]:
import pandas as pd

results = pd.DataFrame(metrics, columns=['accuracy', 'precision', 'recall', 'f1', 'training time'])
results.head()

Unnamed: 0,accuracy,precision,recall,f1,training time
0,0.651163,0.52381,0.6875,0.594595,1902.100572
1,0.709302,0.581395,0.78125,0.666667,1877.728509
2,0.651163,0.522727,0.71875,0.605263,1894.791689
3,0.674419,0.547619,0.71875,0.621622,1848.587361
4,0.686047,0.560976,0.71875,0.630137,1897.302361


In [23]:
results.to_csv('./results/vicreg_fine_tuned_v1.csv')