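# Loading pre-trained models

Reload a trained model (training args + saved `state_dict`), rebuild its test split, and report $R^2$ and SMAPE of the parameter regression on that split.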

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline
import matplotlib.pyplot as plt 

import copy
import os
import types
import torch
import numpy as np
import torch.nn as nn
from torchdiffeq import odeint

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from sklearn.metrics import mean_squared_error, r2_score
from permetrics import RegressionMetric

import sys
sys.path.append('..')

In [2]:
from npd.nn.core import (
	SignatureHead,
	MTANHead,
	LatentStateHead,
	TDABackbone,
	PointNetBackbone,
	JointBackbone,
	VecReconNet, 
	LatentODEfunc, 
	PathToGaussianDecoder)

from dynamics import (
    create_recog_backbone,
    create_recon_backbone,
    create_processor,
    load_data)

def compute_minmax_reverse_stats(prms_orig_file):
    """Compute the min/max statistics needed to undo a [-1, 1] min-max scaling.

    Args:
        prms_orig_file: path to a file written with ``torch.save`` that holds the
            original (unscaled) parameter tensor, presumably (num_samples, num_params).

    Returns:
        (min_d, max_d): minima and maxima reduced over dim 0 with ``keepdim=True``
        (shape (1, num_params) for a 2-D input), suitable for ``compute_minmax_reverse``.
    """
    # Load onto the CPU so the stats are usable even if the tensor was saved from a
    # GPU; predictions are moved to the CPU before they are unscaled.
    prms_orig = torch.load(prms_orig_file, map_location='cpu')
    max_d = prms_orig.max(dim=0, keepdim=True).values
    min_d = prms_orig.min(dim=0, keepdim=True).values
    return min_d, max_d

def compute_minmax_reverse(x, min_d, max_d):
    """Invert a min-max scaling to [-1, 1].

    Maps ``x`` from [-1, 1] back to [min_d, max_d]; the operands broadcast, so
    per-column stats of shape (1, C) apply column-wise to an (N, C) input.
    """
    span = max_d - min_d
    half_shifted = (x + 1.) / 2
    return half_shifted * span + min_d

In [31]:
# Paths for this evaluation run; relative paths resolve against the current
# working directory (presumably the notebook's own directory).
config = dict(
    log_file='../logs/dorsogna_1k_log.pt',       # training log (contains the training args)
    mdl_file='../logs/dorsogna_1k_mdl.pt',       # saved model state_dict
    aux_file='../data/Giusti23a/1k/prms_1k.pt',  # original (unscaled) parameter file
    base_dir='../',                              # prefix used to correct relative data paths
)

Load **training arguments** and **state dictionary**:

In [17]:
# The training log unpacks into three items; only the last (the training args) is used.
# `args` is a pickled non-tensor object, so weights_only=False is required on
# PyTorch >= 2.6 (where weights_only defaults to True). This unpickles arbitrary
# objects: only load log files from a trusted source.
_, _, args = torch.load(config['log_file'], map_location='cpu', weights_only=False)
state_dict = torch.load(config['mdl_file'], map_location='cpu')

Instantiate the model in the same configuration used for training (otherwise the saved `state_dict` will not load):

In [21]:
# Rebuild every sub-network from the training args so the saved state_dict matches.
# Construction order is the same as when the ModuleDict was built inline.
recog_backbone = create_recog_backbone(args)
recon_backbone = create_recon_backbone(args)
processor = create_processor(args)
lnode_net = LatentODEfunc(args.z_dim, args.ode_h_dim)
ptogd_net = PathToGaussianDecoder(nn.Identity(), initial_sigma=1.0)
# Linear read-out to the auxiliary targets; Tanh bounds the outputs to [-1, 1].
regressor = nn.Sequential(
    nn.Linear(processor.get_outdim(), args.num_aux_dim),
    nn.Tanh(),
)

modules = nn.ModuleDict({
    "recog_net": recog_backbone,
    "recon_net": recon_backbone,
    "lnode_net": lnode_net,
    "ptogd_net": ptogd_net,
    "processor": processor,
    "regressor": regressor,
})

modules.load_state_dict(state_dict)  # restore trained weights (strict key matching by default)
modules = modules.to(args.device)

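Load and prepare the data. The 80/20 train/test split is seeded with `args.seed`, presumably to recover the test split held out during training: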

In [22]:
# Resolve the data paths stored in args against `base_dir` on a *copy* of args.
# Mutating args in place would prepend `base_dir` again on every re-run of this cell.
data_args = copy.copy(args)
data_args.vec_inp_file = os.path.join(config['base_dir'], args.vec_inp_file)
data_args.aux_inp_file = os.path.join(config['base_dir'], args.aux_inp_file)
ds = load_data(data_args)

# Seeded 80/20 split; with the training seed this presumably reproduces the
# training-time test split -- TODO confirm against the training script.
split_generator = torch.Generator().manual_seed(args.seed)
trn_set, tst_set = torch.utils.data.random_split(
    ds,
    [0.8, 0.2],
    generator=split_generator)

# Time grid on which the latent ODE is integrated in `predict`.
t = torch.linspace(0, 1.0, args.num_timepts).to(args.device)

# Test loader in fixed order; the dataset supplies its own collate function.
dl_tst = DataLoader(tst_set,
                    batch_size=args.batch_size,
                    shuffle=False,
                    collate_fn=ds.get_collate())

In [32]:
def predict(modules, dl, t, use_mean=False):
    """Run the trained pipeline over every batch of ``dl`` and collect regressor outputs.

    For each batch:
    1. Encode it with ``recog_net`` into the parameters of q(z0).
    2. Draw z0 (or take the mean).
    3. Integrate the latent ODE on ``t`` with Euler steps.
    4. Map the trajectory through ``processor`` and ``regressor``.

    Args:
        modules: ``nn.ModuleDict`` providing 'recog_net', 'lnode_net', 'processor'
            and 'regressor'.
        dl: iterable of batches accepted by ``modules['recog_net']``.
        t: time points passed to ``odeint``.
        use_mean: if True, use the posterior mean as z0 (deterministic predictions);
            the default False samples z0, as before.

    Returns:
        (y_hat, y_org): CPU tensors concatenated over batches. ``y_hat`` is the Tanh
        output of the regressor (in [-1, 1]). ``y_org`` holds the auxiliary targets
        returned by ``recog_net``, presumably on the same scale.

    Note: reads ``args.device`` and ``args.z_dim`` from the notebook's global ``args``.
    """
    y_hat = []
    y_org = []
    modules.eval()
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        # Iterate the loader that was passed in (previously the global `dl_tst`
        # was used, silently ignoring `dl`).
        for batch in dl:
            out, _, _, aux_obs = modules['recog_net'](batch, args.device)

            # First z_dim columns: mean of q(z0); remaining columns: its log-variance.
            qz0_mean, qz0_logvar = out[:, :args.z_dim], out[:, args.z_dim:]
            if use_mean:
                z0 = qz0_mean
            else:
                # Reparameterised sample: z0 = mean + exp(logvar / 2) * eps.
                epsilon = torch.randn(qz0_mean.size()).to(args.device)
                z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean

            # odeint stacks states along a leading time axis; move batch to the front.
            zs = odeint(
                modules['lnode_net'],
                z0,
                t,
                method="euler").permute(1, 0, 2)
            aux_out = modules['regressor'](modules['processor'](zs))

            y_hat.append(aux_out.detach().cpu())
            y_org.append(aux_obs.detach().cpu())

    return torch.cat(y_hat), torch.cat(y_org)

In [33]:
# Undo the [-1, 1] min-max scaling using the column-wise stats of the original parameters.
min_d, max_d = compute_minmax_reverse_stats(config['aux_file'])
unscale = lambda x: compute_minmax_reverse(x, min_d, max_d)

# `predict` samples z0 at random; seed so the reported scores are reproducible.
torch.manual_seed(args.seed)
y_hat, y_org = predict(modules, dl_tst, t)
assert y_hat.shape == y_org.shape, "predictions and targets must align column-wise"
y_hat = unscale(y_hat)  # scale [-1,1] predictions back to original scale
y_org = unscale(y_org)  # scale [-1,1] targets back to original scale

Compute test-set statistics: $R^2$ and SMAPE, each computed per parameter and then averaged over parameters:

In [34]:
metric = RegressionMetric()
num_outputs = y_hat.shape[1]
y_true_np, y_pred_np = y_org.numpy(), y_hat.numpy()

# Score each parameter (column) separately, then average across parameters.
r2_per_col = [r2_score(y_org[:, col], y_hat[:, col]) for col in range(num_outputs)]
smape_per_col = [
    metric.symmetric_mean_absolute_percentage_error(y_true_np[:, col], y_pred_np[:, col])
    for col in range(num_outputs)
]
scores = {'r2s': np.mean(r2_per_col), 'smp': np.mean(smape_per_col)}

num_params = np.sum([param.numel() for param in modules.parameters()])
print(f'#Parameters={num_params}')
print(f'#Samples={len(tst_set)}')
print('R2={:0.4f} | SMAPE={:0.4f}'.format(
    scores['r2s'], scores['smp']))

#Parameters=135931
#Samples=200
R2=0.8429 | SMAPE=0.0918
