In [1]:
import os
import time
import math
import random
import itertools
import numpy as np
import pandas as pd
# PyTorch
import torch
import torchvision
import torchmetrics

In [2]:
import sys
sys.path.append('../src/')

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import losses
import utils

In [3]:
# Build a ResNet-50 (no weights argument is passed) and replace its final
# fully-connected layer with a 10-output linear head.
model = torchvision.models.resnet50()
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,  # 2048 for ResNet-50
    out_features=10,
    bias=True,
)
# Flatten all parameter tensors into a single 1-D vector so that its length is
# the total parameter count of the network.
params = torch.nn.utils.parameters_to_vector(model.parameters())
print(params.shape[0])

23528522


In [4]:
# Directory holding one CSV of per-epoch metrics for every training run.
experiments_directory = '/cluster/tufts/hugheslab/eharve06/informed-priors/experiments/retrained_CIFAR-10_VI'
# Hyperparameter grid: a CSV is read for every (lr_0, method, n, random_state)
# combination, so a missing file makes pd.read_csv raise.
lr_0s = [0.1, 0.01, 0.001, 0.0001]
methods = ['l2-zero', 'l2-sp', 'ptyl']
ns = [100, 1000, 10000, 50000]
random_states = [1001, 2001, 3001]
# Same number as the parameter count printed for the ResNet-50 with a 10-output
# head earlier in this notebook; run file names encode kappa = num_params / n.
num_params = 23528522

columns = ['avg_sec/epoch', 'method', 'model_name', 'n', 'random_state', 'runtime', 'train_loss', 'test_acc', 'test_nll']

# Collect one record per run and build the DataFrame once at the end, instead of
# enlarging an empty frame row by row (quadratic, and leaves dtype inference to
# the enlargement path).
records = []
for lr_0, method, n, random_state in itertools.product(lr_0s, methods, ns, random_states):
    model_name = f'{method}_kappa={num_params/n}_lr_0={lr_0}_n={n}_random_state={random_state}'
    temp_df = pd.read_csv(f'{experiments_directory}/{model_name}.csv')
    sec_per_epoch = temp_df['train_sec/epoch']
    records.append({
        'avg_sec/epoch': sec_per_epoch.mean(),
        'method': method,
        'model_name': model_name,
        'n': n,
        'random_state': random_state,
        # Total training time of this single run (sum over its epochs).
        'runtime': sec_per_epoch.sum(),
        # Metrics are taken from the last row of the run's CSV.
        'train_loss': temp_df.train_loss.values[-1],
        'test_acc': temp_df.val_or_test_acc.values[-1],
        'test_nll': temp_df.val_or_test_nll.values[-1],
    })

# `columns` fixes the column order; the index is 0..len(records)-1, matching the
# labels the row-by-row version produced.
results_df = pd.DataFrame(records, columns=columns)

In [5]:
# For every (method, n, random_state) pick the run (i.e. the learning rate) with
# the lowest final training loss, and record the total time spent across all
# learning rates tried for that configuration.
group_keys = ['method', 'n', 'random_state']
runs_by_config = results_df.groupby(group_keys)
# Row label in results_df of the lowest-train_loss run within each group.
min_indices = runs_by_config['train_loss'].idxmin()
# Summed runtime over every run in the group (all learning rates).
total_time = runs_by_config['runtime'].sum()
# Align total_time to the selected rows by group key rather than relying on two
# independent groupbys happening to produce the same order, and use .assign so
# a new frame is built instead of writing into a slice of results_df.
best_results_df = results_df.loc[min_indices].assign(
    total_time=total_time.reindex(min_indices.index).to_numpy(),
)
best_results_df

