In [1]:
# System
import os
import sys
import tabulate
import time

# Data processing
import numpy as np
import math as m

# Results presentation
from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output
import matplotlib
import matplotlib.pyplot as plt

# NN related stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# from torch.autograd import Variable

import data
import models
import utils
import correlation


%matplotlib inline

In [2]:
# Number of checkpoint ids in range(690, 1011, 20); equals args.n_models (17).
count = len(range(690, 1011, 20))
count

17

In [3]:
class GlobalArguments:
    """Notebook-wide experiment settings (a stand-in for argparse arguments).

    Every attribute has a default. Keyword arguments override individual
    defaults, e.g. ``GlobalArguments(batch_size=256)``. An unknown keyword
    raises ``TypeError`` so that a typo is not silently ignored.
    """

    def __init__(self, **overrides):
        self.model       = 'VGG16BN'       # attribute name looked up in `models`
        self.dataset     = 'CIFAR10'
        self.data_path   = 'Data/'
        self.batch_size  = 128
        self.num_workers = 4
        self.transform   = 'VGG'
        self.use_test    = False
        self.models_path = 'Checkpoints/'
        # NOTE(review): 17 == len(range(690, 1011, 20)). The FGE loading cell
        # currently uses range(210, 1011, 20), i.e. 41 checkpoints.
        self.n_models    = 17

        defaults = vars(self)
        for name, value in overrides.items():
            if name not in defaults:
                raise TypeError(
                    'GlobalArguments got an unexpected keyword argument %r' % name)
            setattr(self, name, value)


args = GlobalArguments()

In [4]:
# Build the data loaders from the config. The result is indexed by split name
# ('train' / 'test') in later cells; num_classes sizes the model head.
loader_options = (
    args.dataset, args.data_path, args.batch_size,
    args.num_workers, args.transform, args.use_test,
)
loaders, num_classes = data.loaders(*loader_options)

Files already downloaded and verified
Using train (45000) + validation (5000)
Files already downloaded and verified


In [5]:
# Resolve the architecture spec (used below via `.base` and `.kwargs`) by name.
model_name = args.model
architecture = getattr(models, model_name)

In [7]:
# models_list_ind = []
# indicies = np.arange(args.n_models)
# # np.random.shuffle(indicies[1:-1])

# for i in indicies:
#     model = architecture.base(num_classes=num_classes, **architecture.kwargs)
#     checkpoint = torch.load(
#           args.models_path
#         + args.model + '_'
#         + args.dataset + '_'
#         + str(i) + '/checkpoint-200.pt',
#         map_location=torch.device('cpu'))
# #     print (model)
#     model.load_state_dict(checkpoint['model_state'])
#     models_list_ind.append(model)

In [8]:
# Load the FGE snapshots. Checkpoint ids run from 210 to 1010 in steps of 20
# (41 files). The narrower 690..1010 range (17 files) is kept for reference.
# indicies = range (690, 1011, 20)
indicies = range(210, 1011, 20)

# Derive the run directory from the config instead of hard-coding it, so the
# model/dataset part of the path stays in sync with `args`.
fge_run_dir = os.path.join(
    args.models_path,
    'FGE_' + args.model + '_' + args.dataset + '_CYCLE8_3',
)

models_list_fge = []
for i in indicies:
    model = architecture.base(num_classes=num_classes, **architecture.kwargs)
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(
        os.path.join(fge_run_dir, 'fge-' + str(i) + '.pt'),
        map_location=torch.device('cpu'),
    )
    model.load_state_dict(checkpoint['model_state'])
    models_list_fge.append(model)

In [9]:
# Sanity check: exactly one model per requested checkpoint id.
assert len(models_list_fge) == len(indicies)
len(models_list_fge)

41

In [10]:
# Every 8th snapshot plus the last one. With indicies = range(210, 1011, 20),
# these are checkpoints 210, 370, 530, 690, 850 and 1010.
needed_models = [models_list_fge[k] for k in (0, 8, 16, 24, 32, -1)]

