# VICREG

## Setup

In [1]:
! nvidia-smi

Wed Nov 27 12:53:43 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.256.02   Driver Version: 470.256.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA TITAN Xp     Off  | 00000000:01:00.0 Off |                  N/A |
| 28%   53C    P0    81W / 250W |      2MiB / 12196MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms
import torch.distributed as dist

def off_diagonal(x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

def get_byol_transforms(size, mean, std):
    transformT = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=size, scale=(0.08,1), ratio=(3 / 4, 4 / 3)),
        transforms.RandomRotation((-90, 90)),
        transforms.ColorJitter(),
        transforms.GaussianBlur(kernel_size=(23,23), sigma=(0.1, 2.0)),
        transforms.RandomGrayscale(p=0.2),
        transforms.Normalize(mean, std),
        ])

    transformT1 = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(size=size, scale=(0.08,1), ratio=(3 / 4, 4 / 3)),
        transforms.RandomRotation((-90, 90)),
        transforms.ColorJitter(),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=(23,23), sigma=(0.1, 2.0)),
        transforms.Normalize(mean, std),
        ])

    transformEvalT = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(size=size),
        transforms.Normalize(mean, std),        
    ])

    return transformT, transformT1, transformEvalT

def get_cxr_transforms(size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.3, 0.9), ratio=(3/4, 4/3)),
        transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=25, sigma=(0.1, 2.0))], p=0.5),
        transforms.Resize((256, 256)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

    transform_eval = transforms.Compose([
        transforms.ToTensor(),
        transforms.CenterCrop(size=size),
        transforms.Normalize(mean, std),        
    ])

    return transform, transform, transform_eval


class MultiViewDataInjector(object):
    def __init__(self, *args):
        self.transforms = args[0]
        self.random_flip = transforms.RandomHorizontalFlip()

    def __call__(self, sample, *with_consistent_flipping):
        if with_consistent_flipping:
            sample = self.random_flip(sample)
        output = [transform(sample) for transform in self.transforms]
        return output

In [3]:
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self,
    input_size = 2048,
    output_size = 8192,
    depth = 3,
    ):
        super().__init__()
        layers = []
        inp = input_size
        for d in range(depth):
            if d == depth - 1:
                layers.append(nn.Linear(inp, output_size))
            else:
                layers.extend([nn.Linear(inp, output_size), nn.BatchNorm1d(output_size), nn.ReLU(inplace=True)])
                inp = output_size
        self.layer = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.layer(x)

