In [63]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys, copy, os, shutil
from tqdm.notebook import tqdm
from IPython.display import clear_output

# load in our logs file
logs = pd.read_csv("aggregate_logs.csv")
logs = logs.query("a == 0.0")

In [64]:
'''
For each missing data setting + intensity, let's get the best performing variants in terms of 
mean total reward, hyperparameter tuning over eps, alpha, and g
'''

# what are the variables governing our environment? also variables governing our model, and universal hparams
env_params = ["PS", "PW", "MM", "theta", "t-color", "t-in", "t-out"]
model_params = ["IM", "NC", "K", "p-shuf"]
uni_params = ["eps", "a", "g"]
nuisance_params = ["max-iters", "seed"]
metric_params = ["num_episodes", "mean_total_reward", "mean_steps_river", "mean_path_length",
                 "mean_wallclock_time", "mean20_total_reward", "mean20_steps_river",
                 "mean20_path_length", "mean20_wallclock_time"]

# we're only concerned with the mean over 3 seeds
logs = logs.groupby(env_params + model_params + uni_params, dropna=False).mean().reset_index()

In [None]:
# 1. pretty pictures of the environments with the rewards (1 x 2, side-by-side) COLORED CORRECTLY

# 2. visualize Q function on MCAR + theta=0.4: best joint vs. best baseline (looks like shit) (2x2)
# 2 models, 2 {flood vs. no flood.}

# 3. Lineplot of aggregate results (MCAR + max wind + switch): 1x3 grid of subplots, x-axis=theta 
# y-axis \in {mean_total_reward, mean_steps_river, mean_path_length}, each line is a model (only K=1 and K=10)
# pick best alpha, gamma, epsilon, and p_shuffle (if applicable) for each model

# 4. learning plots over time (MCAR + max_wind + max_switch at theta=0.4): 1x3
# x-axis = timestep -- let's do stepplot to make use of per-episode metrics. 
# y-axis: {reward, river_steps, path_length}, 3 baselines + K={1, 10} x {joint, joint-cons} x {best shuffle}

# 5. MI details -- look at K=1, K=5, K=10 and p-shuffle = 0.1 vs. 0.0
# MCAR at theta=0.4, max_wind and max_switch. x-axis: K=1,5,10, 1 line for p-shuffle 1 line for no-shuffle
# lets do a 1x3 for each metric of {reward, river_steps, path_length}

# 6. Less beautiful-looking things -- Mcolor and Mfog (PS and PW are maxed out.). 2x2
# x-axis: timestep (step function), y-axis: path_length, 
# 3 baselines + K={1, 10} x {joint, joint-cons} x {best shuffle}

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
0,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
1,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
2,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
3,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,32.241818,-8007.041919,459.227273,3974.996465,32.241818
4,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,31.768138,-8007.041919,459.227273,3974.996465,31.768138
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3235,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.606696,-2407.250000,158.366667,1082.950000,8.072081
3236,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.706984,-2407.250000,158.366667,1082.950000,8.166112
3237,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.516143,-1660.066667,114.133333,733.866667,5.882746
3238,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.500699,-1660.066667,114.133333,733.866667,5.828956


In [None]:
# PS: 0 vs. 0.1, PW: 0 vs. 0.1
# MCAR + {0.0, 0.05, 0.1, 0.2, 0.4}, Mfog + {2 variants}, Mcolor + {2 variants}
# gamma: {1.0, 0.5, 0.0}, alpha: {1.0, SW boobooed}, epsilon: {0.0, 0.05} -- just take whatever performed best.

In [79]:
logs.query("`t-in` == 0.25 and PS == 0.1 and PW == 0.1")\
.sort_values(by="mean_total_reward").head(20)

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
3092,0.1,0.1,Mfog,,,0.25,0.1,joint,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,112.618417,-3105.216667,194.35,1457.066667,110.892861
3091,0.1,0.1,Mfog,,,0.25,0.1,joint,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,113.787003,-3105.216667,194.35,1457.066667,112.472425
3090,0.1,0.1,Mfog,,,0.25,0.1,joint,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,116.226653,-3105.216667,194.35,1457.066667,114.763261
3128,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,112.091131,-3105.216667,194.35,1457.066667,110.73026
3127,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,111.798938,-3105.216667,194.35,1457.066667,110.608356
3126,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,10.0,...,1.0,34.0,-3167.79712,199.298245,1475.112918,113.572385,-3105.216667,194.35,1457.066667,112.293089
3079,0.1,0.1,Mfog,,,0.25,0.1,joint,1.0,5.0,...,1.0,38.333333,-2913.400687,186.286962,1337.818026,59.095469,-3038.45,191.333333,1417.45,62.67305
3114,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,5.0,...,1.0,38.333333,-2913.400687,186.286962,1337.818026,51.304049,-3038.45,191.333333,1417.45,54.3356
3115,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,5.0,...,1.0,38.333333,-2913.400687,186.286962,1337.818026,57.264882,-3038.45,191.333333,1417.45,59.968584
3116,0.1,0.1,Mfog,,,0.25,0.1,joint-conservative,1.0,5.0,...,1.0,38.333333,-2913.400687,186.286962,1337.818026,56.858065,-3038.45,191.333333,1417.45,59.72675


