<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [31]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

In [32]:
%pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [33]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [34]:
folder_name = 'resnet50_50-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu


In [35]:
# download and extract model files
# os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
# os.system('unzip {}'.format(folder_name))
# !ls

In [36]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [37]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [38]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=3, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=3, drop_last=False, shuffle=shuffle)#changed num of workers
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=3, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=3, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

In [39]:
# with open(os.path.join('./config.yml')) as file:
#   config = yaml.load(file)

class config_class:
  arch = "resnet18"
  dataset_name = "cifar10"
  def __init__(self):
    self.arch
    self.dataset_name

config = config_class()
config.arch = "resnet18"
config.dataset_name = "cifar10"

In [40]:
from resnet_simclr import ResNetSimCLR
# from models.resnet_simclr import  resnet_simclr
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if config.arch == 'resnet18':
  model = ResNetSimCLR(base_model='resnet18', out_dim=10).to(device)
elif config.arch == 'resnet50':
  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

In [41]:
# checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)
# state_dict = checkpoint['state_dict']

# for k in list(state_dict.keys()):

#   if k.startswith('backbone.'):
#     if k.startswith('backbone') and not k.startswith('backbone.fc'):
#       # remove prefix
#       state_dict[k[len("backbone."):]] = state_dict[k]
#   del state_dict[k]

In [42]:
# log = model.load_state_dict(state_dict, strict=False)
# assert log.missing_keys == ['fc.weight', 'fc.bias']

In [43]:
if config.dataset_name == 'cifar10':
  train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
  train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)

Files already downloaded and verified
Files already downloaded and verified
Dataset: cifar10


In [44]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [45]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [46]:
#visualize the dataset

import torch
from torchvision import models
from torchsummary import summary
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [47]:
epochs = 10
for epoch in range(epochs):
  print("Started epoch {}".format(epoch))
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)
    print("Batch {}".format(counter))
    print("x_batch:", x_batch.shape)
    print("y_batch:", y_batch.shape)
    

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

