# Tutorial 6-4: The Self-Supervised Evolution – "From SimCLR to SimSiam"

**Course:** CSEN 342: Deep Learning  
**Topic:** Self-Supervised Learning (SSL), Contrastive Learning, and Siamese Networks

## Objective
Deep learning typically requires massive labeled datasets. **Self-Supervised Learning (SSL)** removes this bottleneck by generating its own labels from the data itself.

The core idea is simple: **Invariance**. If we take an image of a dog and crop it, flip it, or change its color, it is *still* an image of a dog. The network should therefore output (nearly) the same feature vector for every such version.

In this tutorial, we will trace the evolution of modern SSL:
1.  **SimCLR (2020):** Uses **Contrastive Learning**. It pulls augmentations of the *same* image together (Positives) and pushes augmentations of *different* images apart (Negatives).
2.  **SimSiam (2021):** Removes the need for negative pairs entirely. It uses a **Siamese Network** with a **Stop-Gradient** operation to prevent the model from cheating (collapsing to a constant solution).

We will evaluate our self-supervised model using a **Linear Probe**: training a simple classifier on the frozen features to see if the network learned meaningful concepts (like "dog" or "plane") without supervision.

---

## Part 1: The Engine (Data Augmentation)

The most critical component of SSL is the augmentation pipeline. We need to generate two "views" ($x_i, x_j$) of every image. 

We use **CIFAR-10**. We will define a `SimCLRTransform` that applies:
1.  Random Resized Crop (forces model to learn parts-to-whole).
2.  Random Horizontal Flip.
3.  Color Jitter (forces model to ignore color histograms).
4.  Grayscale (optional, but helps).

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os

# Config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128 # Larger batch sizes are better for SimCLR (more negatives)
epochs = 5       # We keep it short for the tutorial; real SSL takes 100+ epochs
proj_dim = 128   # Dimension of the projection head

# 1. Define The Augmentation Wrapper
class SimCLRTransform:
    """
    Generates two different random augmentations of the same image.
    """
    def __init__(self, size=32):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(size=size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            # Standard CIFAR-10 Normalization
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])

    def __call__(self, x):
        # Return two views
        return self.transform(x), self.transform(x)

# 2. Load Data
data_root = '../data'
os.makedirs(data_root, exist_ok=True)

# Train set (Unlabeled - we ignore the labels!)
train_dataset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=SimCLRTransform()
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Test set (Labeled - for linear probe evaluation later)
# Standard transform for testing
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
test_dataset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=test_transform
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 3. Visualization
def imshow(img):
    # Un-normalize for display
    img = img * torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1) + torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    npimg = img.clamp(0, 1).numpy()  # clip to the valid display range
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# Get a batch
(x_i, x_j), _ = next(iter(train_loader))

plt.figure(figsize=(10, 4))
for i in range(4):
    plt.subplot(2, 4, i+1)
    imshow(x_i[i])
    plt.title("View 1")
    plt.subplot(2, 4, i+5)
    imshow(x_j[i])
    plt.title("View 2")
plt.show()

---

## Part 2: SimCLR (Contrastive Learning)

**Architecture:**
1.  **Backbone ($f$):** A ResNet-18 (without the final classification layer). This extracts representation $h$.
2.  **Projection Head ($g$):** A small MLP (Linear $\to$ ReLU $\to$ Linear) that maps $h$ to $z$.

**Why the Projection Head?** 
The SimCLR paper reports that features taken before the projection head ($h$) transfer better than those after it ($z$), and conjectures that the contrastive loss discards some information (like color or orientation) to achieve invariance. We want the *Backbone* ($h$) to keep that information for downstream tasks, so we perform the destructive contrastive step in the *Projection Head* ($z$) instead.

**The Loss (NT-Xent):**
For a positive pair $(i, j)$ in a batch of $N$ images ($2N$ augmented views), we maximize their similarity while minimizing the similarity of $i$ to each of the other $2(N-1)$ views in the batch.
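
To make the arithmetic concrete: with cosine similarity $\mathrm{sim}(\cdot,\cdot)$ and temperature $\tau$, the loss for anchor $i$ with positive $j$ is

$$ \ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k \neq i} \exp(\mathrm{sim}(z_i, z_k)/\tau)} $$

where $k$ runs over all $2N$ views except $i$ itself. The following is a minimal, dependency-free sketch of that formula, written as explicit loops rather than the vectorized form used in the notebook. The function name `nt_xent_reference` and the toy 2-D embeddings are assumptions made for illustration only.

```python
import math

def cosine(u, v):
    """Cosine similarity between two plain Python vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def nt_xent_reference(z_i, z_j, temperature=0.5):
    """Loop-based NT-Xent, averaged over all 2N anchors."""
    z = z_i + z_j            # list concatenation: 2N embeddings
    n = len(z_i)
    total = 0.0
    for a in range(2 * n):
        pos = (a + n) % (2 * n)   # the positive of view a sits N slots away
        # Denominator: every view except the anchor itself
        # (the positive plus the 2(N-1) negatives)
        denom = sum(math.exp(cosine(z[a], z[k]) / temperature)
                    for k in range(2 * n) if k != a)
        numer = math.exp(cosine(z[a], z[pos]) / temperature)
        total += -math.log(numer / denom)
    return total / (2 * n)

# Toy batch with N = 2 (hypothetical embeddings)
view_1 = [[1.0, 0.0], [0.0, 1.0]]
aligned_view_2 = [[0.9, 0.1], [0.1, 0.9]]    # each view is close to its partner
shuffled_view_2 = [[0.1, 0.9], [0.9, 0.1]]   # each view is close to the *other* image

loss_aligned = nt_xent_reference(view_1, aligned_view_2)
loss_shuffled = nt_xent_reference(view_1, shuffled_view_2)
print(f"aligned: {loss_aligned:.4f}  shuffled: {loss_shuffled:.4f}")
```

The aligned batch should give the lower loss, because each anchor's largest similarity is to its own positive.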

In [None]:
class SimCLR(nn.Module):
    def __init__(self, base_model=None):
        super(SimCLR, self).__init__()
        
        # 1. Backbone (ResNet-18)
        # We drop the final FC layer to get features directly
        resnet = torchvision.models.resnet18(weights=None)  # random init; `pretrained=` is deprecated in recent torchvision
        # CIFAR-10 is small (32x32), so we replace the first 7x7 conv with 3x3
        # to avoid downsampling too aggressively at the start.
        resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        resnet.maxpool = nn.Identity() # Remove maxpool for small images
        
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        feature_dim = resnet.fc.in_features

        # 2. Projection Head (MLP)
        # Maps 512 -> 128
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, proj_dim)
        )

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return h, z

# NT-Xent Loss Function
def nt_xent_loss(z_i, z_j, temperature=0.5):
    batch_size = z_i.shape[0]
    
    # Concatenate all features: [z_i, z_j] -> size (2N, D)
    features = torch.cat([z_i, z_j], dim=0)
    
    # Calculate Cosine Similarity Matrix
    features = F.normalize(features, dim=1)
    similarity_matrix = torch.matmul(features, features.T) # (2N, 2N)
    
    # Mask for the diagonal: an embedding must not be compared with itself
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=features.device)
    
    # For each item i (0...N-1), its positive pair is at i + N
    # For each item i+N, its positive pair is at i
    labels = torch.cat([
        torch.arange(batch_size) + batch_size,
        torch.arange(batch_size)
    ]).to(features.device)
    
    # The logits are the similarity scores / temperature
    logits = similarity_matrix / temperature
    
    # Mask out self-similarity with a very large negative number so it
    # receives (numerically) zero probability in the softmax
    logits = logits.masked_fill(mask, -9e15)
    
    # Each row k is a (2N)-way classification whose correct class is the
    # positive of k; cross_entropy applies log_softmax internally
    loss = F.cross_entropy(logits, labels)
    return loss

model_simclr = SimCLR().to(device)
optimizer_simclr = optim.Adam(model_simclr.parameters(), lr=1e-3)

print("SimCLR Model Initialized.")

---

## Part 3: The Evolution to SimSiam

SimCLR requires **Negative Pairs** (the other images in the batch) to prevent collapse. If we only pulled positive pairs together, the network could output the same constant vector for every image, making every pair perfectly similar and trivially minimizing the loss while learning nothing.

**SimSiam** solves this *without* negatives. It uses two tricks:
1.  **Predictor Head:** One branch has an extra MLP ($p$) that tries to predict the output of the other branch.
2.  **Stop-Gradient:** We calculate loss against the *target* branch, but we **do not** backpropagate errors through the target branch. This turns the target branch into a stable "teacher" for the predictor.

**Loss:** Negative Cosine Similarity.
$$ D(p_1, z_2) = - \frac{p_1}{\|p_1\|_2} \cdot \frac{z_2}{\|z_2\|_2} $$
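
The effect of the stop-gradient can be checked directly with autograd. Below is a minimal sketch, assuming random tensors as stand-ins for one predictor output (`p1`) and one projection (`z2`); the helper name `neg_cos` is hypothetical. It shows that, after `detach()`, the loss sends a gradient to the predictor side only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for a predictor output and a target projection
p1 = torch.randn(4, 8, requires_grad=True)
z2 = torch.randn(4, 8, requires_grad=True)

def neg_cos(p, z):
    """Negative cosine similarity, averaged over the batch."""
    return -(F.normalize(p, dim=1) * F.normalize(z, dim=1)).sum(dim=1).mean()

# Stop-gradient: z2 is treated as a constant target
loss = neg_cos(p1, z2.detach())
loss.backward()

grad_p1_exists = p1.grad is not None   # the predictor side receives a gradient
grad_z2_exists = z2.grad is not None   # the target side does not
print(grad_p1_exists, grad_z2_exists)  # True False
```

In the full model the encoder still receives gradients, but only through the predictor path (`p1` depends on `z1`), never through the detached target.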

In [None]:
class SimSiam(nn.Module):
    def __init__(self):
        super(SimSiam, self).__init__()
        
        # Reuse the SimCLR backbone/projection structure
        self.simclr_base = SimCLR()
        feature_dim = self.simclr_base.projection_head[2].out_features

        # 3. Predictor Head (MLP)
        # Used ONLY in SimSiam on the "student" branch
        self.predictor = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, feature_dim)
        )

    def forward(self, x1, x2):
        # Get projections (z)
        _, z1 = self.simclr_base(x1)
        _, z2 = self.simclr_base(x2)
        
        # Predict the other view
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        
        # Stop-gradient: detach() returns a new tensor that shares the same underlying
        # data but is cut out of the computation graph (requires_grad=False), so no
        # gradient flows back through the target branch. DETACH IS CRITICAL.
        return p1, p2, z1.detach(), z2.detach()

def negative_cosine_similarity(p, z):
    p = F.normalize(p, dim=1)
    z = F.normalize(z, dim=1)
    return -(p * z).sum(dim=1).mean()

model_simsiam = SimSiam().to(device)
optimizer_simsiam = optim.Adam(model_simsiam.parameters(), lr=1e-3)

print("SimSiam Model Initialized.")

---

## Part 4: Training Loop

We will train SimSiam for a few epochs. (Training SimCLR follows the same loop structure, just different loss, but we will focus on SimSiam here as it is the more advanced method).
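
For comparison, here is a rough sketch of what a SimCLR training step could look like. It is self-contained and runs on CPU, so it makes several assumptions: a toy MLP (`toy_encoder`) replaces the ResNet, noisy copies of random tensors replace real augmentations, and `nt_xent` is a compact re-statement of the Part 2 loss. With the notebook's own objects, the corresponding step would take `_, z1 = model_simclr(x1)` and `_, z2 = model_simclr(x2)` and pass them to `nt_xent_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def nt_xent(z_i, z_j, temperature=0.5):
    """Vectorized NT-Xent, mirroring the loss defined in Part 2."""
    n = z_i.shape[0]
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    logits = (z @ z.T) / temperature
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, -9e15)   # never match a view with itself
    labels = torch.cat([torch.arange(n) + n, torch.arange(n)]).to(z.device)
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)

# Toy encoder standing in for backbone + projection head (assumption)
toy_encoder = nn.Sequential(
    nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU(), nn.Linear(64, 16)
)
optimizer = torch.optim.Adam(toy_encoder.parameters(), lr=1e-3)

# Fake "views": the same random images under two different noise perturbations
images = torch.randn(8, 3, 32, 32)
x1 = images + 0.1 * torch.randn_like(images)
x2 = images + 0.1 * torch.randn_like(images)

toy_losses = []
for step in range(20):
    optimizer.zero_grad()
    # Unlike SimSiam: no predictor and no detach; negatives come from the batch
    loss = nt_xent(toy_encoder(x1), toy_encoder(x2))
    loss.backward()
    optimizer.step()
    toy_losses.append(loss.item())

print(f"first: {toy_losses[0]:.4f}  last: {toy_losses[-1]:.4f}")
```

The loop structure matches the SimSiam loop; only the forward pass and the loss differ.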

**Note:** Real SSL training takes 100-1000 epochs. We run only a handful (`epochs = 5` in the config above), just to verify that the loss decreases.
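
A falling loss alone cannot distinguish real learning from collapse, since a collapsed model also reaches the minimum of the negative cosine similarity. One diagnostic, used in the SimSiam paper, is the per-dimension standard deviation of the L2-normalized outputs: it sits near $1/\sqrt{d}$ for well-spread $d$-dimensional embeddings and drops to 0 under collapse. The sketch below illustrates this on synthetic tensors; the helper name `output_std` and the fake embeddings are assumptions, and in the notebook you would pass a batch of `z1` instead.

```python
import torch
import torch.nn.functional as F

def output_std(z):
    """Mean per-dimension std of L2-normalized embeddings (0 means total collapse)."""
    z = F.normalize(z, dim=1)
    return z.std(dim=0).mean().item()

torch.manual_seed(0)
healthy = torch.randn(256, 128)           # varied embeddings, one row per image
collapsed = torch.ones(256, 128) * 0.37   # the same vector for every image

print(f"healthy:   {output_std(healthy):.4f}  (1/sqrt(128) = {128 ** -0.5:.4f})")
print(f"collapsed: {output_std(collapsed):.4f}")
```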

In [None]:
print("Starting SimSiam Training (Pre-training)...")

model_simsiam.train()
losses = []

for epoch in range(epochs):
    total_loss = 0
    for i, ((x1, x2), _) in enumerate(train_loader):
        x1, x2 = x1.to(device), x2.to(device)
        
        optimizer_simsiam.zero_grad()
        
        # Forward Pass
        # p1 predicts z2, p2 predicts z1
        p1, p2, z1, z2 = model_simsiam(x1, x2)
        
        # Symmetric Loss
        # L = 0.5 * D(p1, z2) + 0.5 * D(p2, z1)
        loss = negative_cosine_similarity(p1, z2) / 2 + negative_cosine_similarity(p2, z1) / 2
        
        loss.backward()
        optimizer_simsiam.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.4f}")

plt.plot(losses)
plt.title("SimSiam Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Neg Cosine Similarity")
plt.show()

---

## Part 5: The Linear Probe Evaluation

How do we know if the model learned anything? 

We freeze the backbone (disable gradient updates) and attach a fresh linear layer to its output (features $h$). We then train *only* this linear layer on the labeled training data for 1 epoch.

If the accuracy lands well above the 10% chance level (e.g., >30-40% after just 1 epoch of linear training), that suggests the backbone has learned to separate the classes in feature space purely from unlabeled data. For a stricter check, compare against the same probe on a randomly initialized backbone, which can also score above chance.
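
Freezing has two parts that are easy to conflate: blocking gradients, and keeping BatchNorm running statistics fixed (which requires `eval()` mode, not just `no_grad`). The sketch below checks both on a hypothetical tiny backbone standing in for the pre-trained ResNet; the layer sizes and variable names are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical tiny backbone standing in for the pre-trained ResNet
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 10)   # the linear probe

# Freeze: no gradients, and eval mode so BatchNorm running stats stay fixed
for param in backbone.parameters():
    param.requires_grad_(False)
backbone.eval()

bn = backbone[1]
mean_before = bn.running_mean.clone()

x = torch.randn(16, 3, 32, 32)
y = torch.randint(0, 10, (16,))
loss = F.cross_entropy(head(backbone(x)), y)
loss.backward()

backbone_has_grads = any(p.grad is not None for p in backbone.parameters())
bn_stats_unchanged = torch.equal(mean_before, bn.running_mean)
print(backbone_has_grads, bn_stats_unchanged, head.weight.grad is not None)
# False True True
```

Only the probe's parameters receive gradients, so an optimizer built on `head.parameters()` updates nothing else.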

In [None]:
print("Starting Linear Probe Evaluation...")

# 1. Create a Linear Classifier on top of the FROZEN backbone
class LinearProbe(nn.Module):
    def __init__(self, backbone, feature_dim=512, num_classes=10):
        super(LinearProbe, self).__init__()
        self.backbone = backbone
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        with torch.no_grad(): # FREEZE BACKBONE
            h = self.backbone(x).flatten(start_dim=1)
        return self.fc(h)

probe_model = LinearProbe(model_simsiam.simclr_base.backbone).to(device)
probe_opt = optim.Adam(probe_model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# 2. Reload the training data with the plain transform (one un-augmented view per image); this time we use the labels
train_ds_labeled = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, 
    transform=test_transform # Use simple transform for probing
)
train_loader_labeled = DataLoader(train_ds_labeled, batch_size=128, shuffle=True)

# 3. Train the Probe just for one epoch
probe_model.train()
probe_model.backbone.eval()  # keep the frozen backbone's BatchNorm statistics fixed
for i, (imgs, labels) in enumerate(train_loader_labeled):
    imgs, labels = imgs.to(device), labels.to(device)
    
    probe_opt.zero_grad()
    preds = probe_model(imgs)
    loss = criterion(preds, labels)
    loss.backward()
    probe_opt.step()
    
    if i % 100 == 0: print(f"Probe Step {i}, Loss: {loss.item():.4f}")

# 4. Test Accuracy
probe_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = probe_model(imgs)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

acc = 100 * correct / total
print(f"\nLinear Probe Accuracy: {acc:.2f}%")
print("(Random guessing is 10%. Accuracy well above that suggests SSL learned useful features.)")

### Discussion

**What just happened?**
You trained a ResNet backbone without ever showing it a label (SimSiam phase). It learned to recognize patterns by figuring out that a cropped dog and a color-jittered dog are the "same thing".

Then, you froze that backbone and trained a tiny linear layer. If that linear layer classifies images with decent accuracy, it is evidence that the **backbone learned semantic features** (it clustered dogs near dogs and planes near planes in the vector space) purely through self-supervision.