# Self-supervised learning project
Comparing a ResNet pretrained with SimCLR (self-supervised) against a conventionally pretrained ResNet, both fine-tuned on a small labeled subset of CIFAR10.

## Universal settings and helpers

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

from lightly.data import LightlyDataset
from lightly.transforms import SimCLRTransform, utils

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Compare the device *type*: a torch.device never compares equal to the plain
# string "cuda", so `device == "cuda"` would silently skip this branch.
if device.type == "cuda":
    # Should increase performance on GPUs with tensor cores
    # Not sure if it is supported on all GPUs; consider disabling
    torch.set_float32_matmul_precision('medium')

In [None]:
# DataLoader worker processes: up to 32 (the author's machine has 32 CPU
# threads), capped at the number of CPUs actually present so smaller machines
# are not oversubscribed. `os.cpu_count()` may return None, hence the fallback.
num_workers = min(32, os.cpu_count() or 1)
input_size = 32  # CIFAR10 images are 32x32

In [None]:
# Fixed seed so repeated runs start from the same random state.
seed = 1
# Lightning helper that applies the seed to the random number generators it
# manages (see the Lightning docs for the exact list).
pl.seed_everything(seed)

This function prints some metrics we can use to compare the different classifiers.

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

def eval_classifier(classifier, dataloader=None):
    """Print a confusion matrix and a classification report for a classifier.

    The classifier is moved to the module-level ``device``, switched to eval
    mode and run over every batch without gradient tracking. The predicted
    class of a sample is the index of its largest output along dimension 1.

    Args:
        classifier: Module that maps a batch of images to per-class scores.
        dataloader: Iterable of ``(images, labels)`` batches to evaluate on.
            When omitted, the module-level ``test_classifier_dl`` is used,
            which is what this function always evaluated on before.

    Returns:
        None. The metrics are printed.
    """
    if dataloader is None:
        dataloader = test_classifier_dl

    classifier.to(device)
    classifier.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            outputs = classifier(images.to(device))
            preds = outputs.argmax(dim=1)
            # Collect on the CPU as plain Python ints; labels never need to
            # visit the device because they are only used for the metrics.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # sklearn convention: true labels first, predictions second.
    conf_matrix = confusion_matrix(all_labels, all_preds)
    class_report = classification_report(all_labels, all_preds)

    print("Confusion Matrix:\n", conf_matrix)
    print("\nClassification Report:\n", class_report)

## Define the classifier

The same classifier is used for the SimCLR model and the pretrained ResNet model.

In [None]:
from pytorch_lightning.callbacks import EarlyStopping
from lightly.models.utils import deactivate_requires_grad

