In [3]:
import os
import sys
# Make the parent directory importable (the project modules Agents,
# Environments and utils are assumed to live there). os.path.dirname is
# separator-aware, unlike splitting the path on '/', so this also works on Windows.
sys.path.append(os.path.dirname(os.getcwd()))

import numpy as np
import matplotlib.pyplot as plt
import matplotlib

from Agents import QLearningAgent, BayesianQAgent, PSRLAgent, UbeNoUnrollAgent, MomentMatchingAgent
from Environments import DeepSea, WideNarrow, PriorMDP

from utils import solve_tabular_continuing_PI,   \
                  run_experiment,                \
                  run_oracle_experiment,         \
                  load_agent,                    \
                  get_agent_and_oracle_regret

from tqdm import tqdm_notebook as tqdm

# Output locations for saved figures and agent logs.
# os.makedirs(..., exist_ok=True) creates any missing parent ('results') and is
# a no-op when the directory already exists, so re-running this cell is safe
# and there is no check-then-create race.
fig_loc = 'results/figures/'
os.makedirs(fig_loc, exist_ok=True)
os.makedirs('results/agent_logs', exist_ok=True)

# Global plot styling: large tick labels and legend text, very large figure titles.
plt.rcParams.update({
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 16,
    'figure.titlesize': 50,
})

In [11]:
# Agent parameters.
# agent_classes, agent_names and agent_params are index-aligned: entry i of each
# list describes the same agent (the experiment cells zip classes with params).
agent_classes = [
    QLearningAgent,       # QL
    BayesianQAgent,       # BQL
    PSRLAgent,            # PSRL
    UbeNoUnrollAgent,     # UBE
    MomentMatchingAgent,  # MM
]
# Discount factor shared by every agent.
GAMMA = 0.9

# Prior hyperparameters shared by the Bayesian agents (BQL, PSRL, UBE, MM).
# NOTE(review): presumably Normal-Gamma reward-prior parameters; confirm in Agents.
BAYES_PRIOR = {'mu0': 0.0, 'lamda': 1.0, 'alpha': 2.0, 'beta': 2.0}

# One hyperparameter dict per agent, index-aligned with agent_classes.
# Each entry is a fresh dict (the experiment cells mutate them in place).
agent_params = [
    # QL
    {'gamma'            : GAMMA,
     'dither_param'     : 0.5,
     'lr'               : 0.1,
     'Q0'               : 0.0,
     'anneal_timescale' : float('inf'),
     'dither_mode'      : 'epsilon-greedy'},

    # BQL
    {'gamma'               : GAMMA,
     **BAYES_PRIOR,
     'num_mixture_samples' : 1000},

    # PSRL (kappa: presumably the transition-prior concentration — confirm)
    {'gamma'            : GAMMA,
     'kappa'            : 1.0,
     **BAYES_PRIOR},

    # UBE
    {'gamma'            : GAMMA,
     'kappa'            : 1.0,
     **BAYES_PRIOR,
     'zeta'             : 1.0,
     'num_dyn_samples'  : 100},

    # MM
    {'gamma'            : GAMMA,
     'kappa'            : 1.0,
     **BAYES_PRIOR,
     'zeta'             : 1.0,
     'num_dyn_samples'  : 100},
]

# Legend labels. The zeta values are read from the params so the plots always
# show the value that was actually run (previously the UBE label claimed
# zeta = 0.1 while the params used 1.0). If 0.1 was intended for UBE, change
# the parameter above and the label follows.
ube_zeta = agent_params[3]['zeta']
mm_zeta = agent_params[4]['zeta']
agent_names = ['QL', 'BQL', 'PSRL',
               f'UBE ($\\zeta = {ube_zeta}$)',
               f'MM ($\\zeta = {mm_zeta}$)']

# DeepSea

In [19]:
# Run every agent on DeepSea instances of increasing size N, several seeds each.
num_seeds = 10

for N in [4, 6, 8, 10, 12]:

    # Run length grows linearly with N; save_every is 1% of the run length.
    num_time_steps = 1250 * N
    save_every = num_time_steps // 100

    # delta shrinks exponentially with N and parameterises all three reward pairs.
    delta = 1e-1 * np.exp(- N / 4)
    rew_params = ((0., delta), (-delta, delta), (1., delta))
    env_params = {'N'          :  N,
                  'episodic'   :  False,
                  'rew_params' :  rew_params}

    # Define environment
    environment = DeepSea(env_params)

    # Number of PI steps and maximum buffer length (PSRL, UBE and MM only)
    max_iter = 2 * N

    # Loop-invariant: the largest entry of the environment's mean-reward array.
    Rmax = environment.get_mean_P_and_R()[1].max()

    for i, (agent_class, agent_param) in enumerate(zip(agent_classes, agent_params)):

        # NOTE: updates the shared agent_params dicts in place, so after this
        # cell they hold the values from the last environment run.
        agent_param['max_iter'] = max_iter
        agent_param['sa_list'] = environment.sa_list()
        agent_param['Rmax'] = Rmax

        # Indices 0 and 1 (QL, BQL) use a buffer of length 1; PSRL, UBE and MM use N + 1.
        max_buffer_length = 1 if i <= 1 else N + 1

        for seed in range(num_seeds):

            # A fresh agent per seed keeps the seeds independent regardless of
            # whether run_experiment resets the agent (not visible here).
            agent = agent_class(agent_param)

            run_experiment(environment=environment,
                           agent=agent,
                           seed=seed,
                           num_time_steps=num_time_steps,
                           max_buffer_length=max_buffer_length,
                           save_every=save_every)