class VicReg(nn.Module):
    def __init__(self,
    backend = 'resnet50',
    input_size = 2048,
    output_size = 8192,
    depth_projector = 3,
    lmbd = 5e-3, u = 1, v= 1, epsilon = 1e-3):

        super().__init__()
        self.backend = backend
        self.projector = MLP(input_size=input_size, output_size=output_size, depth=depth_projector)
        self.output_size = output_size
        self.epsilon = epsilon
        self.u = u
        self.v = v
        self.lmbd = lmbd
    
    def forward(self, x, y):
        x = self.backend(x)
        y = self.backend(y)
        x = self.projector(x)
        y = self.projector(y)

        bs = x.size(0)
    
        repr_loss = F.mse_loss(x, y)
    
        # x = torch.cat(FullGatherLayer.apply(x), dim=0)
        # y = torch.cat(FullGatherLayer.apply(y), dim=0)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)
    
        std_x = torch.sqrt(x.var(dim=0) + self.epsilon)
        std_y = torch.sqrt(y.var(dim=0) + self.epsilon)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2
    
        cov_x = (x.T @ x) / (bs - 1)
        cov_y = (y.T @ y) / (bs - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(self.output_size) + off_diagonal(cov_y).pow_(2).sum().div(self.output_size)
    
        loss = (
            self.u * repr_loss
            + self.v * std_loss
            + self.lmbd * cov_loss
        )
        
        # Free tensors
        del x, y, cov_x, cov_y, std_x, std_y
        torch.cuda.empty_cache()

        return loss

In [4]:
# From https://github.com/noahgolmant/pytorch-lars
""" Layer-wise adaptive rate scaling for SGD in PyTorch! """
from torch.optim.optimizer import Optimizer, required

class LARS(Optimizer):
    r"""Implements layer-wise adaptive rate scaling for SGD.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): base learning rate (\gamma_0)
        momentum (float, optional): momentum factor (default: 0) ("m")
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
            ("\beta")
        eta (float, optional): LARS coefficient
        max_epoch: maximum training epoch to determine polynomial LR decay.

    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    Large Batch Training of Convolutional Networks:
        https://arxiv.org/abs/1708.03888

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """
    def __init__(self, params, lr=required, momentum=.9,
                 weight_decay=.0005, eta=0.001, max_epoch=200):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}"
                             .format(weight_decay))
        if eta < 0.0:
            raise ValueError("Invalid LARS coefficient value: {}".format(eta))

        self.epoch = 0
        defaults = dict(lr=lr, momentum=momentum,
                        weight_decay=weight_decay,
                        eta=eta, max_epoch=max_epoch)
        super(LARS, self).__init__(params, defaults)

    def step(self, epoch=None, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
            epoch: current epoch to calculate polynomial LR decay schedule.
                   if None, uses self.epoch and increments it.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            eta = group['eta']
            lr = group['lr']
            max_epoch = group['max_epoch']

            for p in group['params']:
                if p.grad is None:
                    continue

                param_state = self.state[p]
                d_p = p.grad.data

                weight_norm = torch.norm(p.data)
                grad_norm = torch.norm(d_p)

                # Global LR computed on polynomial decay schedule
                decay = (1 - float(epoch) / max_epoch) ** 2
                global_lr = lr * decay

                # Compute local learning rate for this layer
                local_lr = eta * weight_norm / \
                    (grad_norm + weight_decay * weight_norm)

                # Update the momentum term
                actual_lr = local_lr * global_lr

                if 'momentum_buffer' not in param_state:
                    buf = param_state['momentum_buffer'] = \
                            torch.zeros_like(p.data)
                else:
                    buf = param_state['momentum_buffer']
                buf.mul_(momentum).add_(actual_lr, d_p + weight_decay * p.data)
                p.data.add_(-buf)

        return loss

## Pre-training

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset, random_split
import numpy as np
from tqdm import tqdm
import time

In [6]:
# Set params
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
epochs = 100
batch_size = 32
base_lr = 0.00025
weight_decay = 1e-6

In [7]:
# Load datasets
transform_x1, transform_x2, transform_test = get_cxr_transforms()

dataset_path = './datasets/chestx-ray14-v3'
train_dataset = datasets.ImageFolder(root=f'{dataset_path}/train', transform=MultiViewDataInjector([transform_x1, transform_x2]))
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

test_dataset = datasets.ImageFolder(root=f'{dataset_path}/test', transform=MultiViewDataInjector([transform_x1, transform_x2]))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=8)

In [12]:
# Train loop
def train_loop(model, epoch, optimizer, train_loader, device):
    tk0 = tqdm(train_loader)
    train_loss = []

    start_time = time.time()

    for (x, x1), _ in tk0:
        x = x.to(device)
        x1 = x1.to(device)

        loss = model.forward(x, x1)

        train_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step(epoch)

        # Free tensors
        del x, x1, loss
        torch.cuda.empty_cache()

    return train_loss, time.time() - start_time

In [9]:
from torchvision.models import efficientnet_b0

# Load model
model = efficientnet_b0(pretrained=True)
embed_dim = model.classifier[1].in_features
output_features = embed_dim*4
model.classifier[1] = nn.Identity()
model = VicReg(input_size=embed_dim, output_size=output_features, backend=model, depth_projector=3)
model

VicReg(
  (backend): EfficientNet(
    (features): Sequential(
      (0): ConvNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): ConvNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (scale_activation

In [13]:
# Train script
start_time = time.time()

model = model.to(device)
params = model.parameters()
optimizer = LARS(params, lr=base_lr, weight_decay=weight_decay)

for epoch in range(epochs):
    train_loss, run_time = train_loop(model, epoch, optimizer, train_loader, device)
    print(f'Epoch: [{epoch+1}/{epochs}]  loss: {np.mean(train_loss)}, time: {run_time}')

end_time = time.time()
train_time = end_time - start_time
print(f'Training time {train_time}')

100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:11<00:00,  1.97it/s]


Epoch: [1/100]  loss: 0.8958430943034944, time: 191.5934774875641


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:13<00:00,  1.95it/s]


Epoch: [2/100]  loss: 0.8951006262075334, time: 193.77315425872803


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:16<00:00,  1.93it/s]


Epoch: [3/100]  loss: 0.8942718471168841, time: 196.10740327835083


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:16<00:00,  1.92it/s]


