In [None]:
from tqdm import tqdm
import numpy as np
import pandas as pd
import time
import datetime
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import os
import shutil
import json
import pickle
import itertools
from scipy.signal import savgol_filter

# Seed every RNG the evaluation may draw from so Restart & Run All reproduces the
# same metrics. NumPy alone is not enough: the models below are torch/pyro, and
# pyro samples through torch's RNG. torch is imported here (it is imported again
# further down) only so it can be seeded at this point.
SEED = 42
np.random.seed(SEED)
import torch
torch.manual_seed(SEED)

from formats import experiment_pb2
from formats import  quantification_pb2

from skimage import io
import pandas as pd
import utils


from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset, ChainDataset


import pyro
import pyro.distributions as dist
import pyro.poutine
from pyro.infer import MCMC, NUTS
import math
import torch.nn as nn
import torch.nn.functional as F
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer.autoguide.guides import AutoDiagonalNormal
import pyro.distributions.constraints as constraints
from tqdm import trange

import utils



In [None]:
# Neighbourhood radius forwarded to `data.get_dataset` in the data-loading cell.
# The `_px` suffix suggests the unit is pixels (TODO confirm in data.py).
local_radius_px = 5

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
import data

# Build the dataset with the configured neighbourhood radius. Wrap it in a loader
# with fixed batches of 512; no shuffling is requested, so batch order follows
# the dataset's order.
dataset = data.get_dataset(local_radius_px=local_radius_px)
data_loader = DataLoader(dataset=dataset, batch_size=512)

In [None]:
from importlib import reload

import models

# Re-execute models.py so source edits are picked up without restarting the
# kernel. The from-import must come after the reload so the names below are
# bound to the freshly loaded classes, not stale ones.
models = reload(models)
from models import FusionModel, JointVAE, NaiveFusionModel

In [None]:
def eval_recon(x_hat,q_hat, x, q):
    """Return coefficient-of-determination (R^2) scores for two reconstructions.

    ``x`` is flattened to 2-D over its last axis and ``x_hat`` is reshaped to the
    same layout. ``q`` and ``q_hat`` are compared as given. In both cases
    R^2 = 1 - SS_res / SS_tot, where SS_tot is measured around the mean along
    axis 0 and both sums run over every entry.

    Args:
        x_hat: reconstruction of ``x`` (any shape reshapeable to ``(-1, x.shape[-1])``).
        q_hat: reconstruction of ``q``.
        x: target tensor whose last axis is the feature axis.
        q: second target tensor.

    Returns:
        ``(x_r2, q_r2)`` as Python floats.
    """
    def _r2(target, pred):
        # Pooled R^2: residuals and deviations are summed over all elements.
        residual = ((target - pred) ** 2).sum()
        total = ((target - target.mean(0)) ** 2).sum()
        return (1 - residual / total).item()

    n_features = x.shape[-1]
    x_score = _r2(x.reshape(-1, n_features), x_hat.reshape(-1, n_features))
    q_score = _r2(q, q_hat)
    return x_score, q_score

In [None]:
def eval_model(z,c,m,h=256,d=4,r=5,dir='models',loader=None):
    """Load a saved model checkpoint and score its reconstructions with ``eval_recon``.

    The checkpoint is read from
    ``{dir}/{m.__name__}-z-{z}-c-{c}-h-{h}-d-{d}-r-{r}.pt``. The values ``z``,
    ``c``, ``h``, ``d`` and ``r`` are used only to build that file name.

    Args:
        z, c, h, d, r: hyper-parameter values encoded in the checkpoint name.
        m: model class (anything with ``__name__``) that names the checkpoint.
        dir: directory containing the checkpoints.
        loader: iterable of ``(x, q)`` batches to evaluate on. Defaults to the
            notebook-level ``data_loader``, so existing calls behave as before.

    Returns:
        ``(x_r2, q_r2)`` floats as returned by ``eval_recon``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loader is None:
        loader = data_loader

    path = f'{dir}/{m.__name__}-z-{z}-c-{c}-h-{h}-d-{d}-r-{r}.pt'
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only host.
    # torch.load unpickles arbitrary objects, so only point it at trusted files.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects a
    # pickled nn.Module. Pass weights_only=False there if this load fails.
    model = torch.load(path, map_location=device).eval().to(device)

    x, q, x_hat, q_hat = [], [], [], []
    # Inference only: no_grad avoids building (and then discarding) an autograd
    # graph for every batch, which saves memory on large datasets.
    with torch.no_grad():
        # tqdm infers the bar length from len(loader), so the count always matches
        # the loader's real number of batches regardless of batch size.
        for batch_x, batch_q in tqdm(loader, leave=False):
            batch_x = batch_x.to(device)
            batch_q = batch_q.to(device)
            batch_x_hat, batch_q_hat = model.reconstruct(batch_x, batch_q)
            x.append(batch_x.cpu())
            q.append(batch_q.cpu())
            x_hat.append(batch_x_hat.cpu())
            q_hat.append(batch_q_hat.cpu())

    if not x:
        raise ValueError(f'loader yielded no batches while evaluating {path}')

    # Every piece was moved to the CPU above, so concatenation happens on the CPU.
    x = torch.cat(x, dim=0)
    q = torch.cat(q, dim=0)
    x_hat = torch.cat(x_hat, dim=0)
    q_hat = torch.cat(q_hat, dim=0)

    return eval_recon(x_hat, q_hat, x, q)

In [None]:
# Evaluate every (model type, z, c) combination on the full dataset.
# `data_loader` is created in the data-loading cell and is reused here rather
# than rebuilt with identical settings.
z_values = [1,2,3]
c_values = [1]
model_types = [FusionModel, NaiveFusionModel, JointVAE]
eval_results = dict()
for model_type, z, c in tqdm(list(itertools.product(model_types, z_values, c_values))):
    # Put every swept hyper-parameter in the key. Without `c`, runs that differ
    # only in c would overwrite each other in eval_results.
    model_label = f'{model_type.__name__}-z-{z}-c-{c}'
    eval_results[model_label] = eval_model(z, c, model_type, dir='models')
    


In [None]:
# One line per evaluated model: its label followed by the (x_r2, q_r2) tuple.
for label, scores in eval_results.items():
    print(f'{label}     :   {scores}')