1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


5


HBox(children=(IntProgress(value=0, max=5001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


7


HBox(children=(IntProgress(value=0, max=7501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


9


HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


11


HBox(children=(IntProgress(value=0, max=12501), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


1


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))


13


HBox(children=(IntProgress(value=0, max=15001), HTML(value='')))




# WideNarrow

In [22]:
# Run every agent on WideNarrow instances over a grid of N and W, several seeds each.
num_seeds = 10

for N in [1, 2]:
    for W in [6, 10, 14]:

        # Run length grows with N * W; save_every is 1% of the run length.
        num_time_steps = 1000 * N * W
        save_every = num_time_steps // 100

        # (mean, sigma) reward pairs: low-mean, high-mean, low-mean.
        mu_l, mu_h = 0.0, 0.5
        sig_l, sig_h = 1.0, 1.0

        rew_params = ((mu_l, sig_l), (mu_h, sig_h), (mu_l, sig_l))
        env_params = {'N'          : N,
                      'W'          : W,
                      'rew_params' : rew_params}

        # Define environment
        environment = WideNarrow(env_params)

        # Number of PI steps and maximum buffer length (PSRL, UBE and MM only)
        max_iter = 4 * N

        # Loop-invariant: the largest entry of the environment's mean-reward array.
        Rmax = environment.get_mean_P_and_R()[1].max()

        for i, (agent_class, agent_param) in enumerate(zip(agent_classes, agent_params)):

            # NOTE: updates the shared agent_params dicts in place, so after this
            # cell they hold the values from the last environment run.
            agent_param['max_iter'] = max_iter
            agent_param['sa_list'] = environment.sa_list()
            agent_param['Rmax'] = Rmax

            # Indices 0 and 1 (QL, BQL) use a buffer of length 1; PSRL, UBE and MM use 2N + 1.
            max_buffer_length = 1 if i <= 1 else 2 * N + 1

            for seed in range(num_seeds):

                # A fresh agent per seed keeps the seeds independent regardless of
                # whether run_experiment resets the agent (not visible here).
                agent = agent_class(agent_param)

                run_experiment(environment=environment,
                               agent=agent,
                               seed=seed,
                               num_time_steps=num_time_steps,
                               max_buffer_length=max_buffer_length,
                               save_every=save_every)

HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=6001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=10001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=14001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=12001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=20001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




HBox(children=(IntProgress(value=0, max=28001), HTML(value='')))




# PriorMDP

In [None]:
# Sweep over PriorMDP instances: number of states Ns, number of actions Na and
# four environment seeds. For every instance, run each agent in `agent_classes`
# for 10 experiment seeds, saving progress every 1% of the run.
for Ns in [2, 3]:
    for Na in [4, 6, 8]:
        for env_seed in range(4):

            # Run length scales with the number of states
            num_time_steps = 1000 * Ns
            save_every = num_time_steps // 100

            env_params = {'Ns'         : Ns,
                          'Na'         : Na,
                          'kappa'      : 1.0,
                          'mu0'        : 0.0,
                          'lamda'      : 1.0,
                          'alpha'      : 4.0,
                          'beta'       : 4.0,
                          'seed'       : env_seed}

            # Define environment
            environment = PriorMDP(env_params)

            # Number of PI steps (PSRL, UBE and MM only)
            max_iter = 2 * Ns

            # Largest mean reward of this environment; identical for every agent,
            # so compute it once per environment rather than once per agent.
            Rmax = environment.get_mean_P_and_R()[1].max()

            for i, (agent_class, agent_param) in enumerate(zip(agent_classes, agent_params)):

                # NOTE: this updates the shared dicts in `agent_params` in place, so
                # after the cell runs they hold the settings of the last environment.
                agent_param['max_iter'] = max_iter
                agent_param['sa_list'] = environment.sa_list()
                agent_param['Rmax'] = Rmax

                agent = agent_class(agent_param)

                # QL and BQL (indices 0 and 1 in `agent_classes`) use a buffer of
                # length 1; PSRL, UBE and MM use a buffer sized from this
                # environment's Ns. (Previously this used `N`, which is not defined
                # in this cell and leaked in from an earlier experiment loop.)
                max_buffer_length = 1 if i <= 1 else Ns + 1

                # NOTE(review): the same agent instance is reused for all seeds;
                # presumably run_experiment resets/copies it per seed — confirm.
                for seed in range(10):

                    run_experiment(environment=environment,
                                   agent=agent,
                                   seed=seed,
                                   num_time_steps=num_time_steps,
                                   max_buffer_length=max_buffer_length,
                                   save_every=save_every)