In [71]:
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

In [72]:
directory = "../logs"

tasks = ["pendulum_regression", "pendulum_interpolation", "activity_classification","physionet_interpolation", "rotating_mnist" ]

experiments = defaultdict(list)
for task in tasks:
    for log_file in os.listdir(directory):
        if task in log_file:
            filename = os.path.join(directory,log_file)
            if filename.endswith(".json"):      
                print(filename)      
                with open(filename, 'r') as f:
                    dic = json.load(f)
                    if dic["final"]["epoch"] == 990:
                        if task == "physionet_interpolation":
                            experiments[f'{task}_{dic["args"]["sample_tp"]}'].append(dic['final'])
                        else:
                            experiments[task].append(dic['final'])




../logs/pendulum_regression_5328.json
../logs/pendulum_regression_53864.json
../logs/pendulum_regression_87399.json
../logs/pendulum_regression_10850.json
../logs/pendulum_regression_72986.json
../logs/pendulum_regression_43325.json
../logs/pendulum_regression_68092.json
../logs/pendulum_regression_56207.json
../logs/pendulum_regression_65334.json
../logs/pendulum_regression_9790.json
../logs/pendulum_interpolation_76020.json
../logs/pendulum_interpolation_5910.json
../logs/pendulum_interpolation_14474.json
../logs/pendulum_interpolation_93831.json
../logs/pendulum_interpolation_61052.json
../logs/pendulum_interpolation_56807.json
../logs/pendulum_interpolation_59211.json
../logs/pendulum_interpolation_36438.json
../logs/pendulum_interpolation_77345.json
../logs/pendulum_interpolation_63339.json
../logs/pendulum_interpolation_79478.json
../logs/pendulum_interpolation_40940.json
../logs/pendulum_interpolation_34971.json
../logs/pendulum_interpolation_19179.json
../logs/pendulum_interpol

In [73]:
summarized = defaultdict(dict)
for task in experiments.keys():
    for key in experiments[task][0].keys():
        summarized[task][key] = [experiment[key] for experiment in experiments[task]]

In [74]:
for task in summarized.keys():
    print(task, len(summarized[task]['trn_elbo']))

pendulum_regression 8
pendulum_interpolation 11
activity_classification 8
physionet_interpolation_0.5 8
rotating_mnist 10


In [75]:
def gen_table(exp, key, scale, prec=2):
    l = len(exp[key])
    r1 = f"|      Run     |     Mean    |" + "".join([f"   {i}   |" for i in range(1,l+1)]) + "\n"
    r2 = f"| -------------| ----------- |" + "".join([f" ----- |" for i in range(1,l+1)]) + "\n"
    r3 =  (f"|{key} | {exp[key].mean()*scale:.{prec}f} \u00B1 {exp[key].std()*scale:.{prec}f}|" 
    +"".join([ f" {x*scale:.{prec}f} |" for x in exp[key].values])
    )
    return r1 + r2 + r3

In [78]:
import pandas as pd 
# task = 'physionet_interpolation_0.5'
task = 'activity_classification'
exp = pd.DataFrame.from_dict(summarized[task])
target = 'tst_aux_acc*'

print(gen_table(exp, target, 100, prec=2))
print(exp)


|      Run     |     Mean    |   1   |   2   |   3   |   4   |   5   |   6   |   7   |   8   |
| -------------| ----------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
|tst_aux_acc* | 90.58 ± 0.48| 90.03 | 90.82 | 90.64 | 90.78 | 90.22 | 89.91 | 91.12 | 91.14 |
   trn_loss  trn_elbo    trn_kl0       trn_klp  trn_log_pxz  trn_aux_val  \
0 -0.166629 -0.234294  55.928303  13061.939096   -36.463468     0.006767   
1 -0.189094 -0.250202  55.973752  14803.383114   -39.023938     0.006111   
2 -0.207947 -0.273369  55.930261  13100.706892   -42.329285     0.006542   
3 -0.179430 -0.243604  55.912941  13260.945647   -37.880083     0.006417   
4  0.206266  0.015034  55.346198  12449.186306     1.003128     0.019123   
5  0.206579  0.014839  55.215253  12449.401103     0.973840     0.019174   
6 -0.189289 -0.253822  55.945511  13825.852842   -39.469383     0.006453   
7 -0.171779 -0.236753  55.959229  14236.256106   -36.949680     0.006497   

   tst_loss  tst_aux_val  tst_a

In [30]:
exp

Unnamed: 0,trn_loss,trn_elbo,trn_kl0,trn_klp,trn_log_pxz,tst_loss,tst_elbo,tst_kl0,tst_klp,tst_log_pxz,...,tst_mse_trgt*,val_loss,val_elbo,val_kl0,val_klp,val_log_pxz,val_mse_full,val_mse_trgt,val_mse_trgt*,epoch
0,0.311623,0.311623,55.993642,115290.676389,2675.900239,0.338022,0.338022,55.989768,112542.064757,4228.885666,...,0.012027,0.336025,0.336025,55.993793,113203.701389,4203.76888,0.010574,0.011639,0.011207,990
1,0.311633,0.311633,55.993793,115068.064236,2676.012012,0.336738,0.336738,55.993511,112801.727951,4212.750803,...,0.011466,0.333905,0.333905,55.992166,113266.909722,4177.174316,0.009777,0.010955,0.010821,990
2,0.311534,0.311534,55.993468,111676.097569,2675.499186,0.335471,0.335471,55.990538,110257.68125,4197.110655,...,0.011036,0.333836,0.333836,55.992925,110622.352431,4176.576063,0.009748,0.010974,0.010602,990
3,0.311478,0.311478,55.993793,115261.109375,2674.658084,0.337112,0.337112,55.993533,113163.293056,4217.408691,...,0.01202,0.334723,0.334723,55.993793,114442.510417,4187.309679,0.010104,0.011429,0.011159,990
4,0.311591,0.311591,55.993349,122807.135069,2674.877089,0.335856,0.335856,55.990159,119507.160937,4201.018001,...,0.011802,0.335526,0.335526,55.992166,120143.727431,4196.823947,0.010277,0.011958,0.011348,990
5,0.311099,0.311099,55.993793,113321.648785,2671.581261,0.336702,0.336702,55.991873,111868.158681,4212.403212,...,0.011445,0.33443,0.33443,55.992925,110700.693576,4184.008192,0.009978,0.01106,0.010763,990
6,0.312182,0.312182,55.993425,112962.785069,2680.957422,0.337059,0.337059,55.991981,110513.877604,4217.007726,...,0.011743,0.335996,0.335996,55.993793,110108.418403,4203.715061,0.010562,0.011113,0.010894,990
7,0.311489,0.311489,55.980333,111830.693403,2675.09452,0.336691,0.336691,55.958013,109775.532639,4212.473242,...,0.011575,0.335628,0.335628,55.965039,111064.776042,4199.008138,0.010372,0.011723,0.011362,990
8,0.311736,0.311736,55.985596,111324.75434,2677.274642,0.336525,0.336525,55.97181,108936.059028,4210.466829,...,0.01089,0.333803,0.333803,55.979579,110062.75,4176.213379,0.009736,0.010985,0.010531,990
9,0.312657,0.312657,55.993793,112418.743924,2685.104785,0.336826,0.336826,55.991981,109799.831944,4214.164605,...,0.011338,0.334788,0.334788,55.993793,111820.050347,4188.389648,0.010131,0.011217,0.010788,990