In [11]:
def ensemble_models (model_list, dataloader, device=torch.device('cpu')):
    """Accuracy of each model alone and of the growing ensemble.

    Parameters
    ----------
    model_list : list of torch.nn.Module
        Models to evaluate, in the order they are added to the ensemble.
    dataloader : iterable of (inputs, labels) batches
        Evaluated in a SINGLE pass, so predictions stay row-aligned with the
        labels even when the loader shuffles or applies random augmentation.
    device : torch.device
        Device the models and inputs are moved to for the forward pass.

    Returns
    -------
    (acc_list, ens_acc_list) : tuple of lists of float
        ``acc_list[k]`` is the accuracy of ``model_list[k]`` alone.
        ``ens_acc_list[k]`` is the accuracy of the argmax of the summed outputs
        of ``model_list[:k + 1]``. Both lists are empty if ``model_list`` is empty.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no batches.
    """
    if len(model_list) == 0:
        return [], []

    for model in model_list:
        # Module.to / eval act in place. eval() puts BatchNorm/Dropout layers
        # (if any) into inference mode.
        model.eval().to(device)

    target, predictions = _collect_predictions(model_list, dataloader, device)

    print (predictions[0].shape, predictions[0].dtype)
    print (target.shape, target.dtype)

    n_samples = len(target)
    # NOTE(review): raw model outputs are summed (no softmax). Averaging
    # probabilities is a common alternative for ensembles.
    sum_prediction = torch.zeros_like(predictions[0])
    acc_list = []
    ens_acc_list = []
    for prediction in tqdm(predictions):
        sum_prediction += prediction
        acc     = torch.eq(prediction    .argmax(dim=1), target).sum().item() / n_samples
        ens_acc = torch.eq(sum_prediction.argmax(dim=1), target).sum().item() / n_samples
        print ('Accuracy: ', acc, 'Ensemble accuracy: ', ens_acc)
        acc_list    .append(acc)
        ens_acc_list.append(ens_acc)

    return acc_list, ens_acc_list


def _collect_predictions(model_list, dataloader, device):
    """Run every model on every batch in one pass over ``dataloader``.

    Returns ``(target, predictions)``. ``target`` is the concatenated labels,
    and ``predictions[k]`` is the concatenated (CPU) outputs of
    ``model_list[k]``, row-aligned with ``target``.
    """
    target = []
    outputs_per_model = [[] for _ in model_list]
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            target.append(labels.detach().cpu())
            inputs = inputs.to(device)
            for model, outputs in zip(model_list, outputs_per_model):
                outputs.append(model(inputs).detach().cpu())
    if not target:
        raise ValueError('ensemble_models: dataloader yielded no batches')
    return (torch.cat(target, dim=0),
            [torch.cat(outputs, dim=0) for outputs in outputs_per_model])

In [None]:
# Cross-entropy based comparison of the selected FGE snapshots on the test
# split, computed on CPU. The exact output layout is defined in `correlation`.
eval_device = torch.device('cpu')
cor_matrix = correlation.cross_entropy_Nmodels(needed_models, loaders['test'], eval_device)

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [None]:
# acc_list_ind, ens_acc_list_ind = ensemble_models(models_list_ind, loaders['train'], torch.device('cpu'))

In [None]:
# NOTE(review): this evaluates on the *training* split (the correlation cell
# uses loaders['test']). Switch to loaders['test'] for held-out accuracy.
acc_list_fge, ens_acc_list_fge = ensemble_models(
    models_list_fge,
    loaders['train'],
    device=torch.device('cpu'),
)

In [None]:
def plot (y_mas, time, savefig=None, labels=None):
    """Plot several accuracy curves against a shared x-axis and show the figure.

    Parameters
    ----------
    y_mas : sequence of sequences
        One curve per element. Each must have the same length as ``time``.
    time : sequence
        Shared x values (here: number of models). Note: shadows the ``time``
        module inside this function only.
    savefig : str or path-like, optional
        If given, the figure (including its legend) is also written to this path.
    labels : sequence of str, optional
        Legend entry per curve. Defaults to the curve index, as before.

    Raises
    ------
    ValueError
        If ``labels`` is given with a different length than ``y_mas``.
    """
    if labels is None:
        labels = [str(k) for k in range(len(y_mas))]
    labels = list(labels)
    if len(labels) != len(y_mas):
        raise ValueError('plot: got %d labels for %d curves' % (len(labels), len(y_mas)))

    fig, ax = plt.subplots()
    for y, label in zip(y_mas, labels):
        ax.plot(time, y, label=label)

    ax.set(xlabel='Number of models', ylabel='Accuracy',
           title='Ensembling methods comparison')
    ax.grid()
    # Add the legend BEFORE saving; previously it was added after savefig,
    # so the saved file had no legend.
    ax.legend()

    if savefig is not None:
        fig.savefig(savefig)
    plt.show()

In [None]:
# The independent-model cells above are commented out, so acc_list_ind /
# ens_acc_list_ind do not exist on a fresh kernel. They would also have a
# different length (args.n_models) than the FGE lists. Plot the FGE curves only.
# x starts at 1: x = k is the k-th snapshot alone (first curve) and the
# ensemble of the first k snapshots (second curve).
# plot ([acc_list_ind, acc_list_fge, ens_acc_list_ind, ens_acc_list_fge], np.arange(len(acc_list_ind)))
n_fge_models = len(acc_list_fge)
plot([acc_list_fge, ens_acc_list_fge], np.arange(1, n_fge_models + 1))
