# Evaluation

## Overview

- [vicsek-10k](#vicsek-10k)
- [dorsogna-1k](#dorsogna-1k)
- [dorsogna-10k](#dorsogna-10k)
- [volex-10k](#volex-10k)


---

In [6]:
import os 
import sys
import glob
import torch
import numpy as np

import pandas as pd
from collections import defaultdict
from sklearn.metrics import r2_score, mean_squared_error
from statsmodels.stats.multitest import multipletests
from statsmodels.tools.eval_measures import rmspe, rmse
from permetrics import RegressionMetric

from IPython.display import display, HTML

sys.path.append('../')
from npd.utils.core import compute_minmax_reverse_stats, compute_minmax_reverse

In [10]:
def find_best_idx_from_val(
    aux_t: torch.Tensor, 
    aux_p: torch.Tensor, 
    val_percent:float=0.25,
    generator:"torch.Generator | None"=None):
    """Pick the best epoch on a random validation split of the simulations.

    The rows (simulations) are randomly partitioned into a validation part and
    a remaining "test" part. For every epoch, the R2 score between targets and
    predictions is computed per auxiliary dimension on the validation rows and
    averaged over dimensions; the epoch with the highest average wins.

    Args:
        aux_t: Per-epoch targets, indexable by epoch; each entry is indexed as
            a 2D array (rows = simulations, columns = auxiliary dimensions).
        aux_p: Per-epoch predictions, same layout as ``aux_t``.
        val_percent: Fraction of rows used for validation (rounded down).
        generator: Optional ``torch.Generator`` used for the random
            permutation. Pass a seeded generator to make the split, and hence
            the selected epoch, reproducible. ``None`` keeps the previous
            behavior (global torch RNG).

    Returns:
        Tuple ``(val_idx, tst_idx, best_i)``: validation row indices (torch
        tensor, in permutation order), remaining row indices (sorted numpy
        array) and the index of the best epoch.

    Raises:
        ValueError: If ``val_percent`` yields an empty validation split.
    """
    assert len(aux_t) == len(aux_p)
    assert aux_t[-1].shape[0] == aux_p[-1].shape[0] 
    assert aux_t[-1].shape[1] == aux_p[-1].shape[1]
    num_aux_dims = aux_p[-1].shape[1]
    num_epochs = len(aux_t) 
    
    n_total = aux_t[-1].shape[0] 
    n_val = int(val_percent*n_total)
    if n_val < 1:
        # Scoring an empty split cannot work; fail early with a clear message.
        raise ValueError(
            f'val_percent={val_percent} yields an empty validation split '
            f'for {n_total} samples')

    all_idx = torch.randperm(n_total, generator=generator)
    val_idx = all_idx[0:n_val]
    # Everything not used for validation; np.setdiff1d returns sorted indices.
    tst_idx = np.setdiff1d(all_idx.numpy(), val_idx.numpy())

    # Mean validation R2 (over auxiliary dimensions) for every epoch.
    val_scores = []
    for n in range(num_epochs):
        per_dim = [
            r2_score(aux_t[n][val_idx,d], aux_p[n][val_idx,d]) 
            for d in range(num_aux_dims)]
        val_scores.append(np.mean(per_dim))

    best_i = np.argmax(val_scores)
    return val_idx, tst_idx, best_i 

In [12]:
def load_stats(
    file_list: list, 
    orig_prms_file: str, 
    show_progress:bool=False,
    val_percent:float=0.2):
    """Collect per-run evaluation statistics from saved tracker files.

    For every existing file in ``file_list`` the saved
    ``(trn_tracker, tst_tracker, args)`` triple is loaded, the best epoch is
    selected on a validation split of the testing simulations, and R2, RMSE
    and SMAPE are computed per simulation parameter on the remaining rows.

    Args:
        file_list: Paths of the saved run files; non-existing paths are skipped.
        orig_prms_file: Path handed to ``compute_minmax_reverse_stats`` to
            obtain the ``(min, max)`` statistics used by
            ``compute_minmax_reverse`` (presumably to undo min-max scaling of
            the parameters -- confirm against ``npd.utils.core``).
        show_progress: If True, print each file name before processing it.
        val_percent: Fraction of testing simulations split off as validation
            set for choosing the best epoch.

    Returns:
        ``pd.DataFrame`` with one row per loaded file. Columns:
        ``r2s_param<d>``, ``rmse_param<d>``, ``smape_param<d>`` for every
        parameter ``d``, their means ``r2s``, ``rmse``, ``smape``, plus
        ``file``, ``weight-decay``, ``tps-frac``, ``backbone`` and, if at
        least one run defines it, ``processor`` (``None`` for runs without it).
    """
    metric = RegressionMetric()
    min_d, max_d = compute_minmax_reverse_stats(orig_prms_file)
     
    stats = defaultdict(list)
    saw_processor = False
    for fname in file_list:
        if show_progress: 
            print(fname)
        if not os.path.exists(fname): 
            continue
            
        # NOTE(review): torch.load unpickles arbitrary Python objects; only
        # point this at log files from a trusted source.
        _trn_tracker, tst_tracker, args = torch.load(fname)
        
        r2s_all = []
        rmse_all = []
        smape_all = []

        # split-off N% of testing simulations as validation set to find the best epoch
        _val_idx, tst_idx, best_i = find_best_idx_from_val(
            tst_tracker['epoch_aux_t'], 
            tst_tracker['epoch_aux_p'], val_percent) 

        aux_p = compute_minmax_reverse(tst_tracker['epoch_aux_p'][best_i], min_d, max_d)
        aux_t = compute_minmax_reverse(tst_tracker['epoch_aux_t'][best_i], min_d, max_d)
        assert aux_p.shape == aux_t.shape
        num_aux_dims = aux_p.shape[1]

        # iterate over all available simulation parameters
        for aux_d in range(num_aux_dims):            
            tmp_t = aux_t[tst_idx,aux_d].numpy()
            tmp_p = aux_p[tst_idx,aux_d].numpy()

            # evaluation measures (computed on the non-validation rows only)
            r2_val = r2_score(tmp_t, tmp_p)
            rmse_val = metric.root_mean_squared_error(tmp_t, tmp_p)
            smape_val = metric.symmetric_mean_absolute_percentage_error(tmp_t, tmp_p)
            
            stats['r2s_param'+str(aux_d)].append(r2_val)
            stats['rmse_param'+str(aux_d)].append(rmse_val)
            stats['smape_param'+str(aux_d)].append(smape_val)
            
            r2s_all.append(r2_val)
            rmse_all.append(rmse_val)
            smape_all.append(smape_val)
    
        stats['r2s'].append(np.mean(r2s_all))
        stats['rmse'].append(np.mean(rmse_all))  
        stats['smape'].append(np.mean(smape_all))  
        stats['file'].append(fname)
        stats['weight-decay'].append(args.weight_decay)
        stats['tps-frac'].append(args.tps_frac)
        stats['backbone'].append(args.backbone)
        # Always append so that all columns keep the same length, even when
        # only some runs were saved with a 'processor' argument.
        saw_processor = saw_processor or hasattr(args, 'processor')
        stats['processor'].append(getattr(args, 'processor', None))

    if not saw_processor:
        # No run defines 'processor': omit the column entirely.
        stats.pop('processor', None)
    return pd.DataFrame(stats)

## vicsek-10k

### Dynamics (Ours)

In [30]:
# Parameter file and all log files matching '*dynamics*' for vicsek-10k.
vicsek_10k_prms_file = '../data/vicsek-10k/prms_10k.pt' 
vicsek_10k_files = glob.glob('../logs/vicsek-10k/*dynamics*')
print(f'Found {len(vicsek_10k_files)} files!')

Found 45 files!


In [31]:
vicsek_10k_df_dynamics = load_stats(vicsek_10k_files, vicsek_10k_prms_file, False)

# Per-run averages over the four simulation parameters, inserted at fixed
# column positions (allow_duplicates=True tolerates re-running the cell).
vicsek_10k_df_dynamics.insert(4, "r2avg", (
    vicsek_10k_df_dynamics['r2s_param0'] +
    vicsek_10k_df_dynamics['r2s_param1'] +
    vicsek_10k_df_dynamics['r2s_param2'] +
    vicsek_10k_df_dynamics['r2s_param3'])/4, allow_duplicates=True)
vicsek_10k_df_dynamics.insert(5, "smapeavg", (
    vicsek_10k_df_dynamics['smape_param0'] +
    vicsek_10k_df_dynamics['smape_param1'] +
    vicsek_10k_df_dynamics['smape_param2'] +
    vicsek_10k_df_dynamics['smape_param3'])/4, allow_duplicates=True)

# Mean / std / number of runs per (backbone, processor, weight-decay) setting.
vicsek_10k_df_dynamics_summary = vicsek_10k_df_dynamics.groupby([
    'backbone',
    'processor',
    'weight-decay'])[['smapeavg',
                      'r2avg',
                      'rmse_param0', 
                      'rmse_param1',
                      'rmse_param2',
                      'rmse_param3']].aggregate(['mean','std','count'])

# The former `.style...set_caption('Dynamics')` statement was dropped: it
# built a Styler and discarded it, so it never affected the displayed table.
with pd.option_context('display.float_format', '{:0.3f}'.format):
    display(vicsek_10k_df_dynamics_summary)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,smapeavg,smapeavg,smapeavg,r2avg,r2avg,r2avg,rmse_param0,rmse_param0,rmse_param0,rmse_param1,rmse_param1,rmse_param1,rmse_param2,rmse_param2,rmse_param2,rmse_param3,rmse_param3,rmse_param3
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count
backbone,processor,weight-decay,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
joint,z_mtantwins,0.001,0.145,0.005,15,0.571,0.032,15,0.612,0.03,15,0.749,0.099,15,0.449,0.01,15,0.976,0.019,15
ptsdyn_only,z_mtantwins,0.001,0.199,0.014,15,0.266,0.083,15,1.031,0.094,15,1.126,0.076,15,0.496,0.024,15,1.174,0.043,15
topdyn_only,z_mtantwins,0.001,0.146,0.006,15,0.574,0.032,15,0.616,0.03,15,0.748,0.103,15,0.445,0.007,15,0.971,0.024,15


## dorsogna-1k 

### Dynamics (Ours)

In [39]:
# Parameter file and all log files matching '*dynamics*' for dorsogna-1k.
dorsogna_1k_prms_file = '../data/dorsogna-1k/prms_1k.pt'
dorsogna_1k_files = glob.glob('../logs/dorsogna-1k/*dynamics*')
print(f'Found {len(dorsogna_1k_files)} files!')

Found 45 files!


In [40]:
dorsogna_1k_df_dynamics = load_stats(dorsogna_1k_files, dorsogna_1k_prms_file, False)

# Per-run averages over the two simulation parameters, inserted at fixed
# column positions (allow_duplicates=True tolerates re-running the cell).
dorsogna_1k_df_dynamics.insert(4, "r2avg", (
    dorsogna_1k_df_dynamics['r2s_param0'] +
    dorsogna_1k_df_dynamics['r2s_param1'])/2, allow_duplicates=True)
dorsogna_1k_df_dynamics.insert(5, "smapeavg", (
    dorsogna_1k_df_dynamics['smape_param0'] +
    dorsogna_1k_df_dynamics['smape_param1'])/2, allow_duplicates=True)


# Mean / std / number of runs per (backbone, processor, weight-decay) setting.
dorsogna_1k_df_dynamics_summary = dorsogna_1k_df_dynamics.groupby([
    'backbone',
    'processor',
    'weight-decay'])[['smapeavg',
                      'r2avg',
                      'rmse_param0', 
                      'rmse_param1']].aggregate(['mean','std','count'])
# The former `.style...set_caption('Dynamics');` statement was dropped: it
# built a Styler and discarded it, so it never affected the displayed table.
with pd.option_context('display.float_format', '{:0.3f}'.format):
    display(dorsogna_1k_df_dynamics_summary)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,smapeavg,smapeavg,smapeavg,r2avg,r2avg,r2avg,rmse_param0,rmse_param0,rmse_param0,rmse_param1,rmse_param1,rmse_param1
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,count,mean,std,count,mean,std,count,mean,std,count
backbone,processor,weight-decay,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2
joint,z_mtantwins,0.001,0.072,0.004,15,0.924,0.006,15,0.085,0.003,15,0.197,0.012,15
ptsdyn_only,z_mtantwins,0.001,0.136,0.019,15,0.799,0.037,15,0.184,0.02,15,0.262,0.025,15
topdyn_only,z_mtantwins,0.001,0.102,0.007,15,0.84,0.013,15,0.131,0.01,15,0.28,0.01,15


## dorsogna-10k

### Dynamics (Ours)

In [41]:
# Parameter file and all log files matching '*dynamics*' for dorsogna-10k.
dorsogna_10k_prms_file = '../data/dorsogna-10k/prms_10k.pt'
dorsogna_10k_files = glob.glob('../logs/dorsogna-10k/*dynamics*')
print(f'Found {len(dorsogna_10k_files)} files!')

Found 45 files!


In [42]:
dorsogna_10k_df_dynamics = load_stats(dorsogna_10k_files, dorsogna_10k_prms_file, False)

# Average R2 over the four simulation parameters (column position 4).
dorsogna_10k_df_dynamics.insert(
    4, "r2avg",
    dorsogna_10k_df_dynamics['r2s_param0']
        .add(dorsogna_10k_df_dynamics['r2s_param1'])
        .add(dorsogna_10k_df_dynamics['r2s_param2'])
        .add(dorsogna_10k_df_dynamics['r2s_param3'])
        .div(4),
    True)

# Average SMAPE over the four simulation parameters (column position 5).
dorsogna_10k_df_dynamics.insert(
    5, "smapeavg",
    dorsogna_10k_df_dynamics['smape_param0']
        .add(dorsogna_10k_df_dynamics['smape_param1'])
        .add(dorsogna_10k_df_dynamics['smape_param2'])
        .add(dorsogna_10k_df_dynamics['smape_param3'])
        .div(4),
    True)

# Mean / std / number of runs per (backbone, processor, weight-decay) setting.
dorsogna_10k_df_dynamics_summary = (
    dorsogna_10k_df_dynamics
    .groupby(['backbone', 'processor', 'weight-decay'])
    [['smapeavg', 'r2avg',
      'rmse_param0', 'rmse_param1', 'rmse_param2', 'rmse_param3']]
    .agg(['mean', 'std', 'count'])
)
with pd.option_context('display.float_format', '{:0.4f}'.format):
    display(dorsogna_10k_df_dynamics_summary)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,smapeavg,smapeavg,smapeavg,r2avg,r2avg,r2avg,rmse_param0,rmse_param0,rmse_param0,rmse_param1,rmse_param1,rmse_param1,rmse_param2,rmse_param2,rmse_param2,rmse_param3,rmse_param3,rmse_param3
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count
backbone,processor,weight-decay,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
joint,z_mtantwins,0.001,0.0892,0.0034,15,0.6882,0.0196,15,0.2277,0.0064,15,0.4884,0.0302,15,0.1473,0.0036,15,0.1342,0.005,15
ptsdyn_only,z_mtantwins,0.001,0.0885,0.0049,15,0.6717,0.0291,15,0.2384,0.0069,15,0.478,0.0431,15,0.1407,0.0064,15,0.1386,0.0091,15
topdyn_only,z_mtantwins,0.001,0.091,0.0051,15,0.6768,0.0276,15,0.2313,0.0089,15,0.4996,0.0301,15,0.147,0.0049,15,0.136,0.0071,15


## volex-10k

### Dynamics (Ours)

In [43]:
# All log files matching '*dynamics*' for volex-10k.
volex_10k_files = glob.glob('../logs/volex-10k/*dynamics*')
# NOTE(review): the other 10k datasets use 'prms_10k.pt'; confirm that
# 'prms_1k.pt' is the intended parameter file for volex-10k.
volex_10k_prms_file = '../data/volex-10k/prms_1k.pt'
# Report the number of volex files (previously printed the dorsogna-10k count).
print(f'Found {len(volex_10k_files)} files!')

Found 45 files!


In [44]:
volex_10k_df_dynamics = load_stats(volex_10k_files, volex_10k_prms_file, False)

# Average R2 over the four simulation parameters (column position 4).
volex_10k_df_dynamics.insert(
    4, "r2avg",
    volex_10k_df_dynamics['r2s_param0']
        .add(volex_10k_df_dynamics['r2s_param1'])
        .add(volex_10k_df_dynamics['r2s_param2'])
        .add(volex_10k_df_dynamics['r2s_param3'])
        .div(4),
    True)

# Average SMAPE over the four simulation parameters (column position 5).
volex_10k_df_dynamics.insert(
    5, "smapeavg",
    volex_10k_df_dynamics['smape_param0']
        .add(volex_10k_df_dynamics['smape_param1'])
        .add(volex_10k_df_dynamics['smape_param2'])
        .add(volex_10k_df_dynamics['smape_param3'])
        .div(4),
    True)

# Mean / std / number of runs per (backbone, processor, weight-decay) setting.
volex_10k_df_dynamics_summary = (
    volex_10k_df_dynamics
    .groupby(['backbone', 'processor', 'weight-decay'])
    [['smapeavg', 'r2avg',
      'rmse_param0', 'rmse_param1', 'rmse_param2', 'rmse_param3']]
    .agg(['mean', 'std', 'count'])
)
with pd.option_context('display.float_format', '{:0.3f}'.format):
    display(volex_10k_df_dynamics_summary)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,smapeavg,smapeavg,smapeavg,r2avg,r2avg,r2avg,rmse_param0,rmse_param0,rmse_param0,rmse_param1,rmse_param1,rmse_param1,rmse_param2,rmse_param2,rmse_param2,rmse_param3,rmse_param3,rmse_param3
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count,mean,std,count
backbone,processor,weight-decay,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2
joint,z_mtantwins,0.001,0.081,0.006,15,0.869,0.019,15,0.053,0.007,15,0.233,0.027,15,0.106,0.006,15,0.099,0.005,15
ptsdyn_only,z_mtantwins,0.001,0.096,0.007,15,0.807,0.026,15,0.079,0.009,15,0.332,0.039,15,0.112,0.003,15,0.112,0.004,15
topdyn_only,z_mtantwins,0.001,0.082,0.006,15,0.867,0.018,15,0.054,0.009,15,0.233,0.024,15,0.106,0.005,15,0.1,0.005,15