In [73]:
logs.query("`t-color` == 0.1 and PS == 0.1 and PW == 0.1")\
.sort_values(by="mean_total_reward").head(20)

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
2970,0.1,0.1,Mcolor,,0.1,,,joint,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.778019,-2717.683333,179.766667,1200.783333,9.699498
3013,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.220634,-2717.683333,179.766667,1200.783333,9.162026
3012,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.63387,-2717.683333,179.766667,1200.783333,9.548823
3008,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.45523,-2717.683333,179.766667,1200.783333,9.377126
3007,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.496097,-2717.683333,179.766667,1200.783333,9.4242
3006,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.76084,-2717.683333,179.766667,1200.783333,9.660813
2978,0.1,0.1,Mcolor,,0.1,,,joint,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.46428,-2717.683333,179.766667,1200.783333,9.414179
3014,0.1,0.1,Mcolor,,0.1,,,joint-conservative,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.439354,-2717.683333,179.766667,1200.783333,9.392675
2976,0.1,0.1,Mcolor,,0.1,,,joint,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.728194,-2717.683333,179.766667,1200.783333,9.650713
2977,0.1,0.1,Mcolor,,0.1,,,joint,1.0,1.0,...,1.0,36.0,-2962.739719,191.672186,1338.690043,10.439723,-2717.683333,179.766667,1200.783333,9.388127


In [72]:
logs.query("`t-color` == 0.0 and PS == 0.1 and PW == 0.1")\
.sort_values(by="mean_total_reward").head(20)

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
2930,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,52.349906,-2940.966667,188.216667,1348.016667,55.771753
2929,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,56.86989,-2940.966667,188.216667,1348.016667,60.309951
2928,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,48.890852,-2940.966667,188.216667,1348.016667,51.769567
2894,0.1,0.1,Mcolor,,0.0,,,joint,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,51.541914,-2940.966667,188.216667,1348.016667,54.637806
2893,0.1,0.1,Mcolor,,0.0,,,joint,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,56.515392,-2940.966667,188.216667,1348.016667,59.969268
2892,0.1,0.1,Mcolor,,0.0,,,joint,1.0,5.0,...,1.0,38.333333,-2808.478741,181.949495,1271.933286,49.93945,-2940.966667,188.216667,1348.016667,52.992981
2880,0.1,0.1,Mcolor,,0.0,,,joint,1.0,1.0,...,1.0,40.0,-2661.544974,168.390212,1247.033069,9.794437,-2523.683333,162.15,1165.333333,9.149413
2923,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,1.0,...,1.0,40.0,-2661.544974,168.390212,1247.033069,9.496928,-2523.683333,162.15,1165.333333,8.883409
2922,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,1.0,...,1.0,40.0,-2661.544974,168.390212,1247.033069,9.645071,-2523.683333,162.15,1165.333333,9.013811
2918,0.1,0.1,Mcolor,,0.0,,,joint-conservative,1.0,1.0,...,1.0,40.0,-2661.544974,168.390212,1247.033069,9.691133,-2523.683333,162.15,1165.333333,9.056673


In [65]:
logs.query()

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
0,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
1,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
2,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
3,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,32.241818,-8007.041919,459.227273,3974.996465,32.241818
4,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,31.768138,-8007.041919,459.227273,3974.996465,31.768138
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3235,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.606696,-2407.250000,158.366667,1082.950000,8.072081
3236,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.706984,-2407.250000,158.366667,1082.950000,8.166112
3237,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.516143,-1660.066667,114.133333,733.866667,5.882746
3238,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.500699,-1660.066667,114.133333,733.866667,5.828956


In [62]:
logs.query("IM == 'missing-state'").sort_values("mean_total_reward", ascending=False)

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
1251,0.0,0.1,MCAR,0.4,,,,missing-state,,,...,1.0,52.333333,-1940.209443,121.095036,951.354116,12.297618,-1613.683333,100.233333,812.583333,10.686564
1071,0.0,0.1,MCAR,0.1,,,,missing-state,,,...,1.0,52.333333,-1940.209443,121.095036,951.354116,12.502689,-1613.683333,100.233333,812.583333,10.871584
1432,0.0,0.1,Mcolor,,0.1,,,missing-state,,,...,1.0,52.333333,-1940.209443,121.095036,951.354116,13.086207,-1613.683333,100.233333,812.583333,11.360300
1431,0.0,0.1,Mcolor,,0.1,,,missing-state,,,...,1.0,52.333333,-1940.209443,121.095036,951.354116,15.228258,-1613.683333,100.233333,812.583333,13.207655
1343,0.0,0.1,Mcolor,,0.0,,,missing-state,,,...,1.0,52.333333,-1940.209443,121.095036,951.354116,12.898148,-1613.683333,100.233333,812.583333,11.112496
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2329,0.1,0.0,Mfog,,,0.25,0.1,missing-state,,,...,1.0,0.000000,,,,,,,,
2330,0.1,0.0,Mfog,,,0.25,0.1,missing-state,,,...,1.0,0.000000,,,,,,,,
2418,0.1,0.0,Mfog,,,0.50,0.0,missing-state,,,...,1.0,0.000000,,,,,,,,
2419,0.1,0.0,Mfog,,,0.50,0.0,missing-state,,,...,1.0,0.000000,,,,,,,,


