In [None]:
# Render matplotlib figures inline, and automatically reload edited project
# modules (e.g. the `src` package imported below) before each cell executes.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import gc
import os
import sys

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt

In [None]:
# Make the project's `src` package importable. The repository root can be
# overridden with the TOOL_PRESENCE_ROOT environment variable; when it is unset
# the original hard-coded location is used, so existing setups keep working.
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Parse a fixed option list (rather than sys.argv, which holds Jupyter's own
# arguments) so every run uses the same configuration, then build the
# non-augmented datasets and dataloaders from it.
cli_options = [
    '--root=/users/dli44/tool-presence/',
    '--data-dir=data/youtube_data/',
    '--image-size=64',
    '--loss-function=mmd',
    '--z-dim=64',
]
parser = utils.setup_argparse()
args = parser.parse_args(args=cli_options)
datasets, dataloaders = utils.setup_data(args, augmentation=False)

In [None]:
# Checkpoints to compare, relative to args.root. Per the file names, the first
# is the MMD-VAE and the rest are beta-VAEs with increasing beta.
model_paths = [
    # MMD-VAE
    'data/mmd_vae/mmd_beta_1.0_epoch_50.torch',
    # standard VAE (beta = 1)
    'data/weights/beta_vae/augmentation_2fc/augment_beta_1_2fc_epoch_50.torch',
    # beta = 5
    'data/weights/beta_vae/augmentation_2fc/augment_beta_5_2fc_epoch_50.torch',
    # beta = 20
    'data/weights/beta_vae/augmentation_2fc/augment_beta_20_2fc_epoch_50.torch',
    # beta = 50
    'data/weights/beta_vae/augmentation_2fc/augment_beta_50_2fc_epoch_50.torch',
]

In [None]:
# Best-effort cleanup of a `model` left over from an earlier run in this kernel.
# Guarded so that on a fresh kernel the cell is a no-op instead of a NameError.
if 'model' in globals():
    del model
# Collect unreachable objects first so their CUDA blocks go back to PyTorch's
# caching allocator, then release the cached blocks (no-op if CUDA is unused).
gc.collect()
torch.cuda.empty_cache()

In [None]:
def live_tensor_summaries():
    """Return ``(type, size)`` for every GC-tracked tensor or tensor-holding object.

    An object is reported if it is a tensor itself, or if its ``.data``
    attribute is a tensor (e.g. ``nn.Parameter``). Useful for hunting down
    what is still holding memory.

    Returns:
        list of ``(type, torch.Size)`` tuples, in ``gc.get_objects()`` order.
    """
    summaries = []
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                summaries.append((type(obj), obj.size()))
        except Exception:
            # Some tracked objects raise on attribute access (e.g. dead weakref
            # proxies); skip them without also swallowing KeyboardInterrupt.
            continue
    return summaries


for obj_type, obj_size in live_tensor_summaries():
    print(obj_type, obj_size)

In [None]:
# Estimate the mean validation log p(x) for each checkpoint in `model_paths`.
# Models are built, evaluated and freed one at a time to bound GPU memory use.
pxs = []

for path in model_paths:
    model = m.VAE(image_channels=args.image_channels,
                  image_size=args.image_size,
                  h_dim1=1024,
                  h_dim2=128,
                  zdim=args.z_dim).to(c.device)
    # map_location lets checkpoints saved on a GPU load onto whatever c.device is.
    model.load_state_dict(torch.load(os.path.join(args.root, path),
                                     map_location=c.device))
    # Evaluation mode for any dropout/batch-norm layers the VAE may contain.
    model.eval()
    print(path)

    # Inference only: skipping autograd graph construction also saves GPU memory.
    with torch.no_grad():
        # 128 is passed straight through to estimate_logpx (presumably the
        # number of samples per image — TODO confirm against src/utils).
        logpx = utils.estimate_logpx(dataloaders['val'], model, args, 128)
    # nanmean so NaN per-sample estimates do not poison the average.
    pxs.append(np.nanmean(logpx))

    # Free GPU memory before the next checkpoint: drop the last reference,
    # collect garbage so CUDA blocks return to the cache, then empty the cache.
    del model
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Mean validation log p(x) per checkpoint, labelled by path so each value can
# be traced back to the model that produced it.
pd.DataFrame(list(zip(model_paths, pxs)), columns=['model_path', 'mean_logpx'])

In [None]:
load_model = True
model_name = "mmd_beta_1.0_epoch_50.torch"
model_path = os.path.join(args.root, 'data/mmd_vae', model_name)

# The evaluation loop deletes `model` after each checkpoint to free GPU memory,
# so build the network here instead of relying on a leftover kernel variable.
model = m.VAE(image_channels=args.image_channels,
              image_size=args.image_size,
              h_dim1=1024,
              h_dim2=128,
              zdim=args.z_dim).to(c.device)
if load_model:
    # map_location lets GPU-saved weights load onto whatever c.device is.
    model.load_state_dict(torch.load(model_path, map_location=c.device))
# The cells below only run inference / visualisation with this model.
model.eval()

In [None]:
# Show the two validation frames (indices 1 and 9) used as interpolation
# endpoints side by side; transpose(1, 2, 0) moves the leading (assumed
# channel) axis last, as imshow expects.
start_img = datasets['val'][1][0].numpy().transpose(1, 2, 0)
end_img = datasets['val'][9][0].numpy().transpose(1, 2, 0)

fig, ax = plt.subplots()
ax.set_title("Initial Images\nStart, End")
ax.imshow(np.hstack([start_img, end_img]))

In [None]:
# Reconstruct the two endpoint frames and show the originals (top row) above
# their reconstructions (bottom row).
start_tensor = datasets['val'][1][0]
end_tensor = datasets['val'][9][0]

# Inference only: no autograd graph is needed, and outputs won't require grad
# when converted to images.
with torch.no_grad():
    recon1, z, _, _ = model(start_tensor.unsqueeze(0).to(c.device))
    recon2, z, _, _ = model(end_tensor.unsqueeze(0).to(c.device))

recon1 = utils.torch_to_image(recon1)
recon2 = utils.torch_to_image(recon2)

originals = np.hstack([utils.torch_to_image(start_tensor),
                       utils.torch_to_image(end_tensor)])
recons = np.hstack([recon1, recon2])

fig, ax = plt.subplots()
ax.set_title("Originals (top) vs. reconstructions (bottom)")
ax.imshow(np.vstack([originals, recons]))

In [None]:
# Decode a path through latent space between the two endpoint frames and save
# the resulting strip alongside the MMD-VAE weights.
images = v.latent_interpolation(datasets['val'][1][0],
                                datasets['val'][9][0],
                                model=model)

# The loaded checkpoint is the MMD-VAE (mmd_beta_1.0), not a beta=5 VAE, so
# the title must say so.
fig = v.plot_interpolation(images, "Interpolation\nMMD-VAE")

plt.savefig(os.path.join(args.root, 'data/mmd_vae', 'mmd_tool_motion.png'),
            bbox_inches='tight', dpi=400, pad_inches=0.0)

In [None]:
# Latent codes of the start (index 1) and end (index 9) frames, taking the
# first row of each result, and their element-wise difference.
a, b = (utils.torch_to_numpy(v.get_latent_vector(datasets['val'][idx][0], model))[0]
        for idx in (1, 9))
diff = a - b

In [None]:
# Sweep each latent dimension in turn, starting from the start frame, and plot
# the decodings. plt.show() flushes each figure as soon as it is drawn, so the
# notebook does not keep args.z_dim figures open at once (matplotlib warns past
# rcParams['figure.max_open_warning'], default 20). Assumes plot_interpolation
# opens a new figure per call — TODO confirm in src/visualization.
for zdim in range(args.z_dim):
    images = v.explore_latent_dimension(datasets['val'][1][0], model, zdim=zdim)
    v.plot_interpolation(images, title='z = {}'.format(zdim))
    plt.show()