class Classifier(pl.LightningModule):
    """Linear 10-class classifier trained on top of a frozen backbone.

    Args:
        backbone: Feature extractor. Its output is flattened per sample and
            fed to the linear layer; its parameters are frozen here.
        num_ftrs: Number of values per sample the backbone yields after
            flattening, i.e. the input width of the linear layer.
    """

    def __init__(self, backbone, num_ftrs):
        super().__init__()
        self.backbone = backbone

        # freeze the backbone so that only the linear layer is trained
        deactivate_requires_grad(backbone)

        self.fc = nn.Linear(num_ftrs, 10)  # 10 is the number of output classes

        self.criterion = nn.CrossEntropyLoss()
        # (batch size, number of correct predictions) per validation batch;
        # reduced to an accuracy and cleared in on_validation_epoch_end.
        self.validation_step_outputs = []

    def forward(self, x):
        """Return the raw class logits for a batch of inputs."""
        y_hat = self.backbone(x).flatten(start_dim=1)
        y_hat = self.fc(y_hat)
        return y_hat

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        x, y = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss_fc", loss)
        return loss

    def on_train_epoch_end(self):
        """Log parameter histograms once per training epoch."""
        self.custom_histogram_weights()

    def custom_histogram_weights(self):
        """Write a histogram of every named parameter to the logger.

        Uses ``self.logger.experiment.add_histogram``, so the attached logger
        must expose that method (presumably a TensorBoard logger).
        """
        for name, params in self.named_parameters():
            self.logger.experiment.add_histogram(name, params, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and record the batch's correct count."""
        x, y = batch
        logits = self.forward(x)

        # CrossEntropyLoss expects raw logits and applies log-softmax itself.
        # Feeding it softmax probabilities (as this method used to) applies
        # softmax twice, which squashes the loss range and makes `val_loss`
        # incomparable with `train_loss_fc`.
        loss = self.criterion(logits, y)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)

        # softmax is monotonic, so the arg-max over logits is the prediction
        _, predicted = torch.max(logits, 1)
        num = predicted.shape[0]
        correct = (predicted == y).float().sum()
        self.validation_step_outputs.append((num, correct))
        return num, correct

    def on_validation_epoch_end(self):
        """Reduce the recorded batch counts to `val_acc` and reset them."""
        if self.validation_step_outputs:
            total_num = 0
            total_correct = 0
            for num, correct in self.validation_step_outputs:
                total_num += num
                total_correct += correct
            acc = total_correct / total_num
            self.log("val_acc", acc, on_epoch=True, prog_bar=True)
            self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return SGD over the linear layer plus a cosine LR schedule.

        The schedule length is read from the module-level ``max_epochs`` at
        the time this method is called.
        """
        optim = torch.optim.SGD(self.fc.parameters(), lr=2e-2)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs)
        return [optim], [scheduler]

## Load classifier dataset

The same dataset is used to fine-tune the classifier using the SimCLR model and the pretrained ResNet model.

We use just 1% of the labeled CIFAR10 training set (500 of its 50,000 images).

In [None]:
# Batch size shared by the classifier's train and test DataLoaders.
batch_size = 64

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Deterministic preprocessing for the classifier data. Normalise with the
# ImageNet statistics, the same ones this notebook's SimCLR `test_transform`
# uses, instead of a generic (0.5, 0.5, 0.5) scaling. The frozen backbones
# then receive inputs on the scale they were presumably trained with
# (confirm against SimCLRTransform's default and the torchvision weights).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(
         mean=utils.IMAGENET_NORMALIZE["mean"],
         std=utils.IMAGENET_NORMALIZE["std"])])

# Full labeled training split.
train_classifier_ds_full = torchvision.datasets.CIFAR10(
    root="datasets/cifar10",
    train=True,
    download=True,
    transform=transform
)

# Held-out test split.
test_classifier_ds = torchvision.datasets.CIFAR10(
    root="datasets/cifar10",
    train=False,
    download=True,
    transform=transform
)

# Selecting 1% of the training data: distinct indices drawn without
# replacement from NumPy's global random state.
num_train_samples = len(train_classifier_ds_full)
subset_size = int(0.01 * num_train_samples)
indices = np.random.choice(num_train_samples, subset_size, replace=False)
train_classifier_ds = Subset(train_classifier_ds_full, indices)

# DataLoaders
# 1% subset, reshuffled every epoch.
train_classifier_dl = DataLoader(
    train_classifier_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

# Whole training split, reshuffled every epoch.
train_classifier_dl_full = DataLoader(
    train_classifier_ds_full,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

# Test split in a fixed order.
test_classifier_dl = DataLoader(
    test_classifier_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

## Pretraining ResNet with SimCLR

Some hyperparameters

In [None]:
# Batch size for both SimCLR DataLoaders. 4096 is the value used by the
# original SimCLR (https://github.com/google-research/simclr?tab=readme-ov-file#pretraining)
simclr_batch_size = 4096

### Load the training and testing data

In [None]:
# Augmentation pipeline used for contrastive pretraining. `vf_prob` and
# `rr_prob` are set explicitly to 0.5 (presumably the vertical-flip and
# random-rotation probabilities; see the lightly documentation).
transform = SimCLRTransform(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5,
)

# Augmentation-free preprocessing: resize, convert to a tensor, then
# normalise with the ImageNet mean/std shipped with lightly.
_imagenet_stats = utils.IMAGENET_NORMALIZE
_resize = torchvision.transforms.Resize((input_size, input_size))
_to_tensor = torchvision.transforms.ToTensor()
_normalise = torchvision.transforms.Normalize(
    mean=_imagenet_stats["mean"],
    std=_imagenet_stats["std"],
)
test_transform = torchvision.transforms.Compose([_resize, _to_tensor, _normalise])

In [None]:
# Both datasets are built from the same CIFAR10 root without a `train=`
# argument, so they wrap whichever split torchvision uses by default; they
# differ only in the transform applied to each image.
_cifar_root = "datasets/cifar10"
train_simclr_ds = torchvision.datasets.CIFAR10(
    _cifar_root, download=True, transform=transform
)
test_simclr_ds = torchvision.datasets.CIFAR10(
    _cifar_root, download=True, transform=test_transform
)

# Settings common to both loaders.
_simclr_loader_args = dict(batch_size=simclr_batch_size, num_workers=num_workers)

# Pretraining loader: reshuffled each epoch, incomplete last batch dropped.
train_simclr_dl = torch.utils.data.DataLoader(
    train_simclr_ds, shuffle=True, drop_last=True, **_simclr_loader_args
)

# Embedding loader: fixed order, every sample kept.
test_simclr_dl = torch.utils.data.DataLoader(
    test_simclr_ds, shuffle=False, drop_last=False, **_simclr_loader_args
)

### Implement SimCLR model

In [None]:
from lightly.loss import NTXentLoss
from lightly.models.modules.heads import SimCLRProjectionHead

# Inspired by
# https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_simclr_clothing.html#create-the-simclr-model 
class SimCLRModel(pl.LightningModule):
    """ResNet-18 backbone plus projection head, trained with NT-Xent loss.

    The attributes ``backbone``, ``projection_head`` and ``criterion`` are
    the sub-modules; their names determine the keys of the saved state dict.
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised ResNet-18; keep every child module except the
        # last one (the classification layer) as the feature backbone.
        base_net = torchvision.models.resnet18()
        trunk_layers = list(base_net.children())[:-1]
        self.backbone = nn.Sequential(*trunk_layers)

        # Width of the backbone features = input width of the removed layer.
        feature_dim = base_net.fc.in_features
        self.projection_head = SimCLRProjectionHead(feature_dim, feature_dim, 128)

        self.criterion = NTXentLoss()

    def forward(self, x):
        """Return the projection-head output for a batch of images."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)

    def training_step(self, batch, batch_idx):
        """Compute and log the contrastive loss between the two views."""
        views, _ = batch  # labels are not used for self-supervised training
        view_a, view_b = views
        proj_a = self.forward(view_a)
        proj_b = self.forward(view_b)
        loss = self.criterion(proj_a, proj_b)
        self.log("train_loss_ssl", loss, on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Return SGD with momentum over all parameters and a cosine schedule.

        The schedule length is read from the module-level ``max_epochs``.
        """
        sgd = torch.optim.SGD(
            self.parameters(),
            lr=6e-2,
            momentum=0.9,
            weight_decay=5e-4,
        )
        cosine = torch.optim.lr_scheduler.CosineAnnealingLR(sgd, max_epochs)
        return [sgd], [cosine]