Epoch: [4/100]  loss: 0.8931942925882087, time: 196.9852659702301


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [5/100]  loss: 0.8926089172640805, time: 198.52095651626587


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [6/100]  loss: 0.8917242310665272, time: 199.83911752700806


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [7/100]  loss: 0.8909359654736897, time: 199.70517992973328


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [8/100]  loss: 0.8904558727665554, time: 199.34570908546448


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [9/100]  loss: 0.8898204147184967, time: 199.2631320953369


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [10/100]  loss: 0.8892012079241415, time: 198.9500434398651


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [11/100]  loss: 0.8883832801902105, time: 198.2771406173706


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [12/100]  loss: 0.8874900831431939, time: 199.02566385269165


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [13/100]  loss: 0.8870017533264463, time: 198.25108432769775


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [14/100]  loss: 0.8862234368210747, time: 198.12313508987427


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [15/100]  loss: 0.8857569232504204, time: 199.0055365562439


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [16/100]  loss: 0.8852145397473895, time: 198.97399878501892


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [17/100]  loss: 0.8842151568680213, time: 199.3594663143158


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [18/100]  loss: 0.8833925225747326, time: 200.05258202552795


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [19/100]  loss: 0.8828140889841413, time: 200.24986600875854


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [20/100]  loss: 0.8823403766546299, time: 200.4049482345581


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [21/100]  loss: 0.8814554316972298, time: 200.52200150489807


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [22/100]  loss: 0.8811780852930886, time: 200.34282636642456


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [23/100]  loss: 0.8799124681760394, time: 200.70960998535156


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [24/100]  loss: 0.8795196123854824, time: 201.70652151107788


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [25/100]  loss: 0.8788522567067828, time: 200.92377829551697


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [26/100]  loss: 0.8783283039690957, time: 201.77794289588928


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [27/100]  loss: 0.8779435110470605, time: 201.71964120864868


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [28/100]  loss: 0.8769346992805521, time: 201.80504512786865


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [29/100]  loss: 0.87634478170405, time: 202.1645007133484


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [30/100]  loss: 0.875686980586834, time: 202.52016401290894


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [31/100]  loss: 0.8749421869636212, time: 202.31433939933777


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:22<00:00,  1.87it/s]


Epoch: [32/100]  loss: 0.8745006625614469, time: 202.35027360916138


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [33/100]  loss: 0.8739768157875727, time: 201.39208817481995


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [34/100]  loss: 0.8732498553064134, time: 201.25915026664734


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [35/100]  loss: 0.872496997710889, time: 201.46399927139282


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [36/100]  loss: 0.8720572362185786, time: 200.41816115379333


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [37/100]  loss: 0.8716130792779266, time: 201.05916547775269


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [38/100]  loss: 0.8708564798352579, time: 201.07016324996948


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [39/100]  loss: 0.870528535868125, time: 200.9608817100525


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [40/100]  loss: 0.8694204587154287, time: 200.4762098789215


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [41/100]  loss: 0.8692345655469037, time: 200.2870659828186


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [42/100]  loss: 0.868613305861357, time: 199.73852133750916


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [43/100]  loss: 0.8679208785453171, time: 199.14261054992676


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [44/100]  loss: 0.8675287189307036, time: 199.47419142723083


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [45/100]  loss: 0.8671657659073986, time: 199.8659405708313


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [46/100]  loss: 0.8663436243773768, time: 199.29248809814453


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [47/100]  loss: 0.8663930856676959, time: 199.54005551338196


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [48/100]  loss: 0.8658899885636789, time: 199.157719373703


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [49/100]  loss: 0.8649519472210495, time: 199.18470811843872


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [50/100]  loss: 0.8645683438689621, time: 198.76797580718994


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [51/100]  loss: 0.8640688949476474, time: 198.34739756584167


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [52/100]  loss: 0.8636427668351976, time: 198.33520221710205


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [53/100]  loss: 0.8630457443219645, time: 198.50255823135376


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [54/100]  loss: 0.862727849414109, time: 199.1151692867279


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [55/100]  loss: 0.8623525495567019, time: 199.2866427898407


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [56/100]  loss: 0.8615595532788171, time: 199.70905876159668


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [57/100]  loss: 0.8610990509469673, time: 200.01586318016052


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [58/100]  loss: 0.8607371151447296, time: 200.50033020973206


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [59/100]  loss: 0.8604432908946245, time: 200.73246145248413


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [60/100]  loss: 0.8598744796697425, time: 201.20284962654114


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [61/100]  loss: 0.8593640849388465, time: 200.2835955619812


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [62/100]  loss: 0.8589191942934006, time: 200.39198446273804


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [63/100]  loss: 0.8584460892059185, time: 201.1406705379486


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [64/100]  loss: 0.8582978448855183, time: 200.95063710212708


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [65/100]  loss: 0.8574250863973426, time: 200.8277280330658


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [66/100]  loss: 0.857158369330502, time: 201.01833319664001


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [67/100]  loss: 0.8569873353791615, time: 201.88142275810242


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.87it/s]


