In [1]:
#from MuZeroParallel import MuZero as MuZeroP

from MuZero import MuZero

from GridWorldOption import GridWorldOption


from GridWorld import GridWorld

from mcts.MCTS import MCTS
import numpy as np
import matplotlib.pyplot as plt

from random import choice
from copy import deepcopy
import time

import pickle
import pandas as pd

from matplotlib.pyplot import figure
import itertools
from tqdm import tqdm


In [2]:
# Path to the grid map used throughout this notebook; GridWorld(...) is built from its text.
MAP_NAME = './maps/door.map'

# Decode explicitly as UTF-8 so the map reads the same on every platform
# (open()'s default encoding is locale-dependent).
with open(MAP_NAME, encoding='utf-8') as f:
    the_map = f.read()

In [3]:
def get_g(full_options):
    """Tabulate the outcome of running each option from every position it is valid in.

    For every position in ``GridWorld(the_map).possible_positions`` and every option
    whose ``is_valid(position)`` is true, a fresh ``GridWorld(the_map)`` is placed at that
    position and the option is rolled out. The rollout stops when ``get_action`` returns
    ``-1``, or right after a step for which it signalled ``should_break``.

    Args:
        full_options: iterable of option objects exposing ``is_valid(position)`` and
            ``get_action(env) -> (action, should_break)``. It is materialised once, so a
            one-shot iterator works too.

    Returns:
        dict mapping ``(position, option)`` to ``(next_position, rewards)``:
        ``next_position`` is the state returned by the last ``test_env.step`` call, or the
        start ``position`` if the option took no step; ``rewards`` is the list of
        undiscounted per-step rewards in execution order.
    """
    # Iterated once per position below, so a generator must not be exhausted after the first.
    full_options = list(full_options)

    env = GridWorld(the_map)

    G_model = {}
    for position in tqdm(env.possible_positions):
        for option in full_options:
            if not option.is_valid(position):
                continue

            test_env = GridWorld(the_map)
            test_env.cur_x = position[0]
            test_env.cur_y = position[1]

            # Start at the initial position: an option that terminates before taking any
            # step must map back to where it started. Without this, the value would leak
            # from the previous rollout, or raise UnboundLocalError on the very first one.
            next_position = position
            rewards = []
            while True:
                action, should_break = option.get_action(test_env)
                if action == -1:
                    break

                # NOTE(review): the `done` flag from step() is ignored here (get_v honours
                # it); confirm that stepping past a terminal state is intended.
                next_position, r, _ = test_env.step(action)
                rewards.append(r)

                if should_break:
                    break

            G_model[(position, option)] = (next_position, rewards)
    return G_model


In [4]:
def get_v(gamma=0.99):
    """Tabulate a discounted rollout return for every reachable position of ``the_map``.

    From each position in ``GridWorld(the_map).possible_positions``, a fresh
    ``GridWorld(the_map)`` is rolled out with a ``GridWorldOption`` targeting the env's goal
    cell ``(goal_x, goal_y)``. The rollout stops when the option returns ``-1`` or the
    environment has reported ``done``. The collected rewards are then discounted.

    Args:
        gamma: per-step discount factor (default 0.99, the value previously hard-coded).

    Returns:
        dict mapping position -> ``sum_k gamma**k * r_k`` over that rollout's rewards
        (0 when no step was taken).
    """
    V_model = {}

    env = GridWorld(the_map)

    for position in tqdm(env.possible_positions):
        env_test = GridWorld(the_map)

        # A fresh option per rollout, exactly as before, in case get_action keeps
        # per-rollout state. The goal-targeting option is built with id 10.
        the_opt = GridWorldOption((env_test.goal_x, env_test.goal_y), {'all'}, 10)

        env_test.cur_x = position[0]
        env_test.cur_y = position[1]

        rewards = []
        done = False
        while True:
            action, _ = the_opt.get_action(env_test)

            # `done` is checked after querying the option, preserving the original order.
            if action == -1 or done:
                break

            _, r, done = env_test.step(action)
            rewards.append(r)

        # Accumulate left-to-right (same summation order as before).
        G = 0
        for k, r in enumerate(rewards):
            G += (gamma ** k) * r

        V_model[position] = G
    return V_model