Train the model

In [None]:
# Number of SimCLR pretraining epochs. Passed to the Trainer and also read by
# the models' configure_optimizers as the cosine schedule length.
max_epochs = 1000

In [None]:
simclr_model = SimCLRModel()
# Use the GPU when one is available and fall back to the CPU otherwise, in
# line with the notebook's `device` selection. A hard-coded "gpu" accelerator
# makes the Trainer fail on CPU-only machines.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    log_every_n_steps=1,
)
trainer.fit(simclr_model, train_simclr_dl)

In [None]:
# Make sure the output directory exists; torch.save does not create it.
os.makedirs("models", exist_ok=True)
# Save under the path the later "load model" cells read from
# ("models/pretrained-simclr.pt"), so the notebook runs top to bottom.
torch.save(simclr_model.state_dict(), "models/pretrained-simclr.pt")

### Evaluate performance (qualitative)

The Lightly docs included some instructions on how to visualize the nearest neighbors of a couple of images, so let's take a look at that.

In [None]:
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import numpy as np

# Borrowed from
# https://docs.lightly.ai/self-supervised-learning/tutorials/package/tutorial_simclr_clothing.html#create-the-simclr-model 
def generate_embeddings(model, dataloader):
    """Generate L2-normalised backbone features for every image in a loader.

    Args:
        model: Object exposing ``backbone`` (the feature extractor) and
            ``device`` (where inputs are moved before the forward pass).
        dataloader: Iterable of ``(images, labels)`` batches; labels are
            ignored.

    Returns:
        NumPy array with one row per image, in iteration order, each row
        scaled to unit Euclidean length (all-zero rows stay zero).
    """
    batches = []
    with torch.no_grad():
        for img, _ in dataloader:  # Only extract the images, ignore the labels
            emb = model.backbone(img.to(model.device)).flatten(start_dim=1)
            # Move each batch to the CPU straight away: the result feeds
            # NumPy/sklearn code, which cannot consume CUDA tensors, and this
            # also avoids piling up every embedding in device memory.
            batches.append(emb.cpu())

    embeddings = torch.cat(batches, 0)
    # Row-wise L2 normalisation, the same operation sklearn's `normalize`
    # performs with its default arguments.
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.numpy()

