In [None]:
# Interactive (nbagg) matplotlib figures; autoreload 2 re-imports all modules (e.g. the local `src` package) before each cell executes.
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os

# Make the parent of the notebook's working directory importable so the `src` package below resolves.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)
    
from src import constants as c
from src.model import VAE
from src import visualization as v

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
from sklearn import decomposition, manifold

from tqdm import tqdm, tnrange, tqdm_notebook

In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
# Resize, center-crop to a c.image_size square, then convert each image to a tensor.
data_transforms = transforms.Compose([
    transforms.Resize(c.image_size),
    transforms.CenterCrop(c.image_size),
    transforms.ToTensor(),
])

# One ImageFolder per split, read from <data_home>/surgical_data/<split>.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(c.data_home, 'surgical_data/', split),
                                transform=data_transforms)
    for split in ('train', 'val')
}
# NOTE: shuffle=True, so these loaders must not be used where row order has to match surgical_labels.csv.
dataloaders = {
    split: torch.utils.data.DataLoader(dataset,
                                       batch_size=c.batch_size,
                                       shuffle=True)
    for split, dataset in image_datasets.items()
}

In [None]:
# One VAE per latent size (currently only zdim=10), with fixed hidden widths 1024 -> 128, placed on c.device.
models = dict(
    (zdim, VAE(image_channels=c.image_channels,
               image_size=c.image_size,
               h_dim1=1024,
               h_dim2=128,
               zdim=zdim).to(c.device))
    for zdim in [10]
)

In [None]:
# Load trained weights into each model.
# - os.path.join matches how every other path in this notebook is built (the old string concatenation
#   silently broke when c.data_home had no trailing separator).
# - map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
# - eval() switches layers such as dropout/batch-norm (if the VAE has any) to inference behaviour.
for zdim, model in models.items():
    weights_path = os.path.join(c.data_home, 'weights',
                                'tools_vae_{}_epoch_50_zdim_{}.torch'.format(c.image_size, zdim))
    model.load_state_dict(torch.load(weights_path, map_location=c.device))
    model.eval()

In [None]:
# Per-image labels. Rows are assumed to follow the image encoding order (train then val) — TODO confirm.
labels = pd.read_csv(
    filepath_or_buffer=os.path.join(c.data_home, 'surgical_data/', 'surgical_labels.csv')
)

In [None]:
# Encode every image, train split first and then val, into one latent vector per image.
# Order matters: these rows are later concatenated column-wise with surgical_labels.csv.
# The train/val loops were identical copies, so they are folded into a single loop over splits.
encoded_inputs = {zdim: [] for zdim in [10]}

with torch.no_grad():
    for zdim in tqdm_notebook(encoded_inputs):
        model = models[zdim]
        for split in ['train', 'val']:
            dataset = image_datasets[split]
            for index in tnrange(len(dataset), desc=split):
                image = dataset[index][0]
                # Add a leading batch axis of size 1.
                data = image.view(-1, c.image_channels, c.image_size, c.image_size).to(c.device)
                # sampling() is fed encode()'s outputs; presumably a reparameterised draw, so results vary
                # between runs unless torch is seeded — confirm in src.model.
                latent_vector = model.sampling(*model.encode(data)).cpu().detach().numpy()
                # Iterating the array yields one 1-D latent row per image in the batch (here always one).
                encoded_inputs[zdim].extend(latent_vector)

In [None]:
# Join latent vectors (integer columns 0..zdim-1) with the label columns, row by row.
# pd.concat(axis=1) aligns on the index and would silently NaN-pad a length mismatch, so check first.
for zdim in encoded_inputs:
    assert len(encoded_inputs[zdim]) == len(labels), (
        'zdim={}: {} encoded images but {} label rows'.format(zdim, len(encoded_inputs[zdim]), len(labels)))

dataframes = {zdim: pd.concat([pd.DataFrame(encoded_inputs[zdim]), labels], axis=1)
              for zdim in encoded_inputs}

In [None]:
# Stack the per-zdim frames vertically; the dict keys (zdim) become the outer index level.
latent_space = pd.concat(dataframes, axis=0)
latent_space.head(5)

In [None]:
# Persist each zdim's latent + label frame as CSV under <data_home>/dataframes/.
for zdim, frame in dataframes.items():
    csv_name = 'encoded_inputs_{}_epoch_50_zdim_{}.csv'.format(c.image_size, zdim)
    frame.to_csv(os.path.join(c.data_home, 'dataframes', csv_name))