Started epoch 0
Batch 0
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 1
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 2
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 3
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 4
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 5
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 6
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 7
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 8
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 9
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 10
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 11
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 12
x_batch: torch.Size([256, 3, 32, 32])
y_batch: torch.Size([256])
Batch 13
x_batch: torch.Size([25

KeyboardInterrupt: 

In [None]:
#visualize the dataset

import torch
from torchvision import models
from torchsummary import summary
import torchvision.models as models
#show what the model is made of we can compare it to the original resnet 18
# resnet18 = models.resnet18()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
# from torchviz import make_dot
# y = model(img)
# make_dot(y, params=dict(list(auto_model.named_parameters()))).render("torchviz", format="png")

In [None]:
#compute model loss



In [None]:
#copy and pasted because didn't manage to import it
from torch import nn

class LinearClassifier(nn.Module):
    def __init__(self):
        super(LinearClassifier, self).__init__()
        self.fc1 = nn.Linear(128, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [None]:
#joined model 
class JoinedModel(nn.Module):
    def __init__(self,num_classes=10):
        
        super(JoinedModel, self).__init__()
        #uses the resnets weights already trained
        self.simCLR = model
        #classifier parts
        self.classifier = LinearClassifier()
    def forward(self, x):
        x = self.simCLR.forward(x,no_projection_head=True)
        #classifier part
        x = self.classifier(x)
        return x
joined_model = JoinedModel().to(device)

In [None]:
#verify the model shape
summary(joined_model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]          36,864
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
       BasicBlock-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

In [None]:
#freeze all weights 
for param in joined_model.parameters():
    param.requires_grad = False
#unfreeze the classifier
joined_model.classifier.fc1.weight.requires_grad = True
joined_model.classifier.fc1.bias.requires_grad = True
joined_model.classifier.fc2.weight.requires_grad = True
joined_model.classifier.fc2.bias.requires_grad = True

In [None]:
#verify weight are frozen

for name, param in joined_model.named_parameters():
    print(name, param.requires_grad)


simCLR.backbone.conv1.weight False
simCLR.backbone.bn1.weight False
simCLR.backbone.bn1.bias False
simCLR.backbone.layer1.0.conv1.weight False
simCLR.backbone.layer1.0.bn1.weight False
simCLR.backbone.layer1.0.bn1.bias False
simCLR.backbone.layer1.0.conv2.weight False
simCLR.backbone.layer1.0.bn2.weight False
simCLR.backbone.layer1.0.bn2.bias False
simCLR.backbone.layer1.1.conv1.weight False
simCLR.backbone.layer1.1.bn1.weight False
simCLR.backbone.layer1.1.bn1.bias False
simCLR.backbone.layer1.1.conv2.weight False
simCLR.backbone.layer1.1.bn2.weight False
simCLR.backbone.layer1.1.bn2.bias False
simCLR.backbone.layer2.0.conv1.weight False
simCLR.backbone.layer2.0.bn1.weight False
simCLR.backbone.layer2.0.bn1.bias False
simCLR.backbone.layer2.0.conv2.weight False
simCLR.backbone.layer2.0.bn2.weight False
simCLR.backbone.layer2.0.bn2.bias False
simCLR.backbone.layer2.0.downsample.0.weight False
simCLR.backbone.layer2.0.downsample.1.weight False
simCLR.backbone.layer2.0.downsample.1.bias 

In [49]:
#train the joined model
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, joined_model.parameters()), lr=learning_rate)


In [51]:
retrain_epochs = 10
#train the classification layer
for epoch in range(retrain_epochs):
    print("Epoch: {}".format(epoch))
    for i, (images, labels) in enumerate(train_loader):
        print(i)
        images = images.to(device)
        labels = labels.to(device)

        outputs = joined_model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #loss
        print("Loss: {}".format(loss.item()))
        print("Epoch: {}/{}".format(epoch, retrain_epochs - 1))

Epoch: 0
0
Loss: 0.5399525761604309
Epoch: 0/9
1
Loss: 0.519212007522583
Epoch: 0/9
2
Loss: 0.45982852578163147
Epoch: 0/9
3
Loss: 0.545820415019989
Epoch: 0/9
4
Loss: 0.6507019400596619
Epoch: 0/9
5
Loss: 0.5678039789199829
Epoch: 0/9
6
Loss: 0.5231722593307495
Epoch: 0/9
7
Loss: 0.65208500623703
Epoch: 0/9
8
Loss: 0.6153600811958313
Epoch: 0/9
9
Loss: 0.5568593144416809
Epoch: 0/9
10
Loss: 0.5739806294441223
Epoch: 0/9
11
Loss: 0.712602972984314
Epoch: 0/9
12
Loss: 0.6553792357444763
Epoch: 0/9
13
Loss: 0.5993602871894836
Epoch: 0/9
14
Loss: 0.6396824717521667
Epoch: 0/9
15
Loss: 0.6791321039199829
Epoch: 0/9
16
Loss: 0.704905092716217
Epoch: 0/9
17
Loss: 0.7251819372177124
Epoch: 0/9
18
Loss: 0.6250635981559753
Epoch: 0/9
19
Loss: 0.6624881029129028
Epoch: 0/9
20
Loss: 0.6796400547027588
Epoch: 0/9
21
Loss: 0.6937943696975708
Epoch: 0/9
22
Loss: 0.6520704030990601
Epoch: 0/9
23
Loss: 0.6232017278671265
Epoch: 0/9
24
Loss: 0.6794021129608154
Epoch: 0/9
25
Loss: 0.7212855219841003
Epo

In [52]:
# evaluate the joined model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = joined_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * (correct / total)))
    # save the model
    torch.save(joined_model.state_dict(), './joined_model_simCLR_Classif.pth')

Accuracy of the model on the 10000 test images: 62.760000000000005 %


In [None]:
#save the joined_model for resnet supervised encoder
torch.save(joined_model.state_dict(), './saved_models/joined_model_SimCLR_LinearClass.pth')