Epoch: [68/100]  loss: 0.8564745517319472, time: 201.72744035720825


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [69/100]  loss: 0.8557784342891955, time: 201.26887893676758


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [70/100]  loss: 0.855580722055738, time: 200.61423015594482


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [71/100]  loss: 0.855286129567989, time: 200.240172624588


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [72/100]  loss: 0.8549626994700659, time: 200.0749969482422


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [73/100]  loss: 0.8542175026482375, time: 200.54785680770874


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [74/100]  loss: 0.8539603317856158, time: 200.37876963615417


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [75/100]  loss: 0.853442347711987, time: 200.63526844978333


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [76/100]  loss: 0.8531838505671768, time: 200.52086758613586


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:21<00:00,  1.88it/s]


Epoch: [77/100]  loss: 0.8529345571364044, time: 201.26562404632568


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.89it/s]


Epoch: [78/100]  loss: 0.8523046287594649, time: 200.5060522556305


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:20<00:00,  1.88it/s]


Epoch: [79/100]  loss: 0.8520244104521615, time: 200.55772423744202


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [80/100]  loss: 0.8516901127560429, time: 199.3030264377594


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [81/100]  loss: 0.8512557910548316, time: 199.12538361549377


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [82/100]  loss: 0.8513614018127401, time: 198.67746591567993


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [83/100]  loss: 0.8508493838802217, time: 198.24900770187378


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.91it/s]


Epoch: [84/100]  loss: 0.8507139660693981, time: 197.4820384979248


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.91it/s]


Epoch: [85/100]  loss: 0.850045617453005, time: 197.4932792186737


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.92it/s]


Epoch: [86/100]  loss: 0.8501250923626007, time: 197.11886715888977


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.92it/s]


Epoch: [87/100]  loss: 0.8494645013380303, time: 197.05017757415771


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.91it/s]


Epoch: [88/100]  loss: 0.8490141381031622, time: 197.53760170936584


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.92it/s]


Epoch: [89/100]  loss: 0.8486938610594109, time: 197.12932896614075


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.91it/s]


Epoch: [90/100]  loss: 0.8488604926243031, time: 197.44296765327454


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.92it/s]


Epoch: [91/100]  loss: 0.8479638485996811, time: 197.1139805316925


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:17<00:00,  1.91it/s]


Epoch: [92/100]  loss: 0.8479974427551189, time: 197.5750970840454


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.91it/s]


Epoch: [93/100]  loss: 0.8473331035760344, time: 198.1093454360962


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [94/100]  loss: 0.8471233076834804, time: 198.45048785209656


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:18<00:00,  1.90it/s]


Epoch: [95/100]  loss: 0.8472175902475125, time: 198.49177360534668


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [96/100]  loss: 0.8468498557648331, time: 199.0325276851654


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.90it/s]


Epoch: [97/100]  loss: 0.8463553101928146, time: 199.3070101737976


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [98/100]  loss: 0.846441183140669, time: 199.840589761734


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]


Epoch: [99/100]  loss: 0.8461569287474193, time: 199.6111056804657


100%|██████████████████████████████████████████████████████████████████████████████| 378/378 [03:19<00:00,  1.89it/s]

Epoch: [100/100]  loss: 0.845874012778045, time: 199.61502313613892
Training time 19961.72344970703





In [14]:
# Save results
model_backbone_weights = model.backend
print(model_backbone_weights)

EfficientNet(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): ConvNormActivation(
 

In [16]:
! mkdir ./output/VICReg

In [17]:
save_model_path = './output/VICReg'
torch.save({ 'model_state_dict': model_backbone_weights.state_dict() }, f'{save_model_path}/efficientnet_b0_backbone_weights.ckpt')

## Fine-tuning

In [18]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
import torch

# Path to the dataset
dataset_path = './datasets/COVIDGR_1.0'
positive_path = os.path.join(dataset_path, 'P')
negative_path = os.path.join(dataset_path, 'N')

# Data transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.3, 0.9), ratio=(3/4, 4/3)),
    transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8
        ),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=25, sigma=(0.1, 2.0))], p=0.5),
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Load dataset with ImageFolder
dataset = datasets.ImageFolder(root=dataset_path, transform=train_transform)

