In [None]:
# Optional dependency for push notifications (uncomment on a fresh environment / Colab).
# !pip install notify_run
# Interactive matplotlib backend (classic Notebook only; JupyterLab needs `%matplotlib widget`).
%matplotlib notebook 
# Automatically re-import edited modules (e.g. src.model) before each cell executes.
%load_ext autoreload
%autoreload 2

# Uncomment when running on Google Colab with the data stored on Google Drive.
# from google.colab import drive
# drive.mount('/content/gdrive')

# Setup

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from tqdm import tqdm, tnrange, tqdm_notebook
from notify_run import Notify

In [None]:
# notify-run client; the training loop calls `notify.send(...)` once per epoch.
notify = Notify()
# notify.register()  # presumably only needed once to set up a channel — see notify_run docs

In [None]:
import sys
import os

# Put the notebook's parent directory (the project root) on the import path so that
# `from src import ...` resolves; only append it once to avoid duplicate entries.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [None]:
from src import constants as c
from src.model import VAE

# Dataset Initialization

In [None]:
# Surgical-tools dataset: ImageFolder expects one sub-folder per class under surgical_data/.
# Images are resized/cropped to a square of side c.image_size and converted to tensors.
dataset = torchvision.datasets.ImageFolder(os.path.join(c.data_home, "surgical_data/"),
                                           transform=transforms.Compose([
                                               transforms.Resize(c.image_size),
                                               transforms.CenterCrop(c.image_size),
                                               transforms.ToTensor()
                                           ]))

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(c.validation_split * dataset_size))

# Seed both RNGs: numpy drives the train/validation split below; torch drives the
# SubsetRandomSampler ordering, weight initialisation and (presumably) the VAE's
# reparameterisation sampling, so training is reproducible from a fresh kernel.
np.random.seed(c.seed)
torch.manual_seed(c.seed)
np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
assert len(train_indices) + len(val_indices) == dataset_size

# Creating PT data samplers and loaders. The samplers restrict each loader to its subset
# and randomise order, so `shuffle` is left at its default (False), as DataLoader requires
# when a sampler is supplied.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = DataLoader(dataset=dataset,
                          batch_size=c.batch_size,
                          sampler=train_sampler)

test_loader = DataLoader(dataset=dataset,
                         batch_size=c.batch_size,
                         sampler=valid_sampler)

# Training setup