In [54]:
logs.groupby(env_params + model_params + uni_params, dropna=False).mean().reset_index()

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
0,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
1,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
2,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,0.000000,,,,,,,,
3,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,32.241818,-8007.041919,459.227273,3974.996465,32.241818
4,0.0,0.0,MCAR,0.0,,,,joint,1.0,1.0,...,1.0,11.000000,-8007.041919,459.227273,3974.996465,31.768138,-8007.041919,459.227273,3974.996465,31.768138
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3235,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.606696,-2407.250000,158.366667,1082.950000,8.072081
3236,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,48.000000,-2271.652190,149.958701,1023.023879,7.706984,-2407.250000,158.366667,1082.950000,8.166112
3237,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.516143,-1660.066667,114.133333,733.866667,5.882746
3238,0.1,0.1,Mfog,,,0.5,0.0,random-action,,,...,1.0,60.333333,-1911.805445,131.327353,830.859265,6.500699,-1660.066667,114.133333,733.866667,5.828956


In [None]:
# group by environment and base model
best_logs = logs.sort_values(by="mean_total_reward", ascending=False)\
.groupby(env_params + model_params + ["seed"], dropna=False).first().reset_index()

In [48]:
logs.groupby(env_params + model_params + uni_params, dropna=False).mean().reset_index()\
.query("PS == 0.1 and PW == 0.1 and MM == 'MCAR' and theta == 0.4")\
.sort_values(by="mean20_path_length").query("`p-shuf` == 0.1")

Unnamed: 0,PS,PW,MM,theta,t-color,t-in,t-out,IM,NC,K,...,seed,num_episodes,mean_total_reward,mean_steps_river,mean_path_length,mean_wallclock_time,mean20_total_reward,mean20_steps_river,mean20_path_length,mean20_wallclock_time
5621,0.1,0.1,MCAR,0.4,,,,joint,1.0,5.0,...,1.0,4827.000000,79.587514,1.217072,10.458837,0.524983,75.983333,1.700000,9.716667,0.487950
5651,0.1,0.1,MCAR,0.4,,,,joint,1.0,10.0,...,1.0,4554.666667,73.451192,1.839328,10.994856,0.966935,72.750000,2.016667,10.100000,0.889966
5644,0.1,0.1,MCAR,0.4,,,,joint,1.0,10.0,...,1.0,3270.000000,78.299606,0.816176,15.354806,1.382801,81.200000,0.683333,13.650000,1.235317
5627,0.1,0.1,MCAR,0.4,,,,joint,1.0,5.0,...,1.0,4365.333333,77.026303,1.390494,11.459255,0.586693,75.600000,1.300000,13.700000,0.702916
5626,0.1,0.1,MCAR,0.4,,,,joint,1.0,5.0,...,1.0,2382.666667,70.734988,1.028170,21.011481,1.117321,78.400000,0.966667,13.900000,0.742031
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5712,0.1,0.1,MCAR,0.4,,,,joint-conservative,1.0,10.0,...,1.0,36.000000,-3040.557041,196.402496,1373.934581,107.185355,-2981.150000,193.333333,1342.150000,104.777549
5715,0.1,0.1,MCAR,0.4,,,,joint-conservative,1.0,10.0,...,1.0,43.333333,-736.671939,5.693342,786.431864,71.653182,-1491.950000,3.433333,1562.050000,142.908998
5673,0.1,0.1,MCAR,0.4,,,,joint-conservative,1.0,1.0,...,1.0,14.000000,-2636.174603,20.886243,2549.198413,23.341778,-2636.174603,20.886243,2549.198413,23.341778
5595,0.1,0.1,MCAR,0.4,,,,joint,1.0,1.0,...,1.0,11.666667,-2630.072751,16.642857,2581.287037,23.460455,-2630.072751,16.642857,2581.287037,23.460455


In [47]:
logs.columns

Index(['PS', 'PW', 'MM', 'theta', 't-color', 't-in', 't-out', 'IM', 'NC', 'K',
       'p-shuf', 'max-iters', 'eps', 'a', 'g', 'seed', 'num_episodes',
       'mean_total_reward', 'mean_steps_river', 'mean_path_length',
       'mean_wallclock_time', 'mean20_total_reward', 'mean20_steps_river',
       'mean20_path_length', 'mean20_wallclock_time'],
      dtype='object')