In [1]:
import os
import itertools
import numpy as np
import pandas as pd
# PyTorch
import torch
import torchvision
import matplotlib.pyplot as plt

In [2]:
import sys
sys.path.append('../src/')
# Importing our custom module(s)
import utils
import losses

In [3]:
def get_df(path):
    """Read a CSV file into a DataFrame.

    The CSV column named 'Unnamed: 0' is used as the DataFrame index.

    Args:
        path: Location of the CSV file, passed straight to `pd.read_csv`.

    Returns:
        The parsed `pd.DataFrame`.
    """
    return pd.read_csv(path, index_col='Unnamed: 0')

def get_last_epoch(df):
    """Return the final row of `df`, selected by position, as a Series."""
    final_row = df.iloc[-1]
    return final_row
    
def get_best_hyperparameters(experiments_path, lr_0s, ns, prior_scales, prior_type, random_states, weight_decays):
    """Pick, for every (n, random_state) pair, the run with the lowest final validation NLL.

    For each pair, every combination of `lr_0s` x `prior_scales` x `weight_decays`
    is looked up as `<experiments_path>/<model_name>.csv`. The combination whose
    last `val_or_test_nll` entry is smallest wins, and its last `val_or_test_acc`
    entry is recorded alongside the hyperparameters.

    Args:
        experiments_path: Directory containing one CSV log per run.
        lr_0s: Iterable of initial learning rates to consider.
        ns: Iterable of `n` values (one output row per n/random_state pair).
        prior_scales: Iterable of prior scales. A falsy entry (e.g. None)
            selects the file name variant without a `prior_scale=` field.
        prior_type: String prefix of the run file names.
        random_states: Iterable of random seeds.
        weight_decays: Iterable of weight decay values to consider.

    Returns:
        A `pd.DataFrame` with columns
        ['lr_0', 'n', 'prior_scale', 'prior_type', 'random_state', 'val_acc', 'weight_decay']
        and one row (index 0, 1, ...) per (n, random_state) pair.

    Raises:
        ValueError: If no run of a (n, random_state) pair has a final validation
            NLL below infinity (e.g. all NaN, or an empty hyperparameter grid).
    """
    columns = ['lr_0', 'n', 'prior_scale', 'prior_type', 'random_state', 'val_acc', 'weight_decay']
    # Collect the winning rows and build the frame once at the end instead of
    # growing a DataFrame row by row.
    records = []
    for n, random_state in itertools.product(ns, random_states):
        best_val_nll = np.inf
        best_hyperparameters = None
        for lr_0, prior_scale, weight_decay in itertools.product(lr_0s, prior_scales, weight_decays):
            if prior_scale:
                model_name = '{}_lr_0={}_n={}_prior_scale={}_random_state={}_weight_decay={}'\
                .format(prior_type, lr_0, n, prior_scale, random_state, weight_decay)
            else:
                model_name = '{}_lr_0={}_n={}_random_state={}_weight_decay={}'\
                .format(prior_type, lr_0, n, random_state, weight_decay)
            path = '{}/{}.csv'.format(experiments_path, model_name)
            # Read each log once and reuse it for both metrics.
            run_df = get_df(path)
            val_nll = get_val_nll(run_df)
            # A NaN NLL never compares as smaller, so diverged runs are skipped.
            if val_nll < best_val_nll:
                best_val_nll = val_nll
                best_hyperparameters = [lr_0, n, prior_scale, prior_type, random_state,
                                        get_val_acc(run_df), weight_decay]
        if best_hyperparameters is None:
            # Fail loudly here rather than emitting an all-missing row that
            # breaks later when the hyperparameters are used to build file names.
            raise ValueError(
                'No run with a finite validation NLL found for n={}, random_state={} in {}'
                .format(n, random_state, experiments_path)
            )
        records.append(best_hyperparameters)
    return pd.DataFrame(records, columns=columns)

def get_val_acc(df):
    """Return the last entry of the `val_or_test_acc` column of `df`."""
    accuracy_history = df.val_or_test_acc.values
    return accuracy_history[-1]

