In [None]:
import torch
import numpy as np
import matplotlib
from scipy import ndimage
import os, sys
import math
import pickle
import model_utils as modutil
import data_utils as datutil
import hmc

In [None]:
# Data loader initialization.
# Set 1 (raw images): CIFAR-10 train/test, plus CIFAR-100 test as the "valid" loader,
# which the evaluation loop uses for its out-of-class analysis.
# Set 2 (GP+ENCODED256_D164 datasets): CIFAR-10 train/test/valid.
# Train loaders use batches of 50, all other loaders batches of 10; every loader uses 2 workers.
# NOTE(review): trainloader1 is built with shuffle=False while trainloader2 shuffles —
# confirm the unshuffled raw train loader is intentional (e.g. for ordered encode dumps).
def _make_loader(dataset_name, batch_size, shuffle):
    """Build one loader via datutil.generate_dataloaders with this notebook's num_workers=2."""
    return datutil.generate_dataloaders(
        dataset_name, batch_size=batch_size, shuffle=shuffle, num_workers=2)


trainloader1 = _make_loader('CIFAR10_TRAIN', 50, shuffle=False)
testloader1 = _make_loader('CIFAR10_TEST', 10, shuffle=False)
validloader1 = _make_loader('CIFAR100_TEST', 10, shuffle=False)

trainloader2 = _make_loader('GP+ENCODED256_D164_CIFAR10_TRAIN', 50, shuffle=True)
testloader2 = _make_loader('GP+ENCODED256_D164_CIFAR10_TEST', 10, shuffle=False)
validloader2 = _make_loader('GP+ENCODED256_D164_CIFAR10_VALID', 10, shuffle=False)

# CIFAR-100 class indices of interest; passed as interesting_labels to the
# out-of-class (entropy) validation.
interesting_labels = [0, 1, 16, 17, 20, 21, 29, 39, 40, 49, 57, 71, 72, 73, 76]

In [None]:
# Optional: replace the raw-image train/test loaders with the diabetic-retinopathy dataset.
# Off by default. Running this swap unconditionally (e.g. via "Run All") silently
# overwrites the CIFAR-10 trainloader1/testloader1 built in the previous cell while
# validloader1 stays CIFAR-100, so every non-encoded model would be trained/tested on
# one dataset and evaluated out-of-class against another. Set to True deliberately
# when you want the DIAB_RETIN experiment.
USE_DIAB_RETIN = False

if USE_DIAB_RETIN:
    trainloader1 = datutil.generate_dataloaders('DIAB_RETIN_TRAIN', batch_size=7, shuffle=True, num_workers=2)
    testloader1 = datutil.generate_dataloaders('DIAB_RETIN_TEST', batch_size=5, shuffle=False, num_workers=2)

In [None]:
# model to train/load/analyse
# user defined params
# Model(s) to train / load / analyse (user-defined parameters).
# In the evaluation loop, every key that names an existing model_utils global is
# copied onto it (in this insertion order); keys with no such global are reported.
models = [
    dict(
        model_type="PreResNet+GP",  # <Kernel_name> + <GP>
        saved_checkpoint_name="Initialized-PreResNet+GP_depth-164_lr-0.005_mom-0.9_wd-0.0001_FC-_acc-91.36-2019-08-19-16.22",
        fc_setup=[],
        load_model=True,
        partial_load=False,
        component_pretrained_mods=[
            'encoded+GP_depth-110_lr-0.1_mom-0.9_wd-0.0003_FC-_acc-90.23-2019-08-02-15.03',
            'ModifiedResNet-fixed_depth-110_lr-0.1_mom-0.9_wd-0.0001_FC-10_acc-90.92-2019-03-17-12.28',
        ],
        # Alternative depth-164 component checkpoints:
        # component_pretrained_mods=['encoded+GP_depth-164_lr-0.1_mom-0.9_wd-0.0001_FC-_acc-90.88-2019-08-13-17.08',
        #                            'ModifiedResnet-fixed_depth-164_lr-0.1_mom-0.9_wd-0.0001_FC-10_acc-92.03-2019-04-28-03.10'],
        train_model=False,
        train_epoch=5,
        num_classes=10,
        weight_decay=1e-4,
        predef_test_acc=92,
        # NOTE(review): depth is 110 but saved_checkpoint_name says depth-164 —
        # confirm which one is intended before loading with partial_load=False.
        depth=110,
        grid_size=64,
        optim_SGD=True,
        device=torch.device('cuda:2'),
        lr_init=0.008,
        lr_final=0.005,
        gp_kernel_feature=256,  # 256, 640
        print_init_model_state=False,
    ),
]

