## This notebook is a remake of the original for the particular task of taking only the first 100 examples as labeled. We use all the training for self supervision using simCLR. Then, we regularly train our classifier and get arout 37% accuracy on test, which is better than trainint from scratch (around 20%), but not as good as using pretrained networks on Imagenet.

In [1]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision

In [2]:
!pip install gdown



In [3]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [4]:
folder_name = 'resnet18_100-epochs_cifar10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet18_100-epochs_cifar10 1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C


In [5]:
# download and extract model files
os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
os.system('unzip {}'.format(folder_name))
!ls

Downloading...
From: https://drive.google.com/uc?id=1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C
To: /users/eleves-a/2018/tom.sander/MVA/SimcCLRTom/SimcCLRTom/feature_eval/resnet18_100-epochs_cifar10.zip
100%|██████████| 101M/101M [00:01<00:00, 56.9MB/s] 


Archive:  resnet18_100-epochs_cifar10.zip
  inflating: checkpoint_0100.pth.tar  
  inflating: config.yml              
  inflating: events.out.tfevents.1610901418.4cb2c837708d.2683796.0  
  inflating: run.log                 
checkpoint_0100.pth.tar
config.yml
events.out.tfevents.1610901418.4cb2c837708d.2683796.0
mini_batch_logistic_regression_evaluator.ipynb
resnet18_100-epochs_cifar10.zip
run.log


In [6]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [17]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset_ = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())
  indices_train = torch.arange(0,100)
  train_dataset = torch.utils.data.Subset(train_dataset_, indices_train)

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

In [9]:
with open(os.path.join('./config.yml')) as file:
  config = yaml.load(file, Loader=yaml.Loader)

In [10]:
if config.arch == 'resnet18':
  model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

In [12]:
checkpoint = torch.load('checkpoint_0100.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]

In [13]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [18]:
if config.dataset_name == 'cifar10':
  train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
  train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)

Files already downloaded and verified
Files already downloaded and verified
Dataset: cifar10


In [19]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [20]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [21]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [22]:
epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

Epoch 0	Top1 Train accuracy 13.0	Top1 Test accuracy: 13.70059871673584	Top5 test acc: 59.769073486328125
Epoch 1	Top1 Train accuracy 13.0	Top1 Test accuracy: 14.848345756530762	Top5 test acc: 61.12419509887695
Epoch 2	Top1 Train accuracy 14.0	Top1 Test accuracy: 15.774931907653809	Top5 test acc: 61.978973388671875
Epoch 3	Top1 Train accuracy 17.0	Top1 Test accuracy: 16.62109375	Top5 test acc: 63.187618255615234
Epoch 4	Top1 Train accuracy 22.0	Top1 Test accuracy: 17.431640625	Top5 test acc: 64.11305236816406
Epoch 5	Top1 Train accuracy 23.0	Top1 Test accuracy: 18.082490921020508	Top5 test acc: 64.97013092041016
Epoch 6	Top1 Train accuracy 26.0	Top1 Test accuracy: 18.774702072143555	Top5 test acc: 65.81342315673828
Epoch 7	Top1 Train accuracy 29.0	Top1 Test accuracy: 19.48414421081543	Top5 test acc: 66.16383361816406
Epoch 8	Top1 Train accuracy 34.0	Top1 Test accuracy: 20.097082138061523	Top5 test acc: 66.63028717041016









Epoch 9	Top1 Train accuracy 37.0	Top1 Test accuracy: 20.69

Epoch 78	Top1 Train accuracy 92.0	Top1 Test accuracy: 36.586055755615234	Top5 test acc: 85.75769805908203
Epoch 79	Top1 Train accuracy 94.0	Top1 Test accuracy: 36.604434967041016	Top5 test acc: 85.80652618408203
Epoch 80	Top1 Train accuracy 94.0	Top1 Test accuracy: 36.672794342041016	Top5 test acc: 85.86511993408203
Epoch 81	Top1 Train accuracy 94.0	Top1 Test accuracy: 36.722774505615234	Top5 test acc: 85.86511993408203
Epoch 82	Top1 Train accuracy 94.0	Top1 Test accuracy: 36.713008880615234	Top5 test acc: 85.85535430908203
Epoch 83	Top1 Train accuracy 95.0	Top1 Test accuracy: 36.761837005615234	Top5 test acc: 85.87488555908203
Epoch 84	Top1 Train accuracy 96.0	Top1 Test accuracy: 36.791133880615234	Top5 test acc: 85.88465118408203
Epoch 85	Top1 Train accuracy 96.0	Top1 Test accuracy: 36.810665130615234	Top5 test acc: 85.90532684326172
Epoch 86	Top1 Train accuracy 96.0	Top1 Test accuracy: 36.839962005615234	Top5 test acc: 85.95531463623047
Epoch 87	Top1 Train accuracy 96.0	Top1 Test ac