# Import Packages

In [None]:
import torch
from functionalities import dataloader as dl
from functionalities import tracker as tk
from architecture import INN as inn
from functionalities import CIFAR_coder_loss as cl
from functionalities import trainer as tr
from functionalities import filemanager as fm
from functionalities import plot as pl
from functionalities import gpu 

# Pretraining Setup

In [None]:
# Hardware: device index handed to gpu.get_device.
number_dev = 0
device = gpu.get_device(number_dev)

# Model constructor and the base name used when saving / loading models and losses.
get_model = inn.cifar_inn_com
modelname = 'cifar_INN_glow_com_bottleneck'

# Bottleneck sizes to sweep: the powers of two 1, 2, 4, ..., 1024.
latent_dim_lst = [2 ** exponent for exponent in range(11)]

# Training hyperparameters.
num_epoch = 15
batch_size = 128
lr_init = 1e-3
l2_reg = 1e-6
# NOTE(review): presumably the epochs at which the learning rate is changed -- confirm in tr.train_bottleneck.
milestones = [10, 15]

In [None]:
# Load the CIFAR train / test sets and wrap them in loaders that use the configured batch size.
cifar_data = dl.load_cifar()
trainset, testset, classes = cifar_data

loaders = dl.make_dataloaders(trainset, testset, batch_size)
trainloader, validloader, testloader = loaders

# Training

In [None]:
# Train the models for all bottleneck sizes in latent_dim_lst with the 'l1' loss and save them.
model = tr.train_bottleneck(
    num_epoch,
    get_model,
    'l1',
    modelname,
    milestones,
    latent_dim_lst,
    trainloader,
    None,  # NOTE(review): presumably the validation loader slot, left unused -- confirm in tr.train_bottleneck
    testloader,
    a_distr=0,
    a_disen=0,
    lr_init=lr_init,
    l2_reg=l2_reg,
    device=device,
    save_model=True,
)

# Plot Reconstruction and Difference Images Examples

In [None]:
# Plot reconstruction and difference images for the saved model of every bottleneck size.
for lat_dim in latent_dim_lst:
    print("Latent Dimension: ", lat_dim)
    model = fm.load_model('{}_{}_{}'.format(modelname, lat_dim, num_epoch), "{}_bottleneck".format(modelname))
    # Move the model to the configured device rather than a hard-coded 'cuda': the same device is
    # handed to plot_diff, and a fixed 'cuda' fails on CPU-only machines or when another GPU is selected.
    pl.plot_diff(model.to(device), trainloader, lat_dim, device, 100, 10, filename='com_INN_cifar_{}'.format(lat_dim))

# Plot Reconstruction Loss against Bottleneck Size

In [None]:
# Load the stored test and train losses; only the second entry of each (the l1 reconstruction loss) is used.
test_losses = fm.load_variable('bottleneck_test_loss_{}'.format(modelname), modelname)
train_losses = fm.load_variable('bottleneck_train_loss_{}'.format(modelname), modelname)
_, l1_rec_test, _, _, _ = test_losses
_, l1_rec_train, _, _, _ = train_losses

# Plot both curves against the bottleneck size.
pl.plot(
    latent_dim_lst,
    [l1_rec_train, l1_rec_test],
    'bottleneck size',
    'loss',
    ['train', 'test'],
    'Test Reconstruction Loss History',
    '{}_bottleneck_History'.format(modelname),
)

