<a href="https://colab.research.google.com/github/tasn19/scan-repro/blob/main/SCAN_repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

Open question: keep everything in multiple notebook sections, or import .py modules into Colab?

Possible section structure:

1. Setup
   - datasets
   - model
   - criterion
   - utils
2. Pretext
3. SCAN
4. SelfLabel

In [None]:
# Import libraries
import torch
import torchvision
import torchvision.transforms as transforms

## Dataset

Torchvision CIFAR10

In [None]:
# Load CIFAR-10 through torchvision, downloading into `path` when needed.
# No transforms are attached here, so samples come back untransformed.
path = (
    "/content/drive/MyDrive/Colab Notebooks/SCANmaterials/"
    "Unsupervised-Classification/datasets/cifar10"
)
dataset = torchvision.datasets.CIFAR10(
    root=path,
    transform=None,
    target_transform=None,
    download=True,
)


In [None]:
# Retrieve one sample and enlarge it for visual inspection.
img, label = dataset[20]
# PIL's Image.resize returns a new image rather than resizing in place,
# so the result has to be rebound for the resize to take effect.
img = img.resize((64, 64))
print(img, label)

In [None]:
# Leave the image as the cell's final expression so the notebook displays it;
# Image.show() does not work on Colab.
img

Custom Dataset: contains a set of images and a set of the same images in augmented form

In [None]:
# this is how they did it in the paper
from torch.utils.data import Dataset

class CustomDataset(Dataset):
  """Wrap a dataset so every sample is (image, augmented image, label).

  The wrapped dataset's own ``transform`` is taken over and cleared, so the
  wrapper receives untransformed samples and applies the transforms itself.
  With ``step == "simclr"`` the augmented view uses the transform that was
  attached to the wrapped dataset; for any other step both views go through
  the module-level ``base_transform``.
  """

  def __init__(self, dataset, step):
    # Detach the dataset's transform; it is re-applied per view in __getitem__.
    attached_transform = dataset.transform
    dataset.transform = None

    self.dataset = dataset
    self.step = step

    # The plain view always uses the base pipeline; only the augmented view
    # depends on the step.
    self.img_transform = base_transform
    self.augment_transform = attached_transform if step == "simclr" else base_transform

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, index):
    raw_img, label = self.dataset[index]
    plain_view = self.img_transform(raw_img)
    augmented_view = self.augment_transform(raw_img)
    return plain_view, augmented_view, label



Transformations

In [None]:
# Create dictionary of transformation parameters (tp)
# Keyed first by pipeline name ("base" / "simclr"), then by transform name;
# each leaf dict holds the arguments for that transform.
tp = {
    "base": {
        "RandomResizedCrop": {"size": 32},
        "Normalize": {
            "mean": (0.4914, 0.4822, 0.4465),
            "std": (0.2023, 0.1994, 0.2010),
        },
    },
    "simclr": {
        "RandomResizedCrop": {"size": 32, "scale": (0.2, 1.0)},
        "RandomColorJitter": {
            "brightness": 0.4,
            "contrast": 0.4,
            "saturation": 0.4,
            "hue": 0.1,
            "p": 0.8,
        },
        "RandomGrayscale": {"p": 0.2},
        "Normalize": {
            "mean": (0.4914, 0.4822, 0.4465),
            "std": (0.2023, 0.1994, 0.2010),
        },
    },
}

In [None]:
# Transformations to pre-process image  
def get_transform(step):
  """Build the torchvision preprocessing pipeline for a training step.

  Args:
    step: Either ``"base"`` or ``"simclr"``; selects which entry of the
      module-level ``tp`` parameter dictionary is read.

  Returns:
    A ``transforms.Compose`` whose last two stages are ``ToTensor`` and
    ``Normalize``.

  Raises:
    ValueError: If ``step`` is not one of the supported names.
  """
  if step not in ("base", "simclr"):
    # Fail with a clear message instead of reaching the return statement
    # with no pipeline built.
    raise ValueError(f"Unknown transform step {step!r}; expected 'base' or 'simclr'")

  params = tp[step]
  crop = params["RandomResizedCrop"]
  normalize = params["Normalize"]

  if step == "base":
    stages = [
        transforms.RandomResizedCrop(crop["size"]),
        transforms.RandomHorizontalFlip(),
    ]
  else:
    # SimCLR augmentations. TODO: Gaussian blur is not applied here.
    jitter = params["RandomColorJitter"]
    color_jitter = transforms.ColorJitter(
        jitter["brightness"], jitter["contrast"], jitter["saturation"], jitter["hue"])
    stages = [
        transforms.RandomResizedCrop(crop["size"], crop["scale"]),
        transforms.RandomHorizontalFlip(),
        # Colour jitter is only applied with probability p.
        transforms.RandomApply([color_jitter], jitter["p"]),
        transforms.RandomGrayscale(params["RandomGrayscale"]["p"]),
    ]

  # Both pipelines end by converting to a tensor and normalising per channel.
  stages.append(transforms.ToTensor())
  stages.append(transforms.Normalize(normalize["mean"], normalize["std"]))
  return transforms.Compose(stages)

  base_transform = get_transform("base") # CIFAR10 dataset without CustomDataset should always use this?

## Models

ResNet-18 Backbone

In [None]:
# This is paper's version. Also used here: https://github.com/microsoft/snca.pytorch/blob/master/models/resnet_cifar.py
# Differs from the torchvision ResNet-18 model -- TODO: check the differences
"""
This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride is
    not 1 or the channel count changes, in which case it is a 1x1 convolution
    followed by batch norm. When ``is_last`` is true, ``forward`` returns both
    the activated output and the pre-activation sum.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        # Main branch: conv-bn-relu-conv-bn; only the first conv may stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Project the input only when its shape would not match the branch output.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        preact = self.bn2(self.conv2(hidden)) + self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The final 1x1 convolution widens the output to ``expansion * planes``
    channels. The shortcut is an empty ``nn.Sequential`` (identity) unless the
    stride is not 1 or the input width differs from the output width, in which
    case it is a 1x1 convolution followed by batch norm. When ``is_last`` is
    true, ``forward`` returns both the activated output and the pre-activation
    sum.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super().__init__()
        self.is_last = is_last
        out_planes = self.expansion * planes

        # Main branch; the stride (if any) is applied by the 3x3 convolution.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Project the input only when its shape would not match the branch output.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        preact = self.bn3(self.conv3(hidden)) + self.shortcut(x)
        activated = F.relu(preact)
        return (activated, preact) if self.is_last else activated


class ResNet(nn.Module):
    """ResNet feature extractor with a 3x3, stride-1 stem and no max-pooling.

    Four stages of residual blocks (64, 128, 256 and 512 planes) are followed
    by global average pooling; ``forward`` returns the flattened pooled
    features, i.e. there is no classification layer.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        in_channel: Number of channels of the input images.
        zero_init_residual: If true, zero the weight of the last batch norm in
            every ``Bottleneck`` / ``BasicBlock``.
    """

    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super().__init__()
        # Running channel count; updated by _make_layer as stages are added.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for position, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, num_blocks[position], stride)
            setattr(self, f"layer{position + 1}", stage)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for module in self.modules():
                if isinstance(module, Bottleneck):
                    final_bn = module.bn3
                elif isinstance(module, BasicBlock):
                    final_bn = module.bn2
                else:
                    continue
                nn.init.constant_(final_bn.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for position in range(num_blocks):
            block_stride = stride if position == 0 else 1
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's (expanded) output width.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return torch.flatten(pooled, 1)


def resnet18a(**kwargs):
    """Build the ResNet-18 backbone: ``BasicBlock`` with two blocks per stage.

    Keyword arguments are forwarded unchanged to ``ResNet``. The bare model is
    returned here (changed from an earlier variant that wrapped it in a
    ``{'backbone': ...}`` dict alongside its feature dimension, 512).
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)

With SimCLR contrastive model

In [None]:
import torchvision
import torch.nn as nn
import torch.nn.functional as F 

class SimclrContrastiveModel(nn.Module):
  """Backbone followed by a projection head that outputs L2-normalised embeddings.

  Args:
    backbone: Module whose output is fed to the head; the head's first linear
      layer expects ``backboneDim`` input features.
    head: Projection head type, matched case-insensitively. ``'MLP'``
      (default) is the two-layer head Linear -> ReLU -> Linear used by SimCLR;
      ``'linear'`` is a single linear layer.
    featuresDim: Size of the embedding produced by the head.
    backboneDim: Number of features produced by ``backbone``.

  Raises:
    ValueError: If ``head`` is neither ``'MLP'`` nor ``'linear'``.
  """

  def __init__(self, backbone, head = 'MLP', featuresDim = 128, backboneDim = 512):
    super().__init__()
    self.backbone = backbone
    self.backboneDim = backboneDim
    self.head = head

    # The head argument used to be stored but never consulted; honour it so
    # that requesting a linear head no longer silently yields the MLP.
    head_kind = str(head).lower()
    if head_kind == 'mlp':
      self.contrastiveHead = nn.Sequential(nn.Linear(self.backboneDim, self.backboneDim),
                                           nn.ReLU(), nn.Linear(self.backboneDim, featuresDim))
    elif head_kind == 'linear':
      self.contrastiveHead = nn.Linear(self.backboneDim, featuresDim)
    else:
      raise ValueError(f"Invalid head {head!r}; expected 'MLP' or 'linear'")

  def forward(self, x):
    """Return the unit-norm (along dim 1) embeddings of ``x``."""
    features = self.contrastiveHead(self.backbone(x))
    return F.normalize(features, dim = 1)


def get_model(step):
  """Build the model for the given training step.

  Args:
    step: Name of the training step. Only ``"simclr"`` (the pretext task) is
      implemented so far.

  Returns:
    A ``SimclrContrastiveModel`` wrapping a freshly built ``resnet18a``
    backbone.

  Raises:
    ValueError: If ``step`` is anything other than ``"simclr"``.
  """
  # Validate first: other steps used to fall through to the return statement
  # with no model built.
  if step != "simclr":
    # TODO: scan / selflabel need a clustering model plus loading of the
    # pretrained pretext weights.
    raise ValueError(f"Unsupported step {step!r}; only 'simclr' is implemented")

  # Backbone is the CIFAR-style ResNet-18 defined in this file. A truncated
  # torchvision resnet18 was considered as an alternative (open question: how
  # it relates to the forward pass of the author code).
  backbone = resnet18a()
  return SimclrContrastiveModel(backbone)


## Criterion

SimCLR Loss
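For a positive pair $(i, j)$ among the $2N$ embeddings of a batch, the NT-Xent loss is

$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}$$

where $\mathrm{sim}$ is the cosine similarity and $\tau$ is the temperature.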

In [None]:
# loss_i,j = -log( exp(sim(z_i, z_j)/tau) / sum_{k != i} exp(sim(z_i, z_k)/tau) )
# Based on https://www.egnyte.com/blog/2020/07/understanding-simclr-a-framework-for-contrastive-learning/
class SimCLR_loss(nn.Module):
  """NT-Xent (normalised temperature-scaled cross-entropy) loss for SimCLR.

  Args:
    batch_size: Expected number of pairs per batch. It sizes the cached
      ``negatives_mask`` buffer; batches of a different size (for example a
      shorter final batch) are handled by building a mask on the fly.
    temp: Temperature dividing the similarities.
  """

  def __init__(self, batch_size, temp=0.1):
    super().__init__()
    self.batch_size = batch_size
    self.register_buffer("temperature", torch.tensor(temp))
    self.register_buffer("negatives_mask", self._build_negatives_mask(batch_size))

  @staticmethod
  def _build_negatives_mask(batch_size, device=None):
    """Return a (2B, 2B) float mask that is 0 on the diagonal and 1 elsewhere."""
    size = batch_size * 2
    return (~torch.eye(size, size, dtype=torch.bool, device=device)).float()

  def forward(self, emb_i, emb_j):
    """Compute the mean NT-Xent loss over both views.

    emb_i and emb_j are 2-D batches of embeddings (batch, dim), where
    corresponding indices are pairs z_i, z_j as per the SimCLR paper.

    Raises:
      ValueError: If the two batches do not have the same shape.
    """
    if emb_i.shape != emb_j.shape:
      raise ValueError(
          f"emb_i and emb_j must have the same shape, got {tuple(emb_i.shape)} and {tuple(emb_j.shape)}")

    z_i = F.normalize(emb_i, dim=1)
    z_j = F.normalize(emb_j, dim=1)

    # Use the size of this batch rather than the configured one, so a smaller
    # batch does not break the mask / diagonal bookkeeping.
    n_pairs = z_i.shape[0]

    representations = torch.cat([z_i, z_j], dim=0)
    # Rows are unit-norm, so the Gram matrix holds the pairwise cosine
    # similarities without materialising a (2B, 2B, dim) broadcast.
    similarity_matrix = representations @ representations.t()

    # Positive pairs sit on the +/- n_pairs off-diagonals.
    sim_ij = torch.diag(similarity_matrix, n_pairs)
    sim_ji = torch.diag(similarity_matrix, -n_pairs)
    positives = torch.cat([sim_ij, sim_ji], dim=0)

    if n_pairs == self.batch_size:
      negatives_mask = self.negatives_mask
    else:
      negatives_mask = self._build_negatives_mask(n_pairs, device=representations.device)

    nominator = torch.exp(positives / self.temperature)
    # The mask removes each sample's similarity with itself from the sum.
    denominator = negatives_mask * torch.exp(similarity_matrix / self.temperature)

    loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
    loss = torch.sum(loss_partial) / (2 * n_pairs)
    print('NT-Xent loss', loss)
    return loss

## Train

In [None]:
# Train the SimCLR instance-discrimination pretext task for one epoch (its features are meant to be used later to determine the 20 nearest neighbors)
def SimCLR_train(dataloader, model, epoch, criterion, optimizer):
  """Run one pass over ``dataloader``, updating ``model`` with the contrastive loss.

  Args:
    dataloader: Iterable yielding ``(images, augmented_images, labels)``
      batches; the labels are not used.
    model: Network mapping an image batch to embeddings. Batches are moved to
      the device of its first parameter (CPU if it has no parameters).
    epoch: Index of the current epoch. Currently unused; kept for the
      progress recording that is still to be added.
    criterion: Callable taking the two embedding batches and returning a
      scalar loss.
    optimizer: Optimizer over the model's parameters.
  """
  # TODO: record progress (running loss per epoch).
  model.train()

  # Follow the model's device instead of hard-coding .cuda(), so the same
  # loop also runs on a CPU-only runtime.
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device("cpu")

  for ims, aug_ims, _ in dataloader:
    # The model processes images independently, so each view is simply passed
    # as its own batch.
    x_i = ims.to(device, non_blocking=True)
    x_j = aug_ims.to(device, non_blocking=True)

    z_i = model(x_i) # try concatenation x_i and x_j?
    z_j = model(x_j)
    loss = criterion(z_i, z_j)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Pretext Task

In [None]:
# Pretext Task
cifar_path = "/content/drive/MyDrive/Colab Notebooks/SCANmaterials/Unsupervised-Classification/datasets/cifar10"
step = "simclr"
transform = get_transform("simclr")
base_transform = get_transform("base")

# Dictionary containing hyperparameters
# author code epochs = 500, batchsize = 512, num_workers = 8
hyperparams = {"epochs": 10, "batchsize": 4, "weight decay": 0.0001, "momentum": 0.9, "lr": 0.4,
               "lr decay rate": 0.1, "num_workers": 2}

# Load training set.
# Only `torchvision` is imported in this notebook, so the dataset class must be
# reached as torchvision.datasets.CIFAR10; it also has no `base_transform`
# argument. CustomDataset takes the augmentation pipeline from the dataset's
# `transform` and reads the module-level `base_transform` itself.
train1_set = torchvision.datasets.CIFAR10(root=cifar_path, train=True, transform=transform,
                                          download=False)  # change to True
train_set = CustomDataset(train1_set, step)
# enable pin_memory to speed up host to device transfer
# Probably highest possible batch_size=128 & num_workers=2 with memory limits CHECK
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)

# Load testing set
#test_set = CIFAR10(root = cifar_path, transform = transform, download = True)
#test_loader = torch.utils.data.DataLoader(test_set, batch_size = 128, shuffle = True, num_workers = 2, pin_memory = True)


In [None]:
# For initial testing, take a small random subset (100 samples) of the dataset
indices = torch.randperm(len(train_set)).tolist()
expset = torch.utils.data.Subset(train_set, indices[:100])
# Take batch size and worker count from `hyperparams` so the loader stays in
# step with the criterion, which is sized from hyperparams["batchsize"].
# drop_last keeps every batch at that size if the two ever stop dividing evenly.
expload = torch.utils.data.DataLoader(expset, batch_size=hyperparams["batchsize"], shuffle=True,
                                      num_workers=hyperparams["num_workers"], drop_last=True)

In [None]:
# Instantiate model
model = get_model(step)
model.cuda()

# Get criterion
batchsize = hyperparams["batchsize"]
criterion = SimCLR_loss(batchsize)
criterion.cuda()

# Read the schedule settings from `hyperparams` rather than repeating literals.
lr = hyperparams["lr"]
decay_rate = hyperparams["lr decay rate"]
# Instantiate SGD (??) optimizer # original simclr paper used LARS...
params = [p for p in model.parameters() if p.requires_grad] # CHECK
optimizer = torch.optim.SGD(params, lr, momentum=hyperparams["momentum"],
                            weight_decay=hyperparams["weight decay"], nesterov=False)

# Train model
# add warm-up? (to reduce effect of early training)
epochs = hyperparams["epochs"]

# Create the scheduler once, before the loop. Re-creating it every epoch
# restarted its epoch counter each time, so the learning rate never followed
# the cosine curve over `epochs`.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=(lr*(decay_rate**3)))

for epoch in range(epochs):
  print('Epoch ', epoch)
  # get_last_lr() reports the rate currently in effect; `lr` itself is left
  # untouched so it keeps meaning the initial learning rate.
  current_lrs = scheduler.get_last_lr()
  print('Learning Rate ', current_lrs[0], len(current_lrs))

  # train
  SimCLR_train(expload, model, epoch, criterion, optimizer)

  # memory bank

  # validate

  # checkpoint

  # advance the learning-rate schedule by one epoch
  scheduler.step()

# Save model
