# Contrastive learning: SimCLR

In [None]:
%%capture

!pip install pytorch-lightning
!pip install lightning-bolts

import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from torch.utils.data import DataLoader
from torch.multiprocessing import cpu_count

import torchvision.transforms as T
import torchvision.models as models
from torchvision.models import resnet18
from torchvision.datasets import STL10


from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

import matplotlib.pyplot as plt

SEED = 42

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint

In [None]:
def default(val, def_val):
    """Return ``val``, or ``def_val`` when ``val`` is ``None``.

    Only ``None`` triggers the fallback; other falsy values (0, '', [])
    are returned unchanged.
    """
    if val is None:
        return def_val
    return val

def reproducibility(config):
    """Seed torch and numpy (and CUDA when requested) and force deterministic cuDNN.

    Args:
        config: object exposing ``seed`` (any value ``int()`` accepts) and a
            ``cuda`` flag; when ``cuda`` is truthy the CUDA generator is seeded too.
    """
    seed_value = int(config.seed)
    torch.manual_seed(seed_value)
    # Give up cuDNN autotuning in exchange for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed_value)
    if config.cuda:
        torch.cuda.manual_seed(seed_value)

def device_as(t1, t2):
    """Return ``t1`` placed on the same device as ``t2``.

    ``Tensor.to`` returns ``t1`` itself when it already lives on that device.
    """
    target_device = t2.device
    return t1.to(target_device)

def weights_update(model, checkpoint_path):
    """Partially load a Lightning checkpoint into ``model`` (in place).

    Only entries of ``checkpoint['state_dict']`` whose keys also exist in
    ``model.state_dict()`` are copied. Keys missing from the checkpoint keep
    the model's current values, and extra checkpoint keys are ignored.

    Args:
        model: module to update.
        checkpoint_path: path to a file saved with a top-level ``'state_dict'`` entry.

    Returns:
        The same ``model`` instance.
    """
    # map_location='cpu': a checkpoint written while training on a GPU would
    # otherwise fail to deserialize on a CPU-only machine. load_state_dict then
    # copies the tensors into the model's existing parameters on their device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in checkpoint['state_dict'].items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print(f'Checkpoint {checkpoint_path} was loaded')
    return model

## 1. Augmentation

