In [None]:
%matplotlib inline
# Automatically reload edited modules (e.g. the project's `src` package)
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd

In [None]:
# Make the project's `src` package importable. The checkout location can be
# overridden with the TOOL_PRESENCE_ROOT environment variable; the default is
# the original author's path.
module_path = os.path.abspath(os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
parser = utils.setup_argparse()
# Derive --root from the project path configured above instead of repeating the
# absolute path; os.path.join(..., '') keeps the trailing slash of the original.
args = parser.parse_args(args=['--root=' + os.path.join(module_path, ''),
                               '--data-dir=data/cadaver_data/',
                               '--image-size=64',
                               '--loss-function=mmd',
                               '--z-dim=10',
                               ])

In [None]:
# Build datasets and dataloaders from `args` without augmentation (this notebook
# only evaluates a trained model). Both appear to be indexed by split name;
# 'train' and 'val' are used below.
datasets, dataloaders = utils.setup_data(args, augmentation=False)

In [None]:
# Construct the VAE and move it to the configured device. The layer sizes
# (h_dim1, h_dim2, zdim) must match the checkpoint loaded in the next cell,
# because load_state_dict is strict by default.
model = m.VAE(image_channels=args.image_channels,
              image_size=args.image_size,
              h_dim1=1024,
              h_dim2=128,
              zdim=args.z_dim).to(c.device)

In [None]:
load_model = True
model_name = "mmd_zdim10_beta_1.0_epoch_50.torch"
model_path = os.path.join(args.root, 'data/cadaver_mmd_vae', model_name)
if load_model:
    # map_location lets a checkpoint saved on a GPU load when c.device is the CPU
    # (and vice versa) instead of failing on deserialization.
    model.load_state_dict(torch.load(model_path, map_location=c.device))
# Everything below is inference only: switch dropout/batch-norm layers (if the
# VAE has any) to evaluation behaviour.
model.eval()

In [None]:
from scipy.stats import norm
from scipy.special import logsumexp

def compute_samples(data, num_samples, debug=False):
    """ Sample from the importance distribution z_samples ~ q(z|X) and
        evaluate p(z_samples) and q(z_samples|X) for importance sampling.

    Uses the module-level ``model`` (its ``encode`` method), ``utils`` and
    ``c.device``.

    Args:
        data: batch of inputs accepted by ``model.encode`` (a torch tensor).
        num_samples: number of latent samples drawn per input.
        debug: unused; kept so existing callers keep working.

    Returns:
        (z_samples, pz, qz): three float64 arrays of identical shape
        (batch, num_samples, z_dim).
          z_samples -- draws from N(z_mean, exp(z_log_sigma)), independently per dimension.
          pz        -- standard-normal prior density N(0, 1) at each coordinate.
          qz        -- posterior density at the same coordinates.
        Densities are per dimension (not multiplied together), so the caller
        can sum their logs.
    """
    # Encoding is pure inference: don't build an autograd graph.
    with torch.no_grad():
        z_mean, z_log_sigma = model.encode(data.to(c.device))
    z_mean = utils.torch_to_numpy(z_mean)
    # NOTE(review): the second encoder output is treated as log(std); if the
    # encoder actually returns log-variance this should be exp(0.5 * s).
    z_sigma = np.exp(utils.torch_to_numpy(z_log_sigma))

    # Insert a sample axis so loc/scale broadcast to (batch, num_samples, z_dim).
    loc = z_mean[:, np.newaxis, :]
    scale = z_sigma[:, np.newaxis, :]
    z_samples = np.random.normal(loc, scale,
                                 size=(loc.shape[0], num_samples, loc.shape[2]))

    pz = norm.pdf(z_samples)
    qz = norm.pdf(z_samples, loc=loc, scale=scale)
    return z_samples, pz, qz

In [None]:
def estimate_logpx(dataloader, num_samples, debug=False):
    """Importance-sampling estimate of log p(x) for every example in ``dataloader``.

    For each input x, ``num_samples`` latents z_i are drawn from q(z|x) by
    ``compute_samples`` and combined as

        log p(x) = log \\int p(x|z) p(z) dz
                 = log E_{q(z|x)}[ p(x|z) p(z) / q(z|x) ]
                ~= -log n + logsumexp_i( log p(x|z_i) + log p(z_i) - log q(z_i|x) )

    where log p(x|z) is the Bernoulli log-likelihood (negative binary
    cross-entropy) of the pixels under the decoder output.
    See scipy.special.logsumexp.

    Uses the module-level ``model``, ``args``, ``utils``, ``c`` and
    ``compute_samples``.

    Args:
        dataloader: iterable of ``(data, target)`` batches; ``target`` is ignored.
        num_samples: number of importance samples per example.
        debug: if True, print shapes and intermediate values for the last
            example of each batch.

    Returns:
        1-D numpy array with one log p(x) estimate (natural log) per example,
        in dataloader order.
    """
    n_pixels = args.image_size * args.image_size * args.image_channels
    # Clip in float64 with float64 eps. The decoder output is converted from a
    # torch tensor (typically float32), and in float32 `1 - float64 eps` rounds
    # back to exactly 1.0. Saturated pixels would then give log(0) = -inf, and
    # 0 * -inf = nan.
    eps = np.finfo(np.float64).eps
    result = []
    # Pure inference: no autograd graph is needed for encode/decode.
    with torch.no_grad():
        for data, _ in dataloader:
            z_samples, pz, qz = compute_samples(data, num_samples)
            assert z_samples.shape == pz.shape
            assert pz.shape == qz.shape
            for i in range(len(data)):
                datum = utils.torch_to_numpy(data[i]).reshape(n_pixels)
                x_predict = model.decode(torch.from_numpy(z_samples[i]).float().to(c.device))
                x_predict = utils.torch_to_numpy(x_predict).astype(np.float64).reshape(-1, n_pixels)
                x_predict = np.clip(x_predict, eps, 1. - eps)
                p_vals = pz[i]
                q_vals = qz[i]

                # log p(x|z_i): Bernoulli log-likelihood summed over pixels -> (num_samples,)
                logp_xz = np.sum(datum * np.log(x_predict) + (1. - datum) * np.log1p(-x_predict), axis=-1)
                # log p(z_i), log q(z_i|x): sum per-dimension log densities -> (num_samples,)
                logpz = np.sum(np.log(p_vals), axis=-1)
                logqz = np.sum(np.log(q_vals), axis=-1)
                argsum = logp_xz + logpz - logqz
                logpx = -np.log(num_samples) + logsumexp(argsum)
                result.append(logpx)

            # Guard: an empty batch would leave the per-example names undefined.
            if debug and len(data):
                print(x_predict.shape)
                print(p_vals.shape)
                print(q_vals.shape)
                print(logp_xz.shape)
                print(logpz.shape)
                print(logqz.shape)
                print("logp_xz", logp_xz)
                print("logpz", logpz)
                print("logqz", logqz)
                print(argsum.shape)
                print("logpx", logpx)

    return np.array(result)
            
# Seed NumPy's global RNG (compute_samples draws with np.random.normal) so the
# importance-sampled estimate is reproducible across runs.
np.random.seed(0)
# debug=True prints shapes and values for every batch; keep it off for the full run.
logpx = estimate_logpx(dataloaders['val'], num_samples=64, debug=False)

In [None]:
# Mean importance-sampled log p(x) over the validation examples. nanmean skips
# NaN estimates, but not -inf ones.
np.nanmean(logpx)

In [None]:
# Average negative log-likelihood per input value (pixel x channel). The units
# are nats, because estimate_logpx uses natural logarithms.
print(-np.nanmean(logpx)/(args.image_size * args.image_size * args.image_channels))

In [None]:
# Load the surgical label table from the data directory.
# NOTE(review): `labels` is not used by any later cell in this notebook; use it or drop it.
labels = pd.read_csv(os.path.join(args.root, args.data_dir, 'surgical_labels.csv'))

In [None]:
fig, ax = plt.subplots()
ax.set_title("Initial Images\nStart, End")
# Validation examples 1 and 9 side by side (channel-first tensors -> channel-last arrays for imshow).
start_img, end_img = (datasets['val'][idx][0].numpy().transpose(1, 2, 0) for idx in (1, 9))
ax.imshow(np.hstack([start_img, end_img]))

In [None]:
# Compare validation examples 1 and 9 (top row) with their VAE reconstructions (bottom row).
fig = plt.figure()
with torch.no_grad():  # inference only; avoid building an autograd graph
    recon1, z, _, _ = model(datasets['val'][1][0].unsqueeze(0).to(c.device))
    recon2, z, _, _ = model(datasets['val'][9][0].unsqueeze(0).to(c.device))

recon1 = utils.torch_to_image(recon1)
recon2 = utils.torch_to_image(recon2)

originals = np.hstack([utils.torch_to_image(datasets['val'][1][0]),
                       utils.torch_to_image(datasets['val'][9][0])])
recons = np.hstack([recon1, recon2])

plt.title("Top: originals, bottom: reconstructions")
plt.imshow(np.vstack([originals, recons]))

In [None]:
images = v.latent_interpolation(datasets['val'][1][0],
                                datasets['val'][9][0],
                                model=model)

# Label the plot with the checkpoint actually loaded (model_name encodes loss and
# beta). A hard-coded "Beta=5" contradicted the beta_1.0 MMD checkpoint.
fig = v.plot_interpolation(images, f"Interpolation\n{os.path.splitext(model_name)[0]}")

out_dir = os.path.join(args.root, 'data/mmd_vae')
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(os.path.join(out_dir, 'mmd_tool_motion.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Latent vectors of validation examples 1 and 9 (first row of each result), and their difference.
a, b = (utils.torch_to_numpy(v.get_latent_vector(datasets['val'][idx][0], model))[0]
        for idx in (1, 9))
diff = a - b

In [None]:
# Overlay the two latent vectors so per-dimension differences are visible.
fig, ax = plt.subplots()
ax.plot(a, label='val[1]')
ax.plot(b, label='val[9]')
ax.set(title='Latent vectors', xlabel='latent index', ylabel='value')
ax.legend();

In [None]:
# Per-dimension difference between the two latent vectors; a zero line marks "no change".
fig, ax = plt.subplots()
ax.plot(a - b)
ax.axhline(0, color='gray', linewidth=0.5)
ax.set(title='Latent difference (val[1] - val[9])', xlabel='latent index', ylabel='a - b');

In [None]:
fig, ax = plt.subplots()
ax.set_title("Initial Images\nStart, End")
# Training examples 360 and 368 side by side (channel-first tensors -> channel-last arrays for imshow).
start_img, end_img = (datasets['train'][idx][0].numpy().transpose(1, 2, 0) for idx in (360, 368))
ax.imshow(np.hstack([start_img, end_img]))

In [None]:
images = v.latent_interpolation(datasets['train'][360][0],
                                datasets['train'][368][0],
                                model=model)

# Label the plot with the checkpoint actually loaded (model_name encodes loss and
# beta). A hard-coded "Beta=5" contradicted the beta_1.0 MMD checkpoint.
fig = v.plot_interpolation(images, f"Interpolation\n{os.path.splitext(model_name)[0]}")

out_dir = os.path.join(args.root, 'data/mmd_vae')
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing
plt.savefig(os.path.join(out_dir, 'mmd_tool_motion2.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Explore one latent dimension starting from training example 360.
# NOTE(review): `zdim=9` presumably selects latent index 9 (the last one with
# --z-dim=10) rather than a dimension count — confirm against v.explore_latent_dimension.
images = v.explore_latent_dimension(datasets['train'][360][0], model, zdim=9)

In [None]:
# Plot the images produced by the latent-dimension exploration in the previous cell.
v.plot_interpolation(images)