In [None]:
import numpy as np
from tqdm import tqdm_notebook as tqdm
def get_loss(loader, model, criterion, latent_dim, device='cpu'):
    """
    Compute the loss of a model on a train, test or evaluation set wrapped by a loader.

    Each batch is encoded by the model, every entry of the flattened latent code from index latent_dim onward is
    replaced by zero, and the result is decoded again with model(..., rev=True). The criterion is then evaluated on
    the inputs, the unmodified flattened latent code and the reconstruction.

    :param loader: loader that wraps the train, test or evaluation set; must yield (inputs, labels) pairs
    :param model: model that should be tested; it is moved to device and switched to eval mode
    :param criterion: the criterion to compute the loss; called as criterion(inputs, latent, output) and expected to
    return a sequence of at most five scalar tensors
    :param latent_dim: number of leading latent dimensions that are kept; the rest is zeroed before decoding
    :param device: device on which to do the computation (CPU or CUDA). Please use get_device() function to get the
    device, if using multiple GPU's. Default: cpu
    :return: numpy array with five entries; each is the sum over all batches of the corresponding loss term times
    100, divided by len(loader)
    """

    model.to(device)
    model.eval()

    losses = np.zeros(5, dtype=np.double)

    # Pure evaluation: no gradients are needed for any batch.
    with torch.no_grad():
        for inputs, _ in tqdm(loader):
            # The labels are not used by the criterion, so only the inputs are moved to the device.
            inputs = inputs.to(device)

            lat_img = model(inputs)
            lat_shape = lat_img.shape
            # Flatten to (batch, features) so the bottleneck can be applied along one axis.
            lat_img = lat_img.view(lat_img.size(0), -1)

            # Keep the first latent_dim entries and zero out the remainder.
            kept = lat_img[:, :latent_dim]
            zeroed = lat_img.new_zeros(lat_img[:, latent_dim:].shape)
            lat_img_mod = torch.cat([kept, zeroed], dim=1).view(lat_shape)

            output = model(lat_img_mod, rev=True)

            batch_loss = criterion(inputs, lat_img, output)

            # A dedicated index name is used here so the batch loop variables are not shadowed.
            for term_idx, term in enumerate(batch_loss):
                losses[term_idx] += term.item() * 100

    losses /= len(loader)
    return losses


def get_loss_bottleneck(loader, modelname, subdir, latent_dim_lst, num_epoch, device, a_distr, a_rec, a_spar, a_disen):
    """
    Evaluate the saved model of every bottleneck size on the data wrapped by a loader.

    For each entry of latent_dim_lst the model saved as '<modelname>_<latent_dim>_<num_epoch>' is loaded from subdir
    and moved to device, an 'l1' CIFAR_coder_loss criterion is built with the given weights, and the losses are
    computed with get_loss.

    :param loader: loader that wraps the set to evaluate on
    :param modelname: base name of the saved models
    :param subdir: sub-directory passed to fm.load_model
    :param latent_dim_lst: bottleneck sizes to evaluate
    :param num_epoch: epoch count that is part of the saved model names
    :param device: device on which to do the computation
    :param a_distr: a_distr weight passed to the criterion
    :param a_rec: a_rec weight passed to the criterion
    :param a_spar: a_spar weight passed to the criterion
    :param a_disen: a_disen weight passed to the criterion
    :return: tuple of five lists (total, reconstruction, distribution, sparsity and disentanglement loss), each
    holding one value per bottleneck size
    """

    per_dim_losses = []

    for latent_dim in latent_dim_lst:
        print('bottleneck dimension: {}'.format(latent_dim))
        saved_name = '{}_{}_{}'.format(modelname, latent_dim, num_epoch)
        model = fm.load_model(saved_name, subdir).to(device)
        criterion = cl.CIFAR_coder_loss(a_distr=a_distr, a_rec=a_rec, a_spar=a_spar, a_disen=a_disen,
                                        latent_dim=latent_dim, loss_type='l1', device=device)
        per_dim_losses.append(get_loss(loader, model, criterion, latent_dim, device))

    # Split the per-model loss arrays into one list per loss term.
    total_loss, rec_loss, dist_loss, spar_loss, disen_loss = (
        [model_losses[term] for model_losses in per_dim_losses] for term in range(5)
    )

    return total_loss, rec_loss, dist_loss, spar_loss, disen_loss

In [None]:
# Recompute the test losses of every saved bottleneck model and plot the reconstruction term.
x = latent_dim_lst
_, y, _, _, _ = get_loss_bottleneck(
    testloader,
    modelname,
    modelname + '_bottleneck',
    latent_dim_lst,
    num_epoch,
    device,
    a_distr=0,
    a_rec=1,
    a_spar=1,
    a_disen=0,
)
pl.plot(
    x,
    y,
    'latent dimension',
    'loss',
    'l1',
    'Test Reconstruction Loss History',
    '{}_bottleneck_History'.format(modelname),
)