In [1]:
import glob
import os
import re
from pathlib import Path

import numpy as np
import torch

from tangram import Tangram
from mctsnettree import MCTSnetTree
%matplotlib inline

In [2]:
# ---- Numerical precision -------------------------------------------------
# Tensors created without an explicit dtype default to float64 from here on.
torch.set_default_dtype(torch.float64)

# ---- Run options ---------------------------------------------------------
seed = 123       # base RNG seed; the checkpoint-loading cell offsets it
render = False   # when True the evaluation loop calls env.render()
gpu = True       # request CUDA; the next cell clears this if unavailable

# ---- Environment / model sizes -------------------------------------------
# The first three are passed to Tangram(...) in the loading cell.
n_grid, n_blocks, n_possible_blocks = 10, 3, 5
embedding_size = 8   # handed to MCTSnetTree
n_simuls = 10        # handed to mctsnet on every step (alternative value: 25)
n_evals = 10         # number of evaluation episodes

# ---- Checkpoint directory ------------------------------------------------
serialization_path = './models'
print('serialization_path: ', serialization_path)
# Make sure the folder exists (no error if it is already there).
Path(serialization_path).mkdir(parents=True, exist_ok=True)

serialization_path:  ./models


In [3]:
# ---- Device selection ----------------------------------------------------
# Use CUDA only when it was requested AND is actually available; `device` is
# always a torch.device so downstream `.to(device)` calls see one type.
use_cuda = gpu and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
gpu = use_cuda  # clear the flag when CUDA was requested but is unavailable

# ---- Seeding -------------------------------------------------------------
# Seed the CPU generator unconditionally (the CUDA-only path previously left
# it unseeded), and the CUDA generators as well when they are in use.
torch.manual_seed(seed)
if use_cuda:
    torch.cuda.manual_seed_all(seed)

# Leave one core free, but never ask for fewer than one thread
# (os.cpu_count() may be 1 or None).
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)

# ---- Locate the most recent checkpoint -----------------------------------
# Checkpoint files are named  mctsnet_e<episodes>_p<running reward>.pt .
# Parse the name instead of slicing at fixed offsets so that the result does
# not depend on the length of `serialization_path` or on the digit count.
checkpoint_pattern = re.compile(r'mctsnet_e(\d+)_p([0-9.eE+-]+)\.pt$')
checkpoints = []
for candidate in glob.glob(f'{serialization_path}/mctsnet_e*'):
    parsed = checkpoint_pattern.search(os.path.basename(candidate))
    if parsed:
        checkpoints.append((int(parsed.group(1)), float(parsed.group(2)), candidate))
if not checkpoints:
    raise FileNotFoundError(
        f"No 'mctsnet_e<episodes>_p<reward>.pt' checkpoint found in {serialization_path!r}"
    )

# Order numerically by episode count (a plain string sort would rank e9000
# after e70000); the last entry is the checkpoint that gets loaded.
checkpoints.sort()
filesave_paths = [path for _, _, path in checkpoints]
n_episodes, running_reward, _ = checkpoints[-1]

# ---- Load the model ------------------------------------------------------
# Map the weights onto the selected device so the network and the tensors
# built in the evaluation loop live in the same place. Passing the path lets
# torch.load open and close the file itself.
# NOTE(security): torch.load unpickles the file and can execute arbitrary
# code -- only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules; pass weights_only=False there if needed.
mctsnet = torch.load(filesave_paths[-1], map_location=device)
print('Loaded MCTSnet from '+ filesave_paths[-1])

# NOTE(review): this mutates `seed`, so re-running the cell without first
# re-running the config cell keeps adding `n_episodes` to it.
seed += n_episodes
env = Tangram(seed, n_grid, n_blocks, n_possible_blocks)

Running on Device =  cpu
Loaded MCTSnet from ./models/mctsnet_e70000_p0.8333912764112664.pt


In [4]:
# Greedy evaluation: play `n_evals` episodes with the loaded network, always
# taking the highest-probability legal action, and report the success rate.
mctsnet.eval()

# A flat action index is laid out as block * n_grid * n_grid + y * n_grid + x.
cells_per_block = n_grid * n_grid

with torch.no_grad():
    success = 0
    for _ in range(n_evals):
        env.reset()
        tree = MCTSnetTree(env, embedding_size, device)
        done = False
        while not done:
            if render:
                env.render()

            # probs_mask is summed with NumPy and then converted with
            # torch.tensor -- presumably a NumPy array; verify against MCTSnet.
            probs, probs_mask = mctsnet(tree, n_simuls)
            print(np.sum(probs_mask))
            probs = torch.squeeze(probs)
            probs_mask = torch.tensor(probs_mask).to(device)
            masked_probs = probs * probs_mask

            # If masking left every entry at zero, fall back to the mask
            # itself so that argmax still selects an unmasked entry.
            if not torch.any(masked_probs):
                masked_probs += probs_mask

            # Greedy choice. The stochastic alternative would be to sample
            # from Categorical(masked_probs / masked_probs.sum()).
            action = torch.argmax(masked_probs).item()

            # Decode the flat index into (block, y, x) for the environment.
            block, cell = divmod(action, cells_per_block)
            y, x = divmod(cell, n_grid)
            env_action = np.array([block, y, x])
            env_state, env_reward, done = env.step(env_action)

            state = torch.unsqueeze(torch.tensor(env_state[:2]), 0).to(device)
            reward = torch.unsqueeze(torch.tensor([env_reward]), 0).to(device)

            # Re-root the search tree at the child that matches the action
            # just taken, filling in what the environment returned.
            # NOTE(review): if no child carries this action id the root is
            # left unchanged while the environment has already advanced --
            # confirm that mctsnet always expands the chosen action.
            for (child_id, child) in tree.get_root().get_children():
                if child_id == action:
                    child.set_state(state)
                    child.set_reward(reward)
                    child.set_done(done)
                    child.set_probs(probs)
                    child.set_probs_mask(tree.get_env().get_mask())
                    tree.set_root(child)
                    break

        # An episode counts as solved when its final step reward equals 1.
        if env_reward == 1:
            success += 1
        if render:
            env.render()
    print("Success rate: {}%".format(success*100/n_evals))

6.0
4.0
1.0
14.0
6.0
1.0
7.0
4.0
1.0
11.0
4.0
1.0
17.0
2.0
1.0
9.0
3.0
1.0
15.0
2.0
11.0
5.0
1.0
8.0
2.0
1.0
16.0
8.0
1.0
Success rate: 90.0%