Unnamed: 0,avg_sec/epoch,method,model_name,n,random_state,runtime,train_loss,test_acc,test_nll,total_time
84,0.287038,l2-sp,l2-sp_kappa=235285.22_lr_0=0.001_n=100_random_...,100,1001,1722.230124,252390.6,0.705,1.002224,6887.981032
85,0.287962,l2-sp,l2-sp_kappa=235285.22_lr_0=0.001_n=100_random_...,100,2001,1727.77306,683253.7,0.6976,1.120148,7485.08461
86,0.331697,l2-sp,l2-sp_kappa=235285.22_lr_0=0.001_n=100_random_...,100,3001,1990.184736,497417.6,0.7222,0.989427,7704.61928
87,2.804573,l2-sp,l2-sp_kappa=23528.522_lr_0=0.001_n=1000_random...,1000,1001,2403.519479,784980.7,0.8752,0.389472,9604.116255
88,2.688813,l2-sp,l2-sp_kappa=23528.522_lr_0=0.001_n=1000_random...,1000,2001,2304.312855,819827.6,0.8724,0.397867,9308.71659
89,2.632486,l2-sp,l2-sp_kappa=23528.522_lr_0=0.001_n=1000_random...,1000,3001,2256.040596,633631.1,0.8686,0.419948,9086.670098
54,29.271387,l2-sp,l2-sp_kappa=2352.8522_lr_0=0.01_n=10000_random...,10000,1001,2224.625432,3087789.0,0.9499,0.275243,8937.065129
55,27.173933,l2-sp,l2-sp_kappa=2352.8522_lr_0=0.01_n=10000_random...,10000,2001,2065.218935,3189351.0,0.9501,0.272663,8753.639591
56,27.384826,l2-sp,l2-sp_kappa=2352.8522_lr_0=0.01_n=10000_random...,10000,3001,2081.246807,3005865.0,0.9544,0.246688,8565.710572
57,159.481004,l2-sp,l2-sp_kappa=470.57044_lr_0=0.01_n=50000_random...,50000,1001,2392.21506,1330268.0,0.967,0.121275,9360.659174


In [6]:
# Aggregate the selected runs over random seeds. Every column of each
# (method, n) group is first collected into a tuple of per-seed values.
# NOTE: the frame produced by the previous cell is `best_results_df`; the name
# `best_results` is not defined anywhere in this notebook and raised NameError
# on a fresh kernel.
grouped_results = best_results_df.groupby(['method', 'n']).agg(lambda x: tuple(x))
columns = ['test_acc', 'test_nll']
# Summary statistics computed over each per-seed tuple. np.std is called with
# its default arguments (ddof=0).
summary_stats = {'mean': np.mean, 'std': np.std, 'min': np.min, 'max': np.max}
for column in columns:
    for stat_name, stat_fn in summary_stats.items():
        # The lambda is evaluated immediately by .apply, so capturing stat_fn
        # from the enclosing loop is safe here.
        grouped_results[f'{column}_{stat_name}'] = grouped_results[column].apply(lambda item: stat_fn(item))
# Turn the (method, n) index back into ordinary columns for display/selection.
grouped_results = grouped_results.reset_index()
grouped_results

