In [1]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import json
from itertools import chain
import numpy as np

from essl.datasets import Cifar10, SVHN
from essl.backbones import largerCNN_backbone
from essl.evaluate_downstream import finetune_model
from torch.utils.data import DataLoader
import torch
import glob
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_algorithm(params_path):
    with open(params_path, "r") as f:
        params = []
        for i, l in enumerate(f.readlines()):
            params.append(l)
    # get params
    ssl = [j.split(" ")[1].strip("\n") for j in params if "ssl_task" in j][0]
    bs = [j.split(" ")[1].strip("\n") for j in params if "ssl_batch_size" in j][0]
    return ssl, bs

In [18]:
# cifar10
b256_fs = glob.glob("/home/noah/ESSL/final_exps/optimization/exp8*/*")
b256 = [os.path.join(f, "outcomes.json") for f in b256_fs]
b256_names = [get_algorithm(os.path.join(f, "params.txt")) for f in b256_fs]

b32_fs = glob.glob("/home/noah/ESSL/final_exps/optimization/exp6*/*")
b32 = [os.path.join(f, "outcomes.json") for f in b32_fs]
b32_names = [get_algorithm(os.path.join(f, "params.txt")) for f in b32_fs]

# SVHN
b256_fs_svhn = glob.glob("/home/noah/ESSL/final_exps/optimization/exp10*/*")
b256_svhn = [os.path.join(f, "outcomes.json") for f in b256_fs_svhn]
b256_names_svhn = [get_algorithm(os.path.join(f, "params.txt")) for f in b256_fs_svhn]

b32_fs_svhn = glob.glob("/home/noah/ESSL/final_exps/optimization/exp11*/*")
b32_svhn = [os.path.join(f, "outcomes.json") for f in b32_fs_svhn]
b32_names_svhn = [get_algorithm(os.path.join(f, "params.txt")) for f in b32_fs_svhn]

print(b32)
print(b32_names)

['/home/noah/ESSL/final_exps/optimization/exp6_3/3/outcomes.json', '/home/noah/ESSL/final_exps/optimization/exp6_2/8/outcomes.json', '/home/noah/ESSL/final_exps/optimization/exp6_0/2/outcomes.json', '/home/noah/ESSL/final_exps/optimization/exp6_1/2/outcomes.json']
[('SimSiam', '32'), ('BYOL', '32'), ('NNCLR', '32'), ('SwaV', '32')]


In [4]:
def get_data(exp_path):
    chromos = []
    with open(exp_path, "r") as f:
        results = json.load(f)
        for fitness, chromo in zip(results["pop_vals"], results["chromos"]):
            c = list(chain.from_iterable(chromo[1]))
            c.append(fitness[1])
            chromos.append(c)
    columns = list(chain.from_iterable([[f"aug{i}", f"op{i}"] for i in range(1, 4)]))
    columns.append("test acc")
    columns
    df = pd.DataFrame(chromos, columns=columns)
    
    # create data in long format
    ops = set(list(df["aug1"]) + list(df["aug2"]) +  list(df["aug3"]))
    indexes = {op:i for i, op in enumerate(ops)}
    chromos_long = np.zeros([len(chromos), len(ops)+1])
    for i, c in enumerate(chromos):
        for aug, intensity in zip(c[:-2][::2], c[:-2][1::2]):
            chromos_long[i][indexes[aug]] = intensity
            chromos_long[i][-1] = c[-1]
    columns_long = list(ops) + ["fitness"]
    df_long = pd.DataFrame(chromos_long, columns = columns_long)
    return df, df_long, ops

def get_init_pop_avg(exp_path):
    with open(exp_path, "r") as f:
        results = json.load(f)
        c = 0
        p = results['pop_vals'][c]
        init_pop = []
        while p[0] == 0:
            init_pop.append(p[1])
            c+=1
            p = results['pop_vals'][c]
    return sum(init_pop)/len(init_pop)
        
i = get_init_pop_avg(exp_path=b256[0])
print(i)

82.85666666666665


In [19]:
for exp_path, exp_name in zip(b256, b256_names):
    df, df_long, ops = get_data(exp_path)
    print(exp_name, " ", max(df_long['fitness']))

('NNCLR', '256')   84.15
('SimSiam', '256')   84.43
('BYOL', '256')   84.1
('SwaV', '256')   84.23


In [6]:
for exp_path, exp_name in zip(b32, b32_names):
    df, df_long, ops = get_data(exp_path)
    print(exp_name, " ", max(df_long['fitness']))

('SimSiam', '32')   84.08
('BYOL', '32')   84.12
('NNCLR', '32')   84.18
('SwaV', '32')   84.07


