In [1]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
from collections import namedtuple

from definitions import ROOT_DIR
from unc.utils import load_info

plt.rcParams.update({'font.size': 18})



In [63]:
all_paths = {
    '2t_nn_prediction': Path(ROOT_DIR, 'results', '2t_nn_prediction')
}

split_by = ['step_size', 'epsilon']
Args = namedtuple('args', split_by)

In [65]:
def process_dir(dir_path, split_by):
    all_results = {}
        
    for f in tqdm(list(dir_path.iterdir())):
        if not f.is_file() or f.suffix != '.npy' or 'layer' in f.name:
            continue
        info = load_info(f)
        args = info['args'].item()
        
        hparams = Args(*tuple(args[s] for s in split_by))

        if hparams not in all_results:
            all_results[hparams] = []

#         all_results[hparams].append((info['episode_reward'], args))
        all_results[hparams].append((args, info['avg_episode_loss'], f))


    
    return all_results

In [66]:
raw_data = {}

for key, v in all_paths.items():
    processed = process_dir(v, split_by)
    
    raw_data[key] = {}
    for hparams, ret in processed.items():
        loss_lst, arg_lst, f_lst = [], [], []
        for r in ret:
            loss_lst.append(r[1])
            arg_lst.append(r[0])
            f_lst.append(r[2])
        raw_data[key][hparams] = {'loss': np.stack(loss_lst), 'args': arg_lst, 'paths': f_lst}
        
    

  0%|          | 0/94 [00:00<?, ?it/s]

In [67]:
# Here we get our best hparams
all_best = {}
best_over = 500

for key, res in raw_data.items():
    stats_list = []
    for args, rew_dict in res.items():
        # we take mean over both time and seeds
        stats_list.append((args, rew_dict['loss'][:, -best_over].mean(), rew_dict['loss'], rew_dict['args'], rew_dict['paths']))
    all_best[key] = sorted(stats_list, key=lambda x: x[1])[0]
# all_best['2t_nn_prediction'][-2]
[s for s in stats_list if s[0].step_size == 1e-5]


[(args(step_size=1e-05, epsilon=1.0),
  0.10902104,
  array([[0.08887608, 0.10727347, 0.09531257, ..., 0.10582814, 0.13916044,
          0.11166141],
         [0.1084057 , 0.12891631, 0.13388388, ..., 0.06999672, 0.0841626 ,
          0.10832828],
         [0.10730546, 0.09628888, 0.0976272 , ..., 0.04916447, 0.07499634,
          0.09332876],
         ...,
         [0.14245696, 0.17409822, 0.14637515, ..., 0.09915948, 0.03249844,
          0.05999561],
         [0.13801372, 0.12022826, 0.09747491, ..., 0.08916309, 0.0849959 ,
          0.05416443],
         [0.10611156, 0.130829  , 0.13111934, ..., 0.1433242 , 0.07999524,
          0.06916246]], dtype=float32),
  [{'checkpoint_freq': 0,
    'v_max': 100,
    'exploration': 'eps',
    'value_step_size': 0.0001,
    'algo': 'sarsa',
    'save_model': False,
    'k_rnn_hs': 1,
    'max_episode_steps': 200,
    'task_fname': 'task_{}_config.json',
    'offline_eval_freq': 0,
    'batch_size': 64,
    'init_hidden_var': 0.0,
    'resample_

In [68]:
# Checkpoint directories are sorted by seed.
checkpoint_dirs = [None for i in range(len(all_best['2t_nn_prediction'][-2]))]


for arg, path in zip(all_best['2t_nn_prediction'][-2], all_best['2t_nn_prediction'][-1]):
    checkpoints_dir = path.parent / 'checkpoints'
    arg_checkpoint_name = path.name.split('_')[0]
    arg_checkpoint_dir = checkpoints_dir / arg_checkpoint_name
    latest_checkpoint_name = sorted(list(arg_checkpoint_dir.iterdir()))[-1].name
    slurm_results_dir = Path('/home/taodav/scratch/uncertainty/results/2t_nn_prediction')
    slurm_checkpoint_fname = slurm_results_dir / 'checkpoints' / arg_checkpoint_name / latest_checkpoint_name
    checkpoint_dirs[arg['seed'] - 2020] = str(slurm_checkpoint_fname)

In [69]:
checkpoint_dirs[:30]

['/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/3b0cd57081bfd8ac0bcc37708091119c/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/8edd78bdf60ea6a0e31be780ec2141ef/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/c0c05b42d05fa5fb3ee3361fd10077b1/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/789a2dfaecd896fc125ea87390f76970/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/862d45c24e9629e23d726eff5d7767fc/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/05184069ec1e8fe996aa3eaf417d9c1b/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/43065c6b09b65de0840836b674ba78a7/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_prediction/checkpoints/c3868a1bfe69cc0135619af24456c50a/500000.pkl',
 '/home/taodav/scratch/uncertainty/results/2t_nn_predict