In [1]:
import os
import time
import math
import random
import itertools
import numpy as np
import pandas as pd
# PyTorch
import torch
import torchvision
import torchmetrics

In [2]:
import sys
sys.path.append('../src/')

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import losses
import utils

In [3]:
# Rebuild the architecture used for the retrained runs (a ResNet-50 whose final linear
# layer is replaced by a `num_classes`-way head) and print its total parameter count.
# NOTE(review): 37 presumably matches the number of Oxford-IIIT Pet classes — confirm.
num_classes = 37
model = torchvision.models.resnet50()
# Take the head's input width from the backbone instead of hard-coding 2048.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes, bias=True)
# Flatten all parameters into a single 1-D tensor; its length is the parameter count.
params = torch.nn.utils.parameters_to_vector(model.parameters())
print(len(params))

23583845


In [8]:
# Collect the final-epoch metrics of every retrained run into one table (one row per run).
# Each run is expected to have a CSV at `{experiments_directory}/{model_name}.csv` with
# per-epoch columns 'train_sec/epoch', 'train_loss', 'val_or_test_acc' and 'val_or_test_nll'.
experiments_directory = '/cluster/tufts/hugheslab/eharve06/informed-priors/experiments/retrained_Oxford-IIIT_Pets_VI'
lr_0s = [0.1, 0.01, 0.001, 0.0001]
methods = ['l2-zero', 'l2-sp', 'ptyl']
ns = [370, 3441]
random_states = [1001, 2001, 3001]
# The run names encode kappa = (number of model parameters) / n. Reuse the count computed
# in the parameter-count cell (23583845) instead of repeating it as a magic number.
num_params = len(params)

columns = ['avg_sec/epoch', 'method', 'model_name', 'n', 'random_state', 'runtime', 'train_loss', 'test_acc', 'test_nll']
# Accumulate plain dicts and build the DataFrame once. Appending via `.loc` enlargement
# copies the frame on every row and, starting from an empty frame, can leave object dtypes.
records = []
for lr_0, method, n, random_state in itertools.product(lr_0s, methods, ns, random_states):
    model_name = f'{method}_kappa={num_params/n}_lr_0={lr_0}_n={n}_random_state={random_state}'
    run_df = pd.read_csv(os.path.join(experiments_directory, f'{model_name}.csv'))
    records.append({
        'avg_sec/epoch': run_df['train_sec/epoch'].mean(),
        'method': method,
        'model_name': model_name,
        'n': n,
        'random_state': random_state,
        # Sum of the per-epoch training seconds over all logged epochs.
        'runtime': run_df['train_sec/epoch'].sum(),
        # Loss/metrics are taken from the last logged epoch.
        'train_loss': run_df['train_loss'].iloc[-1],
        'test_acc': run_df['val_or_test_acc'].iloc[-1],
        'test_nll': run_df['val_or_test_nll'].iloc[-1],
    })
results = pd.DataFrame(records, columns=columns)

In [9]:
# For each (method, n, random_state), keep the run whose final training loss is lowest.
# lr_0 is the only sweep dimension not in the group keys, so this selects a learning rate.
# `total_time` is the summed runtime of *all* lr_0 runs in the group, i.e. the cost of the
# whole learning-rate search, not just of the selected run.
group_keys = ['method', 'n', 'random_state']
min_indices = results.groupby(group_keys)['train_loss'].idxmin()
# transform('sum') returns the group total aligned to `results`' own index, so each selected
# row receives its own group's total without relying on both groupbys sharing a row order.
total_time = results.groupby(group_keys)['runtime'].transform('sum')
best_results = results.assign(total_time=total_time).loc[min_indices]
best_results