In [19]:
# Step 4: Map labels
dataset.class_to_idx = {'N': 0, 'P': 1}
dataset

Dataset ImageFolder
    Number of datapoints: 852
    Root location: ./datasets/COVIDGR_1.0
    StandardTransform
Transform: Compose(
               RandomResizedCrop(size=(224, 224), scale=(0.3, 0.9), ratio=(0.75, 1.3333), interpolation=bilinear)
               RandomApply(
               p=0.8
               ColorJitter(brightness=[0.6, 1.4], contrast=[0.6, 1.4], saturation=[0.8, 1.2], hue=[-0.1, 0.1])
           )
               RandomGrayscale(p=0.2)
               RandomApply(
               p=0.5
               GaussianBlur(kernel_size=(25, 25), sigma=(0.1, 2.0))
           )
               Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
           )

In [20]:
# Splitting the dataset
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
batch_size = 256
print(train_size, batch_size)

# Split off the test set
train_val_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# DataLoader for the test set (held out)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

766 256


In [21]:
checkpoint_path = "./output/VICReg/efficientnet_b0_backbone_weights.ckpt"
best_params = {"learning_rate": 0.01, "weight_decay": 0.001}

In [22]:
from torchvision.models import efficientnet_b0
import torch.nn as nn
import torch.optim as optim

def get_model():
    # Load the EfficientNet model
    model = efficientnet_b0()
    
    # Modify the final classification head for your dataset
    embed_dim = model.classifier[1].in_features
    model.classifier[1] = nn.Identity()
    
    # Load the pre-trained weights
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint["model_state_dict"]  # Adjust the key if needed
    
    # Load the weights into the model
    msg = model.load_state_dict(state_dict, strict=False)
    print('Pretrained weights found at {} and loaded with msg: {}'.format(checkpoint_path, msg))

    # Freeze model params
    for param in model.features.parameters():
        param.requires_grad = False
    return model, embed_dim

In [23]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import time
from sklearn.model_selection import KFold

def calculate_metrics(true_labels, predictions):
    accuracy = accuracy_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    return accuracy, precision, recall, f1


def train_model(model, classifier, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    classifier = classifier.to(device)
    train_losses, val_losses = [], []
    best_val_f1 = 0.0

    for epoch in range(num_epochs):
        classifier.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            outputs = classifier(outputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        scheduler.step()  # Adjust learning rate

        # Validation phase
        classifier.eval()
        val_loss = 0.0
        all_preds, all_labels = [], []
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                outputs = classifier(outputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        # Metrics
        accuracy, precision, recall, f1 = calculate_metrics(all_labels, all_preds)
        train_losses.append(running_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))

        if f1 > best_val_f1:
            best_val_f1 = f1

        print(f"[{device}] Epoch {epoch+1}/{num_epochs}, Train Loss: {running_loss / len(train_loader):.4f}, "
              f"Val Loss: {val_loss / len(val_loader):.4f}, F1 Score: {f1:.4f}")

    return best_val_f1

# Evaluate model on the test set
def evaluate_model(model, classifier, test_loader):
    classifier.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    classifier = classifier.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            outputs = classifier(outputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')
    f1 = f1_score(all_labels, all_preds, average='binary')
    return accuracy, precision, recall, f1

def grid_search(train_loader, val_loader, learning_rates, weight_decays, num_epochs):
    best_model = None
    best_f1 = 0
    best_params = {}
    for lr in learning_rates:
        for wd in weight_decays:
            model = get_model()
            optimizer = optim.SGD(model.classifier[1].parameters(), lr=lr, weight_decay=wd, momentum=0.9)
            scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / 10))
            criterion = nn.CrossEntropyLoss()
            print(f"\nTraining with lr={lr}, weight_decay={wd}")
            f1_score = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)
            if f1_score > best_f1:
                best_f1 = f1_score
                best_model = model
                best_params = {"learning_rate": lr, "weight_decay": wd}
    print(f"\nBest Model F1: {best_f1} with params {best_params}")
    return best_model, best_params

