# 6. Self-supervised learning

Welcome to this module's notebook, where we dive into self-supervised learning, one of our group's favorite topics! Transfer learning is a prominent topic in machine learning, but implementation-wise it is not as interesting as self-supervised learning.

In this notebook we look at the 'Bootstrap Your Own Latent' (BYOL) method. Self-supervised learning methods have traditionally been associated with very large datasets and carefully tuned batch sizes, which often limited experimentation to large companies such as Google and Meta. Here we aim to show that even with limited resources you can use unlabeled data to train models effectively.

Since we know everyone is working hard on the final assignments, this notebook is limited in terms of implementation exercises. However, there are some questions that may help you understand the general ideas.

In [None]:
import copy

import torch
import torchvision
from torchvision import transforms
from torch import nn
import torch.nn.functional as F

from lightly.loss import NegativeCosineSimilarity
from lightly.models.modules import BYOLPredictionHead, BYOLProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (
    BYOLTransform,
    BYOLView1Transform,
    BYOLView2Transform,
)
from lightly.utils.scheduler import cosine_schedule
import os
from torch.utils.data import Subset

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA



%matplotlib inline

DATA_DIR = os.path.join(os.getenv('TEACHER_DIR'), 'JHS_data')

If you check the lecture slides, you'll notice that BYOL necessitates the use of two different augmented versions of a single image. Take the opportunity to implement this requirement, and carefully consider the transformations you'll use.

In [None]:
# two crop transform
contrastive_transform = transforms.Compose([
                                   #implement the correct augmentations
                                   transforms.ToTensor(),
                                   ])

test_transform = transforms.Compose([
                                   #leave this as it is
                                   transforms.ToTensor(),
                                   ])

class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]
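
To see why the wrapped transform has to be stochastic, here is a toy sketch in plain Python. The `random_shift` helper is hypothetical and only stands in for real image augmentations (it is not the intended answer to the exercise); the "image" is just a list of numbers.

```python
import random


class TwoCropTransform:
    """Apply the same stochastic transform twice to get two views."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]


def random_shift(row, max_shift=3):
    """Toy augmentation (hypothetical): cyclically shift a 1-D 'image'."""
    k = random.randint(-max_shift, max_shift)
    return row[-k:] + row[:-k] if k else list(row)


random.seed(0)
image = [0, 0, 1, 5, 9, 5, 1, 0, 0]
view_a, view_b = TwoCropTransform(random_shift)(image)
print(view_a)
print(view_b)
```

Each call to the transform draws its own random shift, so the two views usually differ while still containing the same content. With a deterministic transform such as a bare `ToTensor()`, both views would be identical.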

We will use the MNIST training dataset for self-supervised training, and the test set to evaluate what our model has learned without any labels.

**Q1.** What will happen if you keep your augmentations too easy?

**Q2.** What happens to your model if you include a specific augmentation such as rotation or [color jitter](https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html)?

In [None]:
# MNIST training set; every sample is returned as two augmented views.
dataset = torchvision.datasets.MNIST(
    DATA_DIR, download=False, train=True, transform=TwoCropTransform(contrastive_transform)
)

testdata = torchvision.datasets.MNIST(
    DATA_DIR, download=False, train=False, transform=test_transform
)

# Create a subset of the test dataset containing only the first 1000 images
subset_indices = range(1000)
testdata = Subset(testdata, subset_indices)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=24,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)

testloader = torch.utils.data.DataLoader(
    testdata,
    batch_size=24,
    shuffle=True,
    drop_last=True,
    num_workers=1,
)

# Visualize the two augmented views for a few example images
def visualize_augmented_images(dataset, indices):
    fig, axes = plt.subplots(len(indices), 2, figsize=(10, 20))

    for i, index in enumerate(indices):
        image, _ = dataset[index]

        # Plot the first augmented view
        axes[i, 0].imshow(image[0].squeeze(), cmap='gray')
        axes[i, 0].set_title('Augmented Image 1')

        # Plot the second augmented view
        axes[i, 1].imshow(image[1].squeeze(), cmap='gray')
        axes[i, 1].set_title('Augmented Image 2')

    plt.tight_layout()
    plt.show()

# Example usage:
indices_to_visualize = range(4)
visualize_augmented_images(dataset, indices_to_visualize)

We have already implemented this part for you. However, it is important to understand the role of the projection and prediction heads.

**Q3.** Why do we add an extra prediction head on top of the student (online) network?

**Q4.** What is the function of the projection head?

In [None]:
class BYOL(nn.Module):
    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone
        self.projection_head = BYOLProjectionHead(512, 1024, 256)
        self.prediction_head = BYOLPredictionHead(256, 1024, 256)

        self.backbone_momentum = copy.deepcopy(self.backbone)
        self.projection_head_momentum = copy.deepcopy(self.projection_head)

        deactivate_requires_grad(self.backbone_momentum)
        deactivate_requires_grad(self.projection_head_momentum)

    def forward(self, x):
        y = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(y)
        p = self.prediction_head(z)
        return p

    def forward_momentum(self, x):
        y = self.backbone_momentum(x).flatten(start_dim=1)
        z = self.projection_head_momentum(y)
        z = z.detach()
        return z

Alright, let's start training! If you see the loss value decreasing, feel free to take a moment to relax; training may take some time to complete.
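
To interpret the loss values printed during training, the sketch below recomputes the loss from the training cell in plain Python for single vectors. The function names are ours and purely illustrative; lightly's `NegativeCosineSimilarity` works on batched tensors and, as far as we know, averages over the batch.

```python
import math


def negative_cosine_similarity(p, z):
    """Return -cos(p, z) for two plain Python vectors."""
    dot = sum(a * b for a, b in zip(p, z))
    norm_p = math.sqrt(sum(a * a for a in p))
    norm_z = math.sqrt(sum(b * b for b in z))
    return -dot / (norm_p * norm_z)


def byol_style_loss(p0, z1, p1, z0):
    """Symmetrised loss, shifted by 1 so that it lies in [0, 2]."""
    return 1 + 0.5 * (
        negative_cosine_similarity(p0, z1) + negative_cosine_similarity(p1, z0)
    )


# Predictions pointing in the same direction as the targets (any length)
aligned = byol_style_loss([1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 3.0])
# Predictions orthogonal to the targets
orthogonal = byol_style_loss([1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0])
# Predictions pointing in the opposite direction
opposite = byol_style_loss([1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0])
print(aligned, orthogonal, opposite)
```

A loss close to 0 therefore means the online predictions point in the same direction as the momentum projections of the other view; the vector lengths do not matter.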

In [None]:
resnet = torchvision.models.resnet18(weights=None)  # weights=None: random initialisation, no pretrained weights
resnet.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # change first layer of resnet to one channel
backbone = nn.Sequential(*list(resnet.children())[:-1])

model = BYOL(backbone)
criterion = NegativeCosineSimilarity()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

epochs = 5
print_interval = 50
total_steps = len(dataloader)

print("Starting Training")
for epoch in range(epochs):
    running_loss = 0.0
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    for i, batch in enumerate(dataloader, 0):
        x0, x1 = batch[0]
        update_momentum(model.backbone, model.backbone_momentum, m=momentum_val)
        update_momentum(model.projection_head, model.projection_head_momentum, m=momentum_val)
        
        # BYOL step: the online prediction of one view is matched to the
        # momentum (target) projection of the other view, symmetrically.
        p0 = model(x0)
        z0 = model.forward_momentum(x0)
        p1 = model(x1)
        z1 = model.forward_momentum(x1)
        loss = 1 + (0.5 * (criterion(p0, z1) + criterion(p1, z0)))
        
        running_loss += loss.item()
        
        if (i + 1) % print_interval == 0:
            average_loss = running_loss / print_interval
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{total_steps}], Loss: {average_loss:.4f}')
            running_loss = 0.0
               
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    torch.save(model.state_dict(), 'byol_model_epoch_{}.pth'.format(epoch+1))
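
The two helpers used in the loop, `cosine_schedule` and `update_momentum`, can be pictured with the plain-Python sketch below. It is an illustration under our own assumptions (function names `cosine_ramp` and `ema_update` are hypothetical, and weights are plain lists), not lightly's actual implementation, which may differ in details.

```python
import math


def cosine_ramp(step, max_steps, start, end):
    """Cosine ramp from `start` (step 0) to `end` (step max_steps - 1)."""
    if max_steps <= 1:
        return end
    progress = step / (max_steps - 1)
    return end - (end - start) * (math.cos(math.pi * progress) + 1) / 2


def ema_update(online, target, m):
    """Exponential moving average: m * target + (1 - m) * online."""
    return [m * t + (1 - m) * o for o, t in zip(online, target)]


online = [1.0, -2.0]   # stand-in for the online (student) weights
target = [0.0, 0.0]    # stand-in for the momentum (target) weights
for epoch in range(5):
    m = cosine_ramp(epoch, 5, 0.996, 1.0)
    target = ema_update(online, target, m)
    print(epoch, round(m, 5), target)
```

With a momentum this close to 1, the target weights move only a tiny step towards the online weights at each update, and at `m = 1` they do not move at all. This slow-moving target is what gives the online network a stable signal to predict.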


Okay, now that training is done, let's see if the model learned anything!

In [None]:
model_path_prefix = "byol_model_epoch_"

# Epoch 0 is the untrained (randomly initialised) model; 1-5 are the saved checkpoints
for epoch_nb in range(0, 6):
    # Load the model for the current epoch
    resnet = torchvision.models.resnet18(weights=None)  # architecture only; weights are loaded below
    resnet.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # change first layer of resnet to one channel
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    model = BYOL(backbone)
    model_path = model_path_prefix + str(epoch_nb) + ".pth"
    if epoch_nb != 0:
        model.load_state_dict(torch.load(model_path))
        
    model.eval()

    # Extract features from the dataset
    features_list = []
    labels_list = []
    with torch.no_grad():
        for images, labels in testloader:
            # flatten keeps the batch dimension even for a batch of size 1
            features = model.backbone(images).flatten(start_dim=1)
            features_list.append(features)
            labels_list.append(labels)

    features = torch.cat(features_list, dim=0).numpy()
    labels = torch.cat(labels_list, dim=0).numpy()
    

    # Perform t-SNE dimensionality reduction
    tsne = TSNE(n_components=2, random_state=42)
    tsne_features = tsne.fit_transform(features)

    # Plot t-SNE visualization for the current epoch
    plt.figure(figsize=(10, 8))
    for i in range(10):
        indices = labels == i
        plt.scatter(tsne_features[indices, 0], tsne_features[indices, 1], label=str(i), s=10)
    plt.title(f't-SNE Visualization of MNIST Features (Epoch {epoch_nb})')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend()
    plt.show()


The visualization should show that the model can already group certain digits in the MNIST dataset, even though it was never trained with the labels!

**Q5.** What are the main advantages of using weights learned through self-supervised learning for a downstream task?
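
A t-SNE plot is a qualitative check. As an optional complement, the sketch below computes a leave-one-out 1-nearest-neighbour accuracy on the features. It is not part of the original notebook: the function name and the toy data are our own, and it uses plain Python lists so that it runs on its own.

```python
def leave_one_out_1nn_accuracy(features, labels):
    """Fraction of points whose nearest other point has the same label."""
    correct = 0
    for i, query in enumerate(features):
        best_j, best_dist = None, float("inf")
        for j, other in enumerate(features):
            if i == j:
                continue  # never match a point with itself
            dist = sum((a - b) ** 2 for a, b in zip(query, other))
            if dist < best_dist:
                best_j, best_dist = j, dist
        correct += labels[best_j] == labels[i]
    return correct / len(features)


# Toy data: two well-separated clusters
toy_features = [[0.0, 0.1], [0.1, 0.0], [5.0, 5.1], [5.1, 5.0]]
toy_labels = [0, 0, 1, 1]
accuracy = leave_one_out_1nn_accuracy(toy_features, toy_labels)
print(accuracy)
```

In the notebook, something like `leave_one_out_1nn_accuracy(features.tolist(), labels.tolist())` could be called inside the loop to compare checkpoints; for 1000 test images the quadratic loop is slow but still feasible.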