Unnamed: 0,avg_sec/epoch,method,model_name,n,random_state,runtime,train_loss,test_acc,test_nll,total_time
42,0.6425,l2-sp,l2-sp_kappa=63740.12162162162_lr_0=0.001_n=370...,370,1001,1927.499683,1274052.0,0.868171,0.553892,7551.031071
43,0.737395,l2-sp,l2-sp_kappa=63740.12162162162_lr_0=0.001_n=370...,370,2001,2212.183542,1192010.0,0.871974,0.561397,8773.028834
44,0.628366,l2-sp,l2-sp_kappa=63740.12162162162_lr_0=0.001_n=370...,370,3001,1885.097442,1182029.0,0.855082,0.608515,8196.339905
45,8.811878,l2-sp,l2-sp_kappa=6853.776518453938_lr_0=0.001_n=344...,3441,1001,2026.731989,962782.6,0.930966,0.23836,8084.142288
46,9.330259,l2-sp,l2-sp_kappa=6853.776518453938_lr_0=0.001_n=344...,3441,2001,2145.959629,951793.2,0.93308,0.235463,8844.304394
47,10.535445,l2-sp,l2-sp_kappa=6853.776518453938_lr_0=0.001_n=344...,3441,3001,2423.152366,968905.3,0.929603,0.240015,8844.180085
36,0.689356,l2-zero,l2-zero_kappa=63740.12162162162_lr_0=0.001_n=3...,370,1001,2068.067386,73829430.0,0.816996,0.887582,8100.566932
37,0.675249,l2-zero,l2-zero_kappa=63740.12162162162_lr_0=0.001_n=3...,370,2001,2025.747493,68533190.0,0.83566,0.83814,8010.904821
38,0.710169,l2-zero,l2-zero_kappa=63740.12162162162_lr_0=0.001_n=3...,370,3001,2130.506056,69337320.0,0.800639,1.003989,8529.742489
39,9.212243,l2-zero,l2-zero_kappa=6853.776518453938_lr_0=0.001_n=3...,3441,1001,2118.815861,40542260.0,0.922984,0.258817,8475.124801


In [10]:
# Gather the per-seed values of every column into tuples for each (method, n) pair, then
# add mean / std / min / max across seeds for the test metrics.
# np.std is numpy's population standard deviation (ddof=0).
seed_stat_fns = {'mean': np.mean, 'std': np.std, 'min': np.min, 'max': np.max}
grouped_results = best_results.groupby(['method', 'n']).agg(lambda values: tuple(values))
columns = ['test_acc', 'test_nll']
for column in columns:
    per_seed_values = grouped_results[column]
    for stat_name, stat_fn in seed_stat_fns.items():
        grouped_results[f'{column}_{stat_name}'] = [stat_fn(values) for values in per_seed_values]
grouped_results = grouped_results.reset_index()
grouped_results

