# Import Packages

In [None]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from functionalities import dataloader as dl
from functionalities import tracker as tk
from architecture import INN as inn
from functionalities import MMD_autoencoder_loss as mmd_loss
from functionalities import trainer as tr
from functionalities import filemanager as fm
from functionalities import plot as pl
from functionalities import gpu 

# Pretraining Setup

In [None]:
# Optimisation hyperparameters forwarded to tr.train_bottleneck below.
num_epoch = 8
batch_size = 32
lr_init = 1e-3
l2_reg = 1e-6
# Epoch milestones handed to the trainer (presumably for LR decay).
# NOTE(review): with num_epoch = 8, a milestone of 8 may never fire if the
# trainer counts epochs from 0 -- confirm in functionalities/trainer.
milestones = [6, 7, 8]

# Bottleneck sizes to sweep: powers of four, 1 through 1024.
latent_dim_lst = [4 ** k for k in range(6)]

# Which INN constructor to use, the name results are stored under, and the
# image size used by the preprocessing transform.
modelname = 'celeba_INN_com_bottleneck'
get_model = inn.celeba_inn_com
image_size = 128

# Device lookup; number_dev is passed to gpu.get_device (presumably a GPU index).
number_dev = 0
device = gpu.get_device(number_dev)

In [None]:

IMAGE_PATH = './img_align_celeba/'

# Preprocessing: resize, then convert the PIL image to a float tensor
# (ToTensor scales pixel values to [0, 1]). No Normalize step is applied,
# so the model sees values in [0, 1].
# NOTE(review): an int passed to Resize rescales the *shorter* edge to
# image_size and preserves aspect ratio, so non-square source images give
# non-square tensors -- confirm celeba_inn_com expects this (a CenterCrop
# would make them square).
transform = transforms.Compose(
    [transforms.Resize(image_size), transforms.ToTensor()]
)

# ImageFolder expects IMAGE_PATH to contain at least one subdirectory of images.
dataset = ImageFolder(root=IMAGE_PATH, transform=transform)

# NOTE(review): 0.2 and False are forwarded positionally to dl.split_dataset;
# presumably the held-out fraction and a boolean option -- confirm there.
trainloader, testloader = dl.split_dataset(dataset, 0.2, batch_size, False)

# Training

In [None]:
# Bottleneck sweep: the trainer receives the full latent_dim_lst and the
# 'l1' loss name (presumably it trains one model per latent size).
# NOTE(review): the positional None between trainloader and testloader is
# opaque here (possibly an unused validation loader) -- confirm in
# functionalities/trainer.
model = tr.train_bottleneck(
    num_epoch,
    get_model,
    'l1',
    modelname,
    milestones,
    latent_dim_lst,
    trainloader,
    None,
    testloader,
    a_distr=0,
    a_disen=0,
    lr_init=lr_init,
    l2_reg=l2_reg,
    device=device,
    save_model=True,
    num_img=25,
    grid_row_size=5,
)

# Plot Example Reconstructions and Difference Images

In [None]:
# Example reconstructions and difference images for every latent size in
# the sweep, evaluated with device='cpu' and not saved (filename=None).
pl.plot_diff_all(
    get_model,
    modelname,
    num_epoch,
    testloader,
    latent_dim_lst,
    device='cpu',
    num_img=25,
    grid_row_size=5,
    figsize=(30, 30),
    filename=None,
    conditional=False,
)

# Plot Reconstruction Loss against Bottleneck Size

In [None]:
# Load the stored test/train loss tuples for the bottleneck sweep; the second
# entry of each 5-tuple is taken as the L1 reconstruction loss per latent size.
# NOTE(review): the tuple layout is defined by the trainer's saving code --
# revisit the unpacking if that format changes.
_, l1_rec_test, _, _, _ = fm.load_variable('bottleneck_test_loss_{}'.format(modelname), modelname)
_, l1_rec_train, _, _, _ = fm.load_variable('bottleneck_train_loss_{}'.format(modelname), modelname)

# Both train and test curves are drawn against the latent dimension, so the
# title (presumably the sixth argument) must not say "Test" only or "History".
# The output name is kept unchanged so previously saved figures stay where
# downstream code expects them.
pl.plot(
    latent_dim_lst,
    [l1_rec_train, l1_rec_test],
    'latent dimension',
    'loss',
    ['train', 'test'],
    'Reconstruction Loss against Bottleneck Size',
    '{}_bottleneck_History'.format(modelname),
)