## Training SimSiam and linear head

Clone the following GitHub repository: https://github.com/Reza-Safdari/SimSiam-91.9-top1-acc-on-CIFAR10

### Training SimSiam:

```bash
python main.py --exp_dir [Directory to store experiment models and results] --data_root /projects/leelab/data/image/cifar10/ --arch resnet18 --learning_rate 0.06 --epochs 800 --weight_decay 5e-4 --momentum 0.9 --batch_size 512 --gpu 7
```

### Training linear head:

```bash
python main_lincls.py --arch resnet18 --num_cls 10 --batch_size 256 --lr 30.0 --weight_decay 0.0 --pretrained [Best model (should be in exp_dir)] --gpu 7 /projects/leelab/data/image/cifar10/
```

## Load pretrained model

In [1]:
import torch
import torch.nn as nn
from cl_explain.encoders.simsiam.resnet import ResNet18

# CUDA device index used for the model, the loss and the evaluation batches.
device = 7

# Root directory handed to the torchvision CIFAR10 dataset.
data_path = '/projects/leelab/data/image/cifar10/'
# Directory holding the checkpoint. load_model appends 'resnet18.pth.tar'
# by plain string concatenation, so the trailing slash is required.
model_path = '/projects/leelab/cl-explainability/encoders/sim_siam/'

def load_model(model_path, device):
    """Build a ResNet18 and restore its weights from a saved checkpoint.

    Args:
        model_path: Path prefix of the checkpoint. The file name
            ``resnet18.pth.tar`` is appended by plain concatenation, so a
            directory must end with a path separator.
        device: Target device passed to ``model.to`` (e.g. a CUDA index).

    Returns:
        The ResNet18 with the checkpoint's ``state_dict`` loaded, moved to
        ``device``.
    """
    # NOTE(review): low_dim=10 presumably matches the 10 CIFAR-10 classes of
    # the linear head - confirm against the ResNet18 definition.
    model = ResNet18(low_dim=10)
    # Deserialize onto the CPU first. Without map_location the tensors are
    # restored to the device they were saved from, which fails when that GPU
    # does not exist here and otherwise allocates memory on the wrong device.
    checkpoint = torch.load(model_path + 'resnet18.pth.tar',
                            map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    return model

# Instantiate the network and restore the checkpoint weights onto `device`.
model = load_model(model_path, device)

## Evaluate validation performance

In [2]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

BATCHSIZE = 256
NUMWORKERS = 20

# Deterministic evaluation preprocessing: upscale to int(32 * 8 / 7) = 36 px,
# take the central 32x32 crop, convert to a tensor and normalize per channel.
# NOTE(review): the normalization constants are presumably the CIFAR-10
# channel statistics used during training - confirm against the training code.
transform_test = transforms.Compose([
    transforms.Resize(
        int(32 * (8 / 7)),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2023, 0.1994, 0.2010),
    ),
])

# Held-out split (train=False) read from data_path.
test_set = datasets.CIFAR10(
    root=data_path,
    train=False,
    transform=transform_test,
)

# Fixed batch order (shuffle=False).
test_loader = DataLoader(
    dataset=test_set,
    batch_size=BATCHSIZE,
    shuffle=False,
    num_workers=NUMWORKERS,
)

In [7]:
# Verify apply_eval_head is working as expected by comparing the output
# shapes with and without the evaluation head on a single batch.
img, label = next(iter(test_loader))
batch = img.to(device)

# Run the check in eval mode and without autograd so the two forward passes
# cannot alter any train-mode layer state (e.g. running normalization
# statistics, if the network has such layers) and do not build a graph.
# The previous mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    print(model(batch, apply_eval_head=True).shape)
    print(model(batch, apply_eval_head=False).shape)
model.train(was_training)

torch.Size([256, 10])
torch.Size([256, 512])


In [10]:
import time

class AverageMeter:
    """Track the latest reading of a metric together with its running mean."""

    def __init__(self, name, fmt=':f'):
        # Label and format spec used when the meter is rendered via str().
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest reading, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Splice the format spec into both placeholders, then fill them from
        # the instance attributes.
        spec = self.fmt
        template = ''.join(('{name} {val', spec, '} ({avg', spec, '})'))
        return template.format(**vars(self))

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    ``output`` is ranked along dimension 1 and ``target`` supplies one label
    per row. The result is a list holding a one-element tensor per requested
    k, in the order given by ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k highest scores per sample, transposed so
        # that row r holds every sample's rank-r prediction.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        # A sample counts as correct for k if its label is in the first k rows.
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True)
            .mul_(100.0 / num_samples)
            for k in topk
        ]

# Define the loss function (criterion) passed to validate(), placed on the
# evaluation device. No optimizer is created in this section.
criterion = nn.CrossEntropyLoss().cuda(device)    

def validate(test_loader, model, criterion, gpu=None, print_freq=500):
    """Evaluate ``model`` on ``test_loader`` and print top-1/top-5 accuracy.

    The model is switched to eval mode (and left there) and all forward
    passes run under ``torch.no_grad()``.

    Args:
        test_loader: Iterable yielding ``(images, target)`` batches.
        model: Network invoked as ``model(images, apply_eval_head=True)``.
        criterion: Loss callable applied to ``(output, target)``.
        gpu: CUDA device that each batch is moved to. When ``None``, images
            and targets both stay on the device the loader produced them on.
        print_freq: Unused; kept so existing callers keep working.

    Returns:
        The sample-weighted average top-1 accuracy in percent, as accumulated
        by ``AverageMeter`` (a 0-dim tensor, because the per-batch accuracies
        are tensors).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for images, target in test_loader:
            # Move images and targets together. Previously the targets were
            # sent to CUDA even when gpu was None, leaving them on a
            # different device than the images and the model output.
            if gpu is not None:
                images = images.cuda(gpu, non_blocking=True)
                target = target.cuda(gpu, non_blocking=True)

            # compute output
            output = model(images, apply_eval_head=True)
            loss = criterion(output, target)

            # measure accuracy and record loss, weighting by batch size so a
            # smaller final batch does not skew the averages
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            num_samples = images.size(0)
            losses.update(loss.item(), num_samples)
            top1.update(acc1[0], num_samples)
            top5.update(acc5[0], num_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        # TODO: this should also be done with the ProgressMeter
        print(f' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}')

    return top1.avg

In [11]:
# Evaluate the loaded model on the test loader; prints Acc@1/Acc@5 and
# returns the average top-1 accuracy.
validate(test_loader, model, criterion, device)

 * Acc@1 92.060 Acc@5 99.800


tensor(92.0600, device='cuda:7')