def plot_knn_examples(embeddings, images, n_neighbors=3, num_examples=6):
    """Draw one figure per random sample showing its nearest neighbours.

    Neighbours are searched within ``embeddings`` itself, and row ``i`` of
    ``embeddings`` is assumed to describe ``images[i]``. Each panel is titled
    with the distance to the sampled row.
    """
    knn = NearestNeighbors(n_neighbors=n_neighbors)
    knn.fit(embeddings)
    distances, indices = knn.kneighbors(embeddings)

    # Random rows to illustrate, drawn without repetition.
    chosen_rows = np.random.choice(len(indices), size=num_examples, replace=False)

    for row in chosen_rows:
        fig = plt.figure(figsize=(15, 3))
        neighbours = zip(indices[row], distances[row])
        for column, (neighbor_idx, dist) in enumerate(neighbours, start=1):
            ax = fig.add_subplot(1, n_neighbors, column)
            ax.imshow(images[neighbor_idx])
            ax.set_title(f"d={dist:.3f}")
            ax.axis("off")

The CIFAR10 dataset is stored as binary files using pickle, so we need to unpickle them.

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

def load_cifar10_batch(file):
    """Unpickle one CIFAR-10 batch file and return the loaded object.

    The file is read with ``encoding='bytes'``, so string keys in the result
    are ``bytes`` objects.
    """
    # NOTE: pickle can execute arbitrary code; only open trusted dataset files.
    with open(file, 'rb') as batch_fh:
        return pickle.load(batch_fh, encoding='bytes')

def unpack_image(data):
    """Turn one flat CIFAR-10 row into a 32x32x3 (row, column, channel) array.

    The row holds three 1024-value colour planes back to back: red first,
    then green, then blue.
    """
    # View the row as (channel, row, column), then move the channel axis last.
    planes = data.reshape(3, 32, 32)
    # Copy so the result is a fresh contiguous array rather than a view.
    return np.moveaxis(planes, 0, -1).copy()

# Collect the images of the five CIFAR-10 batch files, in file order, as
# 32x32x3 arrays.
cifar_dir = Path("datasets/cifar10/cifar-10-batches-py")
images = []
for batch_no in range(1, 6):
    batch = load_cifar10_batch(cifar_dir / f"data_batch_{batch_no}")
    # Each batch maps b'data' to the flat image rows and b'labels' to the
    # class ids; only the rows are used below.
    data, labels = batch[b'data'], batch[b'labels']
    images.extend(unpack_image(row) for row in data)

In [None]:
# Rebuild the architecture and load the pretrained weights from disk.
simclr_model = SimCLRModel()
# map_location="cpu" lets the checkpoint load on machines without a GPU even
# if its tensors were saved from one; the fresh model lives on the CPU anyway.
simclr_model.load_state_dict(
    torch.load("models/pretrained-simclr.pt", map_location="cpu")
)

In [None]:
# Switch to eval mode before extracting features; generate_embeddings does
# not change the model's mode itself.
simclr_model.eval()
embeddings = generate_embeddings(simclr_model, test_simclr_dl)
# NOTE(review): plot_knn_examples indexes `images` with embedding row numbers,
# so `images` must be in the same order as the dataset behind test_simclr_dl.
# Presumably true (unshuffled loader, same CIFAR10 batch files); verify.
plot_knn_examples(embeddings, images)

