In [None]:
# Tip: run `%run MyOtherNotebook.ipynb` in a cell to import definitions from another .ipynb file.
import torchvision
from torchvision import transforms
import torch
from typing import OrderedDict
from tqdm import tqdm
from pathlib import Path
from utils.full_train_cifar10 import full_train_cifar10

# Written with help from https://github.com/kuangliu/pytorch-cifar/blob/master/main.py
# When True, every entry of `all_settings` is trained one after another.
TRAIN_ALL = True
all_settings = [
    "ImageNet_Pretrained",
    "BYOL",
    "DINO",
    "BarlowTwins",
]
# Configuration trained when TRAIN_ALL is False:
# ImageNet_Pretrained, BYOL, DINO or BarlowTwins.
model_setting = "BarlowTwins"

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so this cell does not crash on machines without CUDA (training will be slow).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Let cuDNN benchmark convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True
if device.startswith("cuda"):
    # Cap this process at 80% of the GPU's memory. The call initialises CUDA,
    # so it must only run when a GPU is actually present.
    torch.cuda.set_per_process_memory_fraction(0.8, device=device)
# Size of the classification head attached to every backbone (CIFAR-10).
NUMBER_OF_CLASSES = 10

In [None]:
# Directory that holds the self-supervised checkpoints listed below.
_DEFAULT_WEIGHTS_DIR = "/home/utku/Documents/repos/SSL_OOD"

# model_setting -> (checkpoint file name,
#                   key of the backbone weights inside the checkpoint, or None
#                   when the file itself is the state dict,
#                   prefix to strip from parameter names, or None)
_SSL_CHECKPOINTS = {
    "BYOL": ("resnet50_byol_imagenet2012.pth.tar", "online_backbone", "module."),
    "BarlowTwins": ("barlowT_resnet50.pth", None, None),
    "DINO": ("dino_resnet50_pretrain.pth", None, None),
}


def create_resnet(model_setting, weights_dir=_DEFAULT_WEIGHTS_DIR):
    """Build a ResNet-50 with pretrained backbone weights and a fresh linear head.

    Args:
        model_setting: Which pretrained weights to start from. One of
            "ImageNet_Pretrained" (torchvision's IMAGENET1K_V2 weights),
            "BYOL", "BarlowTwins" or "DINO" (checkpoints read from disk).
        weights_dir: Directory containing the BYOL / BarlowTwins / DINO
            checkpoint files. Defaults to the location used by this project.

    Returns:
        The model, moved to the global `device`, whose `fc` layer is a newly
        initialised `torch.nn.Linear` with `NUMBER_OF_CLASSES` outputs.

    Raises:
        ValueError: If `model_setting` is not one of the supported names.
    """
    # Validate first so a typo fails fast, before any model is built or any
    # checkpoint is read.
    valid_settings = ("ImageNet_Pretrained", *_SSL_CHECKPOINTS)
    if model_setting not in valid_settings:
        raise ValueError(
            f"Unknown model_setting {model_setting!r}; "
            f"expected one of {', '.join(valid_settings)}"
        )

    if model_setting == "ImageNet_Pretrained":
        ResNet = torchvision.models.resnet50(weights="IMAGENET1K_V2")
        number_of_input_features = ResNet.fc.in_features
    else:
        file_name, backbone_key, key_prefix = _SSL_CHECKPOINTS[model_setting]
        ResNet = torchvision.models.resnet50()
        number_of_input_features = ResNet.fc.in_features
        # The checkpoints are loaded strictly into a network without a
        # classifier, so the head is swapped for an identity while loading.
        ResNet.fc = torch.nn.Identity()

        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust. map_location="cpu" makes loading independent of
        # the device the checkpoint was saved from; the model is moved to
        # `device` at the end.
        checkpoint = torch.load(Path(weights_dir) / file_name, map_location="cpu")
        state_dict = checkpoint if backbone_key is None else checkpoint[backbone_key]
        if key_prefix is not None:
            # Drop the leading `module.` from parameter names, but only from
            # keys that actually carry it.
            state_dict = {
                (name[len(key_prefix):] if name.startswith(key_prefix) else name): tensor
                for name, tensor in state_dict.items()
            }
        ResNet.load_state_dict(state_dict)

    # Every setting ends with the same freshly initialised classification head.
    ResNet.fc = torch.nn.Linear(number_of_input_features, NUMBER_OF_CLASSES)
    return ResNet.to(device)

In [None]:
# Train every configuration when TRAIN_ALL is set, otherwise only the one
# selected by `model_setting`.
settings_to_train = all_settings if TRAIN_ALL else [model_setting]
# A dedicated loop variable keeps the user-selected `model_setting` from being
# overwritten, so re-running this cell never depends on a previous run.
for current_setting in settings_to_train:
    ResNet = create_resnet(model_setting=current_setting)
    full_train_cifar10(ResNet, device=device, model_setting=current_setting)