# Step 9: 5-Fold Cross-Validation
def cross_validation(best_params, dataset, test_loader, num_epochs=50, folds=5):
    fold_metrics = []
    kfold = KFold(n_splits=folds, shuffle=True, random_state=100)

    for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset)):
        print(f"\nStarting fold {fold + 1}/{folds}")

        # Split dataset indices for training and validation
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        # Create DataLoaders for this fold
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        model, embed_dim = get_model()
        linear_classifier = nn.Linear(embed_dim, 2) # 2 is the number of features
        optimizer = optim.SGD(linear_classifier.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'], momentum=0.9)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / 10))
        criterion = nn.CrossEntropyLoss()

        start_time = time.time()
        _ = train_model(model, linear_classifier, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs)
        end_time = time.time()

        # Evaluate on test set
        accuracy, precision, recall, f1 = evaluate_model(model, linear_classifier, test_loader)
        fold_metrics.append((accuracy, precision, recall, f1, end_time - start_time))

    return np.array(fold_metrics)

In [24]:
# Run 5-fold cross-validation
metrics = cross_validation(best_params, train_val_dataset, test_loader, num_epochs=50, folds=5)

# Calculate average and standard deviation of metrics across folds
avg_metrics = metrics.mean(axis=0)
std_metrics = metrics.std(axis=0)

print(f"\nAverage metrics over 5 folds in test set:\n"
      f"Accuracy: {avg_metrics[0]:.4f} ± {std_metrics[0]:.4f}\n"
      f"Precision: {avg_metrics[1]:.4f} ± {std_metrics[1]:.4f}\n"
      f"Recall: {avg_metrics[2]:.4f} ± {std_metrics[2]:.4f}\n"
      f"F1 Score: {avg_metrics[3]:.4f} ± {std_metrics[3]:.4f}\n"
      f"Training Time per Fold: {avg_metrics[4]:.2f} ± {std_metrics[4]:.2f} seconds")


Starting fold 1/5
Pretrained weights found at ./output/VICReg/efficientnet_b0_backbone_weights.ckpt and loaded with msg: <All keys matched successfully>
[cuda] Epoch 1/50, Train Loss: 0.6803, Val Loss: 0.7020, F1 Score: 0.4427
[cuda] Epoch 2/50, Train Loss: 0.6768, Val Loss: 0.6813, F1 Score: 0.4806
[cuda] Epoch 3/50, Train Loss: 0.6522, Val Loss: 0.6593, F1 Score: 0.6197
[cuda] Epoch 4/50, Train Loss: 0.6258, Val Loss: 0.6395, F1 Score: 0.6711
[cuda] Epoch 5/50, Train Loss: 0.6094, Val Loss: 0.6044, F1 Score: 0.7006
[cuda] Epoch 6/50, Train Loss: 0.5779, Val Loss: 0.5893, F1 Score: 0.7425
[cuda] Epoch 7/50, Train Loss: 0.5604, Val Loss: 0.6093, F1 Score: 0.6456
[cuda] Epoch 8/50, Train Loss: 0.5432, Val Loss: 0.5691, F1 Score: 0.7284
[cuda] Epoch 9/50, Train Loss: 0.5616, Val Loss: 0.6195, F1 Score: 0.7329
[cuda] Epoch 10/50, Train Loss: 0.5659, Val Loss: 0.6165, F1 Score: 0.7237
[cuda] Epoch 11/50, Train Loss: 0.5312, Val Loss: 0.6187, F1 Score: 0.6708
[cuda] Epoch 12/50, Train Loss

In [25]:
metrics

array([[7.55813953e-01, 6.97674419e-01, 7.89473684e-01, 7.40740741e-01,
        1.88389540e+03],
       [6.97674419e-01, 6.57894737e-01, 6.57894737e-01, 6.57894737e-01,
        1.86215698e+03],
       [6.86046512e-01, 6.27906977e-01, 7.10526316e-01, 6.66666667e-01,
        1.85964647e+03],
       [6.86046512e-01, 6.27906977e-01, 7.10526316e-01, 6.66666667e-01,
        1.83039524e+03],
       [6.86046512e-01, 6.48648649e-01, 6.31578947e-01, 6.40000000e-01,
        1.86942247e+03]])

In [26]:
import pandas as pd

results = pd.DataFrame(metrics, columns=['accuracy', 'precision', 'recall', 'f1', 'training time'])
results.head()

Unnamed: 0,accuracy,precision,recall,f1,training time
0,0.755814,0.697674,0.789474,0.740741,1883.895396
1,0.697674,0.657895,0.657895,0.657895,1862.156977
2,0.686047,0.627907,0.710526,0.666667,1859.646466
3,0.686047,0.627907,0.710526,0.666667,1830.395243
4,0.686047,0.648649,0.631579,0.64,1869.422467


In [28]:
results.to_csv('./results/vicreg_fine_tuned.csv')