In [1]:
import torch
import pandas as pd 

import json 
from math import sqrt

In [2]:
# z-score multiplier for a two-sided 99% interval under a normal approximation.
Z_99_PERCENT = 2.576

def ci(t):
    """Return the half-width of the confidence interval around the mean of ``t``.

    Computed as ``Z_99_PERCENT * std(t) / sqrt(n)``, where ``n`` is the size
    of the first dimension of ``t`` and ``std`` is torch's default
    ``Tensor.std()``.

    Args:
        t: floating-point tensor of samples.

    Returns:
        The half-width as a Python float (add/subtract it from the mean to
        get the interval bounds).
    """
    n_samples = t.size(0)
    # Standard error of the mean, kept as a tensor until the final .item().
    std_err = t.std() / sqrt(n_samples)
    half_width = Z_99_PERCENT * std_err
    return half_width.item()

def to_dict(t):
    """Summarise a tensor of samples as a dict of plain-float statistics.

    The samples are cast to float and sorted, the lowest and highest
    ``n // 4`` values are dropped, and a truncated mean plus a confidence
    interval (via ``ci``) are computed on what remains. ``mean``, ``max``,
    ``min`` and ``std`` are computed on the full, untruncated data.

    Args:
        t: tensor of samples; assumed to be 1-D (TODO confirm against the
            saved evaluation files).

    Returns:
        dict with keys ``mean``, ``max``, ``min``, ``std``, ``trunc-mean``,
        ``CI-low`` and ``CI-high``, each a Python float.
    """
    t = t.float().sort().values
    n_samples = t.size(0)
    quartile = n_samples // 4

    # Use an explicit end index instead of ``-quartile``: with fewer than
    # four samples ``quartile`` is 0 and ``t[0:-0]`` is the EMPTY slice,
    # which turned the truncated mean and both CI bounds into NaN.
    trunc = t[quartile:n_samples - quartile]
    trunc_mean = trunc.mean().item()
    # The interval is built from the truncated samples only (their std and
    # their count), not from the full data.
    ci_range = ci(trunc)

    return {
        'mean': t.mean().item(),
        'max': t.max().item(),
        'min': t.min().item(),
        'std': t.std().item(),
        'trunc-mean': trunc_mean,
        'CI-low': trunc_mean - ci_range,
        'CI-high': trunc_mean + ci_range
    }

def get_results(s):
    """Load one saved evaluation file and summarise its contents.

    Args:
        s: path stem inserted between ``'../results/'`` and ``'_eval.pt'``;
            the path is relative to the current working directory.

    Returns:
        dict with two entries, ``'rewards'`` and ``'episode lens'``, each the
        statistics dict produced by ``to_dict`` for the ``'rews'`` and
        ``'lens'`` entries of the loaded file.
    """
    # NOTE(review): torch.load can unpickle arbitrary objects depending on the
    # torch version / ``weights_only`` setting -- only load trusted files.
    saved = torch.load(f'../results/{s}_eval.pt')
    lens, rews = saved['lens'], saved['rews']

    return {
        'rewards': to_dict(rews),
        'episode lens': to_dict(lens),
    }


In [3]:
import pandas as pd 
def compare_one(n, dir):
    """Fetch the summary statistics for one ``ppo_{n}N_0_last`` result file.

    Args:
        n: value substituted into the result file name; it is also stored
            under the ``'n'`` key of both returned dicts.
        dir: prefix prepended to the file stem (e.g. ``'deg5e/'``).

    Returns:
        ``(rewards, episode_lens)``: the two statistics dicts from
        ``get_results``, each tagged with ``'n'``.
    """
    # Only the "_last" checkpoint is summarised. A side-by-side comparison
    # with a "best" checkpoint (stem ``ppo_{n}N_0``) was sketched here
    # previously but is disabled.
    results = get_results(f'{dir}ppo_{n}N_0_last')
    rewards, ep_lens = results['rewards'], results['episode lens']

    # Tag each summary so rows can be told apart once stacked in a table.
    for stats in (rewards, ep_lens):
        stats['n'] = n

    return rewards, ep_lens

