# Import Packages

In [None]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os

from functionalities import filemanager as fm
from functionalities import dataloader as dl
from tqdm import tqdm_notebook as tqdm
from functionalities import loss as lo
from functionalities import gpu 
from functionalities import plot as pl
from functionalities import trainer as tr

from architecture import CIFAR_autoencoder as cifar

# Pretraining Setup

In [None]:
# --- Experiment identity -------------------------------------------------
# get_model is handed to the trainer as the model factory; modelname is the
# prefix used when trained models and loss histories are stored and reloaded.
modelname = "cifar_classic_bottleneck"
get_model = cifar.cifar_autoencoder

# --- Optimisation settings -----------------------------------------------
num_epoch = 100
batch_size = 128
lr_init = 0.001
# Forwarded to the trainer as `milestones`; presumably the epochs at which the
# learning rate is stepped -- confirm in tr.train_bottleneck_classic.
milestones = [60, 85, 100]

# --- Bottleneck sweep ----------------------------------------------------
# Powers of two from 2**0 up to 2**10, i.e. 1, 2, 4, ..., 1024.
latent_dim_lst = [1 << exponent for exponent in range(11)]

# --- Compute device ------------------------------------------------------
# number_dev is the device index passed to the project's gpu helper.
number_dev = 0
device = gpu.get_device(number_dev)

In [None]:
# Fetch the CIFAR datasets through the project's dataloader helper, then wrap
# them in train / validation / test loaders using the configured batch size.
(trainset, testset, classes) = dl.load_cifar()
(trainloader, validloader, testloader) = dl.make_dataloaders(
    trainset,
    testset,
    batch_size,
)

# Training 

In [None]:
# Run the classic bottleneck training over the sizes in latent_dim_lst.
# The plotting cells afterwards expect a stored model named
# "<modelname>_<size>" for every size and the loss histories under
# "<modelname>_bottleneck".
tr.train_bottleneck_classic(
    num_epoch,
    get_model,
    modelname,
    milestones,
    latent_dim_lst,
    trainloader,
    validloader,
    testloader,
    lr_init=lr_init,
    device=device,
)

# Plot Example Reconstruction and Difference Images

In [None]:
import torchvision
import matplotlib.pyplot as plt

# Number of images shown per grid and number of images per grid row.
num_img = 100
grid_row_size = 10

# A single test batch is reused for every bottleneck size so that the grids
# produced for different sizes show the same images.
img, label = next(iter(testloader))
# Move the batch to the device selected in the setup cell rather than
# hard-coding CUDA: `.cuda()` fails on CPU-only machines and ignores the
# configured device index. The deprecated `Variable` wrapper is not needed.
img = img.to(device)

for i in latent_dim_lst:
    print('bottleneck dimension: {}'.format(i))
    model = fm.load_model('{}_{}'.format(modelname, i)).to(device)
    # Inference only: switch to evaluation mode and skip autograd bookkeeping.
    # NOTE(review): assumes fm.load_model returns a torch.nn.Module (it is
    # already used with `.to(device)` and called like one) -- confirm.
    model.eval()
    with torch.no_grad():
        output = model(img)

    # Bring both batches back to the CPU and convert them to image tensors
    # with the project's helper ([3, 32, 32] is the per-image shape passed on).
    original = pl.to_img(img.detach().cpu(), [3, 32, 32])
    pic = pl.to_img(output.detach().cpu(), [3, 32, 32])

    print("Original Image:")
    pl.imshow(torchvision.utils.make_grid(original[:num_img].detach(), grid_row_size),
              filename="com_classic_cifar_{}_original".format(i))
    print("Reconstructed Image:")
    pl.imshow(torchvision.utils.make_grid(pic[:num_img].detach(), grid_row_size),
              filename="com_classic_cifar_{}_reconstructed".format(i))
    print("Difference:")
    # Shift and halve the signed difference so that "no difference" maps to
    # the middle value 0.5 instead of 0.
    diff_img = (original - pic + 1) / 2
    pl.imshow(torchvision.utils.make_grid(diff_img[:num_img].detach(), grid_row_size),
              filename="com_classic_cifar_{}_difference".format(i))

# Plot Reconstruction Loss against Bottleneck Size

In [None]:
# Reload the stored (train, test) pair saved under "<modelname>_bottleneck"
# and plot both curves against the bottleneck sizes.
train, test = fm.load_variable(modelname + "_bottleneck")
x = []
y = [train, test]

pl.plot(
    latent_dim_lst,
    y,
    'bottleneck size',
    'loss',
    ['train', 'test'],
    'Train & Test Reconstruction Loss History',
    'loss_l1_cifar_lin_bottleneck',
)