In [None]:
# Load (and optionally train) each configured model, then run the out-of-class
# and test-set evaluations on it.
for model in models:

    # Models whose type mentions 'encoded' consume the GP+ENCODED loaders (set 2);
    # all other models use the raw-image loaders (set 1).
    if 'encoded' in model['model_type']:
        trainloader, testloader, validloader = trainloader2, testloader2, validloader2
    else:
        trainloader, testloader, validloader = trainloader1, testloader1, validloader1

    print('='*20, "Loading Model", '='*20)

    # Reset model_utils' settings, then copy over every config key that names an
    # existing model_utils global. Unknown keys are reported and otherwise ignored.
    modutil.refresh_params()
    for propt in model:
        if propt not in vars(modutil):
            print("Model property '%s' not found!"%(propt))
            continue
        setattr(modutil, propt, model[propt])

    # Load / train the model; save only when this config asked for training.
    modutil.load_train(trainloader, testloader, model['partial_load'])
    if model['train_model']:
        print("Saving model!")
        modutil.save_model()

    # Evaluate on the out-of-class loader (restricted to interesting_labels),
    # then on the test loader.
    modutil.validate("out-of-class", validloader, interesting_labels=interesting_labels)
    print("Validation (out of class) data analysis performed")

    modutil.validate("test", testloader)
    print("Test data analysis performed")


In [None]:
# Reload the pickled test-set predictions saved for the first configured model and
# compute ECE with modutil.calculate_ECE (presumably expected calibration error).
# NOTE: pickle.load can execute arbitrary code — only load .testpred files you produced.
N_ECE_BINS = 30
directory = 'saved_predictions/'
test_file = os.path.join(directory, models[0]['saved_checkpoint_name'] + '.testpred')
# Evenly spaced cutoffs 1/N, 2/N, ..., 1.0 passed as ECE_bin
# (presumably the bin upper edges — confirm in modutil.calculate_ECE).
prob_cutoffs = [(i + 1) / float(N_ECE_BINS) for i in range(N_ECE_BINS)]

# Use a separate name for the file handle so `test_file` keeps holding the path,
# and close the file before doing any computation.
with open(test_file, 'rb') as pred_fh:
    stats = pickle.load(pred_fh)

targets = stats['targets']
preds = stats['predictions']
probs = stats['probs']
ece, ece_avg = modutil.calculate_ECE(probs, preds, targets, ECE_bin=prob_cutoffs)

print(ece, ece_avg)

In [None]:
# Display the value stored under 'loss' in the pickled prediction stats loaded in
# the ECE cell above (its type/meaning is set by whatever wrote the .testpred file —
# TODO confirm).
stats['loss']

In [None]:
# Dump encoded features for the train, test and out-of-class sets via modutil.encode_dump,
# using the raw-image (set 1) loaders for all three.
# The train/test lines name trainloader1/testloader1 explicitly rather than the loop
# variables `trainloader`/`testloader`: those hold whichever loader set the *last* model
# in `models` selected, so an 'encoded' model at the end of the list would feed the
# pre-encoded GP+ENCODED loaders into this dump. The valid line already used validloader1.
# NOTE(review): only the first file name carries a '.pt' suffix, and all three say
# CIFAR100 although trainloader1/testloader1 are built from CIFAR10_TRAIN/CIFAR10_TEST —
# confirm the intended output names.
modutil.encode_dump('GP+encoded164ResNet_CIFAR100_256.pt', trainloader1)
modutil.encode_dump('GP+encoded164ResNet_CIFAR100_256_test', testloader1, evalmode=True)
modutil.encode_dump('GP+encoded164ResNet_CIFAR100_256_valid', validloader1, evalmode=True)