In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystan
import pickle
import seaborn as sns

import os
import sys

# Directory appended to sys.path so that the `from src import ...` lines below
# can resolve (presumably the root of the tool-presence checkout -- confirm).
# Set the TOOL_PRESENCE_ROOT environment variable to point at a different
# checkout; when it is unset the original hard-coded location is used.
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    # Guarded so that re-running the cell does not add duplicate entries.
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Fix NumPy's global random state for this session.
np.random.seed(101)

# Matplotlib defaults for every figure below: render text through LaTeX
# (needs a working LaTeX installation) and draw figures at 200 dpi.
matplotlib.rcParams['text.usetex'] = True
matplotlib.rc('figure', dpi=200)

In [None]:
# Input CSV locations, relative to the notebook's working directory.
train_data_file = "../mmd/csv/beta_10.0_zdim_80_train.csv"
test_data_file = "../mmd/csv/beta_10.0_zdim_80_test.csv"
test_labels_file = '../data/youtube_data/val/labels.csv'

# Every file is read with its first column as the index.
train = pd.read_csv(train_data_file, index_col=0)
test = pd.read_csv(test_data_file, index_col=0)
raw_labels = pd.read_csv(test_labels_file, index_col=0)

# Place the test columns side by side with the label columns (aligned on the
# index) and drop every row that has a missing value on either side.
merged = pd.concat([test, raw_labels], axis=1)
test_labels = merged.dropna()

In [None]:
# Switches that steer the cells below.
recompile = False  # True: build the Stan model from source; False: load the pickled model
refit = True       # True: run inference again; False: load the pickled fit
vb = True          # True: variational inference (sm.vb); False: sampling (sm.sampling)

# Stan source file and the pickle caches for the compiled model and the fit.
model = '../model.stan'
compiled_model = "../model.pkl"
# Separate cache file per inference method.
if vb:
    compiled_fit = '../fit_vb.pkl'
else:
    compiled_fit = "../fit.pkl"

# Data dictionary handed to Stan (via sm.vb / sm.sampling further down).
# The key names presumably mirror the data block of ../model.stan -- confirm.
n_features = len(train.columns)

data = {"N": len(train.index),
        "N2": len(test_labels),
        "x": train,
        # test_labels holds the test columns first, followed by the label
        # columns, so the leading n_features columns are taken as the test
        # inputs. Using n_features instead of a literal 10 keeps the width of
        # x_test in step with "D" (assumes test has the same columns as train
        # -- TODO confirm).
        "x_test": test_labels.values[:, :n_features],
        "K": 2,
        "D": n_features}

In [None]:
if not recompile:
    # Reuse the cached compiled model.
    # NOTE: pickle.load executes arbitrary code; only load files you trust.
    with open(compiled_model, 'rb') as model_file:
        sm = pickle.load(model_file)
else:
    # Build the model from the Stan source and cache it for later runs.
    sm = pystan.StanModel(file=model)
    with open(compiled_model, 'wb') as model_file:
        pickle.dump(sm, model_file)

In [None]:
# Stan keeps its own random number generator: np.random.seed() does not reach
# it, and PyStan picks a fresh random seed whenever `seed` is omitted. Passing
# one explicitly makes a re-fit reproducible (value matches the NumPy seed
# set near the top of the notebook).
stan_seed = 101

if refit:
    if vb:
        # Mean-field variational approximation.
        fit = sm.vb(data=data, algorithm='meanfield', seed=stan_seed)
    else:
        # Sampling: 4 chains of 5000 iterations each, no thinning.
        fit = sm.sampling(data=data, iter=5000, chains=4, thin=1,
                          seed=stan_seed)
    # Cache the result so later sessions can skip inference (refit=False).
    with open(compiled_fit, 'wb') as f:
        pickle.dump(fit, f)
else:
    # NOTE: pickle.load executes arbitrary code; only load files you trust.
    with open(compiled_fit, 'rb') as f:
        fit = pickle.load(f)

In [None]:
# Pick the extraction routine that matches how `fit` was produced:
# - vb=True:  `fit` comes from sm.vb and goes through the project helper
#             utils.pystan_vb_extract (presumably returns a dict of draws
#             keyed by parameter name -- confirm against src/utils).
# - vb=False: `fit` comes from sm.sampling and exposes .extract() itself.
result = utils.pystan_vb_extract(fit) if vb else fit.extract()

In [None]:
# Preview the merged test frame; .head() shows only the first rows instead
# of dumping the whole DataFrame into the notebook output.
test_labels.head()

In [None]:
# Score the fitted model against the labelled test frame.
# The helper returns three values; judging by how they are used further down
# they are presumably a confusion matrix, an accuracy and an F-score --
# confirm against utils.get_inference_results.
# NOTE(review): binding `c` here hides the `constants` module that was
# imported under the alias `c` at the top of the notebook.
inference_results = utils.get_inference_results(result, test_labels)
c, a, f = inference_results
print(c, a, f)

In [None]:
# Colour each cell by its share of the row total, while the annotations
# printed inside the cells show the raw values of `c`.
row_totals = c.sum(axis=1)[:, np.newaxis]
row_fractions = c.astype('float') / row_totals

heatmap_options = dict(
    cmap=sns.color_palette("Blues"),
    xticklabels=['Cluster 1', 'Cluster 2'],
    yticklabels=['No Tool', 'Tool'],
    annot=c,
    annot_kws={"size": 28},
    fmt='g',
    cbar=False,
)
sns.heatmap(row_fractions, **heatmap_options)

plt.ylabel("Predictions")
plt.xlabel("Actual")
# Optional titles / export, kept disabled:
# plt.title('MMD-VAE Confusion Matrix\n' + r"$\lambda=1, z=10$")
# plt.title(r"$\beta$" "-VAE Confusion Matrix\n" +
#           r"$\beta=10, z=10$");
# plt.savefig('beta_vae_beta10_zdim_10_confusion.png')

In [None]:
# Histogram over every entry of result['mu'], pooled into 50 bins.
fig = plt.figure()
mu_values = result['mu'].flatten()
plt.hist(mu_values, bins=50)
# Optional title, kept disabled:
# plt.title("Posterior distribution\n"+
# #           r"MMD-VAE $\lambda=1, z=10$")
#           r"$\beta$-VAE $\beta=10, z=10$")
plt.ylabel('Frequency')
plt.xlabel(r'$\mu$')
# Optional export, kept disabled:
# plt.savefig('vb_beta10_zdim_10.png')
# plt.savefig('vb_mmd_lambda1_zdim_10.png')

In [None]:
# Recorded scores per model (TODO: confirm these are accuracies):
#
#   ELBO, beta = 1:    60.9%
#   MMD,  lambda = 1:  68%
#   MMD,  lambda = 10: 58%

In [None]:
# Hand-entered 2x2 integer matrix bound to `c`.
# NOTE(review): `c` is also the alias of the `constants` module imported at
# the top of the notebook; this assignment hides it.
c = np.array([
    [27, 71],
    [50, 159],
])