In [1]:
import matplotlib.pyplot as plt
from os.path import isfile, join, abspath, dirname
from os import listdir
import sys
import pandas as pd
import numpy as np

In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [3]:
sys.path.insert(0, '../src/')

In [4]:
from logger import LogData, TemporalLogger, EnvLogger
from utils import plot_typography

In [5]:
plot_typography(True, 12, 14, 16)

In [10]:
__file__ = "analysis.ipynb"
log_dir = join(dirname(dirname(abspath(__file__))), "log")
envs =  ["PongNoFrameskip-v0", "PongNoFrameskip-v4",
         "BreakoutNoFrameskip-v0", "BreakoutNoFrameskip-v4",
         "SeaquestNoFrameskip-v0", "SeaquestNoFrameskip-v4"]
cols = ['attention_target', 'attention_type', 'env_name', 
        'timestamp', 'mean_reward',
       'mean_feat_std', 'mean_proxy']
decimate = 100


def figure_factory(log_dir, env, variant, rew_scale=5, feat_scale=1, decimate=100, save=True, loc_feat=1, loc_rwd=4, zoom_feat=2.5, zoom_rwd=2.5):
    # sanity check
    if variant not in [0,4]:
        raise ValueError(f"Invalid variant, got {variant}, should be 0 or 4")
    
    # load data 
    el = EnvLogger(f"{env}NoFrameskip-v{variant}", log_dir, decimate)

    # plot
    rwd_metrics = el.plot_decorator(keyword="rewards", save=save, y_inset_std_scale=rew_scale, loc=loc_rwd, zoom=zoom_rwd)
    el.plot_decorator(keyword="features", save=save, y_inset_std_scale=feat_scale, loc=loc_feat, zoom=zoom_feat)

    return rwd_metrics

In [None]:
pong_metrics4 = figure_factory(log_dir, "Pong", 4, 5, .5, 100, save=False, loc_rwd=4, loc_feat=1, zoom_rwd=2.5, zoom_feat=4) 

In [None]:
breakout_metrics4 = figure_factory(log_dir, "Breakout", 4, 3.5, 1, 100, save=False, loc_rwd=2, loc_feat=4, zoom_rwd=2.75, zoom_feat=3.7) 

In [None]:
seaquest_metrics4 = figure_factory(log_dir, "Seaquest", 4, 1.2, 1., 100, save=False, loc_rwd=2, loc_feat=1, zoom_rwd=2.1, zoom_feat=1.7) 

In [None]:
# seeds set

pong_metrics0 = figure_factory(log_dir, "Pong", 0, 3.85, .75, 100, save=False, loc_rwd=4, loc_feat=1)

In [None]:
#not saved
breakout_metrics0 = figure_factory(log_dir, "Breakout", 0, 2.4, .5, 100, save=False, loc_rwd=2, loc_feat=4, zoom_rwd=1.5, zoom_feat=2.5)

In [None]:
#not saved
seaquest_metrics0 = figure_factory(log_dir, "Seaquest", 0, 1.4, .4, 100, save=False, loc_rwd=2, loc_feat=4, zoom_rwd=1.65, zoom_feat=2.5) 

In [14]:
df0 = pd.DataFrame([pong_metrics0, breakout_metrics0, seaquest_metrics0])
df4 = pd.DataFrame([pong_metrics4, breakout_metrics4, seaquest_metrics4])
df=pd.concat([df0,df4])

In [15]:
df

Unnamed: 0,Baseline,"ICM, single attention","ICM, double attention",AttA2C,RCM
0,96.138059,100.0,96.592801,94.411685,99.951489
1,82.247977,89.171067,83.069633,100.0,81.334197
2,96.045262,89.432758,87.058865,65.331944,100.0
0,94.669607,100.0,94.045772,94.119132,98.701972
1,95.039593,96.821546,92.266545,100.0,95.325408
2,89.826709,93.888922,100.0,74.197706,96.550187


In [16]:
df.to_numpy().mean(axis=0), df.to_numpy().std(axis=0), np.median(df.to_numpy(),axis=0)

(array([92.32786763, 94.88571552, 92.17226928, 88.01007786, 95.31054211]),
 array([ 4.98387019,  4.46177857,  5.67653452, 13.35974642,  6.47953637]),
 array([94.85459978, 95.35523433, 93.1561586 , 94.26540853, 97.62607948]))