## SimCLR — A Simple Framework for Contrastive Learning of Visual Representations

SimCLR uses contrastive learning to maximize agreement between 2 augmented versions of the same image.

Steps

1. Take an input image

2. Prepare 2 random augmentations of the image

3. Run a deep neural network like ResNet50 to obtain image embeddings of those augmented images.

4. Run a small fully connected network (an MLP with one hidden ReLU layer) to project the embeddings into another vector space.

5. Calculate the contrastive loss and run backpropagation through both networks. Contrastive loss decreases when projections coming from the same image are similar. 

Reference: [https://arxiv.org/abs/2002.05709](https://arxiv.org/abs/2002.05709)

## Setup

In [None]:
# Use wandb for logging
!pip install --upgrade wandb --quiet

[K     |████████████████████████████████| 1.7 MB 5.3 MB/s 
[K     |████████████████████████████████| 97 kB 4.5 MB/s 
[K     |████████████████████████████████| 180 kB 46.3 MB/s 
[K     |████████████████████████████████| 133 kB 46.0 MB/s 
[K     |████████████████████████████████| 63 kB 1.7 MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


## Imports

In [None]:
from dataclasses import asdict, dataclass
from itertools import chain
from pathlib import Path

import hashlib
import numpy as np
from PIL.Image import Image
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.models import resnet18, resnet34, resnet50
from torchvision.transforms import (
  ColorJitter, Compose, Lambda, Normalize, RandomApply, RandomGrayscale,
  RandomHorizontalFlip, RandomResizedCrop, ToTensor)
from tqdm import tqdm
import wandb

## Hyperparameters

In [None]:
@dataclass(eq=True, frozen=True)
class HParams:
  """Immutable, hashable bundle of hyperparameters for one SimCLR run.

  Because the dataclass is declared with ``eq=True, frozen=True`` it gets a
  field-based ``__hash__``; that hash feeds the ``md5`` property below.

  Raises:
    ValueError: if ``cifar`` is not 10 or 100, or ``resnet_depth`` is not a
      key of ``DEPTH_TO_REPR_DIM``.
  """
  cifar: int = 10  # CIFAR variant, 10 or 100
  crop_size: int = 32  # side length passed to RandomResizedCrop
  colour_distortion: float = 0.5  # strength multiplier for the colour jitter
  batch_size: int = 1024 # 256, 512, 1024, 2048, 4096 evaluated in paper
  xent_temp: float = 0.5 # 0.1, 0.5, 1.0 evaluated in paper
  proj_dim: int = 128  # output width of the projection head
  weight_decay: float = 1e-6  # weight decay coefficient
  max_lr: float = 1.5 # 0.5, 1.0, 1.5 evaluated in paper
  warmup_epochs: int = 10
  cooldown_epochs: int = 90 # 90, 190, 290, 390, 490, 590, 690, 790, 890, 990 evaluated in paper
  use_cosine_scheduler: bool = True
  resnet_depth: int = 18 # 50 evaluated in paper
  # Deliberately un-annotated: a plain class attribute, not a dataclass field,
  # so it takes no part in __init__, equality or hashing.
  DEPTH_TO_REPR_DIM = {18: 512, 34: 512, 50: 2048}

  def __post_init__(self):
    # Explicit raises instead of asserts so validation survives `python -O`.
    if self.cifar not in (10, 100):
      raise ValueError(f"cifar must be 10 or 100, got {self.cifar!r}")
    if self.resnet_depth not in self.DEPTH_TO_REPR_DIM:
      raise ValueError(
        f"resnet_depth must be one of {sorted(self.DEPTH_TO_REPR_DIM)}, "
        f"got {self.resnet_depth!r}")

  @property
  def repr_dim(self) -> int:
    """Width of the encoder output for the configured ResNet depth."""
    return self.DEPTH_TO_REPR_DIM[self.resnet_depth]

  @property
  def md5(self) -> str:
    """MD5 hex digest of the decimal string of ``hash(self)``."""
    return hashlib.md5(str(hash(self)).encode('utf-8')).hexdigest()

hp = HParams()

## Image Augmentations

In [None]:
class SimCLRAugment:
  """Callable transform returning an image together with one augmented copy.

  The augmentation pipeline is, in order: random resized crop to
  ``hp.crop_size``, random horizontal flip, colour jitter (applied with
  probability 0.8, strengths scaled by ``hp.colour_distortion``) and random
  grayscale conversion (probability 0.2).
  """

  def __init__(self, hp: HParams):
    strength = hp.colour_distortion
    # Positional order is brightness, contrast, saturation, hue.
    jitter = ColorJitter(0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength)
    pipeline = [
      RandomResizedCrop(hp.crop_size),
      RandomHorizontalFlip(),
      RandomApply([jitter], p=0.8),
      RandomGrayscale(p=0.2),
    ]
    self.simclr_augment = Compose(pipeline)

  def __call__(self, img: Image):
    """Return ``(img, augmented)``: the untouched input first, then its augmented view."""
    augmented = self.simclr_augment(img)
    return (img, augmented)

## Data Loaders

In [None]:
def get_loaders(hp: HParams):
  """Create the CIFAR training and test data loaders.

  Training samples pass through ``SimCLRAugment`` and are stacked into one
  tensor per image (crops x channels x height x width); test samples are only
  converted to tensors (channels x height x width). Data is downloaded to
  ``./data`` when missing.

  Returns:
    ``(train_loader, test_loader)``; only the training loader shuffles.
  """
  if hp.cifar == 10:
    dataset_cls = CIFAR10
  elif hp.cifar == 100:
    dataset_cls = CIFAR100

  to_tensor = ToTensor()

  def stack_crops(crops):
    # Convert every crop to a tensor and stack along a new leading axis.
    return torch.stack([to_tensor(crop) for crop in crops])

  train_dataset = dataset_cls(
    "./data",
    train=True,
    transform=Compose([SimCLRAugment(hp), Lambda(stack_crops)]),
    download=True,
  )
  test_dataset = dataset_cls("./data", train=False, transform=ToTensor(), download=True)

  train_loader = DataLoader(
    train_dataset, batch_size=hp.batch_size, shuffle=True, num_workers=1, pin_memory=True)
  test_loader = DataLoader(
    test_dataset, batch_size=hp.batch_size, shuffle=False, num_workers=1, pin_memory=True)
  return train_loader, test_loader

## Encoder and Projector Models

In [None]:
class Encoder(torch.nn.Module):
  """ResNet backbone whose stem and head are replaced for feature extraction.

  After construction the first convolution is a 3x3, stride-1, padding-1
  layer without bias, and both the max-pool and the final fully connected
  layer are ``Identity`` modules, so ``forward`` returns the pooled features
  instead of class logits.
  """

  def __init__(self, hp: HParams):
    super().__init__()
    constructors = {18: resnet18, 34: resnet34, 50: resnet50}
    if hp.resnet_depth in constructors:
      build = constructors[hp.resnet_depth]
      self.resnet = build(pretrained=False, num_classes=hp.cifar)

    stem = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    self.resnet.conv1 = stem
    self.resnet.maxpool = torch.nn.Identity()
    self.resnet.fc = torch.nn.Identity()

  def forward(self, x):
    """Return the backbone features for the image batch ``x``."""
    features = self.resnet(x)
    return features

class Projector(torch.nn.Module):
  """Two-layer MLP head: ``repr_dim -> repr_dim -> proj_dim`` with a ReLU in between."""

  def __init__(self, hp: HParams):
    super().__init__()
    in_dim, out_dim = hp.repr_dim, hp.proj_dim
    # Attribute names are part of the state_dict keys, so they stay l1 / l2.
    self.l1 = torch.nn.Linear(in_dim, in_dim)
    self.l2 = torch.nn.Linear(in_dim, out_dim)

  def forward(self, x):
    """Project encoder features ``x`` into the ``proj_dim``-wide space."""
    hidden = self.l1(x)
    hidden = F.relu(hidden)
    return self.l2(hidden)

## Contrastive Loss

In [None]:
def nt_xent_loss(z: torch.Tensor, xent_temp: float):
  """Compute the NT-Xent (normalised temperature-scaled cross-entropy) loss.

  Args:
    z: Projections of shape ``(2 * N, projection_dim)``. Rows ``2*i`` and
      ``2*i + 1`` must be the two views of the same image.
    xent_temp: Temperature that divides the cosine similarities.

  Returns:
    Scalar tensor: the cross-entropy of picking the partner view among all
    other rows, averaged over the ``2 * N`` paired rows.
  """
  num_pairs = z.shape[0] // 2 # Can be less than batch_size for last batch in epoch
  # F.normalize clamps the norm at a tiny epsilon, so an all-zero projection
  # yields a zero vector instead of the NaN produced by a plain 0 / 0.
  znorm = F.normalize(z, p=2, dim=1)
  cos_sim = znorm @ znorm.t() / xent_temp
  # A large negative self-similarity removes each row from its own softmax.
  cos_sim.fill_diagonal_(-1e5)
  neg_log_prob = -F.log_softmax(cos_sim, dim=1)
  # Build the index tensors on z's device to avoid a host round-trip.
  first = 2 * torch.arange(num_pairs, device=z.device)
  second = first + 1
  pair_losses = neg_log_prob[first, second] + neg_log_prob[second, first]
  return pair_losses.sum() / (2 * num_pairs)

## Train Function

In [None]:
def train(encoder, projector, train_loader, optimizer, epoch, xent_temp: float) -> None:
  """Run one contrastive training epoch and log the mean batch loss to wandb.

  Args:
    encoder: Backbone module applied to every crop.
    projector: Head applied to the encoder output.
    train_loader: Yields ``(data, target)`` where ``data`` is 5-D
      (batch x crops x channels x height x width); ``target`` is ignored.
    optimizer: Optimizer stepped once per batch.
    epoch: Used for the progress-bar label and as the wandb step.
    xent_temp: Temperature forwarded to ``nt_xent_loss``.
  """
  encoder.train()
  projector.train()

  # Send batches to wherever the encoder lives; fall back to CUDA for a
  # parameter-free encoder.
  first_param = next(encoder.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cuda')

  batch_losses = []

  # Labels are not needed for the contrastive objective, so they are never
  # copied to the device.
  for data, _ in tqdm(train_loader, leave=False, desc=f'epoch {epoch}'):
    # The 5-way unpack also guards against non-5-D batches.
    _, _, c, h, w = data.size()
    data = data.to(device)

    optimizer.zero_grad()

    # Flatten (batch, crops) into one axis; crops of one image stay adjacent,
    # which is the row layout nt_xent_loss pairs up.
    z = projector(encoder(data.view((-1, c, h, w))))
    loss = nt_xent_loss(z, xent_temp)

    loss.backward()
    optimizer.step()
    batch_losses.append(loss.item())

  wandb.log({"Train Loss": np.mean(batch_losses)}, step=epoch)

## Evaluation Functions

In [None]:
@torch.no_grad()
def prepare_xy(encoder, loader):
  """Encode every batch of ``loader`` and return features and labels as NumPy arrays.

  The encoder is switched to eval mode and left there. For 5-D batches
  (batch x crops x channels x height x width) only crop index 0 is encoded;
  4-D batches are encoded as they are.

  Args:
    encoder: Module mapping an image batch to feature vectors.
    loader: Iterable of ``(data, target)`` tensor pairs.

  Returns:
    ``(X, y)``: features and targets concatenated over all batches.
  """
  # Only the encoder is touched here; the function no longer reaches for a
  # module-level `projector`, which it never used for the embeddings.
  encoder.eval()

  # Send batches to wherever the encoder lives; fall back to CUDA for a
  # parameter-free encoder.
  first_param = next(encoder.parameters(), None)
  device = first_param.device if first_param is not None else torch.device('cuda')

  features, labels = [], []
  for data, target in loader:
    data = data.to(device)
    if data.dim() == 5:
      data = data[:, 0]
    features.append(encoder(data).cpu().numpy())
    labels.append(target.numpy())

  return np.concatenate(features), np.concatenate(labels)

def evaluate_logistic(X, y, Xt, yt, epoch: int) -> None:
  """Fit a logistic-regression probe on standardised features and log its accuracy.

  The scaler is fitted on the training features only and then applied to the
  test features. Train and test accuracy are logged to wandb at step ``epoch``.

  Args:
    X, y: Training features and labels.
    Xt, yt: Test features and labels.
    epoch: wandb step under which the two accuracies are logged.
  """
  standardiser = StandardScaler()
  train_features = standardiser.fit_transform(X)
  test_features = standardiser.transform(Xt)

  classifier = LogisticRegression(
    random_state=0, solver='lbfgs', multi_class='multinomial', max_iter=1000, n_jobs=1,
  )
  classifier.fit(train_features, y)

  train_accuracy = np.mean(classifier.predict(train_features) == y)
  test_accuracy = np.mean(classifier.predict(test_features) == yt)
  wandb.log(
    {'Train Evaluation': train_accuracy, 'Test Evaluation': test_accuracy},
    step=epoch,
  )

def evaluate_features(encoder, projector, train_loader, test_loader, epoch: int) -> None:
  """Embed both splits with the encoder and score them with a logistic-regression probe.

  ``projector`` is accepted for signature compatibility but is not used here.
  """
  train_xy = prepare_xy(encoder, train_loader)
  test_xy = prepare_xy(encoder, test_loader)
  evaluate_logistic(*train_xy, *test_xy, epoch)

## Restore State

In [None]:
def save_state(hp, epoch: int, encoder, projector, optimizer, scheduler):
  """Write a training checkpoint to ``<hp.md5>.pkl`` in the working directory.

  The checkpoint holds the hyperparameters, the given epoch number and the
  state dicts of the encoder, projector, optimizer and scheduler.
  """
  state = {}
  state['hparams'] = hp
  state['epoch'] = epoch
  state['encoder_state_dict'] = encoder.state_dict()
  state['projector_state_dict'] = projector.state_dict()
  state['optimizer_state_dict'] = optimizer.state_dict()
  state['scheduler_state_dict'] = scheduler.state_dict()

  checkpoint_path = hp.md5 + '.pkl'
  torch.save(state, checkpoint_path)

def load_state(hp):
  """Return the checkpoint stored at ``<hp.md5>.pkl``, or ``None`` if that file does not exist."""
  checkpoint_path = hp.md5 + '.pkl'
  try:
    return torch.load(checkpoint_path)
  except FileNotFoundError:
    # No checkpoint for this hyperparameter set yet.
    return None

# Look for a checkpoint written by an earlier run with the same hyperparameters
# (the file name is derived from hp.md5). `checkpoint` stays None on a fresh start.
checkpoint = load_state(hp)
if checkpoint is not None:
  print(f"Restoring training state from epoch {checkpoint['epoch']}")
  # Continue with the hyperparameters stored in the checkpoint.
  hp = checkpoint['hparams']

## Instantiate State

In [None]:
# Seed PyTorch from the hyperparameter hash so a given configuration always
# starts from the same random state.
torch.manual_seed(hash(hp))

# Dataset
train_loader, test_loader = get_loaders(hp)

# Models
encoder = Encoder(hp)
projector = Projector(hp)
if checkpoint is not None:
  encoder.load_state_dict(checkpoint['encoder_state_dict'])
  projector.load_state_dict(checkpoint['projector_state_dict'])
encoder = encoder.cuda()
projector = projector.cuda()

# Optimizers and Schedulers
# Warmup starts at max_lr / warmup_epochs; with no warmup, start directly at
# max_lr instead of dividing by zero.
init_lr = hp.max_lr / max(hp.warmup_epochs, 1)
# Take weight decay from the hyperparameters (default 1e-6) so changing
# hp.weight_decay actually affects training.
optimizer = SGD(
  chain(encoder.parameters(), projector.parameters()),
  lr=init_lr,
  weight_decay=hp.weight_decay,
)
cosine_scheduler = CosineAnnealingLR(optimizer, hp.cooldown_epochs)
if checkpoint is not None:
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  cosine_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

# Starting Epoch
# save_state is called with `epoch + 1`, so a checkpoint already stores the
# next epoch to run.
epoch = 1 if checkpoint is None else checkpoint['epoch']

# Wandb
wandb.init(anonymous='must')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Train and Evaluation Loop

In [None]:
%%wandb
# Last epoch of the run: linear warmup followed by the cooldown phase.
total_epochs = hp.warmup_epochs + hp.cooldown_epochs
for epoch in range(epoch, total_epochs + 1):
  # Log the learning rate that is about to be used for this epoch.
  wandb.log({"Learning Rate": optimizer.param_groups[0]['lr']}, step=epoch)
  train(encoder, projector, train_loader, optimizer, epoch, hp.xent_temp)

  if epoch <= hp.warmup_epochs:
    # Linear warmup towards max_lr over hp.warmup_epochs epochs (previously a
    # hard-coded 10, which only matched the default), capped at max_lr.
    # Set by hand because: Pytorch LambdaLR scheduler is buggy...
    optimizer.param_groups[0]['lr'] = min(hp.max_lr, hp.max_lr * (epoch + 1) / hp.warmup_epochs)
  elif hp.use_cosine_scheduler:
    cosine_scheduler.step()

  # Evaluate and checkpoint after the first epoch, every 10th epoch and the last one.
  if (epoch == 1) or (epoch % 10 == 0) or (epoch == total_epochs):
    evaluate_features(encoder, projector, train_loader, test_loader, epoch)
    # Store the next epoch so a restored run continues after this one.
    save_state(hp, epoch + 1, encoder, projector, optimizer, cosine_scheduler)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
epoch 25:  76%|███████▌  | 37/49 [03:58<01:16,  6.40s/it]