<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [2]:
!pip install gdown

Collecting gdown
  Downloading gdown-3.12.2.tar.gz (8.2 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... [?25ldone
[?25h  Created wheel for gdown: filename=gdown-3.12.2-py3-none-any.whl size=9681 sha256=e7aeb48796863f84a249d2893133f26d1c48510ad467ccf5c082f4cac3144282
  Stored in directory: /home/arora.roh/.cache/pip/wheels/ba/e0/7e/726e872a53f7358b4b96a9975b04e98113b005cd8609a63abc
Successfully built gdown
Installing collected packages: gdown
Successfully installed gdown-3.12.2


In [3]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [4]:
folder_name = 'resnet50_50-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu


In [5]:
# download and extract model files
os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
os.system('unzip {}'.format(folder_name))
!ls

checkpoint_0040.pth.tar
config.yml
events.out.tfevents.1610927742.4cb2c837708d.2694093.0
mini_batch_logistic_regression_evaluator.ipynb
resnet50_50-epochs_stl10.zip
training.log


In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda


In [8]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

In [10]:
with open(os.path.join('./config.yml')) as file:
  config = yaml.load(file)

  


In [11]:
if config.arch == 'resnet18':
  model = torchvision.models.resnet18(pretrained=False, num_classes=10).to(device)
elif config.arch == 'resnet50':
  model = torchvision.models.resnet50(pretrained=False, num_classes=10).to(device)

In [12]:
checkpoint = torch.load('checkpoint_0040.pth.tar', map_location=device)
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):

  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]

In [13]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']

In [14]:
if config.dataset_name == 'cifar10':
  train_loader, test_loader = get_cifar10_data_loaders(download=True)
elif config.dataset_name == 'stl10':
  train_loader, test_loader = get_stl10_data_loaders(download=True)
print("Dataset:", config.dataset_name)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./data/stl10_binary.tar.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/stl10_binary.tar.gz to ./data
Files already downloaded and verified
Dataset: stl10


In [15]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [17]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [18]:
epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(train_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(test_loader):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")











Epoch 0	Top1 Train accuracy 33.82008361816406	Top1 Test accuracy: 48.32763671875	Top5 test acc: 93.48388671875
Epoch 1	Top1 Train accuracy 52.78607940673828	Top1 Test accuracy: 55.16357421875	Top5 test acc: 95.23681640625
Epoch 2	Top1 Train accuracy 57.42761993408203	Top1 Test accuracy: 57.44873046875	Top5 test acc: 95.7958984375
Epoch 3	Top1 Train accuracy 58.763790130615234	Top1 Test accuracy: 58.07861328125	Top5 test acc: 96.23046875
Epoch 4	Top1 Train accuracy 59.525508880615234	Top1 Test accuracy: 58.56689453125	Top5 test acc: 96.34033203125
Epoch 5	Top1 Train accuracy 60.1872673034668	Top1 Test accuracy: 59.0087890625	Top5 test acc: 96.474609375
Epoch 6	Top1 Train accuracy 60.6560173034668	Top1 Test accuracy: 59.2919921875	Top5 test acc: 96.54052734375
Epoch 7	Top1 Train accuracy 61.27412796020508	Top1 Test accuracy: 59.49951171875	Top5 test acc: 96.66259765625
Epoch 8	Top1 Train accuracy 61.44990921020508	Top1 Test accuracy: 59.83642578125	Top5 test acc: 96.7822265625


Epoch 75	Top1 Train accuracy 69.95519256591797	Top1 Test accuracy: 66.3818359375	Top5 test acc: 97.890625
Epoch 76	Top1 Train accuracy 70.03331756591797	Top1 Test accuracy: 66.3818359375	Top5 test acc: 97.90283203125
Epoch 77	Top1 Train accuracy 70.05284881591797	Top1 Test accuracy: 66.40625	Top5 test acc: 97.90283203125
Epoch 78	Top1 Train accuracy 70.11144256591797	Top1 Test accuracy: 66.455078125	Top5 test acc: 97.890625
Epoch 79	Top1 Train accuracy 70.15050506591797	Top1 Test accuracy: 66.5283203125	Top5 test acc: 97.87841796875
Epoch 80	Top1 Train accuracy 70.18956756591797	Top1 Test accuracy: 66.56494140625	Top5 test acc: 97.87841796875
Epoch 81	Top1 Train accuracy 70.24816131591797	Top1 Test accuracy: 66.5283203125	Top5 test acc: 97.87841796875
Epoch 82	Top1 Train accuracy 70.32628631591797	Top1 Test accuracy: 66.6015625	Top5 test acc: 97.890625
Epoch 83	Top1 Train accuracy 70.36534881591797	Top1 Test accuracy: 66.66259765625	Top5 test acc: 97.890625
Epoch 84	Top1 Train accuracy