In [5]:
# `the_map` is already loaded from MAP_NAME by the map-loading cell near the top;
# get_g, get_v and GridWorld below all read that same global, so it is not re-read here.

# Cell coordinates of the four rooms. The ranges skip row 6 and column 8, which
# presumably hold the walls; the door cells used below are (3, 8), (6, 3), (6, 13), (11, 8).
first_room_pos = [(i, j) for i in range(6) for j in range(8)]
second_room_pos = [(i, j) for i in range(6) for j in range(9, 15)]
third_room_pos = [(i, j) for i in range(7, 13) for j in range(8)]
fourth_room_pos = [(i, j) for i in range(7, 13) for j in range(9, 15)]

options = [
    # Primitive actions 0-3; the (0, 0) target is a placeholder and meaningless.
    GridWorldOption((0, 0), {'all'}, 0, 0),
    GridWorldOption((0, 0), {'all'}, 1, 1),
    GridWorldOption((0, 0), {'all'}, 2, 2),
    GridWorldOption((0, 0), {'all'}, 3, 3),
    # Door options 4-7: each targets one door cell and is valid in the two rooms that
    # door joins plus those rooms' other two doors.
    GridWorldOption((3, 8), set(first_room_pos + second_room_pos + [(6, 3)] + [(6, 13)]), 4),
    GridWorldOption((6, 3), set(first_room_pos + third_room_pos + [(3, 8)] + [(11, 8)]), 5),
    GridWorldOption((6, 13), set(second_room_pos + fourth_room_pos + [(3, 8)] + [(11, 8)]), 6),
    GridWorldOption((11, 8), set(third_room_pos + fourth_room_pos + [(6, 3)] + [(6, 13)]), 7),
]

# Rollout-computed dynamics (per position/option) and value (per position) tables.
G_model = get_g(options)
V_model = get_v()

env = GridWorld(the_map)
mu_debug = MuZero(env, options)

# Root state for the search below (s_next is kept as a name for any later cells).
s_next = mu_debug.env.reset()
s0 = s_next

# Overwrite MuZero's f value table and g model table with the rollout-computed tables,
# so the MCTS below searches over them.
mu_debug.f.v_table = V_model
mu_debug.g.model_table = G_model

100%|██████████| 172/172 [00:00<00:00, 1608.80it/s]
100%|██████████| 172/172 [00:00<00:00, 1384.14it/s]


In [6]:
# Number of MCTS simulations to run from the root state s0.
N_SIMULATIONS = 600

mcts = MCTS(s0, mu_debug.f, mu_debug.g, mu_debug.options)

# Left as the last expression so the returned array (one entry per option) is displayed.
mcts.run_sim(N_SIMULATIONS)

array([0.07666667, 0.15666667, 0.16333333, 0.08833333, 0.28833333,
       0.22666667, 0.        , 0.        ])

In [7]:
# Print the root node's per-option search statistics (N, Q, P and the UCT terms).
mcts.info()

Expanded:  True
Leaf:  False
hidden state (1, 1)
Option 0
N(s, o) 64
Q(s, o) [-24.198897479046863]
P(s, o) 0.16666666666666666
Q'(s, o):  0.0293486155797228
UCT exploration:  0.3768445758127966
Prior regulation:  1.2801238164214412
U(s, o) 0.10974990167758886
Option 1
N(s, o) 189
Q(s, o) [-23.242468234525973]
P(s, o) 0.16666666666666666
Q'(s, o):  0.0857037841799595
UCT exploration:  0.128920512778062
Prior regulation:  1.2801238164214412
U(s, o) 0.11320948731870314
Option 2
N(s, o) 171
Q(s, o) [-23.288970453974915]
P(s, o) 0.16666666666666666
Q'(s, o):  0.08296375878619067
UCT exploration:  0.14241219434785918
Prior regulation:  1.2801238164214412
U(s, o) 0.11334796574177958
Option 3
N(s, o) 64
Q(s, o) [-24.198897479046867]
P(s, o) 0.16666666666666666
Q'(s, o):  0.02934861557972259
UCT exploration:  0.3768445758127966
Prior regulation:  1.2801238164214412
U(s, o) 0.10974990167758863
Option 4
N(s, o) 49
Q(s, o) [-24.696715909716314, -23.93607667648572, -23.167754218673466, -22.39167092