In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
from collections import namedtuple

from definitions import ROOT_DIR
from unc.utils import load_info

plt.rcParams.update({'font.size': 18})



In [27]:
all_paths = {
    '2t_nn_prediction': Path(ROOT_DIR, 'results', '2t_nn_prediction')
}

split_by = ['step_size', 'epsilon']
Args = namedtuple('args', split_by)

In [28]:
def process_dir(dir_path, split_by):
    all_results = {}
        
    for f in tqdm(list(dir_path.iterdir())):
        if not f.is_file() or f.suffix != '.npy':
            continue
        info = load_info(f)
        args = info['args'].item()
        
        hparams = Args(*tuple(args[s] for s in split_by))

        if hparams not in all_results:
            all_results[hparams] = []

#         all_results[hparams].append((info['episode_reward'], args))
        all_results[hparams].append((args, info['avg_episode_loss'], f))


    
    return all_results

In [29]:
raw_data = {}

for key, v in all_paths.items():
    processed = process_dir(v, split_by)
    
    raw_data[key] = {}
    for hparams, ret in processed.items():
        loss_lst, arg_lst, f_lst = [], [], []
        for r in ret:
            loss_lst.append(r[1])
            arg_lst.append(r[0])
            f_lst.append(r[2])
        raw_data[key][hparams] = {'loss': np.stack(loss_lst), 'args': arg_lst, 'paths': f_lst}
        
    

  0%|          | 0/243 [00:00<?, ?it/s]

In [62]:
# Here we get our best hparams
all_best = {}
best_over = 500

for key, res in raw_data.items():
    stats_list = []
    for args, rew_dict in res.items():
        # we take mean over both time and seeds
        stats_list.append((args, rew_dict['loss'][:, -best_over].mean(), rew_dict['loss'], rew_dict['args'], rew_dict['paths']))
    all_best[key] = sorted(stats_list, key=lambda x: x[1])[0]
# all_best['2t_nn_prediction'][-2]
[s for s in stats_list if s[0].step_size == 1e-5]


[(args(step_size=1e-05, epsilon=1.0),
  0.27036023,
  array([[0.29788962, 0.27775106, 0.25880617, ..., 0.28746188, 0.33162522,
          0.16581258],
         [0.32760417, 0.41153118, 0.37234184, ..., 0.36327934, 0.20247269,
          0.27912807],
         [0.264392  , 0.37890106, 0.30763337, ..., 0.28080788, 0.24747871,
          0.39746556],
         ...,
         [0.27284187, 0.27630654, 0.36924455, ..., 0.3432928 , 0.3332931 ,
          0.25080335],
         [0.32730085, 0.35114232, 0.3743248 , ..., 0.23663229, 0.28662664,
          0.24163316],
         [0.30665952, 0.4144486 , 0.398225  , ..., 0.33662984, 0.18914647,
          0.30163354]], dtype=float32),
  [{'po_degree': 0.0,
    'epsilon_start': 1.0,
    'buffer_size': 20000,
    'test_episodes': 5,
    'p_prefilled': 0.0,
    'k_rnn_hs': 1,
    'checkpoint_freq': 0,
    'distance_noise': False,
    'weight_init': 'fan_avg',
    'unnormalized_counts': False,
    'fixed_gvf_path': None,
    'max_episode_steps': 200,
    'epsilo

In [59]:
# Checkpoint directories are sorted by seed.
checkpoint_dirs = [None for i in range(len(all_best['2t_nn_prediction'][-2]))]


for arg, path in zip(all_best['2t_nn_prediction'][-2], all_best['2t_nn_prediction'][-1]):
    checkpoints_dir = path.parent / 'checkpoints'
    arg_checkpoint_name = path.name.split('_')[0]
    arg_checkpoint_dir = checkpoints_dir / arg_checkpoint_name
    latest_checkpoint_name = sorted(list(arg_checkpoint_dir.iterdir()))[-1].name
    slurm_results_dir = Path('/home/taodav/scratch/uncertainty/results/2t_nn_prediction')
    slurm_checkpoint_fname = slurm_results_dir / 'checkpoints' / arg_checkpoint_name / latest_checkpoint_name
    checkpoint_dirs[arg['seed'] - 2020] = str(slurm_checkpoint_fname)

In [60]:
checkpoint_dirs[:30]

['/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/f1ae73a28536159489b708d248a8a02e/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/9ef9314dc18a228abc35095921646bfc/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/2147d68114ecd9bc91c236ddde5d8c2b/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/fd587082f2a90e8036bba1c0c11e7c46/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/cef489976dedec7ed44934333c05925e/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/aaafa5a2192e86755110ee9a92606e52/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/fe06cb483a7d0b65c2fb550821e82258/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/fb360bbcf03f8427fd4745be3be8e1c8/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_predict