### SimCLR

A Simple Framework for Contrastive Learning of Visual Representations

In [4]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torchvision as tv
import torchvision.datasets as tvd
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from typing import Union


# Setting device and seed for experiments reproductibility

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
seed = 42
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [58]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])    
])

train_dataset = tvd.CIFAR10(root="./data/cifar/train", download=True, transform=transform, train=True)
test_dataset = tvd.CIFAR10(root="./data/cifar/test", download=True, transform=transform, train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar/train/cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [00:08<00:00, 19644378.69it/s]


Extracting ./data/cifar/train/cifar-10-python.tar.gz to ./data/cifar/train
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar/test/cifar-10-python.tar.gz


100%|██████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [00:08<00:00, 20205291.82it/s]


Extracting ./data/cifar/test/cifar-10-python.tar.gz to ./data/cifar/test


In [7]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [59]:
batch_size = 16
train_dl = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_dl = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)

### Basic model

In [60]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        self.fc1 = nn.Linear(16 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [61]:
epochs = 10
lr = 1e-3
momentum = 0.9

net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum)

In [62]:
for epoch in range(epochs): 
    running_loss = 0.0
    for idx, data in enumerate(train_dl):
        inputs, labels = data
        
        # Zero the parameter gradients
        optimizer.zero_grad()
    
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if idx % 2000 == 1999:    # print every 2000 mini-batches
            print(f"Epoch:{epoch + 1} Minibatch end: {idx + 1:5d} Loss: {running_loss / 2000:.3f}")
            running_loss = 0.0

[1,  2000] loss: 2.223
[2,  2000] loss: 1.638
[3,  2000] loss: 1.432
[4,  2000] loss: 1.320
[5,  2000] loss: 1.233
[6,  2000] loss: 1.155
[7,  2000] loss: 1.084
[8,  2000] loss: 1.034
[9,  2000] loss: 0.980
[10,  2000] loss: 0.927
Finished Training


In [63]:
# Save the model
PATH = f"./cifar_net_{epochs}_{lr}_m{momentum}.pth"
torch.save(net.state_dict(), PATH)

### Computing accuracy for the basic model

In [66]:
correct = 0
total = 0

with torch.no_grad():
    for data in test_dl:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the {len(test_dl)} test images: {100 * correct // total:.2f} %')

Accuracy of the network on the 625 test images: 62.00 %


### SimCLR learnig procedure

### Compare the two training procedures