In [None]:
# Project each model's latent vectors onto their leading principal components and store them as
# columns pc1..pcN on that zdim's slice of latent_space.
# Previously only columns [0..4] were fed to PCA, ignoring half of the 10-dim latent space.
for zdim in [10]:
    components = min(5, zdim)              # PCA cannot return more components than input features
    latent_columns = list(range(zdim))     # latent dims are the integer columns 0..zdim-1
    pca = decomposition.PCA(n_components=components)
    pca_result = pca.fit_transform(latent_space.loc[zdim][latent_columns].values)
    for i in range(components):
        latent_space.loc[zdim, 'pc{}'.format(i + 1)] = pca_result[:, i]
    print('z={}:\tExplained variation per principal component: {} {}'.format(zdim, sum(pca.explained_variance_ratio_),pca.explained_variance_ratio_))

# Testing Interpolation

## Similar anatomy, tool enters the frame

In [None]:
# Interpolation endpoints side by side: train images 0 and 10 (CHW tensors transposed to HWC for imshow).
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([image_datasets['train'][idx][0].numpy().transpose(1, 2, 0)
                      for idx in (0, 10)]))

In [None]:
# Decode frames interpolated in latent space from train image 0 to train image 10; the first 10 are plotted.
images = v.latent_interpolation(
    image_datasets['train'][0][0],
    image_datasets['train'][10][0],
    model=models[10],
)

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")

plt.savefig(os.path.join(c.data_home, 'figures', 'tool_different_anatomy_similar.png'),
            bbox_inches='tight', dpi=200, pad_inches=0.0)

In [None]:
# Interpolation endpoints side by side: train images 1000 and 1200 (HWC for imshow).
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([image_datasets['train'][idx][0].numpy().transpose(1, 2, 0)
                      for idx in (1000, 1200)]))

In [None]:
images = v.latent_interpolation(
    image_datasets['train'][1000][0],
    image_datasets['train'][1200][0],
    model=models[10],
)

fig = plt.figure(figsize=(10, 2))
# Lay the decoded frames out left to right as a single strip.
plt.imshow(np.hstack(images))

## Tool present, different anatomy

In [None]:
# Interpolation endpoints side by side: train images 10 and 400 (HWC for imshow).
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([image_datasets['train'][idx][0].numpy().transpose(1, 2, 0)
                      for idx in (10, 400)]))

In [None]:
# Latent interpolation from train image 10 to train image 400; the first 10 decoded frames are plotted.
images = v.latent_interpolation(
    image_datasets['train'][10][0],
    image_datasets['train'][400][0],
    model=models[10],
)

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
plt.savefig(os.path.join(c.data_home, 'figures', 'tool_present_anatomy_different.png'),
            bbox_inches='tight', dpi=200, pad_inches=0.0)

In [None]:
# Interpolation endpoints side by side: train images 400 and 1000 (HWC for imshow).
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([image_datasets['train'][idx][0].numpy().transpose(1, 2, 0)
                      for idx in (400, 1000)]))

In [None]:
# Interpolate between the endpoints previewed just above (train images 400 and 1000).
# This previously re-ran the 10 -> 400 interpolation, a copy-paste of the earlier cell.
images = v.latent_interpolation(image_datasets['train'][400][0], image_datasets['train'][1000][0], model=models[10])
fig = plt.figure(figsize=(10, 2))
# Decoded frames laid out left to right as a single strip.
plt.imshow(np.hstack(images))

## Tool initially present, different anatomy

In [None]:
# Interpolation endpoints side by side: train images 10 and 830 (HWC for imshow).
fig = plt.figure()
plt.title("Initial Images\nStart, End")
plt.imshow(np.hstack([image_datasets['train'][idx][0].numpy().transpose(1, 2, 0)
                      for idx in (10, 830)]))

In [None]:
# Latent interpolation from train image 10 to train image 830; the first 10 decoded frames are plotted.
images = v.latent_interpolation(
    image_datasets['train'][10][0],
    image_datasets['train'][830][0],
    model=models[10],
)

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
plt.savefig(os.path.join(c.data_home, 'figures', 'tool_different_anatomy_different.png'),
            bbox_inches='tight', dpi=200, pad_inches=0.0)

In [None]:
# Endpoints of the 10 -> 830 interpolation as two axis-free panels, the left one 1.25x wider.
f, (a0, a1) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1.25, 1]})
for axis, idx in ((a0, 10), (a1, 830)):
    axis.imshow(image_datasets['train'][idx][0].numpy().transpose(1, 2, 0))
    axis.axis('off')