Unnamed: 0,method,n,avg_sec/epoch,model_name,random_state,runtime,train_loss,test_acc,test_nll,total_time,test_acc_mean,test_acc_std,test_acc_min,test_acc_max,test_nll_mean,test_nll_std,test_nll_min,test_nll_max
0,l2-sp,370,"(0.6424998943010966, 0.7373945138454437, 0.628...",(l2-sp_kappa=63740.12162162162_lr_0=0.001_n=37...,"(1001, 2001, 3001)","(1927.4996829032898, 2212.183541536331, 1885.0...","(1274051.866858108, 1192010.299662162, 1182029...","(0.8681714534759521, 0.8719738125801086, 0.855...","(0.5538923214937254, 0.5613965085723089, 0.608...","(7551.0310707092285, 8773.02883386612, 8196.33...",0.865076,0.007235,0.855082,0.871974,0.574601,0.024176,0.553892,0.608515
1,l2-sp,3441,"(8.81187821160192, 9.33025925781416, 10.535445...",(l2-sp_kappa=6853.776518453938_lr_0=0.001_n=34...,"(1001, 2001, 3001)","(2026.7319886684418, 2145.9596292972565, 2423....","(962782.6223427056, 951793.202742662, 968905.2...","(0.9309657216072084, 0.9330804347991944, 0.929...","(0.2383600340942396, 0.2354628913374615, 0.240...","(8084.142288208008, 8844.304393529892, 8844.18...",0.931216,0.001431,0.929603,0.93308,0.237946,0.001881,0.235463,0.240015
2,l2-zero,370,"(0.6893557953834534, 0.6752491643428803, 0.710...",(l2-zero_kappa=63740.12162162162_lr_0=0.001_n=...,"(1001, 2001, 3001)","(2068.06738615036, 2025.7474930286407, 2130.50...","(73829432.64864865, 68533193.64324324, 6933731...","(0.8169963359832764, 0.8356602191925049, 0.800...","(0.88758203831168, 0.8381401887866645, 1.00398...","(8100.566932439804, 8010.904821395874, 8529.74...",0.817765,0.014308,0.800639,0.83566,0.909904,0.069523,0.83814,1.003989
3,l2-zero,3441,"(9.212242872818656, 8.36751988970715, 8.869133...",(l2-zero_kappa=6853.776518453938_lr_0=0.001_n=...,"(1001, 2001, 3001)","(2118.815860748291, 1924.5295746326447, 2039.9...","(40542259.95268817, 39702876.96297588, 4142613...","(0.9229835271835328, 0.9217308759689332, 0.922...","(0.258816998513884, 0.2770982319232747, 0.2730...","(8475.124800920486, 8273.532461643219, 7895.40...",0.922565,0.000589,0.921731,0.922984,0.269666,0.007845,0.258817,0.277098
4,ptyl,370,"(0.7357667366663615, 0.675787350098292, 0.7894...",(ptyl_kappa=63740.12162162162_lr_0=0.001_n=370...,"(1001, 2001, 3001)","(2207.3002099990845, 2027.362050294876, 2368.4...","(1274936.3504391892, 1192581.1858868245, 11819...","(0.8681714534759521, 0.8725143671035767, 0.854...","(0.5538901293600899, 0.5614445472842311, 0.608...","(8735.077678442001, 8144.731352329254, 9191.08...",0.865166,0.007533,0.854812,0.872514,0.574617,0.024168,0.55389,0.608517
5,ptyl,3441,"(9.065430354035419, 9.044891802124354, 9.55875...",(ptyl_kappa=6853.776518453938_lr_0=0.001_n=344...,"(1001, 2001, 3001)","(2085.0489814281464, 2080.3251144886017, 2198....","(962646.8420916884, 951610.1313607964, 968633....","(0.9309657216072084, 0.9333507418632508, 0.929...","(0.2383294473642077, 0.2354536867370785, 0.239...","(8786.081049442291, 8383.4397585392, 8747.9712...",0.931307,0.001549,0.929603,0.933351,0.237915,0.001864,0.235454,0.239962


In [11]:
# Test-accuracy summary across seeds for each (method, n).
grouped_results.loc[:, ['method', 'n', 'test_acc_mean', 'test_acc_min', 'test_acc_max']]

Unnamed: 0,method,n,test_acc_mean,test_acc_min,test_acc_max
0,l2-sp,370,0.865076,0.855082,0.871974
1,l2-sp,3441,0.931216,0.929603,0.93308
2,l2-zero,370,0.817765,0.800639,0.83566
3,l2-zero,3441,0.922565,0.921731,0.922984
4,ptyl,370,0.865166,0.854812,0.872514
5,ptyl,3441,0.931307,0.929603,0.933351


In [12]:
# Test negative-log-likelihood summary across seeds for each (method, n).
grouped_results.loc[:, ['method', 'n', 'test_nll_mean', 'test_nll_min', 'test_nll_max']]

Unnamed: 0,method,n,test_nll_mean,test_nll_min,test_nll_max
0,l2-sp,370,0.574601,0.553892,0.608515
1,l2-sp,3441,0.237946,0.235463,0.240015
2,l2-zero,370,0.909904,0.83814,1.003989
3,l2-zero,3441,0.269666,0.258817,0.277098
4,ptyl,370,0.574617,0.55389,0.608517
5,ptyl,3441,0.237915,0.235454,0.239962