In [20]:
for exp_path, exp_name in zip(b256_svhn, b256_names_svhn):
    df, df_long, ops = get_data(exp_path)
    print(exp_name, " ", round(max(df_long['fitness']), 2))

('NNCLR', '256')   93.1
('SwaV', '256')   93.2
('BYOL', '256')   92.67
('SimSiam', '256')   92.74


In [22]:
for exp_path, exp_name in zip(b32_svhn, b32_names_svhn):
    df, df_long, ops = get_data(exp_path)
    print(exp_name, " ", round(max(df_long['fitness']), 2))

('SwaV', '32')   93.29
('NNCLR', '32')   92.77
('SimSiam', '32')   93.41
('BYOL', '32')   91.3


top 2 accuracies

In [9]:

def top2(model_path, dataset):
    backbone = largerCNN_backbone()
    model = finetune_model(backbone=backbone.backbone, in_features=backbone.in_features, num_outputs=dataset.num_classes)
    model.load_state_dict(torch.load(model_path))
    model = model.cuda()
    dataloader = DataLoader(dataset.test_data, batch_size=64,shuffle=False)
    total_correct = 0
    total = 0
    for X, y in dataloader:
        X, y = X.cuda(), y.cuda()
        outputs = model(X)
        vals, inds = torch.topk(outputs, k=2)
        output_ohe = torch.nn.functional.one_hot(inds, dataset.num_classes)
        output_ohe = torch.sum(output_ohe, dim=1)
        y_ohe = torch.nn.functional.one_hot(y, dataset.num_classes) 
        correct = torch.sum(output_ohe * y_ohe)
        total_correct += correct.item()
        total+= X.shape[0]
    return total_correct/total
    


In [31]:
for exp_path, exp_name in zip(b32, b32_names):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=Cifar10())
    print(exp_name, " ", round(top2_acc*100, 2))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('SimSiam', '32')   92.58
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('BYOL', '32')   93.2
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('NNCLR', '32')   92.81
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('SwaV', '32')   93.02


In [28]:
for exp_path, exp_name in zip(b256, b256_names):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=Cifar10())
    print(exp_name, " ", round(top2_acc*100, 2))

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('NNCLR', '256')   92.53
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('SimSiam', '256')   93.23
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('BYOL', '256')   92.91
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
('SwaV', '256')   92.19


In [29]:
for exp_path, exp_name in zip(b32_svhn, b32_names_svhn):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=SVHN())
    print(exp_name, " ", round(top2_acc*100, 2))

Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('SwaV', '32')   96.58
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('NNCLR', '32')   95.76
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('SimSiam', '32')   95.82
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('BYOL', '32')   96.12


In [30]:
for exp_path, exp_name in zip(b256_svhn, b256_names_svhn):
    models = glob.glob(os.path.join(os.path.dirname(exp_path), "models/*.pt"))
    if len(models) == 1:
        model_path = models[0]
    elif len(models) == 0:
        continue
    else:
        model_path = [m for m in models if "downstream" in m][0]
    top2_acc = top2(model_path, dataset=SVHN())
    print(exp_name, " ", round(top2_acc*100, 2))

Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('NNCLR', '256')   95.99
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('SwaV', '256')   96.16
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('BYOL', '256')   95.72
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
Using downloaded and verified file: datasets/SVHN/test_32x32.mat
Using downloaded and verified file: datasets/SVHN/train_32x32.mat
('SimSiam', '256')   96.56


In [14]:
for exp_path, exp_name in zip(b256, b256_names):
    print(exp_name, " ", round(get_init_pop_avg( exp_path), 2))

('NNCLR', '256')   82.86
('SimSiam', '256')   83.33
('BYOL', '256')   83.35
('SwaV', '256')   83.47


In [15]:
for exp_path, exp_name in zip(b32, b32_names):
    print(exp_name, " ", round(get_init_pop_avg( exp_path), 2))

('SimSiam', '32')   82.45
('BYOL', '32')   83.46
('NNCLR', '32')   82.91
('SwaV', '32')   82.63


In [16]:
for exp_path, exp_name in zip(b256_svhn, b256_names_svhn):
    print(exp_name, " ", round(get_init_pop_avg( exp_path), 2))

('NNCLR', '256')   92.2
('SwaV', '256')   92.1
('BYOL', '256')   91.96
('SimSiam', '256')   91.94


In [17]:
for exp_path, exp_name in zip(b32_svhn, b32_names_svhn):
    print(exp_name, " ", round(get_init_pop_avg( exp_path), 2))

('SwaV', '32')   92.68
('NNCLR', '32')   92.08
('SimSiam', '32')   92.5
('BYOL', '32')   89.15