In [None]:
def loss_function(recon_x, x, mu, log_var):
    """VAE training objective for one batch.

    Returns the summed binary cross-entropy between ``recon_x`` and ``x`` plus the
    closed-form KL divergence of N(mu, exp(log_var)) from N(0, I), summed over all
    elements. ``recon_x`` must hold probabilities in [0, 1] (required by BCE).
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, 1)) per element, with sigma^2 = exp(log_var).
    per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * per_element.sum()
    return reconstruction_term + kl_term

In [None]:
# Sweep over latent sizes; each VAE is trained from scratch for `epochs` epochs, then its
# weights and a grid of decoded prior samples are written under c.data_home.
zdims = [2, 5, 10, 15]
epochs = 50  # also embedded in the weight / figure filenames below

for zdim in zdims:

    model = VAE(image_channels=c.image_channels,
                image_size=c.image_size,
                h_dim1=1024,
                h_dim2=128,
                zdim=zdim).to(c.device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Both loaders draw from subsets of the same `dataset`, so `len(loader.dataset)` would
    # count every image; the sampler length is the number of images actually iterated.
    n_train = len(train_loader.sampler)
    n_test = len(test_loader.sampler)

    for epoch in tqdm_notebook(range(epochs)):
        notify.send("z-dim = {}, Training Epoch {}".format(zdim, epoch+1))

        # --- Training ---
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(c.device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % c.log_interval == 0:
                tqdm.write('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), n_train,
                           100. * batch_idx / len(train_loader),
                           loss.item() / len(data)))

        tqdm.write('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / n_train))

        # --- Testing ---
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for i, (data, _) in enumerate(test_loader):
                data = data.to(c.device)
                recon_batch, mu, logvar = model(data)
                test_loss += loss_function(recon_batch, data, mu, logvar).item()
                if i == 0:
                    # Save inputs next to their reconstructions. Reshape with the actual
                    # batch size (the last batch may be short) and the configured image shape.
                    n = min(data.size(0), 8)
                    comparison = torch.cat([data[:n],
                                            recon_batch.view(data.size(0),
                                                             c.image_channels,
                                                             c.image_size,
                                                             c.image_size)[:n]])

                    save_image(comparison.cpu(),
                               c.data_home + 'samples/reconstruction_epoch_{}_zdim_{}.png'.format(epoch, zdim), nrow=n)

        if n_test:  # validation split may be empty
            tqdm.write('====> Test set loss: {:.4f}'.format(test_loss / n_test))

    # Filename matches the one the evaluation section loads ("tools_<size>_epochs_<n>_zdim_<z>").
    torch.save(model.state_dict(), c.data_home + "weights/tools_{}_epochs_{}_zdim_{}.torch".format(c.image_size,
                                                                                                 epochs,
                                                                                                 zdim))

    # Decode 64 draws from N(0, I) to visualise what the decoder has learned.
    with torch.no_grad():
        z = torch.randn(64, zdim)
        sample = model.decode(z.to(c.device))
        save_image(sample.view(64,
                               c.image_channels,
                               c.image_size,
                               c.image_size).cpu(),
                   c.data_home + 'figures/latent_space_random_{}_epochs_{}_zdim_{}.png'.format(c.image_size,
                                                                                            epochs,
                                                                                            zdim))

# Model evaluation

In [None]:
import pandas as pd
from sklearn import decomposition, manifold
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# One freshly initialised VAE per latent size under evaluation, keyed by zdim and moved
# to the configured device; trained weights are restored in a later cell.
models = dict(
    (zdim, VAE(image_channels=c.image_channels,
               image_size=c.image_size,
               h_dim1=1024,
               h_dim2=128,
               zdim=zdim).to(c.device))
    for zdim in [5]
)

In [None]:
# Display the zdim=5 model; for a torch nn.Module the repr lists its sub-modules.
models[5]

In [None]:
# Restore trained weights for each model. map_location lets checkpoints saved on a GPU
# load on a CPU-only machine (and vice versa); eval() switches to inference mode.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
for zdim, model in models.items():
    state_dict = torch.load(c.data_home + "weights/tools_64_epochs_50_zdim_{}.torch".format(zdim),
                            map_location=c.device)
    model.load_state_dict(state_dict)
    model.eval()

## Saving latent space encodings

In [None]:
# Encode every training image with each loaded model and collect one latent vector
# (a 1-D array of length zdim) per image. The vector is drawn via the model's `sampling`
# step from the encoder outputs, so repeated runs differ unless torch is seeded.
encoded_inputs = {zdim: [] for zdim in models}

with torch.no_grad():
    for zdim in tqdm_notebook(encoded_inputs):
        vae = models[zdim]
        vae.eval()
        for data, _ in tqdm_notebook(train_loader):
            data = data.to(c.device)
            latent_vector = vae.sampling(*vae.encode(data)).cpu().numpy()
            # Iterating a 2-D array yields its rows: one latent vector per image.
            encoded_inputs[zdim].extend(latent_vector)

In [None]:
# One frame per latent size (columns 0..zdim-1), stacked under an outer index level keyed by zdim.
latent_space = pd.concat([pd.DataFrame(encoded_inputs[k]) for k in [5]], keys=[5])
latent_space.head()

## PCA on latent space

In [None]:
# NOTE(review): this instance is never used — the next cell re-creates `pca` for each
# latent size before fitting. Safe to delete.
pca = decomposition.PCA(n_components=3)

In [None]:
# Project each latent space onto its first three principal components and store them as
# columns pc1..pc3 next to the raw latent coordinates (columns 0..zdim-1).
for zdim in [5]:
    latent_cols = list(range(zdim))
    pca = decomposition.PCA(n_components=3)
    pca_result = pca.fit_transform(latent_space.loc[zdim][latent_cols].to_numpy())
    for k in range(3):
        latent_space.loc[zdim, 'pc{}'.format(k + 1)] = pca_result[:, k]
    print('z={}:\tExplained variation per principal component: {} {}'.format(zdim, sum(pca.explained_variance_ratio_),pca.explained_variance_ratio_))

In [None]:
# 3-D scatter of the first three principal components of the zdim=5 latent codes.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
pcs = latent_space.loc[5]
ax.scatter(pcs['pc1'], pcs['pc2'], pcs['pc3'])
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
ax.set_title("Scatter Plot of Principal Components of 5-dimensional latent space")
plt.show()

In [None]:
def imscatter(x, y, ax, imageData, zoom):
    """Draw each image in `imageData` as a thumbnail at its matching (x, y) point on `ax`.

    Parameters
    ----------
    x, y : sequences of float
        Data coordinates of the thumbnails, taken positionally (pandas index labels are
        ignored); paired element-wise with `imageData`.
    ax : matplotlib Axes
        Axes to draw on; its data limits are extended to cover all points and rescaled.
    imageData : sequence of array-like
        One image per point, values presumably in [0, 1]; each must contain
        ``c.image_channels * c.image_size ** 2`` values.
    zoom : float
        Scale factor forwarded to ``OffsetImage``.

    Returns
    -------
    list
        The ``AnnotationBbox`` artists added to `ax`.
    """
    # Not imported at the top of the notebook; import locally so the helper is self-contained.
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage

    images = []
    for x0, y0, raw in zip(x, y, imageData):
        # Scale to 0-255 and clip before casting so out-of-range values don't wrap in uint8.
        img = np.clip(np.asarray(raw) * 255., 0, 255).astype(np.uint8)
        if c.image_channels == 1:
            img = img.reshape([c.image_size, c.image_size])
        else:
            # assumes channel-first (C, H, W) data as produced by ToTensor — TODO confirm;
            # matplotlib expects channel-last (H, W, C).
            img = img.reshape([c.image_channels, c.image_size, c.image_size]).transpose(1, 2, 0)

        image = OffsetImage(img, zoom=zoom)
        ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=False)
        images.append(ax.add_artist(ab))

    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return images

## t-SNE on latent space

In [None]:
# Embed each latent space in 2-D with t-SNE (seeded so the embedding is reproducible) and
# store the coordinates as columns tsne1/tsne2.
for zdim in [5]:
    latent_cols = list(range(zdim))  # raw latent columns are 0..zdim-1
    tsne = manifold.TSNE(n_components=2, random_state=c.seed)
    tsne_result = tsne.fit_transform(latent_space.loc[zdim][latent_cols].to_numpy())
    latent_space.loc[zdim, 'tsne1'] = tsne_result[:, 0]
    latent_space.loc[zdim, 'tsne2'] = tsne_result[:, 1]

In [None]:
# 2-D scatter of the t-SNE embedding of the zdim=5 latent codes, one point per encoded
# training image. NOTE(review): no title or axis labels are set in the visible lines.
fig = plt.figure()
ax = fig.add_subplot(111)
ax.scatter(latent_space.loc[5]['tsne1'], latent_space.loc[5]['tsne2'])