## Fine-tuning SimCLR model

Some hyperparameters

In [None]:
# Upper bound on fine-tuning epochs (early stopping may end training sooner);
# also read by Classifier.configure_optimizers as the cosine schedule length.
max_epochs = 500

In [None]:
# Rebuild the architecture and load the pretrained weights from disk.
simclr_model = SimCLRModel()
# map_location="cpu" lets the checkpoint load on machines without a GPU even
# if its tensors were saved from one; the fresh model lives on the CPU anyway.
simclr_model.load_state_dict(
    torch.load("models/pretrained-simclr.pt", map_location="cpu")
)

Early stopping prevents overfitting by stopping the training process once validation loss stops improving.

In [None]:
# Watch the logged `val_loss` (lower is better) with a patience of 10 and
# report stopping decisions verbosely.
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=10, verbose=True)

In [None]:
# Linear classifier on top of the SimCLR-pretrained backbone.
simclr_classifier = Classifier(simclr_model.backbone, num_ftrs=512)  # 512 = resnet.fc.in_features
# Use the GPU when one is available and fall back to the CPU otherwise, in
# line with the notebook's `device` selection. A hard-coded "gpu" accelerator
# makes the Trainer fail on CPU-only machines.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    log_every_n_steps=1,
    callbacks=[early_stop_callback],
)
# NOTE(review): the test loader doubles as the validation set, so early
# stopping is tuned on the same data that is later used for the final report.
trainer.fit(simclr_classifier, train_classifier_dl, test_classifier_dl)

In [None]:
# Persist the classifier's full state dict (backbone and linear layer).
torch.save(simclr_classifier.state_dict(), "models/fine-tuned-simclr.cls")

In [None]:
# eval_classifier switches the model to eval mode as well, so this explicit
# call is only a safeguard.
simclr_classifier.eval()
# Prints the confusion matrix and classification report.
eval_classifier(simclr_classifier)

## Fine-tuning non-SimCLR model

In [None]:
# Same epoch budget as for the SimCLR classifier, so both are trained alike.
max_epochs = 500

In [None]:
# ResNet-18 with torchvision's pretrained weights, the baseline the SimCLR
# backbone is compared against.
# NOTE(review): newer torchvision releases replace `pretrained=` with
# `weights=`; confirm against the installed version.
pretrained_resnet = torchvision.models.resnet18(pretrained=True)

The early stopping callback needs to be reset to train the other model.

In [None]:
# Fresh callback for the second training run: watch the logged `val_loss`
# (lower is better) with a patience of 10 and report decisions verbosely.
early_stop_callback = EarlyStopping(monitor="val_loss", mode="min", patience=10, verbose=True)

In [None]:
# The full torchvision ResNet ends in a 1000-way `fc` layer, which is why
# feeding the whole model to Classifier required num_ftrs=1000. Drop that
# layer, exactly as SimCLRModel does for its backbone, so both classifiers
# are trained on the same kind of 512-d features and the comparison is fair.
pt_backbone = nn.Sequential(*list(pretrained_resnet.children())[:-1])
pt_classifier = Classifier(pt_backbone, num_ftrs=pretrained_resnet.fc.in_features)
# Use the GPU when one is available and fall back to the CPU otherwise, in
# line with the notebook's `device` selection. A hard-coded "gpu" accelerator
# makes the Trainer fail on CPU-only machines.
trainer = pl.Trainer(
    max_epochs=max_epochs,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    log_every_n_steps=1,
    callbacks=[early_stop_callback],
)
# NOTE(review): the test loader doubles as the validation set, so early
# stopping is tuned on the same data that is later used for the final report.
trainer.fit(pt_classifier, train_classifier_dl, test_classifier_dl)

In [None]:
# Persist the classifier's full state dict (backbone and linear layer).
torch.save(pt_classifier.state_dict(), "models/fine-tuned-imagenet.cls")

In [None]:
# Prints the confusion matrix and classification report; eval_classifier
# puts the model into eval mode itself.
eval_classifier(pt_classifier)