In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload imported modules before running each cell, so edits to
# the local `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import sys
import numpy as np
import torch
from matplotlib import pyplot as plt
import pandas as pd

In [None]:
# Make the project checkout importable so the `src` package below resolves.
# NOTE(review): this is a machine-specific absolute path — adjust when running elsewhere.
module_path = os.path.abspath('/users/dli44/tool-presence')
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Emulate a command-line invocation of the project's argument parser so the
# notebook shares its configuration with the training scripts.
# NOTE(review): the meaning of the `-a` flag is defined in utils.setup_argparse — confirm there.
cli_args = [
    '--root=/users/dli44/tool-presence/',
    '--data-dir=data/youtube_data/',
    '--image-size=64',
    '--loss-function=mmd',
    '-a',
]
parser = utils.setup_argparse()
args = parser.parse_args(args=cli_args)

In [None]:
# Build the datasets and dataloaders from the parsed arguments. `datasets` is
# indexed by split name ('train', 'val') further down in this notebook.
datasets, dataloaders = utils.setup_data(args)

In [None]:
# Location of the trained checkpoint, resolved relative to the project root.
# NOTE(review): `load_model` is not referenced by any cell visible in this notebook.
load_model = True
model_name = "mmd_beta_1.0_epoch_50.torch"
checkpoint_dir = os.path.join(args.root, 'data/mmd_vae')
model_path = os.path.join(checkpoint_dir, model_name)

In [None]:
# Instantiate the VAE with the architecture of the checkpoint loaded below and
# move it to the project's configured device.
# NOTE(review): the hidden sizes and zdim=64 must match the saved checkpoint — confirm.
vae_kwargs = dict(
    image_channels=args.image_channels,
    image_size=args.image_size,
    h_dim1=1024,
    h_dim2=128,
    zdim=64,
)
model = m.VAE(**vae_kwargs).to(c.device)

In [None]:
# Load the trained weights. `map_location` remaps the checkpoint's tensors onto
# the device the model was moved to, so a checkpoint saved on a GPU still loads
# on a CPU-only machine (and vice versa).
state_dict = torch.load(model_path, map_location=c.device)
model.load_state_dict(state_dict)
# The model is only used for inference below, so switch layers such as
# dropout / batch-norm (if the VAE has any) to evaluation behaviour.
model.eval();

In [None]:
# Read the per-frame labels that are joined column-wise with the latent vectors later on.
labels_csv = os.path.join(c.data_home, 'youtube_data/', 'surgical_labels.csv')
labels = pd.read_csv(labels_csv)

In [None]:
def _encode_dataset(vae, dataset, image_shape, device):
    """Encode every sample of ``dataset`` into a sampled latent vector.

    Samples are encoded one at a time, in dataset order, so the returned list
    lines up index-for-index with the dataset.

    Args:
        vae: model exposing ``encode(x)`` and ``sampling(*encode_outputs)``.
        dataset: indexable dataset whose items have the image tensor at
            position 0.
        image_shape: ``(channels, height, width)`` used to reshape each image
            into a batch before encoding.
        device: device the image batch is moved to before encoding.

    Returns:
        list of 1-D numpy arrays, one latent vector per encoded row.
    """
    latents = []
    with torch.no_grad():
        for index in range(len(dataset)):
            batch = dataset[index][0].view(-1, *image_shape).to(device)
            # NOTE(review): `sampling` presumably draws a stochastic sample from
            # the encoder's output distribution, so repeated runs differ unless
            # the RNG is seeded — confirm against src.model.
            latent_batch = vae.sampling(*vae.encode(batch)).cpu().numpy()
            # Iterating a 2-D array yields its rows: one latent vector per image.
            latents.extend(latent_batch)
    return latents


image_shape = (args.image_channels, args.image_size, args.image_size)

# One list of latent vectors per requested latent size (keys are strings).
# NOTE(review): the same `model` is used for every key; later cells read key '64'.
encoded_inputs = {zdim: [] for zdim in args.z_dim.split(',')}

for zdim, latents in encoded_inputs.items():
    # Train first, then val — the row order must match the labels CSV that is
    # concatenated with these vectors later in the notebook.
    for split in ('train', 'val'):
        latents.extend(_encode_dataset(model, datasets[split], image_shape, c.device))

In [None]:
# Put the 64-d latent vectors (integer column names 0..63) side by side with
# the label columns, then show just the latent part.
latent_frame = pd.DataFrame(encoded_inputs['64'])
latent_space = pd.concat([latent_frame, labels], axis=1)
latent_space.loc[:, list(range(64))]

In [None]:
from sklearn.decomposition import PCA

# Project the latent vectors onto the principal components that together keep
# 99% of the variance, whitening the result.
latent_columns = list(range(64))
pca = PCA(n_components=0.99, whiten=True)
data = pca.fit_transform(latent_space.loc[:, latent_columns])
data.shape

In [None]:
# Inspect the PCA-projected (whitened) latent vectors.
data

In [None]:
from sklearn.mixture import GaussianMixture as GMM

# Sweep the number of mixture components and record the AIC of each fit
# (lower is better) to guide the choice of GMM size.
n_components = np.arange(50, 1000, 50)
models = [GMM(n, covariance_type='full', random_state=0)
          for n in n_components]
# The loop variable is deliberately not called `model`, which is the VAE's name.
aics = [candidate.fit(data).aic(data) for candidate in models]

fig, ax = plt.subplots()
ax.plot(n_components, aics)
ax.set(xlabel='number of GMM components', ylabel='AIC',
       title='GMM model selection on PCA-projected latents');

In [None]:
# Fit a two-component, full-covariance mixture on the PCA-projected latents.
gmm = GMM(n_components=2, covariance_type='full', random_state=0)
gmm.fit(data)

In [None]:
# The GMM was fitted on PCA-projected, whitened vectors, so the raw latent
# columns must go through the same PCA transform before prediction. The latent
# columns are selected by their explicit integer names 0..63.
gmm.predict(pca.transform(latent_space.loc[:, list(range(64))]))

In [None]:
# Sanity-check the pipeline on a single sample: take the first row's 64 latent
# values (the first 64 columns of `latent_space`), keep them as ONE sample of
# shape (1, 64) — PCA expects (n_samples, n_features) — project them, and show
# the per-component membership probabilities.
first_latent = latent_space.iloc[:1, :64].to_numpy(dtype=float)
gmm.predict_proba(pca.transform(first_latent))