In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules (e.g. the project's `src` package)
# when their source files change, so edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd

In [None]:
# Make the project's `src` package importable. The repository root can be
# overridden with the TOOL_PRESENCE_ROOT environment variable; by default it
# falls back to the original author's checkout location.
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Parse a fixed argument list (not sys.argv) so the notebook is reproducible.
# The repository root comes from `module_path` (set in the path-setup cell)
# rather than repeating the hard-coded path; os.path.join(..., '') keeps the
# trailing separator of the original '--root=.../' form.
parser = utils.setup_argparse()
args = parser.parse_args(args=[
    '--root=' + os.path.join(module_path, ''),
    '--data-dir=data/surgical_data/',
    '--image-size=64',
    '--loss-function=mmd',
])

In [None]:
# Build the datasets (indexed below by 'train' / 'val') and their loaders.
# Augmentation is turned off: this notebook only inspects a trained model.
datasets, dataloaders = utils.setup_data(
    args,
    augmentation=False,
)

In [None]:
# Checkpoint selection: the file lives under <root>/data/mmd_vae/.
# NOTE(review): `load_model` is a flag for restoring the checkpoint; confirm
# the loading cell actually consults it.
load_model = True
model_name = "mmd_beta_1.0_epoch_50.torch"
model_path = os.path.join(
    args.root,
    'data/mmd_vae',
    model_name,
)

In [None]:
# Instantiate the VAE and move it to the configured device. The size
# arguments (h_dim1, h_dim2, zdim) must agree with the checkpoint's parameter
# shapes for load_state_dict to succeed, presumably the training-time values.
model = m.VAE(
    image_channels=args.image_channels,
    image_size=args.image_size,
    h_dim1=1024,
    h_dim2=128,
    zdim=args.z_dim,
).to(c.device)

In [None]:
# Restore the trained weights. map_location remaps tensor storage onto
# c.device, so a checkpoint saved on a GPU can be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
if load_model:
    state_dict = torch.load(model_path, map_location=c.device)
    model.load_state_dict(state_dict)

# Every following cell only runs inference, so put the module in eval mode
# (affects layers such as dropout / batch-norm, if the VAE has any).
model.eval();

In [None]:
# Label table shipped with the surgical dataset.
# NOTE(review): `labels` is not referenced by any later cell shown here.
labels_csv = os.path.join(args.root, args.data_dir, 'surgical_labels.csv')
labels = pd.read_csv(labels_csv)

In [None]:
# Show the two validation frames used as interpolation endpoints (val[1] and
# val[9]) side by side. transpose(1, 2, 0) moves the channel axis last, the
# layout imshow expects for colour images.
fig, ax = plt.subplots()
ax.set_title("Initial Images\nStart, End")
start_img, end_img = (
    datasets['val'][idx][0].numpy().transpose(1, 2, 0) for idx in (1, 9)
)
ax.imshow(np.hstack([start_img, end_img]))

In [None]:
# Reconstruct both validation endpoints with the VAE.
# Figure layout: top row = originals, bottom row = reconstructions.
start_tensor = datasets['val'][1][0]
end_tensor = datasets['val'][9][0]

# Inference only: no_grad avoids building an autograd graph for these passes.
with torch.no_grad():
    # The model returns four values; the first is the reconstruction. `z` is
    # kept as a notebook global exactly as before.
    recon1, z, _, _ = model(start_tensor.unsqueeze(0).to(c.device))
    recon2, z, _, _ = model(end_tensor.unsqueeze(0).to(c.device))

recon1 = utils.torch_to_image(recon1)
recon2 = utils.torch_to_image(recon2)

originals = np.hstack([utils.torch_to_image(start_tensor),
                       utils.torch_to_image(end_tensor)])
recons = np.hstack([recon1, recon2])

fig = plt.figure()
plt.title("Top: originals, Bottom: reconstructions")
plt.imshow(np.vstack([originals, recons]))

In [None]:
# Interpolate in latent space between the two validation endpoints and save
# the resulting strip next to the checkpoint.
# The beta in the title must match the loaded checkpoint
# (model_name = "mmd_beta_1.0_epoch_50.torch"); keep the two in sync.
images = v.latent_interpolation(datasets['val'][1][0],
                                datasets['val'][9][0],
                                model=model)

fig = v.plot_interpolation(images, "Interpolation\nBeta=1.0")

plt.savefig(os.path.join(args.root, 'data/mmd_vae', 'mmd_tool_motion.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Latent codes of the two validation endpoints as NumPy arrays (the trailing
# [0] presumably selects the single batch entry -- confirm against
# v.get_latent_vector), plus their element-wise difference.
a, b = (
    utils.torch_to_numpy(v.get_latent_vector(datasets['val'][idx][0], model))[0]
    for idx in (1, 9)
)
diff = a - b

In [None]:
# Overlay the two endpoint latent codes; x is the index into each code.
fig, ax = plt.subplots()
ax.plot(a, label='start (val[1])')
ax.plot(b, label='end (val[9])')
ax.set(title='Latent codes of interpolation endpoints',
       xlabel='latent dimension index', ylabel='value')
ax.legend();

In [None]:
# Per-dimension difference between the endpoint codes (reuses `diff = a - b`
# from the latent-vector cell instead of recomputing it). Entries far from
# zero mark the dimensions that differ most between the two frames.
fig, ax = plt.subplots()
ax.plot(diff)
ax.axhline(0.0, color='gray', linewidth=0.5)
ax.set(title='Latent code difference (start - end)',
       xlabel='latent dimension index', ylabel='a - b');

In [None]:
# Traverse latent dimensions 8..14 one at a time, starting from the same
# validation frame. The loop variable is passed through so each iteration
# explores a different dimension; each plot is titled with its dimension.
for zdim in range(8, 15):
    images = v.explore_latent_dimension(datasets['val'][1][0], model, zdim=zdim)
    v.plot_interpolation(images, f"Latent dimension {zdim}")

In [None]:
# Show the two training frames used as the second interpolation pair
# (train[360] and train[368]) side by side, channel axis moved last for imshow.
fig, ax = plt.subplots()
ax.set_title("Initial Images\nStart, End")
start_img, end_img = (
    datasets['train'][idx][0].numpy().transpose(1, 2, 0) for idx in (360, 368)
)
ax.imshow(np.hstack([start_img, end_img]))

In [None]:
# Interpolate in latent space between the two training endpoints and save the
# strip next to the checkpoint.
# The beta in the title must match the loaded checkpoint
# (model_name = "mmd_beta_1.0_epoch_50.torch"); keep the two in sync.
images = v.latent_interpolation(datasets['train'][360][0],
                                datasets['train'][368][0],
                                model=model)

fig = v.plot_interpolation(images, "Interpolation\nBeta=1.0")

plt.savefig(os.path.join(args.root, 'data/mmd_vae', 'mmd_tool_motion2.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)