In [1]:
from collections import defaultdict
import random
import numpy as np
np.set_printoptions(precision=2, suppress=True)

import time
import copy 
import multiprocess as mp

import gym
from env import FrozenLakeCustom, FrozenLakeSimulator

from mcts_haver_stochastic import run_mcts_trial
from value_iteration import value_iteration

from config import parse_args
from utils import MultiProcess

import logging
# Silence the root logger: only FATAL (== CRITICAL) records are emitted, which
# hides the per-state value-iteration dump logged with logging.warning later on.
logger = logging.getLogger()
logger.setLevel(level=logging.CRITICAL)

In [2]:
# Seed both global RNGs (stdlib `random` and NumPy) so this cell is reproducible.
random.seed(0)
np.random.seed(0)

# Experiment parameters: start from the config defaults, then apply the
# environment overrides used throughout this notebook.
args = parse_args()
for _key, _value in (
    ("ep_max_steps", 20),
    ("map_name", "4x4X"),
    ("is_state_slippery", True),
    ("is_slippery", False),
    ("slippery_mode", "mild"),
):
    args[_key] = _value

# Environment plus a simulator built from its transition table `env.P`.
env_id = "FrozenLake-v1"
env = FrozenLakeCustom(
    map_name=args["map_name"],
    is_state_slippery=args["is_state_slippery"],
    is_slippery=args["is_slippery"],
    slippery_mode=args["slippery_mode"],
    render_mode=args["render_mode"],
)
simulator = FrozenLakeSimulator(env.P)

# Reference solution from value iteration; Q_vit is later handed to every MCTS trial.
V_vit, Q_vit = value_iteration(simulator, args["gamma"], args["vit_thres"])

# Dump V, Q and the greedy action per state. The root logger is set to FATAL
# in the imports cell, so these WARNING records are suppressed unless that
# level is lowered.
for state in range(simulator.num_states):
    logging.warning(f"\n-> state = {state}")
    logging.warning(f"V[state] = {V_vit[state]:0.4f}")
    q_row = Q_vit[state]
    for action in range(simulator.num_actions):
        logging.warning(f"Q[state][action] = {q_row[action]:0.4f}")
    logging.warning(f"best_action={np.argmax(q_row)}")
    
# Manager-backed proxy lists shared with the pool workers: run_trial appends
# each trial's episode reward and MCTS Q estimate to them.
manager = mp.Manager()
ep_reward_list, Q_mcts_list = manager.list(), manager.list()

def run_trial(i_trial, Q_vit, args):
    """Run one seeded MCTS trial in a freshly built environment.

    Both global RNGs (stdlib ``random`` and NumPy) are re-seeded with
    ``10000 + i_trial``. A trial's randomness therefore depends only on its
    index, assuming ``run_mcts_trial`` draws only from these two RNGs
    (TODO confirm).

    Args:
        i_trial: Trial index; selects the seed and is forwarded to
            ``run_mcts_trial``.
        Q_vit: Value-iteration Q table, forwarded unchanged to
            ``run_mcts_trial``.
        args: Experiment configuration, read via ``args[...]`` and forwarded.

    Returns:
        The episode reward reported by ``run_mcts_trial``.

    Side effects:
        Appends the episode reward to ``ep_reward_list`` and the MCTS Q
        estimate to ``Q_mcts_list`` (module-level shared lists).
    """
    seed = 10000 + i_trial
    random.seed(seed)
    np.random.seed(seed)

    # Each trial builds its own env/simulator instead of sharing the notebook's.
    trial_env = FrozenLakeCustom(
        map_name=args["map_name"],
        is_state_slippery=args["is_state_slippery"],
        is_slippery=args["is_slippery"],
        slippery_mode=args["slippery_mode"],
        render_mode=args["render_mode"],
    )
    trial_simulator = FrozenLakeSimulator(trial_env.P)

    Q_mcts, ep_reward = run_mcts_trial(
        trial_env, trial_simulator, Q_vit, i_trial, args)

    ep_reward_list.append(ep_reward)
    Q_mcts_list.append(Q_mcts)
    return ep_reward

In [6]:
# Sweep configuration: MCTS with update_method="haver" and an empty rollout_method.
args["update_method"] = "haver"
args["rollout_method"] = ""

print(f"map_name = {args['map_name']}")
print(f"ep_max_steps = {args['ep_max_steps']}")
print(f"num_trials = {args['num_trials']}")

# UCB exploration scale is held fixed at 64 for this sweep. The candidate grid
# is kept defined for reference but is not iterated by the sweep in this cell.
hparam_ucb_scale_list = [1, 2, 4, 8, 16, 32, 64, 128]
args["hparam_ucb_scale"] = 64

# Haver standard deviations to sweep (powers of two). The sweep loop sets
# args["hparam_haver_var"] = std**2 for each value.
hparam_haver_std_list = [1, 2, 4, 8, 16, 32, 64, 128]

