In [1]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
from itertools import chain
import numpy as np

from essl.datasets import Cifar10, SVHN
from essl.backbones import largerCNN_backbone
from essl.evaluate_downstream import finetune_model
from torch.utils.data import DataLoader
import torch
import glob
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def top2(model_path, dataset):
    backbone = largerCNN_backbone()
    model = finetune_model(backbone=backbone.backbone, in_features=backbone.in_features, num_outputs=dataset.num_classes)
    model.load_state_dict(torch.load(model_path))
    model = model.cuda()
    dataloader = DataLoader(dataset.test_data, batch_size=64,shuffle=False)
    total_correct = 0
    total = 0
    for X, y in dataloader:
        X, y = X.cuda(), y.cuda()
        outputs = model(X)
        vals, inds = torch.topk(outputs, k=2)
        output_ohe = torch.nn.functional.one_hot(inds, dataset.num_classes)
        output_ohe = torch.sum(output_ohe, dim=1)
        y_ohe = torch.nn.functional.one_hot(y, dataset.num_classes) 
        correct = torch.sum(output_ohe * y_ohe)
        total_correct += correct.item()
        total+= X.shape[0]
    return total_correct/total
    

In [3]:
b256 = [
    "/home/noah/ESSL/final_exps/optimization/exp8_4/4/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp8_5/2/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp8_6/1/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp8_7/7/outcomes.json"
]

b256_names = ["SwaV", "BYOL", "SimSiam", "NNCLR"]

b32 = [
    "/home/noah/ESSL/final_exps/optimization/exp6_1/6/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp6_2/8/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp6_3/3/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp6_0/3/outcomes.json"
]

b32_names = ["SwaV", "BYOL", "SimSiam", "NNCLR"]

b256SVHN = [
    "/home/noah/ESSL/final_exps/optimization/exp10_0/1/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp10_1/4/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp10_2/1/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp10_3/3/outcomes.json"
]

b256_names = ["SwaV", "BYOL", "SimSiam", "NNCLR"]

b32SVHN = [
    #"/home/noah/ESSL/final_exps/optimization/exp6_1/6/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp11_1/0/outcomes.json",
    #"/home/noah/ESSL/final_exps/optimization/exp6_3/3/outcomes.json",
    "/home/noah/ESSL/final_exps/optimization/exp11_3/1/outcomes.json"
]

b32_names = [#"SwaV", 
             "BYOL", 
             #"SimSiam", 
             "NNCLR"]

In [4]:
for exp_path, exp_name in zip(b32SVHN, b32_names):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=SVHN())
    print(exp_name, " ", top2_acc)

Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
BYOL   0.9612015980331899
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
NNCLR   0.9575522433927474


In [5]:
for exp_path, exp_name in zip(b256SVHN, b256_names):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=SVHN())
    print(exp_name, " ", top2_acc)

Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
SwaV   0.9615857406269207
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
BYOL   0.9571681007990166
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
SimSiam   0.965619237861094
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
NNCLR   0.9598570989551322
