# Import Packages

In [None]:
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST
import os

from functionalities import filemanager as fm
from functionalities import dataloader as dl
from tqdm import tqdm_notebook as tqdm
from functionalities import loss as lo
from functionalities import plot as pl
from functionalities import trainer as tr
from functionalities import gpu 

from architecture import CelebA_autoencoder as celeba

# Pretraining Setup

In [None]:
# Hyperparameters for the bottleneck-size sweep.
num_epoch = 2  # 10
batch_size = 32
lr_init = 1e-3
image_size = 178  # currently unused: the resize transforms below are commented out
# NOTE(review): these are passed to the trainer as `milestones` (presumably epoch
# indices for an LR schedule); with num_epoch = 2 none of them is ever reached —
# they look tuned for the commented-out 10-epoch setting. Confirm before relying on them.
milestones = [8, 9, 10]
latent_dim_lst = [128, 1000, 4000]  # [2 ** x for x in range(0, 11, 2)]
get_model = celeba.celeba_autoencoder
modelname = "celeba_classic_bottleneck"

# Index of the device to select (this was previously assigned twice; once is enough).
number_dev = 0
device = gpu.get_device(number_dev)
print(device)


# The CelebA dataset must be downloaded manually. ImageFolder expects one
# sub-folder per class, so put the images in ./img_align_celeba/dummy_class/...
IMAGE_PATH = './img_align_celeba/'
transform = transforms.Compose([
    # transforms.Resize(image_size),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = ImageFolder(IMAGE_PATH, transform)
# NOTE(review): presumably a 20% test split with shuffling disabled — confirm
# against dl.split_dataset.
trainloader, testloader = dl.split_dataset(dataset, 0.2, batch_size, False)

# Training

In [None]:
# Run the bottleneck sweep: presumably one autoencoder is trained per size in
# latent_dim_lst. The later cells load models named '<modelname>_<dim>' and a
# '<modelname>_bottleneck' loss history, presumably saved by this call — confirm.
# NOTE(review): the positional `None` is likely an optional (unused) validation
# loader — verify against the signature of tr.train_bottleneck_classic.
tr.train_bottleneck_classic(num_epoch, get_model, modelname, milestones, latent_dim_lst, trainloader,
                     None, testloader, lr_init=lr_init, device=device)

# Plot Example Reconstructions and Difference Images

In [None]:
from torch.autograd import Variable
import torchvision
import matplotlib.pyplot as plt

# Show originals, reconstructions and difference images for one test batch,
# once per trained bottleneck size.
num_img = 25
grid_row_size = 5
# (channels, height, width) passed to pl.to_img; matches the aligned CelebA
# images (178 wide, 218 tall) — assumed, confirm if the transform changes.
img_shape = [3, 218, 178]

img, label = next(iter(testloader))
# Use the configured device rather than hard-coding .cuda(), so this cell also
# runs when `device` is the CPU. (torch.autograd.Variable is deprecated and a no-op.)
img = img.to(device)

for i in latent_dim_lst:
    print('bottleneck dimension: {}'.format(i))
    model = fm.load_model('{}_{}'.format(modelname, i)).to(device)
    # Inference only: switch layers such as dropout/batch-norm to eval behaviour
    # and skip building the autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(img)

    original = pl.to_img(img.cpu().data, img_shape)
    pic = pl.to_img(output.cpu().data, img_shape)

    print("Original Image:")
    pl.imshow(torchvision.utils.make_grid(original[:num_img].detach(), grid_row_size), filename="com_classic_celeba_{}_original".format(i))
    print("Reconstructed Image:")
    pl.imshow(torchvision.utils.make_grid(pic[:num_img].detach(), grid_row_size), filename="com_classic_celeba_{}_reconstructed".format(i))
    print("Difference:")
    # Shift/scale the signed difference into [0, 1] for display, assuming both
    # images are already in [0, 1].
    diff_img = (original - pic + 1) / 2
    pl.imshow(torchvision.utils.make_grid(diff_img[:num_img].detach(), grid_row_size), filename="com_classic_celeba_{}_difference".format(i))

# Plot Reconstruction Loss against Bottleneck Size

In [None]:
# Load the saved (train, test) reconstruction losses for every bottleneck size
# and plot them against latent_dim_lst.
bottleneck_losses = fm.load_variable(f"{modelname}_bottleneck")
train, test = bottleneck_losses
y = [train, test]

pl.plot(
    latent_dim_lst,
    y,
    'bottleneck size',
    'loss',
    ['train', 'test'],
    'Train & Test Reconstruction Loss History',
    'loss_l1_celebA_bottleneck',
)