# MCTS trajectory budgets to evaluate; each produces one LaTeX table row.
num_trajectories_list = [200, 500, 1000, 1500, 2000, 2500, 3000]
# Results accumulated across the sweep. res_text1 / res_text2 are LaTeX table
# rows, one per trajectory budget: res_text1 collects the columns for
# hparam_haver_std <= 8, res_text2 the columns for larger std values.
best_param_list = []
max_reward_mean_list = []
res_text1 = ""
res_text2 = ""
for num_trajectories in num_trajectories_list:
    print(f"\n-> num_trajectories = {num_trajectories}")
    args["mcts_num_trajectories"] = num_trajectories

    best_param = None
    max_reward_mean = -np.inf
    start_time = time.time()
    res_text1 += f"{num_trajectories} "
    res_text2 += f"{num_trajectories} "
    for hparam_haver_std in hparam_haver_std_list:
        args["hparam_haver_var"] = hparam_haver_std**2
        print(f"hparam_haver_var = {args['hparam_haver_var']}")

        # The context manager tears the worker processes down after each sweep
        # point, so pools do not pile up across iterations. starmap blocks
        # until every trial has finished, so terminating on exit is safe.
        with mp.Pool() as pool:
            # run_trial returns each trial's episode reward; use those results
            # directly rather than reading back through the Manager proxy.
            ep_rewards = pool.starmap(
                run_trial, [(i, Q_vit, args) for i in range(args["num_trials"])])

        n_rewards = len(ep_rewards)
        reward_mean = np.mean(ep_rewards)
        reward_std = np.std(ep_rewards, ddof=1) if n_rewards > 1 else 0
        # Standard error of the mean over the trials that were actually collected.
        reward_error = reward_std / np.sqrt(n_rewards) if n_rewards else np.nan

        table_cell = f"& {reward_mean:0.2f} (\u00B1{reward_error:0.2f}) "
        if hparam_haver_std <= 8:
            res_text1 += table_cell
        else:
            res_text2 += table_cell
        print(f"reward = {reward_mean:0.2f} +/- {reward_error:0.2f}")

        # Track the best std for this budget so the summary lists hold real values.
        if reward_mean > max_reward_mean:
            max_reward_mean = reward_mean
            best_param = hparam_haver_std

        # run_trial still appends to these shared lists; clear them so they do
        # not grow across sweep points.
        ep_reward_list[:] = []
        Q_mcts_list[:] = []

    # Bound after the inner loop, so it is always defined (even for an empty grid).
    end_time = time.time()

    # Escaped "\\hline": a bare "\h" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The emitted text is unchanged.
    res_text1 += "\\\\ \n \\hline \n"
    res_text2 += "\\\\ \n \\hline \n"

    print(f"best hparam_haver_std = {best_param} (reward = {max_reward_mean:0.2f})")
    print(f"it takes {end_time-start_time:0.4f}")

    max_reward_mean_list.append(max_reward_mean)
    best_param_list.append(best_param)

map_name = 4x4X
ep_max_steps = 20
num_trials = 20

-> num_trajectories = 5000
hparam_haver_var = 1
reward = -34.75 +/- 10.43
hparam_haver_var = 4
reward = -39.70 +/- 10.76
hparam_haver_var = 16
reward = -44.75 +/- 11.03
hparam_haver_var = 64
reward = -39.70 +/- 10.81
hparam_haver_var = 256
reward = -39.65 +/- 10.74
hparam_haver_var = 1024
reward = -39.40 +/- 10.86
hparam_haver_var = 4096
reward = -39.30 +/- 10.75
hparam_haver_var = 16384
reward = -9.40 +/- 4.98
it takes 376.1376

-> num_trajectories = 7500
hparam_haver_var = 1
reward = -39.35 +/- 10.79
hparam_haver_var = 4
reward = -29.70 +/- 9.80
hparam_haver_var = 16
reward = -39.20 +/- 10.77
hparam_haver_var = 64
reward = -29.50 +/- 9.77
hparam_haver_var = 256
reward = -24.75 +/- 9.07
hparam_haver_var = 1024
reward = -24.65 +/- 9.11
hparam_haver_var = 4096
reward = -24.10 +/- 9.00
hparam_haver_var = 16384
reward = -23.95 +/- 9.10
it takes 587.0572

-> num_trajectories = 10000
hparam_haver_var = 1
reward = -29.45 +/- 9.88
hparam_have

In [4]:
# Emit both halves of the LaTeX results table (small-std columns, then large-std).
print(res_text1, res_text2, sep="\n")

1500 & -38.60 (±10.60) & -39.25 (±10.71) & -53.70 (±11.18) & -44.10 (±11.06) \\ 
 \hline 
2000 & -34.05 (±10.29) & -49.25 (±11.25) & -34.00 (±10.27) & -24.15 (±8.91) \\ 
 \hline 
2500 & -19.45 (±8.16) & -29.55 (±9.81) & -28.85 (±9.77) & -28.95 (±9.78) \\ 
 \hline 
3000 & -24.75 (±9.15) & -39.60 (±10.75) & -19.10 (±8.09) & -24.25 (±9.07) \\ 
 \hline 

1500 & -48.60 (±11.11) & -58.35 (±11.21) & -33.80 (±10.27) & -43.45 (±10.95) \\ 
 \hline 
2000 & -24.10 (±9.00) & -23.95 (±9.13) & -33.80 (±10.20) & -33.80 (±10.25) \\ 
 \hline 
2500 & -24.10 (±9.08) & -14.05 (±6.80) & -19.60 (±8.11) & -28.90 (±9.82) \\ 
 \hline 
3000 & -24.40 (±9.02) & -19.35 (±8.13) & -24.25 (±9.04) & -9.40 (±4.88) \\ 
 \hline 

