In [None]:
%matplotlib inline
# Reload edited modules (e.g. the project's `src` package imported below)
# automatically before each cell executes, so library edits are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pickle
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pystan
import seaborn as sns
import torch
from imageio import imread
from PIL import Image
from torchvision import transforms

# Make the project's `src` package importable. The repository root defaults to
# the original author's checkout but can be overridden with the
# TOOL_PRESENCE_ROOT environment variable instead of editing this cell.
module_path = os.path.abspath(
    os.environ.get('TOOL_PRESENCE_ROOT', '/users/dli44/tool-presence'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src import constants as c
from src import utils
from src import visualization as v
from src import model as m

In [None]:
# Seed every RNG the notebook uses. numpy alone is not enough: the VAE's
# `sampling` step below presumably draws its noise from torch's RNG, so torch is
# seeded too to make the encodings reproducible across kernel restarts.
SEED = 101
np.random.seed(SEED)
torch.manual_seed(SEED)

# text.usetex renders all figure text through LaTeX (requires a LaTeX install).
matplotlib.rc('text', usetex=True)
matplotlib.rcParams['figure.dpi'] = 200

In [None]:
# Frame ids held out for evaluation. Each frame is encoded below and later paired
# *by position* with rows of labels.csv, so order and uniqueness matter: a
# duplicate id would encode the same image twice and shift every later encoding
# relative to its label row.
# NOTE(review): confirm the label rows selected in the labels cell line up with
# exactly these frames, in this order.
test_images = ['0686',
               '0687',
               '0688',
               '0689',
               '0690',
               '0691',
               '0693',
               '0694',
               '0695',
               '0696',
               '0697',
               '0698',
               '0699',
               '0700',
               '0703',
               '0704']
assert len(test_images) == len(set(test_images)), 'duplicate frame ids in test_images'

images = ['../data/youtube_data/train/images/frame_{}.png'.format(number) for number in test_images]

In [None]:
# Training-set table for the beta=10.0 / zdim=80 run (presumably the latent
# encodings used as Stan data below); its first column is used as the index.
train_data_file = "../mmd/csv/beta_10.0_zdim_80_train.csv"
train = pd.read_csv(
    train_data_file,
    index_col=0,
)

In [None]:
model_file = '../mmd/weights/final_beta_10.0_zdim_80_epoch_80.torch'
model = m.VAE(image_channels=3,
              image_size=64,
              h_dim1=1024,
              h_dim2=128,
              zdim=80).to(c.device)
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine;
# load_state_dict copies them into the model's parameters on c.device.
model.load_state_dict(torch.load(model_file, map_location='cpu'))
# Inference only: switch layers such as dropout/batch-norm (if the VAE has any)
# to evaluation behaviour.
# NOTE(review): if VAE.sampling branches on self.training, this also changes it
# to its eval behaviour - confirm that is intended.
model.eval()

# Match the training input pipeline: 64x64 center crop, converted to a [C, H, W] tensor.
tform = transforms.Compose([transforms.Resize(64), transforms.CenterCrop(64), transforms.ToTensor()])
encodings = []

# No gradients are needed to encode images; skip building the autograd graph.
with torch.no_grad():
    for image in images:
        # convert('RGB') drops any alpha channel (or expands grayscale) so the
        # tensor has the 3 channels the model was built with.
        pil_image = Image.fromarray(imread(image)).convert('RGB')
        # Add a batch dimension and move the input to the same device as the model.
        im = tform(pil_image).unsqueeze(0).to(c.device)
        enc = utils.torch_to_numpy(model.sampling(*model.encode(im)))
        encodings.append(enc.squeeze())

# One row per image in `images`, in the same order.
encodings = np.array(encodings)

In [None]:
# The encoding loop appends exactly one latent vector per input frame.
assert len(encodings) == len(images), 'expected one encoding per image'
encodings.shape

In [None]:
labels_file = '../data/youtube_data/train/labels.csv'

test = pd.DataFrame(encodings)
# Keep the header (file row 0) and skip file rows 1-547. The remaining label rows
# are paired with the encodings purely by position (axis=1 concat on a fresh
# RangeIndex); rows with any missing value - including label rows beyond the
# last encoding - are then dropped.
# NOTE(review): confirm that file row 548 corresponds to the first test frame.
labels = pd.concat(
    [
        test,
        pd.read_csv(labels_file, skiprows=range(1, 548), header=0, index_col=0)
          .reset_index(drop=True),
    ],
    axis=1,
).dropna()

In [None]:
# dropna() in the previous cell silently discards frames with incomplete labels;
# report how many encoded frames actually survived before showing the table.
print('{} of {} encoded frames have complete labels'.format(len(labels), len(encodings)))
labels

In [None]:
# recompile: rebuild the Stan model from source instead of unpickling it.
# refit:     rerun inference instead of unpickling a saved fit.
# vb:        variational (mean-field ADVI) instead of MCMC sampling.
recompile = False
refit = True
vb = True
# `model` in this notebook is the torch VAE, not a Stan program, so the Stan
# source path is given explicitly.
# TODO(review): confirm the location of the Stan source file.
stan_model_file = "../model.stan"
compiled_model = "../model.pkl"
compiled_fit = '../fit_vb.pkl' if vb else "../fit.pkl"

# K=2 presumably sets the number of mixture components/classes - confirm
# against the Stan program's data block.
data = {"N": len(train.index),
        "x": train,
        "K": 2,
        "D": len(train.columns)}

if recompile:
    sm = pystan.StanModel(file=stan_model_file)
    with open(compiled_model, 'wb') as f:
        pickle.dump(sm, f)
else:
    # Only unpickle files this notebook wrote itself: pickle can execute code.
    with open(compiled_model, 'rb') as f:
        sm = pickle.load(f)

if refit:
    if vb:
        fit = sm.vb(data=data, algorithm='meanfield')
    else:
        fit = sm.sampling(data=data, iter=5000, chains=4, thin=1)
    with open(compiled_fit, 'wb') as f:
        pickle.dump(fit, f)
else:
    # The compiled model (sm) is already loaded above, which PyStan requires
    # before a pickled fit can be restored.
    with open(compiled_fit, 'rb') as f:
        fit = pickle.load(f)

In [None]:
# utils.pystan_vb_extract is written for the output of sm.vb(); an MCMC fit
# (vb=False) presumably needs a different extraction path, so fail loudly
# instead of producing a confusing error or wrong result downstream.
if not vb:
    raise ValueError('pystan_vb_extract expects a variational (vb=True) fit')
result = utils.pystan_vb_extract(fit)

In [None]:
# Unpack into names that do not shadow `c` (the src.constants alias used as
# `c.device` when building the model); rebinding `c` here would break re-running
# the earlier cells in the same kernel.
res_c, res_a, res_f = utils.get_inference_results(result, labels)
print(res_c, res_a, res_f)