Unnamed: 0,method,n,avg_sec/epoch,model_name,random_state,runtime,train_loss,test_acc,test_nll,total_time,test_acc_mean,test_acc_std,test_acc_min,test_acc_max,test_nll_mean,test_nll_std,test_nll_min,test_nll_max
0,l2-sp,100,"(0.28703835395971933, 0.2879621766805649, 0.33...",(l2-sp_kappa=235285.22_lr_0=0.001_n=100_random...,"(1001, 2001, 3001)","(1722.230123758316, 1727.7730600833893, 1990.1...","(252390.55546875, 683253.696875, 497417.5625)","(0.7049999833106995, 0.6976000070571899, 0.722...","(1.0022236173629762, 1.1201482831954956, 0.989...","(6887.981032133102, 7485.08461022377, 7704.619...",0.708267,0.010305,0.6976,0.7222,1.037266,0.058839,0.989427,1.120148
1,l2-sp,1000,"(2.80457348751215, 2.688813133028313, 2.632486...",(l2-sp_kappa=23528.522_lr_0=0.001_n=1000_rando...,"(1001, 2001, 3001)","(2403.5194787979126, 2304.3128550052643, 2256....","(784980.7453, 819827.5935, 633631.1348000001)","(0.8751999735832214, 0.8723999857902527, 0.868...","(0.3894716707229613, 0.3978671109676359, 0.419...","(9604.116255044937, 9308.716590166092, 9086.67...",0.872067,0.002705,0.8686,0.8752,0.402429,0.012853,0.389472,0.419948
2,l2-sp,10000,"(29.271387263348227, 27.17393336170598, 27.384...",(l2-sp_kappa=2352.8522_lr_0=0.01_n=10000_rando...,"(1001, 2001, 3001)","(2224.6254320144653, 2065.2189354896545, 2081....","(3087788.7058350006, 3189351.18163, 3005865.33...","(0.949899971485138, 0.9501000046730042, 0.9544...","(0.2752425508975982, 0.2726631891250611, 0.246...","(8937.065129041672, 8753.639590501785, 8565.71...",0.951467,0.002076,0.9499,0.9544,0.264865,0.012896,0.246688,0.275243
3,l2-sp,50000,"(159.4810039838155, 147.98391919136049, 148.65...",(l2-sp_kappa=470.57044_lr_0=0.01_n=50000_rando...,"(1001, 2001, 3001)","(2392.2150597572327, 2219.758787870407, 2229.7...","(1330268.157527, 1322979.805988, 1334417.681726)","(0.9670000076293944, 0.968999981880188, 0.9675...","(0.1212750284314155, 0.1171597015619278, 0.122...","(9360.659173727036, 8883.92631316185, 9227.181...",0.967867,0.000838,0.967,0.969,0.120217,0.002196,0.11716,0.122218
4,l2-zero,100,"(0.3054483083486557, 0.28721131598949434, 0.31...",(l2-zero_kappa=235285.22_lr_0=0.001_n=100_rand...,"(1001, 2001, 3001)","(1832.6898500919342, 1723.267895936966, 1865.8...","(76920992.0, 72567470.8, 64847018.0)","(0.6238000392913818, 0.6103999614715576, 0.633...","(2.7010212276458736, 2.8018348682403573, 2.517...","(7561.7303376197815, 7272.718409776688, 7536.5...",0.622533,0.009432,0.6104,0.6334,2.673482,0.117665,2.517591,2.801835
5,l2-zero,1000,"(2.42838150339338, 2.45543369390027, 2.6256449...",(l2-zero_kappa=23528.522_lr_0=0.001_n=1000_ran...,"(1001, 2001, 3001)","(2081.122948408127, 2104.306675672531, 2250.17...","(46154403.76, 43281994.8144, 43791327.2272)","(0.8748998641967773, 0.8754000067710876, 0.871...","(0.5100445497512817, 0.5201272904396056, 0.554...","(8820.797352313995, 8979.942680597305, 8997.84...",0.873833,0.001873,0.8712,0.8754,0.528163,0.018946,0.510045,0.554317
6,l2-zero,10000,"(30.12663439700478, 31.452493655054194, 29.342...",(l2-zero_kappa=2352.8522_lr_0=0.001_n=10000_ra...,"(1001, 2001, 3001)","(2289.6242141723633, 2390.3895177841187, 2230....","(42980314.15488001, 42587871.50560001, 4214105...","(0.9123000502586364, 0.9083999991416932, 0.918...","(0.270970040845871, 0.286863360118866, 0.25891...","(9195.224611759186, 9584.208001375198, 9119.93...",0.9131,0.004202,0.9084,0.9186,0.27225,0.011445,0.258918,0.286863
7,l2-zero,50000,"(136.46109395027162, 153.78041626612347, 149.0...",(l2-zero_kappa=470.57044_lr_0=0.1_n=50000_rand...,"(1001, 2001, 3001)","(2046.916409254074, 2306.706243991852, 2235.98...","(32138092.513312005, 32015075.714208007, 32663...","(0.9321999549865724, 0.9335999488830566, 0.923...","(0.2810456726074218, 0.2754840620994567, 0.309...","(8585.871500015259, 9036.23970079422, 8596.207...",0.929767,0.004468,0.9235,0.9336,0.288585,0.01477,0.275484,0.309225
8,ptyl,100,"(0.3117110659281413, 0.34929334648450217, 0.34...",(ptyl_kappa=235285.22_lr_0=0.001_n=100_random_...,"(1001, 2001, 3001)","(1870.2663955688477, 2095.760078907013, 2090.2...","(118049.86796875, 770929.9953125, 503117.94375)","(0.6980000138282776, 0.6962999701499939, 0.723...","(1.0142373960494997, 1.1354751213073728, 0.982...","(7482.5085418224335, 8451.942963838577, 8297.3...",0.705967,0.012488,0.6963,0.7236,1.044091,0.0659,0.98256,1.135475
9,ptyl,1000,"(2.721273878213266, 2.524125892612453, 2.71914...",(ptyl_kappa=23528.522_lr_0=0.001_n=1000_random...,"(1001, 2001, 3001)","(2332.131713628769, 2163.175889968872, 2330.30...","(790152.7083, 845904.2682500001, 646035.647850...","(0.8743999600410461, 0.8705999851226807, 0.868...","(0.3869720663070679, 0.3949382533550263, 0.418...","(9415.471363067627, 8649.986447572708, 9162.36...",0.871267,0.002334,0.8688,0.8744,0.400282,0.013586,0.386972,0.418937


In [7]:
# Test-accuracy summary (mean / min / max over seeds) per method and n.
grouped_results.loc[:, [
    'method',
    'n',
    'test_acc_mean',
    'test_acc_min',
    'test_acc_max',
]]

Unnamed: 0,method,n,test_acc_mean,test_acc_min,test_acc_max
0,l2-sp,100,0.708267,0.6976,0.7222
1,l2-sp,1000,0.872067,0.8686,0.8752
2,l2-sp,10000,0.951467,0.9499,0.9544
3,l2-sp,50000,0.967867,0.967,0.969
4,l2-zero,100,0.622533,0.6104,0.6334
5,l2-zero,1000,0.873833,0.8712,0.8754
6,l2-zero,10000,0.9131,0.9084,0.9186
7,l2-zero,50000,0.929767,0.9235,0.9336
8,ptyl,100,0.705967,0.6963,0.7236
9,ptyl,1000,0.871267,0.8688,0.8744


In [8]:
# Test-NLL summary (mean / min / max over seeds) per method and n.
grouped_results.loc[:, [
    'method',
    'n',
    'test_nll_mean',
    'test_nll_min',
    'test_nll_max',
]]

Unnamed: 0,method,n,test_nll_mean,test_nll_min,test_nll_max
0,l2-sp,100,1.037266,0.989427,1.120148
1,l2-sp,1000,0.402429,0.389472,0.419948
2,l2-sp,10000,0.264865,0.246688,0.275243
3,l2-sp,50000,0.120217,0.11716,0.122218
4,l2-zero,100,2.673482,2.517591,2.801835
5,l2-zero,1000,0.528163,0.510045,0.554317
6,l2-zero,10000,0.27225,0.258918,0.286863
7,l2-zero,50000,0.288585,0.275484,0.309225
8,ptyl,100,1.044091,0.98256,1.135475
9,ptyl,1000,0.400282,0.386972,0.418937


In [39]:
import torch
import torchvision

In [40]:
# Root directory passed to torchvision for the CIFAR-10 files.
dataset_directory = '/cluster/tufts/hugheslab/eharve06/CIFAR-10'
# The only preprocessing step is the conversion to a tensor.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
# Instantiate the train split first, then the test split, with identical
# settings apart from the `train` flag.
cifar10_train_dataset, cifar10_test_dataset = (
    torchvision.datasets.CIFAR10(
        root=dataset_directory,
        train=is_train,
        transform=transform,
        download=True,
    )
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [65]:
from sklearn.model_selection import train_test_split

# Class labels of the full CIFAR-10 training set, used as stratification keys.
all_targets = np.array(cifar10_train_dataset.targets)

n = 100

# Stage 1: draw a label-stratified subset of n indices from the training set;
# the complementary indices are discarded.
train_and_val_indices, _ = train_test_split(
    range(len(cifar10_train_dataset)),
    train_size=n,
    test_size=None,
    shuffle=True,
    random_state=42,
    stratify=all_targets,
)

# Stage 2: hold out one fifth of that subset for validation, stratified again
# on the labels of the selected indices.
val_size = n // 5
train_indices, val_indices = train_test_split(
    train_and_val_indices,
    train_size=n - val_size,
    test_size=val_size,
    shuffle=True,
    random_state=42,
    stratify=all_targets[train_and_val_indices],
)

# Display the class labels that ended up in the validation split.
all_targets[val_indices]

array([0, 8, 2, 1, 3, 9, 1, 6, 7, 0, 6, 4, 8, 3, 4, 7, 5, 9, 5, 2])