# LSDL CUB, Homework 3. Contrastive Learning [10 pts]

This task is dedicated to contrastive self-supervised methods. We will focus on the SimCLR and BYOL algorithms that were discussed in class. We will conduct experiments on the [STL10](https://cs.stanford.edu/~acoates/stl10/) dataset, which is well suited for pretraining without labels: it contains 100k unlabeled images, 5k labeled training images, and 8k labeled test images.

To submit the task, you must conduct the experiments described in this notebook and write a report on them in PDF format. Along with the report, submit the code needed to reproduce the experiments. Before implementing anything, read all the experiment statements and think about how best to organize the code; do not forget to checkpoint the trained models you will need later. We reserve the right to lower the grade for poorly structured code. Be sure to use **training optimizations** from the [list](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html) in your pipelines (for example, Automatic Mixed Precision) to speed up the experiments. Also note that the report is a **mandatory part** of the assessment: without it, we will not check your assignment. The report must include training curves for all the models you run, and all figures and graphs must be readable.

## 0. Supervised baseline [0 pts]

**If not completed, the maximum for the entire task is 0 points**

We will start our study by training a supervised model from random initialization. Use the labeled train split for training and the labeled test split for testing. We will use ResNet-18\* as the neural network architecture. We recommend looking up hyperparameters and augmentations for training on STL-10 in the literature. The author of the task got an accuracy of about 71-72%.

\**For datasets with smaller images than ImageNet (such as CIFAR-10/100 and STL-10), it is common to use ResNet-18 with a modified stem: the 7x7 convolution is usually replaced with a 3x3 one, and the max-pooling layer (3x3 with stride 2 in torchvision) is removed. To save time, we suggest training the regular torchvision ResNet-18 instead.*

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿
import torch
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm, trange
import wandb

PROJECT_NAME = "HW 3, LSDL 2024. CUB."
wandb.init(project=PROJECT_NAME, name="Supervised")

# Hyperparameters
BATCH_SIZE = 64
LEARNING_RATE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
LR_GAMMA = 0.955
EPOCHS = 100

# Configuration
NUM_WORKERS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = device.type == "cuda"  # mixed precision is only enabled on the GPU
torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest kernels for the fixed input size
print(f"device: {device}")

# Main code: torchvision ResNet-18 trained from scratch with a 10-class head
# (`weights=None` replaces the deprecated `pretrained=False`)
model = torchvision.models.resnet18(weights=None, num_classes=10).to(device)

# ImageNet normalization statistics
mean_stats = [0.485, 0.456, 0.406]
std_stats = [0.229, 0.224, 0.225]
# Augmentations reused from the previous homework
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.33)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),  # converts the PIL image to a tensor with values in [0, 1]
    transforms.Normalize(mean=mean_stats, std=std_stats),
    # RandomErasing works on tensors, so it has to come after ToTensor
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
])
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_stats, std=std_stats),
])

train_set = torchvision.datasets.STL10(root="./", split="train", download=True, transform=train_transforms)
test_set = torchvision.datasets.STL10(root="./", split="test", download=True, transform=test_transforms)

# pin_memory speeds up host-to-GPU copies (see the tuning guide)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=True)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

scaler = torch.amp.GradScaler(enabled=USE_AMP)

log_it = 0

for i in trange(EPOCHS):
    model.train()
    # iterate over batches directly: enumerate() would yield (index, batch) pairs
    for images, labels in train_loader:
        optimizer.zero_grad(set_to_none=True)

        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # autocast expects the device type as a string ("cuda"/"cpu"), not a torch.device
        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):
            outputs = model(images)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        wandb.log({"train_loss": loss.item()}, step=log_it)
        log_it += 1

    # Evaluate on the labeled test split every 5 epochs
    if (i + 1) % 5 == 0:
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = 100 * correct / total
        tqdm.write(f"Epoch {i + 1}/{EPOCHS}, Test Accuracy: {accuracy:.2f}")
        wandb.log({"test_accuracy": accuracy}, step=log_it)

    lr_scheduler.step()

# A distinct file name keeps later runs from overwriting the supervised checkpoint
torch.save(model.state_dict(), "supervised_model.pth")
supervised_accuracy = accuracy

device: cuda
Files already downloaded and verified

Epoch 5/100, Test Accuracy: 33.38
Epoch 10/100, Test Accuracy: 49.41
Epoch 15/100, Test Accuracy: 50.88
Epoch 20/100, Test Accuracy: 54.14
Epoch 25/100, Test Accuracy: 64.00
Epoch 30/100, Test Accuracy: 64.67
Epoch 35/100, Test Accuracy: 67.25
Epoch 40/100, Test Accuracy: 69.04
Epoch 45/100, Test Accuracy: 70.72
Epoch 50/100, Test Accuracy: 72.76
Epoch 55/100, Test Accuracy: 70.01
Epoch 60/100, Test Accuracy: 72.86
Epoch 65/100, Test Accuracy: 72.39
Epoch 70/100, Test Accuracy: 73.46
Epoch 75/100, Test Accuracy: 74.15
Epoch 80/100, Test Accuracy: 74.24
Epoch 85/100, Test Accuracy: 74.58
Epoch 90/100, Test Accuracy: 74.22
Epoch 95/100, Test Accuracy: 74.76
Epoch 100/100, Test Accuracy: 74.72
100%|██████████| 100/100 [20:54<00:00, 12.55s/it]

In [None]:
import pandas as pd
# results_df = pd.DataFrame({"ft": [accuracy]})


## 1. SimCLR [1.5 pts]

Implement and train the SimCLR method from [Chen et al, 2020](https://arxiv.org/pdf/2002.05709.pdf). We want you to implement it yourself, so you may not borrow code from open-source implementations for this task. Use the unlabeled part of STL-10 as the training set, and use the labeled train split to validate the algorithm.

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿
import math

from tqdm import tqdm, trange

# From the paper:
# Default setting. Unless otherwise specified, for data augmentation we use random crop and resize (with random
# flip), color distortions, and Gaussian blur (for details, see Appendix A). We use ResNet-50 as the base encoder
# network, and a 2-layer MLP projection head to project the representation to a 128-dimensional latent space.
# As the loss, we use NT-Xent, optimized using LARS with learning rate of 4.8 (= 0.3 × BatchSize/256) and weight
# decay of 10^-6. We train at batch size 4096 for 100 epochs. Furthermore, we use linear warmup for the first
# 10 epochs, and decay the learning rate with the cosine decay schedule without restarts (Loshchilov & Hutter, 2016).

wandb.finish()  # close the supervised run (no-op if no run is active)
wandb.init(project=PROJECT_NAME, name="SimCLR")

# Hyperparameters (BATCH_SIZE and EPOCHS are reused from the supervised cell)
LEARNING_RATE = 0.3 * BATCH_SIZE / 256  # the paper's linear scaling rule
print(f"LEARNING_RATE: {LEARNING_RATE}")
NUM_WARMUP_EPOCHS = 10
SIMCLR_WEIGHT_DECAY = 1e-6
TEMPERATURE = 0.5  # NT-Xent temperature τ (assumed value, worth tuning)
USE_AMP = device.type == "cuda"  # mixed precision is only enabled on the GPU

# Appendix A:
# In our default pretraining setting (which is used to train our best models), we utilize random crop (with resize
# and random flip), random color distortion, and random Gaussian blur as the data augmentations. The details of
# these three augmentations are provided below.
#
# Random crop and resize to 224x224. We use standard Inception-style random cropping (Szegedy et al., 2015). The
# crop of random size (uniform from 0.08 to 1.0 in area) of the original size and a random aspect ratio (default:
# of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to the original size. This has
# been implemented in Tensorflow as "slim.preprocessing.inception_preprocessing.distorted_bounding_box_crop", or in
# Pytorch as "torchvision.transforms.RandomResizedCrop". Additionally, the random crop (with resize) is always
# followed by a random horizontal/left-to-right flip with 50% probability. This is helpful but not essential. By
# removing this from our default augmentation policy, the top-1 linear evaluation drops from 64.5% to 63.4% for our
# ResNet-50 model trained in 100 epochs.
#
# Color distortion. Color distortion is composed by color jittering and color dropping. We find stronger color
# jittering usually helps, so we set a strength parameter.
#
# Gaussian blur. This augmentation is in our default policy. We find it helpful, as it improves our ResNet-50
# trained for 100 epochs from 63.2% to 64.5%. We blur the image 50% of the time using a Gaussian kernel. We randomly
# sample σ ∈ [0.1, 2.0], and the kernel size is set to be 10% of the image height/width.

# From the original paper (Appendix A)
def get_color_distortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort

simclr_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3)),
    transforms.RandomHorizontalFlip(),  # p=0.5 by default
    get_color_distortion(),
    # Blur 50% of the time with σ ∈ [0.1, 2.0]; 10% of 224 px is ~22, rounded up to 23
    # because GaussianBlur only accepts odd kernel sizes
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_stats, std=std_stats),
])

class DoubleImageDataset(torch.utils.data.Dataset):
    """Returns two independently augmented views of the same image."""
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        image1 = self.transform(image)
        image2 = self.transform(image)
        return torch.stack([image1, image2]), label

# Pretrain on the unlabeled split only, as the task requires
simclr_train_set = DoubleImageDataset(
    torchvision.datasets.STL10(root="./", split="unlabeled", download=True), transform=simclr_transforms)
simclr_train_loader = torch.utils.data.DataLoader(
    simclr_train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
    pin_memory=True, drop_last=True)  # drop_last keeps the number of negatives constant

# Labeled train split (no augmentation) used only for k-NN validation
val_set = torchvision.datasets.STL10(root="./", split="train", download=True, transform=test_transforms)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=256, shuffle=False, num_workers=NUM_WORKERS)


class SimCLR(torch.nn.Module):
    """ResNet-18 encoder f(·) followed by a 2-layer MLP projection head g(·)."""
    def __init__(self, feature_dim=512, proj_dim=128):
        super().__init__()
        self.encoder = torchvision.models.resnet18(weights=None)
        self.encoder.fc = torch.nn.Identity()  # the encoder now returns the 512-d average-pooling output
        self.projector = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(feature_dim, proj_dim),
        )

    def forward(self, x):
        h = self.encoder(x)
        return h, self.projector(h)


def nt_xent_loss(z1, z2, temperature):
    """NT-Xent over 2N views: each view must identify its partner among the other 2N - 1 views."""
    n = z1.size(0)
    z = torch.nn.functional.normalize(torch.cat([z1, z2]), dim=1)  # (2N, d) unit vectors
    sim = z @ z.T / temperature  # scaled cosine similarities
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))  # a view is not its own negative
    # the positive of view i is view i + N, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return torch.nn.functional.cross_entropy(sim, targets)


@torch.no_grad()
def knn_accuracy(encoder, loader, k=20):
    """Leave-one-out k-NN accuracy on the labeled train split (labels are used only here)."""
    encoder.eval()
    feats, labels = [], []
    for images, y in loader:
        feats.append(torch.nn.functional.normalize(encoder(images.to(device)), dim=1))
        labels.append(y.to(device))
    feats, labels = torch.cat(feats), torch.cat(labels)
    sim = feats @ feats.T
    sim.fill_diagonal_(float("-inf"))  # a sample must not vote for itself
    neighbor_labels = labels[sim.topk(k, dim=1).indices]  # (N, k)
    predicted = neighbor_labels.mode(dim=1).values  # majority vote
    return 100 * (predicted == labels).float().mean().item()


def warmup_cosine_lambda(epoch):
    """LR multiplier: linear warmup, then cosine decay without restarts."""
    if epoch < NUM_WARMUP_EPOCHS:
        return (epoch + 1) / NUM_WARMUP_EPOCHS  # never exactly 0, so the first epoch still trains
    progress = (epoch - NUM_WARMUP_EPOCHS) / max(1, EPOCHS - NUM_WARMUP_EPOCHS)
    return 0.5 * (1 + math.cos(math.pi * progress))


simclr_model = SimCLR().to(device)
# LARS is not part of torch.optim; SGD with momentum is used here as a simpler substitute
optimizer = torch.optim.SGD(simclr_model.parameters(), lr=LEARNING_RATE, momentum=0.9,
                            weight_decay=SIMCLR_WEIGHT_DECAY)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_lambda)
scaler = torch.amp.GradScaler(enabled=USE_AMP)

log_it = 0
for i in trange(EPOCHS):
    simclr_model.train()
    for views, _ in simclr_train_loader:  # labels of the unlabeled split are -1 and are ignored
        optimizer.zero_grad(set_to_none=True)
        views = views.to(device, non_blocking=True)  # (B, 2, C, H, W)

        with torch.autocast(device_type=device.type, dtype=torch.float16, enabled=USE_AMP):
            # both views go through the network as one batch of size 2B
            _, z = simclr_model(torch.cat([views[:, 0], views[:, 1]]))
        z1, z2 = z.float().chunk(2)  # compute the loss in float32 for numerical stability
        loss = nt_xent_loss(z1, z2, TEMPERATURE)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        wandb.log({"train_loss": loss.item(), "lr": lr_scheduler.get_last_lr()[0]}, step=log_it)
        log_it += 1

    if (i + 1) % 5 == 0:
        accuracy = knn_accuracy(simclr_model.encoder, val_loader)
        tqdm.write(f"Epoch {i + 1}/{EPOCHS}, k-NN Accuracy: {accuracy:.2f}")
        wandb.log({"knn_accuracy": accuracy}, step=log_it)

    lr_scheduler.step()

torch.save(simclr_model.state_dict(), "simclr_model.pth")
wandb.finish()

## 2. BYOL [2.5 pts]

Similar to the previous task, implement and train the BYOL method from [Grill et al, 2020](https://arxiv.org/pdf/2006.07733.pdf). To check that the projections do not collapse (the variable $z$ from the original paper), plot the standard deviation of $z$ throughout training.
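The BYOL-specific pieces, namely the normalized regression loss between the online prediction and the target projection, the exponential moving average (EMA) update of the target network, and the standard deviation of $z$ used as a collapse check, can be sketched as follows. This is a minimal sketch: the helper names are hypothetical, the toy `Linear` networks stand in for ResNet-18 plus projector, and `tau=0.996` is the base EMA rate from the BYOL paper, which it further increases towards 1 during training. Buffers such as BatchNorm running statistics are not averaged here.

```python
import copy

import torch
import torch.nn.functional as F


def byol_loss(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Normalized MSE between online predictions p and target projections z, i.e. 2 - 2*cos(p, z)."""
    p = F.normalize(p, dim=-1)
    z = F.normalize(z.detach(), dim=-1)  # no gradient flows into the target branch
    return (2 - 2 * (p * z).sum(dim=-1)).mean()


@torch.no_grad()
def ema_update(online: torch.nn.Module, target: torch.nn.Module, tau: float = 0.996) -> None:
    """target <- tau * target + (1 - tau) * online, parameter by parameter (buffers are not averaged)."""
    for p_online, p_target in zip(online.parameters(), target.parameters()):
        p_target.mul_(tau).add_(p_online, alpha=1 - tau)


@torch.no_grad()
def projection_std(z: torch.Tensor) -> float:
    """Mean per-dimension std of l2-normalized projections; 0 means all outputs collapsed to one point."""
    return F.normalize(z.float(), dim=-1).std(dim=0).mean().item()


# Usage with toy networks standing in for encoder + projector
torch.manual_seed(0)
online = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
target = copy.deepcopy(online)  # starts as a copy and is never touched by the optimizer
for param in target.parameters():
    param.requires_grad_(False)
predictor = torch.nn.Linear(16, 16)

x1, x2 = torch.randn(8, 32), torch.randn(8, 32)  # two "views" of a batch
loss = byol_loss(predictor(online(x1)), target(x2)) + byol_loss(predictor(online(x2)), target(x1))
loss.backward()
ema_update(online, target)  # in real training this runs right after optimizer.step()
print(f"std of z: {projection_std(target(x1)):.4f}")
```

Measuring the std on ℓ2-normalized $z$ makes the curve comparable across runs: for projections spread over the unit sphere in $d$ dimensions, each coordinate has a std of roughly $1/\sqrt{d}$, while a collapsed output gives 0.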

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 3. t-SNE [1.5 pts]

Using t-SNE, visualize the embeddings of images from the training and test splits obtained from the supervised, SimCLR, and BYOL models. Use the output of the model's average-pooling layer as the embedding. Plot the points of each of the 10 classes in their own color, so that all points of a class share one color.
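Embedding extraction and the per-class scatter plot can be sketched as below. A minimal sketch, assuming scikit-learn's `TSNE` and matplotlib; `extract_embeddings` and `plot_tsne` are hypothetical helper names, perplexity 30 is scikit-learn's default rather than a tuned value, and the toy encoder with random data at the end only demonstrates the call pattern. For the real models, setting `model.fc = torch.nn.Identity()` makes a torchvision ResNet-18 return its 512-d average-pooling output.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.manifold import TSNE

STL10_CLASSES = ["airplane", "bird", "car", "cat", "deer", "dog", "horse", "monkey", "ship", "truck"]


@torch.no_grad()
def extract_embeddings(encoder, loader, device="cpu"):
    """Stack encoder outputs and labels as numpy arrays; assumes `encoder.fc` is an Identity."""
    encoder.eval()
    feats, labels = [], []
    for images, y in loader:
        feats.append(encoder(images.to(device)).float().cpu())
        labels.append(y)
    return torch.cat(feats).numpy(), torch.cat(labels).numpy()


def plot_tsne(embeddings: np.ndarray, labels: np.ndarray, title: str, ax=None, seed: int = 0):
    """Project (N, D) embeddings to 2-D with t-SNE and draw one colored scatter per class."""
    points = TSNE(n_components=2, init="pca", perplexity=30, random_state=seed).fit_transform(embeddings)
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    for cls, name in enumerate(STL10_CLASSES):
        mask = labels == cls
        # "C0".."C9" are matplotlib's 10 default cycle colors, one per class
        ax.scatter(points[mask, 0], points[mask, 1], s=3, color=f"C{cls}", label=name)
    ax.set_title(title)
    ax.legend(markerscale=4, fontsize=7)
    return points


# Usage with a toy encoder and random data standing in for ResNet-18 and STL-10
toy_encoder = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 32))
toy_data = torch.utils.data.TensorDataset(torch.randn(60, 3, 8, 8), torch.arange(60) % 10)
emb, lab = extract_embeddings(toy_encoder, torch.utils.data.DataLoader(toy_data, batch_size=16))
points = plot_tsne(emb, lab, "toy embeddings")
```

Passing a shared `ax` from `plt.subplots(1, 2)` makes it easy to put the train and test projections of one model side by side.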

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 4. Linear probing [1 pt]

Train a linear probe on top of each self-supervised model and compare its accuracy with that of the supervised baseline.
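Since the encoder stays frozen during linear probing, one cheap option is to compute the average-pooling features of the labeled train and test splits once (for instance with `model.fc = torch.nn.Identity()` in eval mode) and train only a `torch.nn.Linear` layer on the cached tensors. A minimal sketch under that assumption; the helper names, full-batch Adam, and the learning rate are illustrative choices, not a prescribed recipe.

```python
import torch
import torch.nn.functional as F


def train_linear_probe(train_feats, train_labels, num_classes=10, epochs=100, lr=1e-3, weight_decay=1e-4):
    """Fit one linear layer on frozen, precomputed (N, D) features with full-batch Adam."""
    probe = torch.nn.Linear(train_feats.size(1), num_classes)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr, weight_decay=weight_decay)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = F.cross_entropy(probe(train_feats), train_labels)
        loss.backward()
        optimizer.step()
    return probe


@torch.no_grad()
def probe_accuracy(probe, feats, labels):
    """Top-1 accuracy in percent."""
    return 100 * (probe(feats).argmax(dim=1) == labels).float().mean().item()


# Usage on toy, linearly separable features standing in for cached ResNet-18 embeddings
torch.manual_seed(0)
centers = 5 * torch.randn(3, 8)
toy_labels = torch.arange(300) % 3
toy_feats = centers[toy_labels] + 0.1 * torch.randn(300, 8)
probe = train_linear_probe(toy_feats, toy_labels, num_classes=3, epochs=200, lr=0.05)
print(f"toy probe accuracy: {probe_accuracy(probe, toy_feats, toy_labels):.1f}%")
```

The trade-off of caching is that no data augmentation can be applied while training the probe; running the frozen encoder on augmented batches each epoch avoids that at a much higher cost.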

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 5. Fine-tuning [1.5 pts]

Finally, fine-tune the self-supervised models on STL-10 classification. If everything is done correctly, their accuracy should be several percentage points higher than that of the supervised baseline. As in task 3, visualize how the embeddings change after fine-tuning.

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## 6. OOD robustness [2 pts]

Now, we have 5 different models:

- Supervised
- SimCLR + linear probing
- SimCLR + fine-tuning
- BYOL + linear probing
- BYOL + fine-tuning

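We will compare the models by their robustness to out-of-distribution data. As the OOD dataset, we take the CIFAR-10 test split, which shares 9 of its 10 classes with STL-10.
The only CIFAR-10 class without an STL-10 counterpart is "frog" (STL-10 has "monkey" instead), so drop the frog images. Keep in mind that the two datasets number their classes differently (for example, "automobile" is label 1 in CIFAR-10, while "car" is label 2 in STL-10), so the CIFAR-10 labels have to be remapped. Compare the trained models by OOD accuracy.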
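The frog filtering and label remapping can be sketched as follows. A minimal sketch with hypothetical helper names: `filter_and_remap` takes the CIFAR-10 test labels (for example, the `targets` attribute of `torchvision.datasets.CIFAR10(..., train=False)`) and returns the indices to keep, e.g. for `torch.utils.data.Subset`, together with those samples' labels in STL-10 numbering.

```python
import numpy as np

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
STL10_CLASSES = ["airplane", "bird", "car", "cat", "deer", "dog", "horse", "monkey", "ship", "truck"]

# CIFAR-10 label -> STL-10 label; "automobile" corresponds to STL-10 "car", "frog" has no counterpart
CIFAR_TO_STL = {
    c: STL10_CLASSES.index("car" if name == "automobile" else name)
    for c, name in enumerate(CIFAR10_CLASSES)
    if name != "frog"
}


def filter_and_remap(cifar_labels):
    """Return (indices of kept CIFAR-10 samples, their labels in STL-10 numbering)."""
    cifar_labels = np.asarray(cifar_labels)
    keep = np.flatnonzero(cifar_labels != CIFAR10_CLASSES.index("frog"))
    lookup = np.full(len(CIFAR10_CLASSES), -1)  # -1 marks the dropped frog class
    for cifar_label, stl_label in CIFAR_TO_STL.items():
        lookup[cifar_label] = stl_label
    return keep, lookup[cifar_labels[keep]]


keep, stl_labels = filter_and_remap([6, 1, 7, 0, 6, 9])
print(keep.tolist(), stl_labels.tolist())  # [1, 2, 3, 5] [2, 6, 0, 9]
```

STL-10 class 7 ("monkey") never appears among the remapped labels, so a "monkey" prediction on a CIFAR-10 image simply counts as an error. The 32x32 CIFAR-10 images should presumably go through the same resizing and normalization as the STL-10 test images before being fed to the models.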

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿

## Bonus. MoCo [2 pts]

As a bonus, let's look at another contrastive self-supervised method, MoCo v2, from [He et al, 2019](https://arxiv.org/pdf/1911.05722) and [Chen et al, 2020](https://arxiv.org/pdf/2003.04297.pdf). Conduct all the experiments described above with this model.
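Compared with SimCLR, MoCo takes its negatives from a queue of keys produced for previous batches and computes keys with a momentum encoder, an exponential moving average of the query encoder (m = 0.999 in MoCo; not shown here). A minimal sketch of the queue and the InfoNCE loss; the class and function names are hypothetical, and `temperature=0.2` follows the MoCo v2 setting (MoCo v1 used 0.07).

```python
import torch
import torch.nn.functional as F


class KeyQueue:
    """FIFO ring buffer of past l2-normalized keys that serve as negatives."""

    def __init__(self, dim, size, device="cpu"):
        self.keys = F.normalize(torch.randn(size, dim, device=device), dim=1)  # random initial keys
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, k):
        size = self.keys.size(0)
        idx = (self.ptr + torch.arange(k.size(0), device=self.keys.device)) % size  # wrap around
        self.keys[idx] = k  # overwrite the oldest keys
        self.ptr = (self.ptr + k.size(0)) % size


def moco_loss(q, k, queue, temperature=0.2):
    """InfoNCE: the positive is the key of the same image, the negatives are the queued keys."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k.detach(), dim=1)  # keys come from the momentum encoder, no gradient
    l_pos = (q * k).sum(dim=1, keepdim=True)  # (N, 1)
    l_neg = q @ queue.keys.T  # (N, K)
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    targets = torch.zeros(q.size(0), dtype=torch.long, device=q.device)  # the positive is column 0
    return F.cross_entropy(logits, targets)


# Usage with random tensors standing in for query/key projections
torch.manual_seed(0)
queue = KeyQueue(dim=16, size=32)
q = torch.randn(8, 16, requires_grad=True)
k = torch.randn(8, 16)
loss = moco_loss(q, k, queue)
loss.backward()
queue.enqueue(F.normalize(k, dim=1))  # the current keys become negatives for later batches
```

Enqueueing only after the loss is computed keeps each key from being used as a negative against its own query.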

In [None]:
# YOUR SOLUTION HERE (⊃｡•́‿•̀｡)⊃━✿✿✿✿✿✿