def get_val_nll(df):
    """Return the last entry of the `val_or_test_nll` column of `df`."""
    nll_history = df.val_or_test_nll.values
    return nll_history[-1]

In [5]:
# Nonlearned: pick the best hyperparameters per (n, random_state) from the tuning logs.
experiments_path = '/cluster/tufts/hugheslab/eharve06/bdl-transfer-learning/experiments/tuned_CIFAR-10_Copy1'
prior_type = 'nonlearned'
ns = [1000]
random_states = [1001, 2001, 3001]
# A single falsy prior scale, so file names are built without a prior_scale field.
prior_scales = [None]
# Learning-rate grid: 4 log-spaced values from 1e-1 down to 1e-4.
lr_0s = np.logspace(-1, -4, num=4)
# Weight-decay grid: 5 log-spaced values from 1e-2 down to 1e-6, plus 0.
weight_decays = np.append(np.logspace(-2, -6, num=5), 0)
nonlearned_hyperparameters = get_best_hyperparameters(
    experiments_path=experiments_path,
    lr_0s=lr_0s,
    ns=ns,
    prior_scales=prior_scales,
    prior_type=prior_type,
    random_states=random_states,
    weight_decays=weight_decays,
)
nonlearned_hyperparameters

Unnamed: 0,lr_0,n,prior_scale,prior_type,random_state,val_acc,weight_decay
0,0.0001,1000,,nonlearned,1001,0.755,0.0001
1,0.01,1000,,nonlearned,2001,0.82,1e-05
2,0.01,1000,,nonlearned,3001,0.865,0.001


In [21]:
num_heads = 10
# Directory holding the finetuned checkpoints (one .pth per selected hyperparameter row).
experiments_path = '/cluster/tufts/hugheslab/eharve06/bdl-transfer-learning/experiments/retrained_CIFAR-10_Copy1'
# Directory holding the pretrained ResNet-50 weights.
prior_path = '/cluster/tufts/hugheslab/eharve06/resnet50_ssl_prior'
# The pretrained weights do not depend on the loop variable, so read them once.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load('{}/resnet50_ssl_prior_model.pt'.format(prior_path), map_location=torch.device('cpu'))
for (nonlearned_index, nonlearned_row) in nonlearned_hyperparameters.iterrows():
    # Finetuned model
    model_name = 'nonlearned_lr_0={}_n={}_random_state={}_weight_decay={}'\
    .format(nonlearned_row.lr_0, int(nonlearned_row.n), int(nonlearned_row.random_state), nonlearned_row.weight_decay)
    finetuned_checkpoint = torch.load('{}/{}.pth'.format(experiments_path, model_name), map_location=torch.device('cpu'))
    # Pretrained checkpoint
    pretrained_checkpoint = torchvision.models.resnet50()
    # Swap the head for Identity before loading, so the model's keys line up
    # with a checkpoint that carries no classifier weights.
    pretrained_checkpoint.fc = torch.nn.Identity()
    pretrained_checkpoint.load_state_dict(checkpoint)
    pretrained_checkpoint.fc = torch.nn.Linear(in_features=2048, out_features=num_heads, bias=True)
    # Copy the finetuned classifier head into the model. Assigning into the dict
    # returned by state_dict() only rebinds a key of a throwaway dict and leaves
    # the module's parameters untouched, so load the values into the layer instead.
    pretrained_checkpoint.fc.load_state_dict({
        'weight': finetuned_checkpoint['fc.weight'],
        'bias': finetuned_checkpoint['fc.bias'],
    })
    # NOTE(review): interpolate_checkpoints is not defined in the cells shown here, and
    # `interpolations` is overwritten every iteration, so only the last row's result survives.
    interpolations = interpolate_checkpoints(pretrained_checkpoint, finetuned_checkpoint)

pretrained_checkpoint

torch.Size([10, 2048])
torch.Size([10])
torch.Size([10, 2048])
torch.Size([10])
torch.Size([10, 2048])
torch.Size([10])


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 