In [None]:
# Select the interactive "notebook" matplotlib backend for figures in this notebook.
%matplotlib notebook 
# Load the autoreload extension and use mode 2, so edits to imported project
# modules (e.g. `src.*`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src import constants as c
from src.model import VAE
from src import visualization as v

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
from sklearn import decomposition, manifold

from tqdm import tqdm, tnrange, tqdm_notebook

In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
# Preprocessing pipeline shared by both splits: resize, centre-crop, then
# convert the image to a tensor. Sizes come from the project constants.
data_transforms = transforms.Compose([
    transforms.Resize(c.image_size),
    transforms.CenterCrop(c.image_size),
    transforms.ToTensor(),
])

# One ImageFolder dataset and one shuffled DataLoader per split, keyed by
# the split name ('train' / 'val').
image_datasets = {}
dataloaders = {}
for split in ['train', 'val']:
    split_dir = os.path.join(c.data_home, 'surgical_data/', split)
    image_datasets[split] = datasets.ImageFolder(split_dir, data_transforms)
    dataloaders[split] = torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=c.batch_size,
        shuffle=True,
    )

In [None]:
# Build one VAE per latent dimensionality under study (currently only 10)
# and move it to the configured device. Keyed by latent size.
models = {}
for zdim in [10]:
    vae = VAE(
        image_channels=c.image_channels,
        image_size=c.image_size,
        h_dim1=1024,
        h_dim2=128,
        zdim=zdim,
    )
    models[zdim] = vae.to(c.device)

In [None]:
# Restore the trained weights (epoch 50) for every model in `models`.
for zdim, model in models.items():
    # Build the path with os.path.join, like the other cells do, so it is
    # correct whether or not c.data_home ends with a path separator.
    weights_path = os.path.join(
        c.data_home,
        "weights",
        "tools_vae_{}_epoch_50_zdim_{}.torch".format(c.image_size, zdim),
    )
    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device this notebook is configured to use (including CPU-only hosts).
    model.load_state_dict(torch.load(weights_path, map_location=c.device))
    # NOTE(review): the models are only used for inference below; consider
    # calling model.eval() here if VAE contains dropout / batch-norm layers.

In [None]:
# Label table stored alongside the images in the surgical_data folder.
labels = pd.read_csv(
    filepath_or_buffer=os.path.join(
        c.data_home,
        'surgical_data/',
        'surgical_labels.csv',
    )
)

In [None]:
# Encode every image (train split first, then val) into a latent vector,
# collecting one list of vectors per latent dimensionality.
encoded_inputs = {zdim: [] for zdim in [10]}

with torch.no_grad():
    for zdim in tqdm_notebook(encoded_inputs):
        encoder = models[zdim]
        # The two splits were previously handled by two copy-pasted loops;
        # the order (train, then val) is kept so row positions are unchanged.
        for split in ['train', 'val']:
            dataset = image_datasets[split]
            for index in tnrange(len(dataset)):
                # dataset[index] is an (image, target) pair; add a batch axis to the image.
                data = dataset[index][0].view(-1, c.image_channels, c.image_size, c.image_size).to(c.device)
                # sampling() draws a latent sample from the encoder output, so
                # the vectors are stochastic unless a seed is set beforehand.
                latent_vector = encoder.sampling(*encoder.encode(data)).cpu().numpy()
                # Assumes latent_vector is (batch, zdim): iterating it yields
                # one row per input image.
                encoded_inputs[zdim].extend(latent_vector)

In [None]:
# Put the latent vectors side by side with the label table, one frame per
# latent dimensionality. The join is positional (index-aligned), so it
# presumes the label rows are in the same order as the encoded images —
# NOTE(review): confirm against how surgical_labels.csv was generated.
dataframes = {}
for zdim in [10]:
    latent_columns = pd.DataFrame(encoded_inputs[zdim])
    dataframes[zdim] = pd.concat([latent_columns, labels], axis=1)

In [None]:
# Stack the per-zdim frames into a single frame. Passing a dict to pd.concat
# uses the dict keys (the zdim values) as the outer index level, which is why
# later cells select rows with `latent_space.loc[zdim]`.
latent_space = pd.concat(dataframes)
# Show the first rows as a quick sanity check of the combined frame.
latent_space.head()

In [None]:
# Persist each per-zdim frame (latent vectors + labels) as a CSV file.
output_dir = os.path.join(c.data_home, 'dataframes')
# Create the target folder if needed so the export also works on a fresh
# checkout, where it does not exist yet.
os.makedirs(output_dir, exist_ok=True)

for zdim, frame in dataframes.items():
    csv_name = 'encoded_inputs_{}_epoch_50_zdim_{}.csv'.format(c.image_size, zdim)
    frame.to_csv(os.path.join(output_dir, csv_name))