def eval(dir=''):
    """Build a table of reward statistics for n = 10, 20 and 40.

    NOTE: this function shadows the builtin ``eval`` for the rest of the
    notebook session.

    Args:
        dir: prefix forwarded to ``compare_one`` (e.g. ``'deg5e/'``).

    Returns:
        A transposed ``pandas.DataFrame``: one row per statistic (plus an
        ``'n'`` row) and one positional column (0, 1, 2) per value of n.
        The episode-length summaries are computed but not returned.
    """
    reward_rows = []
    for n in (10, 20, 40):
        rewards, _ep_lens = compare_one(n, dir)
        reward_rows.append(rewards)

    return pd.DataFrame(reward_rows).T


In [4]:
# Reward statistics for the result files under the 'deg5e/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='deg5e/')

Unnamed: 0,0,1,2
mean,225.469864,174.02681,161.114014
max,549.098022,574.052246,586.5271
min,153.599762,136.549698,142.350616
std,106.690361,79.759674,33.488491
trunc-mean,186.042252,150.146255,158.031876
CI-low,179.726615,149.260501,157.640253
CI-high,192.357888,151.03201,158.423498
n,10.0,20.0,40.0


In [5]:
# Reward statistics for the result files under the 'deg1e/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='deg1e/')

Unnamed: 0,0,1,2
mean,232.157043,186.053757,160.912735
max,549.098022,574.052246,586.5271
min,153.799759,137.84967,143.100647
std,120.576317,87.04631,33.196182
trunc-mean,180.684418,157.616287,157.781677
CI-low,176.459731,155.530145,157.374338
CI-high,184.909105,159.70243,158.189016
n,10.0,20.0,40.0


In [6]:
# Reward statistics for the result files under the 'og1e/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='og1e/')

Unnamed: 0,0,1,2
mean,260.277557,186.269333,159.449722
max,549.098022,574.052246,586.5271
min,153.999756,136.749695,146.025635
std,131.704697,96.347672,20.444815
trunc-mean,210.702301,152.27243,157.821686
CI-low,201.556477,150.921894,157.477273
CI-high,219.848125,153.622967,158.166098
n,10.0,20.0,40.0


In [8]:
# Reward statistics for the result files under the 'og_inductive/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='og_inductive/')

Unnamed: 0,0,1,2
mean,403.421173,236.254852,167.950592
max,549.098022,574.052246,586.5271
min,203.599564,140.899643,143.750626
std,121.334602,124.590134,42.74884
trunc-mean,407.296997,186.062012,163.757599
CI-low,394.390206,183.000457,163.258862
CI-high,420.203788,189.123566,164.256336
n,10.0,20.0,40.0


In [None]:
# Reward statistics for the result files under the 'og5e/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='og5e/')

Unnamed: 0,0,1,2
mean,258.546722,179.086945,160.78891
max,549.098022,574.052246,586.5271
min,153.699753,136.949692,143.075638
std,142.374573,86.818016,33.642872
trunc-mean,195.354477,150.004852,157.775192
CI-low,187.541623,149.068844,157.416633
CI-high,203.167331,150.940861,158.133752
n,10.0,20.0,40.0


In [None]:
# Reward statistics for the result files under the 'og_tanh/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='og_tanh/')

Unnamed: 0,0,1,2
mean,225.225525,183.326218,159.838776
max,549.098022,574.052246,574.152222
min,153.699753,137.34967,143.02565
std,117.048775,95.514565,27.557026
trunc-mean,176.725998,153.167236,157.393402
CI-low,173.106527,151.800168,157.058372
CI-high,180.345469,154.534304,157.728432
n,10.0,20.0,40.0


In [None]:
# Reward statistics for the result files under the 'deg_dom_rnd/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='deg_dom_rnd/')

Unnamed: 0,0,1,2
mean,244.247482,170.570282,156.893539
max,549.098022,574.052246,175.450577
min,153.499756,137.149689,143.175629
std,124.764923,74.684303,5.465354
trunc-mean,193.200623,149.130661,156.629089
CI-low,186.114025,148.492584,156.280997
CI-high,200.28722,149.768738,156.977182
n,10.0,20.0,40.0


In [None]:
# Reward statistics for the result files under the 'og_dom_rnd/' prefix (columns 0, 1, 2 are n = 10, 20, 40).
eval(dir='og_dom_rnd/')

Unnamed: 0,0,1,2
mean,203.824188,185.624344,157.607651
max,549.098022,574.052246,182.951096
min,153.199768,137.699677,142.450638
std,101.863586,91.390091,5.733146
trunc-mean,163.650803,155.574402,157.402283
CI-low,161.908436,154.042752,157.045367
CI-high,165.393169,157.106051,157.759199
n,10.0,20.0,40.0
