In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the project's source files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import datasets, transforms

from tqdm import tqdm_notebook, tnrange

from sklearn import decomposition, manifold

In [None]:
import sys
import os

# Put the project checkout on the import path so the `src` package imported
# below can be resolved. The entry is added at most once, so re-running this
# cell does not grow `sys.path`.
module_path = os.path.abspath('/users/dli44/tool-presence')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)
    
from src import constants as c
# from src.model import VAE
from src.model import loss_function
from src import visualization as v
from src import utils

In [None]:
# Preprocessing applied to every image: resize, centre-crop to
# `c.image_size`, then convert to a tensor.
data_transforms = transforms.Compose([
    transforms.Resize(c.image_size),
    transforms.CenterCrop(c.image_size),
    # Dataset-augmentation experiments, currently disabled:
    #     transforms.RandomHorizontalFlip(),
    #     transforms.RandomAffine(15),
    transforms.ToTensor(),
])

# One ImageFolder per split, read from <data_home>/surgical_data/<split>.
image_datasets = {
    split: datasets.ImageFolder(
        os.path.join(c.data_home, 'surgical_data', split),
        data_transforms,
    )
    for split in ['train', 'val']
}

# Matching loaders, one per split.
# NOTE(review): the 'val' loader is shuffled as well as 'train' - confirm
# that is intended.
dataloaders = {
    split: DataLoader(dataset, batch_size=c.batch_size, shuffle=True)
    for split, dataset in image_datasets.items()
}

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class VAE_1fc(nn.Module):
    """Small convolutional VAE with one hidden fully connected layer per side.

    Encoder: conv -> ReLU -> conv -> 2x2 max-pool -> dropout -> flatten ->
    fc -> ReLU -> two linear heads producing ``mu`` and ``log_var``.
    Decoder: fc -> ReLU -> fc -> ReLU -> reshape -> 2x upsampling ->
    transposed conv -> ReLU -> transposed conv -> sigmoid.

    Args:
        image_channels: Number of channels of the input (and reconstructed) image.
        image_size: Side length of the square input image. The decoder
            upsamples ``image_size // 2`` by a factor of two, so the
            reconstruction only matches the input size when this is even.
        h_dim2: Width of the hidden fully connected layers.
        zdim: Dimensionality of the latent vector.
        conv_channels: Output channels of the first and second encoder
            convolutions (indices 0 and 1 are used).
    """

    def __init__(self, image_channels, image_size, h_dim2, zdim, conv_channels=(16, 16)):
        super(VAE_1fc, self).__init__()

        self.image_channels = image_channels
        self.image_size = image_size
        self.h_dim2 = h_dim2
        self.zdim = zdim
        # Store a private copy: the default must never be shared (and mutated)
        # across instances, and a caller's list must not alias model state.
        self.conv_channels = list(conv_channels)

        # Spatial side length after the 2x2 max-pool, and the size of the
        # flattened feature map. Parenthesised explicitly: the unparenthesised
        # `s//2 * s//2 * c` evaluates as `((s//2) * s) // 2 * c`, which is only
        # correct for even `s`.
        self._pooled_size = image_size // 2
        self._flat_features = self._pooled_size * self._pooled_size * self.conv_channels[1]

        # Encoder
        self.conv1 = nn.Conv2d(image_channels, self.conv_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(self.conv_channels[0], self.conv_channels[1], kernel_size=3, stride=1, padding=1, bias=False)
        self.pool1 = nn.MaxPool2d(2)

        # Latent vectors
        self.fc1 = nn.Linear(self._flat_features, h_dim2)
        self.fc31 = nn.Linear(h_dim2, zdim)  # head returned first by encode() (mu)
        self.fc32 = nn.Linear(h_dim2, zdim)  # head returned second by encode() (log_var)

        # Decoder
        self.fc3 = nn.Linear(zdim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, self._flat_features)

        self.conv3 = nn.ConvTranspose2d(self.conv_channels[1], self.conv_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.ConvTranspose2d(self.conv_channels[0], image_channels, kernel_size=3, stride=1, padding=1, bias=False)

    def sampling(self, mu, log_var):
        """Draw ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu) # return z sample

    def encode(self, x):
        """Map a batch of images to the pair ``(mu, log_var)``."""
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that model.eval() really disables dropout at inference time.
        x = F.dropout(self.pool1(x), training=self.training)
        x = x.view(-1, self._flat_features)
        x = F.relu(self.fc1(x))
        return self.fc31(x), self.fc32(x)

    def decode(self, z):
        """Map latent vectors back to images with values in [0, 1] (sigmoid output)."""
        z = F.relu(self.fc3(z))
        z = F.relu(self.fc5(z))
        z = z.view(-1, self.conv_channels[1], self._pooled_size, self._pooled_size)
        z = F.interpolate(z, scale_factor=2)
        z = F.relu(self.conv3(z))
        z = torch.sigmoid(self.conv4(z))
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent, and return ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.sampling(mu, log_var)
        return self.decode(z), mu, log_var

In [None]:
# Instantiate the VAE with the image geometry from the project constants and
# the architecture sizes used by the checkpoint loaded in the next cell, then
# move it to the configured device.
vae_kwargs = dict(
    image_channels=c.image_channels,
    image_size=c.image_size,
    h_dim2=128,
    zdim=64,
)
model = VAE_1fc(**vae_kwargs).to(c.device)

In [None]:
# Build the checkpoint path with os.path.join (as the other cells do) so it is
# correct whether or not `c.data_home` ends with a path separator.
weights_path = os.path.join(c.data_home, "weights/beta_vae/1fc/beta_100_vae_64_epoch_50_zdim_64.torch")

# This notebook only runs inference: put the model in eval mode so layers that
# honour `model.training` behave deterministically.
model.eval()

# map_location lets a checkpoint saved on another device (e.g. a GPU) load on
# whatever device this session is configured to use.
model.load_state_dict(torch.load(weights_path, map_location=c.device))

In [None]:
# Load the label table stored next to the image folders.
labels_csv = os.path.join(c.data_home, 'surgical_data/', 'surgical_labels.csv')
labels = pd.read_csv(labels_csv)

In [None]:
# For each beta, place the encoded inputs and the label table side by side
# (column-wise concat).
# NOTE(review): `encoded_inputs` is not defined in any cell visible in this
# notebook - it must come from earlier kernel state; confirm where it is built.
dataframes = dict(
    (beta_key, pd.concat([pd.DataFrame(encoded_inputs[beta_key]), labels], axis=1))
    for beta_key in [5]
)

In [None]:
# `beta` selects which entry of `dataframes` to analyse. It has to be bound
# explicitly: a comprehension's loop variable does not leak into cell scope
# in Python 3, so without this line the cell raises NameError on a fresh kernel.
beta = 5
components = 3

# The latent dimensions are the integer-labelled columns 0..63 of the frame.
latent_columns = list(range(64))

pca = decomposition.PCA(n_components=components)
# Single .loc selection instead of chained `.loc[:][...]` indexing.
pca_result = pca.fit_transform(dataframes[beta].loc[:, latent_columns].values)

# Store each principal component as its own column: pc1, pc2, ...
for i in range(components):
    dataframes[beta]['pc{}'.format(i+1)] = pca_result[:,i]
print('beta={}:\tExplained variation per principal component: {} {}'.format(beta, sum(pca.explained_variance_ratio_),pca.explained_variance_ratio_))

## Rob's suggestions

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Show the two validation samples used as interpolation endpoints, side by side.
fig = plt.figure()
plt.title("Initial Images\nStart, End")

# Element 0 of each sample is the image tensor; move its first axis last
# before handing it to imshow.
val_start_hwc = image_datasets['val'][1][0].numpy().transpose(1,2,0)
val_end_hwc = image_datasets['val'][9][0].numpy().transpose(1,2,0)
plt.imshow(np.hstack([val_start_hwc, val_end_hwc]))

In [None]:
# Compare the two validation images (top row) with the model's
# reconstructions of them (bottom row).
fig = plt.figure()
val_set = image_datasets['val']

# Run each image through the model as a batch of one; only the first of the
# three returned values (the reconstruction) is kept.
recon1, _, _ = model(val_set[1][0].unsqueeze(0).to(c.device))
recon2, _, _ = model(val_set[9][0].unsqueeze(0).to(c.device))

recon1, recon2 = utils.torch_to_numpy(recon1), utils.torch_to_numpy(recon2)

originals = np.hstack([utils.torch_to_numpy(val_set[idx][0]) for idx in (1, 9)])
recons = np.hstack([recon1, recon2])

plt.imshow(np.vstack([originals, recons]))

In [None]:
# Interpolate in latent space between the two validation images shown above
# and plot the resulting sequence.
images = v.latent_interpolation(image_datasets['val'][1][0],
                                image_datasets['val'][9][0],
                                model=model)

fig = v.plot_interpolation(images, "Interpolation\nBeta=100")

# Make sure the output directory exists first: savefig raises if any parent
# directory of the target file is missing.
figure_dir = os.path.join(c.data_home, 'figures', 'augmentation_1fc')
os.makedirs(figure_dir, exist_ok=True)

plt.savefig(os.path.join(figure_dir, 'beta_100_tool_motion.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Latent vectors of two training images and their element-wise difference.
# NOTE(review): `models` (plural) is not defined in any cell visible in this
# notebook - only `model` is. Confirm where the per-beta model dict comes from.
train_set = image_datasets['train']
a = v.get_latent_vector(train_set[85][0], models[5]).cpu().detach().numpy()
b = v.get_latent_vector(train_set[180][0], models[5]).cpu().detach().numpy()
diff = a - b

In [None]:
# Overlay the first row of each latent array (`a` and `b`) on one figure to
# compare them dimension by dimension.
fig = plt.figure()
plt.plot(a[0])
plt.plot(b[0])
# plt.plot(diff[0])

In [None]:
# Display the absolute per-dimension difference between the two latent arrays.
np.abs(diff)

In [None]:
# Show the two training samples used as interpolation endpoints, side by side.
fig = plt.figure()
plt.title("Initial Images\nStart, End")

# Element 0 of each sample is the image tensor; move its first axis last
# before handing it to imshow.
train_start_hwc = image_datasets['train'][360][0].numpy().transpose(1,2,0)
train_end_hwc = image_datasets['train'][368][0].numpy().transpose(1,2,0)
plt.imshow(np.hstack([train_start_hwc, train_end_hwc]))

In [None]:
# Interpolate in latent space between two training images and draw the ten
# frames in a single row.
# NOTE(review): `models` (plural) is not defined in any cell visible in this
# notebook - only `model` is. Confirm where the per-beta model dict comes from.
train_set = image_datasets['train']
images = v.latent_interpolation(train_set[360][0], train_set[368][0], model=models[5])

# fig = v.plot_interpolation(images)
# The first and last panels are slightly wider than the eight in between.
panel_widths = [1.25] + [1] * 8 + [1.25]
fig, ax = plt.subplots(1, 10, figsize=(10, 2), frameon=False,
                       gridspec_kw={'wspace': 0.05, 'width_ratios': panel_widths})
for i, panel in enumerate(ax):
    panel.imshow(images[i])
    panel.axis('off')
ax[0].set_title("Start")
ax[-1].set_title("End")

# plt.savefig(os.path.join(data_home,'figures','tool_different_anatomy_similar.png'), bbox_inches='tight', dpi=400, pad_inches=0.0)