In [None]:
# Project the latent vectors onto their first principal components and store
# them as columns pc1..pcN of `latent_space`.
for zdim in [10]:
    components = 5
    # Use every latent dimension (columns 0..zdim-1). The columns used to be
    # hard-coded to the first five, which ignored half of a 10-d latent space
    # and made a 5-component PCA trivially explain 100% of the variance.
    latent_columns = list(range(zdim))
    pca = decomposition.PCA(n_components=components)
    pca_result = pca.fit_transform(latent_space.loc[zdim][latent_columns].values)
    for i in range(components):
        latent_space.loc[zdim, 'pc{}'.format(i+1)] = pca_result[:, i]
    print('z={}:\tExplained variation per principal component: {} {}'.format(zdim, sum(pca.explained_variance_ratio_),pca.explained_variance_ratio_))

In [None]:
# Draw the PCA figure for the zdim=10 rows using the project plotting helper
# and write the resulting figure to the working directory.
pca_frame = latent_space.loc[10]
fig = v.plot_pca(5, pca_frame)
fig.savefig('./pca_zdim_10.png')

In [None]:
# Embed the latent vectors in 2-D with t-SNE and store the coordinates as
# columns tsne1 / tsne2 of `latent_space`.
for zdim in [10]:
    # Embed the full latent vector (columns 0..zdim-1) rather than a
    # hard-coded subset of the first five dimensions.
    latent_columns = list(range(zdim))
    # t-SNE is stochastic; a fixed random_state makes the embedding (and the
    # figures built from it) reproducible across kernel restarts.
    tsne = manifold.TSNE(n_components=2, random_state=0)
    tsne_result = tsne.fit_transform(latent_space.loc[zdim][latent_columns].values)
    latent_space.loc[zdim, 'tsne1'] = tsne_result[:, 0]
    latent_space.loc[zdim, 'tsne2'] = tsne_result[:, 1]

In [None]:
# Scatter plot of the t-SNE embedding for zdim=10, coloured by the 'Tool'
# column with a blue-white-red colour map.
tsne_frame = latent_space.loc[10]

fig = plt.figure()
ax = fig.add_subplot(111)
scatter = ax.scatter(
    tsne_frame['tsne1'],
    tsne_frame['tsne2'],
    c=tsne_frame['Tool'],
    cmap='bwr',
    alpha=0.2,
)
plt.show()

In [None]:
# Same t-SNE embedding, drawn through the project helper `v.imscatter`,
# which is given the image dataset alongside the coordinates.
# NOTE(review): the coordinates come from latent_space.loc[10], which appears
# to cover train + val rows, while only the 'train' dataset is passed as
# `data` — confirm that imscatter pairs points and images as intended.
tsne_frame = latent_space.loc[10]

fig, ax = plt.subplots()
v.imscatter(
    tsne_frame['tsne1'],
    tsne_frame['tsne2'],
    data=image_datasets['train'],
    ax=ax,
    zoom=0.25,
)

plt.show()

In [None]:
from scipy.stats import norm

# Decode a resolution x resolution grid of latent points: two latent
# dimensions are swept, all other dimensions are held at zero.
resolution = 15
zdim = 10
dimensions = [0, 1]  # indices of the two latent dimensions being varied

# Evenly spaced probabilities in [0.05, 0.95], mapped through the inverse
# normal CDF to obtain grid coordinates in latent space.
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, resolution),
                               np.linspace(0.05, 0.95, resolution)))
z_grid = norm.ppf(u_grid)

# One row per grid point; only the two selected dimensions are non-zero.
sampled = z_grid.reshape(resolution*resolution, 2)
result = np.zeros((resolution*resolution, zdim))
result[:, dimensions[0]] = sampled[:, 0]
result[:, dimensions[1]] = sampled[:, 1]

# Look the model up by its latent size instead of relying on a `model`
# variable left over from an earlier cell's for-loop, and skip gradient
# tracking since this is inference only.
with torch.no_grad():
    x_decoded = models[zdim].decode(torch.from_numpy(result).to(c.device).float())
# Arrange the decoded images as (grid row, grid col, channels, height, width);
# the channel count comes from the project constants rather than a literal 3.
x_decoded = x_decoded.reshape(resolution, resolution, c.image_channels, c.image_size, c.image_size)

In [None]:
%matplotlib notebook
# Render the decoded grid as one mosaic image without any axes decoration
# and save it to disk.
fig, ax = plt.subplots(frameon=False, dpi=300)
# Turn the axes off on the axes just created. Calling plt.axis('off') before
# plt.subplots() acted on whichever figure happened to be current instead.
ax.axis('off')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])

# np.block tiles the nested (row, col) list of channel-first images into one
# channel-first mosaic; transpose moves channels last for imshow.
im = ax.imshow(np.block(list(map(list, x_decoded.detach().cpu().numpy()))).transpose(1,2,0))
fig.savefig('latent-space.png', bbox_inches='tight', pad_inches=0)