f.tight_layout()

In [None]:
# Latent interpolation from train image 0 to train image 830; the first 10 decoded frames are plotted.
images = v.latent_interpolation(
    image_datasets['train'][0][0],
    image_datasets['train'][830][0],
    model=models[10],
)

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(
    1, 10,
    figsize=(10, 2),
    frameon=False,
    gridspec_kw={'wspace': 0.05, 'width_ratios': [1.25] + [1] * 8 + [1.25]},
)
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
plt.savefig(os.path.join(c.data_home, 'figures', 'tool_none_anatomy_different.png'),
            bbox_inches='tight', dpi=200, pad_inches=0.0)

In [None]:
# Latent interpolation from train image 0 to train image 1200; the first 10 decoded frames are plotted.
images = v.latent_interpolation(image_datasets['train'][0][0], image_datasets['train'][1200][0], model=models[10])

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(1,10, figsize=(10,2),
                       frameon=False,gridspec_kw={'wspace':0.05, 'width_ratios':[1.25,1,1,1,1,1,1,1,1,1.25]})
for i in range(10):
    ax[i].imshow(images[i])
    ax[i].axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
# Distinct filename: this cell previously overwrote tool_different_anatomy_different.png written by another cell.
plt.savefig(os.path.join(c.data_home, 'figures', 'tool_different_anatomy_different_0_1200.png'),bbox_inches='tight', dpi=200, pad_inches=0.0)

In [None]:
# Latent interpolation from train image 1200 to train image 830; the first 10 decoded frames are plotted.
images = v.latent_interpolation(image_datasets['train'][1200][0], image_datasets['train'][830][0], model=models[10])

# Endpoint panels are 1.25x wider than the eight intermediate frames.
fig, ax = plt.subplots(1,10, figsize=(10,2),
                       frameon=False,gridspec_kw={'wspace':0.05, 'width_ratios':[1.25,1,1,1,1,1,1,1,1,1.25]})
for i in range(10):
    ax[i].imshow(images[i])
    ax[i].axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")
# Distinct filename: this cell previously overwrote tool_different_anatomy_different.png written by another cell.
plt.savefig(os.path.join(c.data_home, 'figures', 'tool_different_anatomy_different_1200_830.png'),bbox_inches='tight', dpi=200, pad_inches=0.0)

# Interpolate one latent dimension at a time

In [None]:
# Interpolate from train image 0 towards train image 10 one latent dimension at a time.
# Each element of `images` is presumably one row of decoded frames (a list of HWC arrays) — confirm in src.visualization.
images = v.latent_interpolation_by_dimension(image_datasets['train'][0][0],
                                             image_datasets['train'][10][0],
                                             model=models[10],
                                             zdim=10)

fig = plt.figure()
# np.vstack requires a sequence (list/tuple): generators are deprecated since NumPy 1.16 and rejected by newer NumPy.
plt.imshow(np.vstack([np.hstack(im) for im in images]))
plt.axis('off')

# Interpolate only the last three latent dimensions (TODO: not yet implemented)

# Latent space vector addition

In [None]:
# Latent codes of train images 0 (start) and 10 (end) under the zdim=10 model.
latent_start, latent_end = (v.get_latent_vector(image_datasets['train'][idx][0], models[10])
                            for idx in (0, 10))

In [None]:
# Direction in latent space that moves image 0's code onto image 10's code.
diff = -latent_start + latent_end

In [None]:
# Decode the raw difference vector on its own and convert the output to an HWC numpy image.
result = (models[10].decode(diff)
          .cpu()
          .detach()
          .numpy()
          .squeeze()
          .transpose(1, 2, 0))

In [None]:
# Left to right: start image (0), end image (10), decoded difference vector.
fig = plt.figure()
plt.imshow(np.hstack([
    image_datasets['train'][0][0].numpy().transpose(1, 2, 0),
    image_datasets['train'][10][0].numpy().transpose(1, 2, 0),
    result,
]))

In [None]:
# Latent vector arithmetic: shift train image 830's code by the 0 -> 10 direction and decode the result.
new_start = v.get_latent_vector(image_datasets['train'][830][0], models[10])
new_end = new_start + diff
result = (models[10].decode(new_end)
          .cpu()
          .detach()
          .numpy()
          .squeeze()
          .transpose(1, 2, 0))

# Left: original image 830; right: its decoded, shifted counterpart.
fig = plt.figure()
plt.imshow(np.hstack([
    image_datasets['train'][830][0].numpy().transpose(1, 2, 0),
    result,
]))