In [None]:
class Augment:
    """Produce two independently augmented views of one image (SimCLR-style).

    ``train_transform`` applies, in this order:
      1. color jitter (brightness/contrast/saturation ``s``, hue ``s * 0.5``), p=0.8;
      2. a 3x3 Gaussian blur, p=0.5;
      3. always: a random resized crop keeping 7%-100% of the image area and
         resized to ``img_size``, a horizontal flip (p=0.5) and random
         grayscale conversion (p=0.2);
      4. conversion to a tensor normalised with ImageNet mean/std.
    ``test_transform`` only converts to a tensor and normalises.

    NOTE(review): the SimCLR paper crops first and blurs last. Here jitter and
    blur run before the crop, and the fixed 3x3 kernel is well below the
    ~10%-of-image-size kernel the paper suggests. Kept as-is.

    Args:
        img_size: output side length of the random resized crop.
        s: color-jitter strength (torchvision only accepts hue <= 0.5,
            so ``s`` should not exceed 1).
    """

    def __init__(self, img_size, s=1):
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]

        jitter = T.ColorJitter(brightness=s, contrast=s, saturation=s, hue=s * 0.5)
        geometric_and_gray = [
            T.RandomResizedCrop(size=img_size, scale=(0.07, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomGrayscale(p=0.2),
        ]

        self.train_transform = T.Compose([
            T.RandomApply([jitter], p=0.8),
            T.RandomApply([T.GaussianBlur((3, 3))], p=0.5),
            # p=1.0 means the group is always applied (it still draws one random number).
            T.RandomApply(geometric_and_gray, p=1.0),
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

        self.test_transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def __call__(self, x):
        # Two independent draws of the training pipeline -> a positive pair.
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view

In [None]:
def get_stl_dataloader(batch_size, transform=None, split="unlabeled", shuffle=False, num_workers=0):
    """Return a DataLoader over an STL-10 split (downloaded into ./ if absent).

    Args:
        batch_size: samples per batch.
        transform: callable applied to each PIL image (e.g. ``Augment(...)`` for
            pre-training, or ``Augment(...).test_transform`` for evaluation).
        split: STL-10 split passed to torchvision's ``STL10``
            ('train', 'test', 'unlabeled' or 'train+unlabeled').
        shuffle: reshuffle the data every epoch (off by default, as before).
        num_workers: DataLoader worker processes (0 = load in the main process).
    """
    # `split` used to be ignored (always 'train'), so 'unlabeled' pre-training
    # and the 'test' validation loader silently read the labelled train images.
    dataset = STL10("./", split=split, transform=transform, download=True)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

In [None]:
def imshow(img):
    """Display one ImageNet-normalised (C, H, W) image tensor with matplotlib.

    Undoes the normalisation applied by ``Augment``, clamps to [0, 1] (the
    range matplotlib expects for float RGB data; otherwise it clips with a
    warning) and transposes to (H, W, C).
    """
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)
    # Normalize(-mean/std, 1/std) computes x * std + mean, i.e. the inverse transform.
    unnormalize = T.Normalize((-mean / std).tolist(), (1.0 / std).tolist())
    # detach()/cpu(): .numpy() rejects tensors that require grad or live on a GPU.
    npimg = unnormalize(img.detach().cpu()).clamp(0.0, 1.0).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

## 2. Projection

In [None]:
class AddProjection(nn.Module):
    """Backbone plus a SimCLR projection head.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` so that the
    backbone returns the features that ``fc`` used to receive (this assumes
    the backbone ends with ``self.fc``, like torchvision ResNets). The head is
    Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm.

    Args:
        config: object with ``embedding_size`` (output size of the head).
        model: backbone module exposing ``.fc``. When None, a randomly
            initialised torchvision ResNet-18 is built.
        mlp_dim: width of the head's input/hidden layer. When None, it is
            taken from ``model.fc.in_features``.
    """

    def __init__(self, config, model=None, mlp_dim=512):
        super(AddProjection, self).__init__()
        embedding_size = config.embedding_size
        # Build the default backbone only when needed. The previous eager
        # default(model, resnet18(...)) constructed (and randomly initialised)
        # a throw-away network even when a backbone was supplied.
        if model is None:
            model = models.resnet18(pretrained=False, num_classes=embedding_size)
        self.backbone = model
        if mlp_dim is None:
            mlp_dim = self.backbone.fc.in_features
        print('Dim MLP input:', mlp_dim)
        self.backbone.fc = nn.Identity()

        # MLP projection head.
        self.projection = nn.Sequential(
            nn.Linear(in_features=mlp_dim, out_features=mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
            nn.BatchNorm1d(embedding_size),
        )

    def forward(self, x, return_embedding=False):
        """Return projections, or raw backbone features when ``return_embedding`` is True."""
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)

## 3. Loss

In [None]:
class ContrastiveLoss(nn.Module):
    """
    Vanilla Contrastive loss, also called InfoNceLoss as in SimCLR paper (NT-Xent).

    Given B pairs of projections, the 2B vectors are L2-normalised and compared
    with cosine similarity divided by ``temperature``. For row i the positive is
    its partner view (index i + B or i - B). Every other row except i itself is
    a negative. The loss is the mean over all 2B rows of
    ``-log(exp(pos_i) / sum_{k != i} exp(sim_ik))``.

    Args:
        batch_size: nominal batch size. ``forward`` derives the real size from
            its inputs, so a smaller final batch works. ``self.mask`` is still
            built from this value for backward compatibility.
        temperature: divisor applied to similarities before the softmax.
    """

    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Ones everywhere except the diagonal (self-similarity). No longer used
        # by forward(); kept so existing attribute access keeps working.
        self.mask = (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float()

    def calc_similarity_batch(self, a, b):
        """Return the (2B, 2B) cosine-similarity matrix of ``cat([a, b])``."""
        representations = torch.cat([a, b], dim=0)
        return F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)

    def forward(self, proj_1, proj_2):
        # 1. L2-normalise, so a plain matmul gives cosine similarities. This
        #    avoids the (2B, 2B, D) broadcast used by calc_similarity_batch.
        z = torch.cat([F.normalize(proj_1, p=2, dim=1), F.normalize(proj_2, p=2, dim=1)], dim=0)
        logits = (z @ z.T) / self.temperature

        # 2. Positives sit on the +/-B off-diagonals. B is taken from the input,
        #    not the constructor, so a smaller final batch no longer crashes on
        #    a mask shape mismatch.
        batch_size = proj_1.shape[0]
        positives = torch.cat([torch.diag(logits, batch_size), torch.diag(logits, -batch_size)], dim=0)

        # 3. Drop self-similarity from the denominator via -inf. logsumexp then
        #    gives log(sum_{k != i} exp(logit_ik)) without overflowing exp().
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=logits.device)
        log_denominator = torch.logsumexp(logits.masked_fill(self_mask, float('-inf')), dim=1)
        return (log_denominator - positives).mean()

## 4. SimCLR

In [None]:
class SimCLR_pl(pl.LightningModule):
    """Lightning module for SimCLR pre-training.

    Wraps a backbone in ``AddProjection`` and minimises ``ContrastiveLoss``
    between the projections of two augmented views of each image.

    Args:
        config: hyper-parameter object. Reads ``batch_size``, ``temperature``,
            ``epochs``, ``lr``, ``weight_decay`` and
            ``gradient_accumulation_steps`` (``AddProjection`` also reads
            ``embedding_size``). Optional: ``warmup_epochs``, ``warmup_start_lr``.
        model: backbone passed to ``AddProjection`` (None -> its default).
        feat_dim: backbone feature size, used as the projection MLP width.
    """

    def __init__(self, config, model=None, feat_dim=512):
        super().__init__()
        self.config = config

        self.model = AddProjection(config, model=model, mlp_dim=feat_dim)

        self.loss = ContrastiveLoss(config.batch_size, temperature=self.config.temperature)

    def forward(self, X):
        return self.model(X)

    def training_step(self, batch, batch_idx):
        # batch = (views, labels). `views` holds the two augmentations produced
        # by Augment; labels are not used during pre-training.
        images, labels = batch
        z1 = self.model(images[0])
        z2 = self.model(images[1])
        loss = self.loss(z1, z2)
        self.log('Contrastive loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def _resolve_warmup_epochs(self, max_epochs):
        """Return a warmup length strictly shorter than the run (at least 1).

        The previous fixed ``warmup_epochs=10`` equalled the default
        ``epochs=10``, so the whole run was linear warmup (starting from
        ``warmup_start_lr``, 0 by default) and cosine annealing never started.
        """
        requested = getattr(self.config, 'warmup_epochs', None)
        if requested is None:
            # Historical value was 10; cap it at ~10% of the run.
            requested = min(10, max_epochs // 10)
        return max(1, min(int(requested), max_epochs - 1))

    def configure_optimizers(self):
        max_epochs = int(self.config.epochs)
        param_groups = define_param_groups(self.model, self.config.weight_decay, 'adam')
        lr = self.config.lr
        # Each param group sets its own weight_decay, which overrides this default.
        optimizer = Adam(param_groups, lr=lr, weight_decay=self.config.weight_decay)

        print(f'Learning Rate {lr}, '
              f'Effective batch size {self.config.batch_size * self.config.gradient_accumulation_steps}')

        scheduler_warmup = LinearWarmupCosineAnnealingLR(
            optimizer,
            max_epochs=max_epochs,
            warmup_epochs=self._resolve_warmup_epochs(max_epochs),
            warmup_start_lr=getattr(self.config, 'warmup_start_lr', 0.0),
        )

        return [optimizer], [scheduler_warmup]

## 5. Train

In [None]:
def define_param_groups(model, weight_decay, optimizer_name):
    """Split ``model``'s parameters into a weight-decayed and a non-decayed group.

    Excluded from weight decay (and from layer adaptation):
      * every parameter of a BatchNorm module (detected by module type, so
        e.g. ``downsample.1.*`` or ``projection.1.*`` are caught too);
      * any parameter whose name contains ``'bn'`` (the original heuristic);
      * biases, when ``optimizer_name == 'lars'``.

    Args:
        model: module whose ``named_parameters()`` are grouped.
        weight_decay: weight decay of the first (decayed) group.
        optimizer_name: optimizer identifier; only ``'lars'`` changes the rules.

    Returns:
        Two param-group dicts (decayed first), each with ``params``,
        ``weight_decay`` and ``layer_adaptation`` keys. Parameter order follows
        ``model.named_parameters()``.
    """
    batch_norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    # Fully-qualified names of parameters owned directly by BatchNorm modules.
    batch_norm_param_names = {
        f'{module_name}.{param_name}' if module_name else param_name
        for module_name, module in model.named_modules()
        if isinstance(module, batch_norm_types)
        for param_name, _ in module.named_parameters(recurse=False)
    }

    def exclude_from_wd_and_adaptation(name):
        if name in batch_norm_param_names or 'bn' in name:
            return True
        return optimizer_name == 'lars' and 'bias' in name

    decayed, not_decayed = [], []
    for name, param in model.named_parameters():
        (not_decayed if exclude_from_wd_and_adaptation(name) else decayed).append(param)

    return [
        {
            'params': decayed,
            'weight_decay': weight_decay,
            'layer_adaptation': True,
        },
        {
            'params': not_decayed,
            'weight_decay': 0.,
            'layer_adaptation': False,
        },
    ]

In [None]:
class Hparams:
    """Hyper-parameters for SimCLR pre-training (plain attribute container)."""

    def __init__(self):
        # --- run length / reproducibility / hardware ---
        self.epochs = 10                      # number of pre-training epochs
        self.seed = 77777                     # seed consumed by reproducibility()
        self.cuda = True                      # use an NVIDIA GPU (also seeds CUDA in reproducibility())
        self.img_size = 96                    # side length of the augmented crops
        # --- checkpoints ---
        self.save = "./saved_models/"         # checkpoint directory
        self.load = False                     # flag: load a pretrained checkpoint (not read by the visible cells)
        # --- optimisation ---
        self.gradient_accumulation_steps = 5  # effective batch = batch_size * this
        self.batch_size = 200
        self.lr = 3e-4                        # Adam learning rate
        self.weight_decay = 1e-6
        # --- contrastive objective ---
        self.embedding_size = 128             # projection size; the SimCLR paper uses 128
        self.temperature = 0.5                # typical choices: 0.1 or 0.5
        self.checkpoint_path = './SimCLR_ResNet18.ckpt'  # checkpoint to resume from; edit as needed

In [None]:
# Number of visible CUDA devices (passed to Trainer(gpus=...)).
available_gpus = torch.cuda.device_count()
save_model_path = os.path.join(os.getcwd(), "saved_models/")
print('available_gpus:', available_gpus)
filename = 'SimCLR_ResNet18_adam_'
resume_from_checkpoint = False
train_config = Hparams()

# Seed before any network weights are initialised.
reproducibility(train_config)
save_name = f"{filename}.ckpt"

model = SimCLR_pl(train_config, model=resnet18(pretrained=False), feat_dim=512)

transform = Augment(train_config.img_size)
data_loader = get_stl_dataloader(train_config.batch_size, transform)

# Gradient accumulation: from epoch 0 on, step the optimizer every N batches.
accumulator = GradientAccumulationScheduler(
    scheduling={0: train_config.gradient_accumulation_steps}
)

In [None]:
# Keep the two checkpoints with the lowest epoch-level contrastive loss, plus the last one.
checkpoint_callback = ModelCheckpoint(
    dirpath=save_model_path,
    filename=filename,
    monitor='Contrastive loss_epoch',
    mode='min',
    save_top_k=2,
    save_last=True,
)

In [None]:
trainer_kwargs = dict(
    callbacks=[accumulator, checkpoint_callback],
    gpus=available_gpus,
    max_epochs=train_config.epochs,
)
if resume_from_checkpoint:
    # Continue a previous run from the configured checkpoint file.
    trainer_kwargs['resume_from_checkpoint'] = train_config.checkpoint_path
trainer = Trainer(**trainer_kwargs)

trainer.fit(model, data_loader)

In [None]:
# Write the final pre-training state to disk...
trainer.save_checkpoint(save_name)
# ...and hand it to the browser (google.colab only exists inside Colab).
from google.colab import files
files.download(save_name)

In [None]:
# Rebuild the pre-training module and copy the trained weights into it.
model_pl = weights_update(
    SimCLR_pl(train_config, model=resnet18(pretrained=False)),
    "SimCLR_ResNet18_adam_.ckpt",
)

# Only the ResNet-18 backbone is exported; the projection head is a separate attribute.
resnet18_backbone_weights = model_pl.model.backbone
print(resnet18_backbone_weights)
torch.save(
    {'model_state_dict': resnet18_backbone_weights.state_dict()},
    'resnet18_backbone_weights.ckpt',
)

In [None]:
class SimCLR_eval(pl.LightningModule):
    """Supervised classifier on top of a feature extractor (fine-tuning or linear probe).

    The backbone's features go through a single linear layer trained with
    cross-entropy and SGD (momentum 0.9). With ``linear_eval=True`` the
    backbone is frozen: its parameters stop requiring gradients, it stays in
    eval mode (so BatchNorm running statistics are not updated) and only the
    linear head is optimised.

    Args:
        lr: SGD learning rate.
        model: backbone mapping images to ``feat_dim`` features (e.g. a
            ResNet-18 whose ``fc`` is ``nn.Identity``). Required.
        linear_eval: freeze the backbone and train only the head.
        feat_dim: backbone feature size (512 for ResNet-18).
        num_classes: number of output classes (10 for STL-10).

    Raises:
        ValueError: if ``model`` is None.
    """

    def __init__(self, lr, model=None, linear_eval=False, feat_dim=512, num_classes=10):
        super().__init__()
        if model is None:
            # nn.Sequential(None, head) would only fail later, inside forward().
            raise ValueError('SimCLR_eval requires a backbone `model`')
        self.lr = lr
        self.linear_eval = linear_eval
        if self.linear_eval:
            # model.eval() alone is not enough: gradients would still be computed,
            # and Lightning calls .train() on the whole module when fitting
            # (handled by the train() override below).
            model.requires_grad_(False)
            model.eval()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(feat_dim, num_classes),
        )

        self.model = torch.nn.Sequential(
            model, self.mlp
        )
        self.loss = torch.nn.CrossEntropyLoss()

    def train(self, mode=True):
        """Switch train/eval mode, keeping a frozen backbone in eval mode."""
        super().train(mode)
        if getattr(self, 'linear_eval', False):
            self.model[0].eval()
        return self

    def forward(self, X):
        return self.model(X)

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = self.loss(z, y)
        self.log('Cross Entropy loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        predicted = z.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        self.log('Train Acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        z = self.forward(x)
        loss = self.loss(z, y)
        self.log('Val CE loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)

        predicted = z.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        self.log('Val Accuracy', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        if self.linear_eval:
            print(f"\n\n Attention! Linear evaluation \n")
            optimizer = SGD(self.mlp.parameters(), lr=self.lr, momentum=0.9)
        else:
            optimizer = SGD(self.model.parameters(), lr=self.lr, momentum=0.9)
        return [optimizer]


class Hparams:
    """Hyper-parameters for supervised fine-tuning / evaluation (plain attribute container)."""

    def __init__(self):
        # --- run length / reproducibility / hardware ---
        self.epochs = 10                      # number of fine-tuning epochs
        self.seed = 77777                     # seed consumed by reproducibility()
        self.cuda = True                      # use an NVIDIA GPU (also seeds CUDA in reproducibility())
        self.img_size = 96                    # image side length
        # --- checkpoints ---
        self.save = "./saved_models/"         # checkpoint directory
        # --- optimisation ---
        self.gradient_accumulation_steps = 1  # 1 = no accumulation
        self.batch_size = 128
        self.lr = 1e-3                        # SGD learning rate
        # --- kept from the pre-training config ---
        self.embedding_size = 128             # projection size; the SimCLR paper uses 128
        self.temperature = 0.5                # typical choices: 0.1 or 0.5


# --- general set-up ---------------------------------------------------------
available_gpus = torch.cuda.device_count()  # number of visible CUDA devices
train_config = Hparams()
save_model_path = os.path.join(os.getcwd(), "saved_models/")
print('available_gpus:', available_gpus)
filename = 'SimCLR_ResNet18_finetune_'
reproducibility(train_config)
save_name = f"{filename}_Final.ckpt"

# --- backbone: ResNet-18 with the SIMCLR pre-trained weights, no classifier --
backbone = models.resnet18(pretrained=False)
backbone.fc = nn.Identity()
checkpoint = torch.load('resnet18_backbone_weights.ckpt')
backbone.load_state_dict(checkpoint['model_state_dict'])
model = SimCLR_eval(train_config.lr, model=backbone, linear_eval=False)

# --- data: no augmentation, only ToTensor + normalisation -------------------
transform_preprocess = Augment(train_config.img_size).test_transform
data_loader = get_stl_dataloader(
    train_config.batch_size, transform=transform_preprocess, split='train'
)
data_loader_test = get_stl_dataloader(
    train_config.batch_size, transform=transform_preprocess, split='test'
)

# --- callbacks and trainer --------------------------------------------------
accumulator = GradientAccumulationScheduler(
    scheduling={0: train_config.gradient_accumulation_steps}
)
# Keep the two epochs with the best validation accuracy, plus the last one.
checkpoint_callback = ModelCheckpoint(
    filename=filename,
    dirpath=save_model_path,
    save_last=True,
    save_top_k=2,
    monitor='Val Accuracy_epoch',
    mode='max',
)
trainer = Trainer(
    callbacks=[checkpoint_callback, accumulator],
    gpus=available_gpus,
    max_epochs=train_config.epochs,
)

trainer.fit(model, data_loader, data_loader_test)
trainer.save_checkpoint(save_name)

In [None]:
# Baseline for comparison: the same ResNet-18 + linear head trained from a
# random initialisation (pretrained=False -> neither ImageNet nor SimCLR weights).
resnet = models.resnet18(pretrained=False)
resnet.fc = nn.Identity()
print('random init weights, no pretraining')
model = SimCLR_eval(train_config.lr, model=resnet, linear_eval=False)

# preprocessing and data loaders (no augmentation)
transform_preprocess = Augment(train_config.img_size).test_transform
data_loader = get_stl_dataloader(128, transform=transform_preprocess, split='train')
data_loader_test = get_stl_dataloader(128, transform=transform_preprocess, split='test')

# Use names of its own: reusing `filename` / `save_name` from the fine-tuning
# cell made trainer.save_checkpoint() overwrite the fine-tuned model's final
# checkpoint, and gave the baseline's ModelCheckpoint files the same prefix.
baseline_filename = 'ResNet18_random_init_'
baseline_save_name = baseline_filename + '_Final.ckpt'
checkpoint_callback = ModelCheckpoint(filename=baseline_filename, dirpath=save_model_path)

trainer = Trainer(callbacks=[checkpoint_callback],
                  gpus=available_gpus,
                  max_epochs=train_config.epochs)

trainer.fit(model, data_loader, data_loader_test)
trainer.save_checkpoint(baseline_save_name)

In [None]:
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/ --port 6010