<a href="https://colab.research.google.com/github/sthalles/SimCLR/blob/simclr-refactor/feature_eval/mini_batch_logistic_regression_evaluator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
import torch
import sys
import numpy as np
import os
import yaml
import matplotlib.pyplot as plt
import torchvision
import os
from PIL import Image

import numpy as np
import pandas as pd
from tqdm import tqdm

In [41]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, root, split, transform):
        r"""
        Args:
            root: Location of the dataset folder, usually it is /dataset
            split: The split you want to used, it should be one of train, val or unlabeled.
            transform: the transform you want to applied to the images.
        """
        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(root, split)
        label_path = os.path.join(root, f"{split}_label_tensor.pt")

        self.num_images = len(os.listdir(self.image_dir))

        if os.path.exists(label_path):
            self.labels = torch.load(label_path)
        else:
            self.labels = -1 * torch.ones(self.num_images, dtype=torch.long)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        with open(os.path.join(self.image_dir, f"{idx}.png"), 'rb') as f:
            img = Image.open(f).convert('RGB')

        return self.transform(img), self.labels[idx]

In [37]:
!pwd

/home/rahulahuja/nyu/dl/NYU_DL_comp/SimCLR/feature_eval


In [2]:
!pip install gdown

Collecting gdown
  Downloading gdown-3.12.2.tar.gz (8.2 kB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... [?25ldone
[?25h  Created wheel for gdown: filename=gdown-3.12.2-py3-none-any.whl size=9681 sha256=0c3cb355ee627461a9115749a300d2306e12259c91eab4db10d861eb32f1c7f7
  Stored in directory: /home/rahulahuja/.cache/pip/wheels/ba/e0/7e/726e872a53f7358b4b96a9975b04e98113b005cd8609a63abc
Successfully built gdown
Installing collected packages: gdown
Successfully installed gdown-3.12.2
You should consider upgrading via the '/home/rahulahuja/anaconda3/bin/python -m pip install --upgrade pip' command.[0m


In [12]:
def get_file_id_by_model(folder_name):
  file_id = {'resnet18_100-epochs_stl10': '14_nH2FkyKbt61cieQDiSbBVNP8-gtwgF',
             'resnet18_100-epochs_cifar10': '1lc2aoVtrAetGn0PnTkOyFzPCIucOJq7C',
             'resnet50_50-epochs_stl10': '1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu'}
  return file_id.get(folder_name, "Model not found.")

In [13]:
folder_name = 'resnet50_50-epochs_stl10'
file_id = get_file_id_by_model(folder_name)
print(folder_name, file_id)

resnet50_50-epochs_stl10 1ByTKAUsdm_X7tLcii6oAEl5qFRqRMZSu


In [14]:
# download and extract model files
os.system('gdown https://drive.google.com/uc?id={}'.format(file_id))
os.system('unzip {}'.format(folder_name))
!ls

checkpoint_0040.pth.tar
config.yml
events.out.tfevents.1610927742.4cb2c837708d.2694093.0
resnet50_50-epochs_stl10.zip
sample_data
training.log


In [3]:
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets

In [5]:
device = 'cuda:4' if torch.cuda.is_available() else 'cpu'
print("Using device:", device)

Using device: cuda:4


In [58]:
def get_stl10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.STL10('./data', split='train', download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.STL10('./data', split='test', download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

def get_cifar10_data_loaders(download, shuffle=False, batch_size=256):
  train_dataset = datasets.CIFAR10('./data', train=True, download=download,
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = datasets.CIFAR10('./data', train=False, download=download,
                                  transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=10, drop_last=False, shuffle=shuffle)
  return train_loader, test_loader

def get_nyu_data_loaders(shuffle=True, batch_size=256):
  train_dataset = CustomDataset('../../dataset', split='train',
                                  transform=transforms.ToTensor())

  train_loader = DataLoader(train_dataset, batch_size=batch_size,
                            num_workers=0, drop_last=False, shuffle=shuffle)
  
  test_dataset = CustomDataset('../../dataset', split='val',transform=transforms.ToTensor())

  test_loader = DataLoader(test_dataset, batch_size=2*batch_size,
                            num_workers=4, drop_last=False, shuffle=False)
  return train_loader, test_loader



In [8]:
model = torchvision.models.resnet50(pretrained=False, num_classes=800).to(device)

In [72]:
checkpoint = torch.load('../runs/Apr16_19-00-37_rahulahuja-U2099/checkpoint_latest.pth.tar', map_location=device)
# /home/rahulahuja/nyu/dl/NYU_DL_comp/SimCLR/runs/Apr13_09-18-19_rahulahuja-U2099/checkpoint_0200.pth.tar
state_dict = checkpoint['state_dict']

for k in list(state_dict.keys()):
  if k.startswith('backbone.'):
    if k.startswith('backbone') and not k.startswith('backbone.fc'):
      # remove prefix
      state_dict[k[len("backbone."):]] = state_dict[k]
  del state_dict[k]

In [73]:
checkpoint['state_dict'].keys()

odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 'bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked', 'layer1.0.conv1.weight', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.num_batches_tracked', 'layer1.0.conv2.weight', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.num_batches_tracked', 'layer1.0.conv3.weight', 'layer1.0.bn3.weight', 'layer1.0.bn3.bias', 'layer1.0.bn3.running_mean', 'layer1.0.bn3.running_var', 'layer1.0.bn3.num_batches_tracked', 'layer1.0.downsample.0.weight', 'layer1.0.downsample.1.weight', 'layer1.0.downsample.1.bias', 'layer1.0.downsample.1.running_mean', 'layer1.0.downsample.1.running_var', 'layer1.0.downsample.1.num_batches_tracked', 'layer1.1.conv1.weight', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.num_batches_tracked', 'layer1.1.conv2.we

In [74]:
log = model.load_state_dict(state_dict, strict=False)
assert log.missing_keys == ['fc.weight', 'fc.bias']
# log.missing_keys

In [59]:
train_loader, test_loader = get_nyu_data_loaders()

In [47]:
# freeze all layers but the last fc
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

In [48]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.0008)
criterion = torch.nn.CrossEntropyLoss().to(device)

In [49]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [76]:
epochs = 100
for epoch in range(epochs):
  top1_train_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(tqdm(train_loader)):
#     print(counter)
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
    loss = criterion(logits, y_batch)
    top1 = accuracy(logits, y_batch, topk=(1,))
    top1_train_accuracy += top1[0]

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  top1_train_accuracy /= (counter + 1)
  top1_accuracy = 0
  top5_accuracy = 0
  for counter, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    logits = model(x_batch)
  
    top1, top5 = accuracy(logits, y_batch, topk=(1,5))
    top1_accuracy += top1[0]
    top5_accuracy += top5[0]
  
  top1_accuracy /= (counter + 1)
  top5_accuracy /= (counter + 1)
  print(f"Epoch {epoch}\tTop1 Train accuracy {top1_train_accuracy.item()}\tTop1 Test accuracy: {top1_accuracy.item()}\tTop5 test acc: {top5_accuracy.item()}")

100%|██████████| 100/100 [04:41<00:00,  2.82s/it]
100%|██████████| 50/50 [00:51<00:00,  1.02s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 0	Top1 Train accuracy 1.40234375	Top1 Test accuracy: 1.21875	Top5 test acc: 4.64453125


100%|██████████| 100/100 [04:46<00:00,  2.86s/it]
100%|██████████| 50/50 [01:06<00:00,  1.33s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1	Top1 Train accuracy 2.11328125	Top1 Test accuracy: 1.80859375	Top5 test acc: 6.48046875


100%|██████████| 100/100 [05:14<00:00,  3.14s/it]
100%|██████████| 50/50 [00:59<00:00,  1.20s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 2	Top1 Train accuracy 2.93359375	Top1 Test accuracy: 2.44921875	Top5 test acc: 8.33984375


100%|██████████| 100/100 [05:04<00:00,  3.04s/it]
100%|██████████| 50/50 [00:54<00:00,  1.09s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 3	Top1 Train accuracy 4.109375	Top1 Test accuracy: 3.2109375	Top5 test acc: 10.91015625


100%|██████████| 100/100 [05:17<00:00,  3.17s/it]
100%|██████████| 50/50 [00:55<00:00,  1.10s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 4	Top1 Train accuracy 5.48828125	Top1 Test accuracy: 3.984375	Top5 test acc: 12.85546875


100%|██████████| 100/100 [05:14<00:00,  3.14s/it]
100%|██████████| 50/50 [01:04<00:00,  1.29s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 5	Top1 Train accuracy 6.6796875	Top1 Test accuracy: 4.99609375	Top5 test acc: 14.75


100%|██████████| 100/100 [05:37<00:00,  3.37s/it]
100%|██████████| 50/50 [00:56<00:00,  1.14s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 6	Top1 Train accuracy 8.01171875	Top1 Test accuracy: 5.73828125	Top5 test acc: 16.28515625


100%|██████████| 100/100 [05:11<00:00,  3.12s/it]
100%|██████████| 50/50 [00:51<00:00,  1.04s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 7	Top1 Train accuracy 8.94140625	Top1 Test accuracy: 6.3203125	Top5 test acc: 17.49609375


100%|██████████| 100/100 [05:03<00:00,  3.03s/it]
100%|██████████| 50/50 [00:57<00:00,  1.15s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 8	Top1 Train accuracy 9.89453125	Top1 Test accuracy: 6.91796875	Top5 test acc: 18.59375


100%|██████████| 100/100 [04:47<00:00,  2.88s/it]
100%|██████████| 50/50 [01:01<00:00,  1.22s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 9	Top1 Train accuracy 10.546875	Top1 Test accuracy: 7.4140625	Top5 test acc: 19.52734375


100%|██████████| 100/100 [04:52<00:00,  2.93s/it]
100%|██████████| 50/50 [01:00<00:00,  1.22s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 10	Top1 Train accuracy 11.48046875	Top1 Test accuracy: 7.78515625	Top5 test acc: 20.49609375


100%|██████████| 100/100 [05:15<00:00,  3.15s/it]
100%|██████████| 50/50 [00:56<00:00,  1.13s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 11	Top1 Train accuracy 11.9453125	Top1 Test accuracy: 8.296875	Top5 test acc: 21.08984375


100%|██████████| 100/100 [05:25<00:00,  3.26s/it]
100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 12	Top1 Train accuracy 12.64453125	Top1 Test accuracy: 8.4453125	Top5 test acc: 21.6171875


100%|██████████| 100/100 [05:25<00:00,  3.25s/it]
100%|██████████| 50/50 [00:55<00:00,  1.11s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 13	Top1 Train accuracy 13.234375	Top1 Test accuracy: 8.703125	Top5 test acc: 22.23828125


100%|██████████| 100/100 [05:15<00:00,  3.15s/it]
100%|██████████| 50/50 [01:01<00:00,  1.23s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 14	Top1 Train accuracy 13.75	Top1 Test accuracy: 8.9609375	Top5 test acc: 22.80859375


100%|██████████| 100/100 [05:25<00:00,  3.26s/it]
100%|██████████| 50/50 [01:01<00:00,  1.22s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 15	Top1 Train accuracy 14.23828125	Top1 Test accuracy: 9.171875	Top5 test acc: 23.39453125


100%|██████████| 100/100 [05:32<00:00,  3.33s/it]
100%|██████████| 50/50 [01:06<00:00,  1.34s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 16	Top1 Train accuracy 14.4765625	Top1 Test accuracy: 9.50390625	Top5 test acc: 23.6875


100%|██████████| 100/100 [05:17<00:00,  3.18s/it]
100%|██████████| 50/50 [01:05<00:00,  1.30s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 17	Top1 Train accuracy 14.953125	Top1 Test accuracy: 9.71484375	Top5 test acc: 24.11328125


100%|██████████| 100/100 [05:30<00:00,  3.30s/it]
100%|██████████| 50/50 [01:04<00:00,  1.29s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 18	Top1 Train accuracy 15.37109375	Top1 Test accuracy: 9.90625	Top5 test acc: 24.3828125


100%|██████████| 100/100 [05:31<00:00,  3.31s/it]
100%|██████████| 50/50 [01:04<00:00,  1.30s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 19	Top1 Train accuracy 15.546875	Top1 Test accuracy: 10.01953125	Top5 test acc: 24.76171875


100%|██████████| 100/100 [05:44<00:00,  3.44s/it]
100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 20	Top1 Train accuracy 15.65625	Top1 Test accuracy: 10.12890625	Top5 test acc: 25.12109375


100%|██████████| 100/100 [06:23<00:00,  3.84s/it]
100%|██████████| 50/50 [01:07<00:00,  1.36s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 21	Top1 Train accuracy 16.21484375	Top1 Test accuracy: 10.26171875	Top5 test acc: 25.40234375


100%|██████████| 100/100 [05:23<00:00,  3.23s/it]
100%|██████████| 50/50 [01:07<00:00,  1.36s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 22	Top1 Train accuracy 16.2890625	Top1 Test accuracy: 10.40234375	Top5 test acc: 25.63671875


100%|██████████| 100/100 [05:06<00:00,  3.07s/it]
100%|██████████| 50/50 [00:57<00:00,  1.15s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 23	Top1 Train accuracy 16.58203125	Top1 Test accuracy: 10.48828125	Top5 test acc: 26.01953125


100%|██████████| 100/100 [05:04<00:00,  3.05s/it]
100%|██████████| 50/50 [00:59<00:00,  1.19s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 24	Top1 Train accuracy 16.93359375	Top1 Test accuracy: 10.59765625	Top5 test acc: 26.1953125


100%|██████████| 100/100 [05:18<00:00,  3.18s/it]
100%|██████████| 50/50 [01:02<00:00,  1.25s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 25	Top1 Train accuracy 16.859375	Top1 Test accuracy: 10.7109375	Top5 test acc: 26.43359375


100%|██████████| 100/100 [05:15<00:00,  3.15s/it]
100%|██████████| 50/50 [01:04<00:00,  1.29s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 26	Top1 Train accuracy 17.31640625	Top1 Test accuracy: 10.9453125	Top5 test acc: 26.76953125


100%|██████████| 100/100 [05:30<00:00,  3.30s/it]
100%|██████████| 50/50 [01:02<00:00,  1.25s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 27	Top1 Train accuracy 17.625	Top1 Test accuracy: 11.08203125	Top5 test acc: 27.0


100%|██████████| 100/100 [05:35<00:00,  3.35s/it]
100%|██████████| 50/50 [01:16<00:00,  1.52s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 28	Top1 Train accuracy 17.61328125	Top1 Test accuracy: 11.1875	Top5 test acc: 27.20703125


100%|██████████| 100/100 [05:50<00:00,  3.51s/it]
100%|██████████| 50/50 [01:12<00:00,  1.45s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 29	Top1 Train accuracy 17.83203125	Top1 Test accuracy: 11.3125	Top5 test acc: 27.27734375


100%|██████████| 100/100 [05:56<00:00,  3.57s/it]
100%|██████████| 50/50 [01:00<00:00,  1.22s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 30	Top1 Train accuracy 17.8515625	Top1 Test accuracy: 11.3984375	Top5 test acc: 27.48046875


100%|██████████| 100/100 [05:51<00:00,  3.51s/it]
100%|██████████| 50/50 [00:59<00:00,  1.19s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 31	Top1 Train accuracy 18.24609375	Top1 Test accuracy: 11.51171875	Top5 test acc: 27.734375


100%|██████████| 100/100 [05:52<00:00,  3.53s/it]
100%|██████████| 50/50 [01:02<00:00,  1.24s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 32	Top1 Train accuracy 18.40234375	Top1 Test accuracy: 11.62109375	Top5 test acc: 27.93359375


100%|██████████| 100/100 [06:18<00:00,  3.79s/it]
100%|██████████| 50/50 [01:03<00:00,  1.28s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 33	Top1 Train accuracy 18.6953125	Top1 Test accuracy: 11.75390625	Top5 test acc: 28.0546875


100%|██████████| 100/100 [05:59<00:00,  3.60s/it]
100%|██████████| 50/50 [01:00<00:00,  1.22s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 34	Top1 Train accuracy 18.8828125	Top1 Test accuracy: 11.90234375	Top5 test acc: 28.21875


100%|██████████| 100/100 [05:50<00:00,  3.50s/it]
100%|██████████| 50/50 [01:10<00:00,  1.41s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 35	Top1 Train accuracy 18.8984375	Top1 Test accuracy: 11.91015625	Top5 test acc: 28.33203125


100%|██████████| 100/100 [04:56<00:00,  2.97s/it]
100%|██████████| 50/50 [01:02<00:00,  1.26s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 36	Top1 Train accuracy 19.2734375	Top1 Test accuracy: 12.0	Top5 test acc: 28.421875


100%|██████████| 100/100 [05:08<00:00,  3.08s/it]
100%|██████████| 50/50 [01:05<00:00,  1.30s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 37	Top1 Train accuracy 19.421875	Top1 Test accuracy: 12.12109375	Top5 test acc: 28.69921875


100%|██████████| 100/100 [05:30<00:00,  3.30s/it]
100%|██████████| 50/50 [01:09<00:00,  1.39s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 38	Top1 Train accuracy 19.48828125	Top1 Test accuracy: 12.203125	Top5 test acc: 28.78125


100%|██████████| 100/100 [05:57<00:00,  3.57s/it]
100%|██████████| 50/50 [01:00<00:00,  1.21s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 39	Top1 Train accuracy 19.81640625	Top1 Test accuracy: 12.3046875	Top5 test acc: 28.984375


100%|██████████| 100/100 [06:02<00:00,  3.62s/it]
100%|██████████| 50/50 [01:02<00:00,  1.25s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 40	Top1 Train accuracy 20.015625	Top1 Test accuracy: 12.453125	Top5 test acc: 29.1171875


100%|██████████| 100/100 [06:03<00:00,  3.64s/it]
100%|██████████| 50/50 [00:58<00:00,  1.17s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 41	Top1 Train accuracy 20.03125	Top1 Test accuracy: 12.45703125	Top5 test acc: 29.1953125


100%|██████████| 100/100 [05:33<00:00,  3.33s/it]
100%|██████████| 50/50 [01:01<00:00,  1.24s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 42	Top1 Train accuracy 20.0078125	Top1 Test accuracy: 12.48828125	Top5 test acc: 29.265625


100%|██████████| 100/100 [06:16<00:00,  3.76s/it]
100%|██████████| 50/50 [01:05<00:00,  1.30s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 43	Top1 Train accuracy 20.30859375	Top1 Test accuracy: 12.62109375	Top5 test acc: 29.453125


100%|██████████| 100/100 [05:50<00:00,  3.50s/it]
100%|██████████| 50/50 [01:06<00:00,  1.34s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 44	Top1 Train accuracy 20.3125	Top1 Test accuracy: 12.625	Top5 test acc: 29.515625


100%|██████████| 100/100 [05:58<00:00,  3.58s/it]
100%|██████████| 50/50 [00:57<00:00,  1.14s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 45	Top1 Train accuracy 20.29296875	Top1 Test accuracy: 12.73828125	Top5 test acc: 29.60546875


100%|██████████| 100/100 [05:39<00:00,  3.40s/it]
100%|██████████| 50/50 [01:00<00:00,  1.20s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 46	Top1 Train accuracy 20.2890625	Top1 Test accuracy: 12.8515625	Top5 test acc: 29.6953125


100%|██████████| 100/100 [06:18<00:00,  3.78s/it]
100%|██████████| 50/50 [01:08<00:00,  1.38s/it]
  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 47	Top1 Train accuracy 20.64453125	Top1 Test accuracy: 12.87109375	Top5 test acc: 29.73828125


 12%|█▏        | 12/100 [00:46<05:43,  3.90s/it]


KeyboardInterrupt: 

In [None]:
def test(model):
    for counter, (x_batch, y_batch) in enumerate(tqdm(test_loader)):
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(x_batch)

        top1, top5 = accuracy(logits, y_batch, topk=(1,5))
        top1_accuracy += top1[0]
        top5_accuracy += top5[0]

      top1_accuracy /= (counter + 1)
      top5_accuracy /= (counter + 1)

In [78]:
model1= torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50', _use_new_zipfile_serialization=False)
model2= torch.hub.load('facebookresearch/swav', 'resnet50')

Using cache found in /home/rahulahuja/.cache/torch/hub/facebookresearch_barlowtwins_main


TypeError: __init__() got an unexpected keyword argument '_use_new_zipfile_serialization'