In [1]:
import gc
import glob
import os
import re
from itertools import count
from pathlib import Path

import numpy as np
import torch
import torch.optim as optim
from torch.distributions import Categorical

from mctsnet import MCTSnet
from mctsnettree import MCTSnetTree
from tangram import Tangram
%matplotlib inline

In [2]:
# All tensors created without an explicit dtype are float64.
torch.set_default_dtype(torch.float64)
# Autograd anomaly detection adds checks to every backward pass and slows
# training considerably; switch it on only while hunting NaN/inf gradients.
detect_anomaly = False
torch.autograd.set_detect_anomaly(detect_anomaly)

# --- training loop ---
gamma=0.9            # discount used when accumulating returns (R = r + gamma * R)
seed=543             # base RNG seed
render=False         # call env.render() after reset and after every step
log_interval=1       # print a progress line every `log_interval` episodes
gpu=True             # use CUDA when it is available
load_agent=True      # resume from the newest checkpoint in serialization_path

# --- environment ---
n_grid = 10              # board side length
n_blocks = 3             # maximum number of steps per episode
n_possible_blocks = 5    # number of block types an action can choose from

# --- network architecture ---
embedding_size = 128
readout_hidden_size = 128
action_dims = [n_possible_blocks,n_grid,n_grid]
state_dims = [2,n_grid,n_grid]
embedding_n_residual_blocks = 3
embedding_channel_sizes = [64,64,64,32]
embedding_kernels = [3,3,3,1]
embedding_strides = [1,1,1,1]
policy_n_residual_blocks = 2
policy_channel_sizes = [32,32,32,32]
policy_kernels = [3,3,3,1]
policy_strides = [1,1,1,1]

n_simuls = 10 #25    # simulations per search passed to mctsnet(tree, n_simuls)

# --- checkpointing ---
serialization_path = './models'
print('serialization_path: ',serialization_path)
serialize_every_n_episodes = 10000   # 0 (or negative) disables periodic checkpoints
# Create the checkpoint folder if it does not exist yet.
Path(serialization_path).mkdir(parents=True, exist_ok=True)

serialization_path:  ./models


In [3]:
def find_latest_checkpoint(directory):
    """Locate the newest ``mctsnet_e{episode}_p{running_reward}.pt`` checkpoint.

    Parameters
    ----------
    directory : str
        Folder searched with the glob ``{directory}/mctsnet_e*``.

    Returns
    -------
    tuple or None
        ``(path, episode, running_reward)`` for the file with the highest
        episode number, or ``None`` if no file matches the naming pattern.
        Episodes are compared numerically, so ``e100000`` ranks above
        ``e99999``; a plain string sort would not do this. Matching names
        whose reward part is not a float are skipped.
    """
    name_pattern = re.compile(r"mctsnet_e(\d+)_p(.+)\.pt")
    candidates = []
    for path in glob.glob(f'{directory}/mctsnet_e*'):
        match = name_pattern.fullmatch(os.path.basename(path))
        if match is None:
            continue
        try:
            reward = float(match.group(2))
        except ValueError:
            continue
        candidates.append((int(match.group(1)), path, reward))
    if not candidates:
        return None
    episode, path, reward = max(candidates)
    return path, episode, reward


if gpu and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"
    gpu = False
# torch.manual_seed seeds the CPU generator and all CUDA devices, so both
# paths are reproducible. Previously the GPU path left the CPU RNG unseeded.
torch.manual_seed(seed)
# Leave one core free, but never request fewer than one thread.
# os.cpu_count() may return None.
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)

latest_checkpoint = find_latest_checkpoint(serialization_path) if load_agent else None
if latest_checkpoint is not None:
    checkpoint_path, n_episodes, running_reward = latest_checkpoint
    # map_location lets a checkpoint written on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # checkpoints. Recent PyTorch (weights_only=True default) may also refuse a
    # pickled nn.Module.
    mctsnet = torch.load(checkpoint_path, map_location=device)
    print('Loaded MCTSnet from '+ checkpoint_path)
else:
    mctsnet = MCTSnet(embedding_size,
                      readout_hidden_size,
                      action_dims,
                      state_dims,
                      embedding_n_residual_blocks,
                      embedding_channel_sizes,
                      embedding_kernels,
                      embedding_strides,
                      policy_n_residual_blocks,
                      policy_channel_sizes,
                      policy_kernels,
                      policy_strides,
                      device).to(device)
    n_episodes = 0
    running_reward = -1
    print('Initialized new MCTSnet')
# Offset the environment seed by the resumed episode count so a resumed run
# does not replay the same environment sequence.
seed += n_episodes
env = Tangram(seed, n_grid, n_blocks, n_possible_blocks)
tree = MCTSnetTree(env, embedding_size, device)
optimizer = optim.Adam(mctsnet.parameters(), lr=5e-4)
#optimizer = torch.optim.SGD(mctsnet.parameters(), lr=5e-4)
#optimizer = torch.optim.SGD(mctsnet.parameters(), lr=5e-4)

Running on Device =  cpu
Loaded MCTSnet from ./models/mctsnet_e20000_p0.09373902464399213.pt


In [4]:
def select_action(tree, n_simuls, mctsnet):
    """Run an MCTSnet search from the tree root and sample a legal action.

    Parameters
    ----------
    tree : MCTSnetTree
        Search tree passed straight to ``mctsnet``.
    n_simuls : int
        Number of simulations passed to ``mctsnet``.
    mctsnet : MCTSnet
        Called as ``mctsnet(tree, n_simuls)`` and expected to return
        ``(probs, probs_mask)``. It must also provide ``update_saved_log_probs``.

    Returns
    -------
    env_action : np.ndarray
        ``[block, y, x]``, decoded from the flat action index with the global
        ``n_grid`` (``index = block * n_grid**2 + y * n_grid + x``).
    action : torch.Tensor
        The sampled flat action index.

    Side effects
    ------------
    Appends the log-probability of ``action`` under the unmasked ``probs`` to
    ``mctsnet`` via ``update_saved_log_probs``.
    """
    probs, probs_mask = mctsnet(tree, n_simuls)
    probs = torch.squeeze(probs)

    # Cast the mask to probs' dtype and device. With an integer or bool mask,
    # normalising the fallback distribution below would otherwise fail.
    mask = torch.as_tensor(probs_mask, dtype=probs.dtype, device=probs.device)
    masked_probs = probs * mask
    if not bool((masked_probs != 0).any()):
        # The policy puts no mass on any legal action, so fall back to a
        # uniform choice among the legal actions.
        masked_probs = mask
    masked_probs = masked_probs / masked_probs.sum()

    action = Categorical(masked_probs).sample()
    action_id = action.item()
    block, within_block = divmod(action_id, n_grid * n_grid)
    y, x = divmod(within_block, n_grid)
    env_action = np.array([block, y, x])

    # NOTE(review): the action is sampled from the masked, renormalised
    # distribution, but the log-probability saved for REINFORCE is taken under
    # the unmasked `probs`. It is -inf in the fallback case above. Confirm this
    # is intentional.
    mctsnet.update_saved_log_probs(Categorical(probs).log_prob(action))
    return env_action, action

In [5]:
def finish_episode(mctsnet, optimizer):
    """Apply one REINFORCE update from the episode buffered on ``mctsnet``.

    Reads the per-step rewards (``get_rewards``) and the saved action
    log-probabilities (``get_saved_log_probs``). It builds discounted returns
    ``G_t = r_t + gamma * G_{t+1}`` using the global ``gamma``, then minimises
    ``sum_t -log_prob_t * G_t`` with one optimizer step. The scalar loss is
    recorded via ``update_losses``, and both buffers are then cleared.

    Parameters
    ----------
    mctsnet : MCTSnet
        Holds the episode buffers and the parameters being optimised.
    optimizer : torch.optim.Optimizer
        Optimizer over ``mctsnet``'s parameters.

    Raises
    ------
    ValueError
        If the reward and log-probability buffers have different lengths.
        ``zip`` would otherwise silently pair rewards with the wrong actions.
    """
    rewards = list(mctsnet.get_rewards())
    log_probs = list(mctsnet.get_saved_log_probs())
    if len(rewards) != len(log_probs):
        raise ValueError(
            f"finish_episode: {len(rewards)} rewards but {len(log_probs)} saved log-probs")
    if not rewards:
        # Nothing to learn from an empty episode; just reset the buffers.
        mctsnet.delete_rewards()
        mctsnet.delete_saved_log_probs()
        return

    # Accumulate discounted returns backwards. Each reward is turned into a
    # Python float first, so single-element reward tensors on any device all
    # produce a flat 1-D returns tensor.
    returns = []
    running_return = 0.0
    for reward in reversed(rewards):
        running_return = float(reward) + gamma * running_return
        returns.append(running_return)
    returns.reverse()
    returns = torch.tensor(returns, device=log_probs[0].device)

    optimizer.zero_grad()
    loss = torch.cat([torch.unsqueeze(-log_prob * step_return, 0)
                      for log_prob, step_return in zip(log_probs, returns)]).sum()
    loss.backward()
    #torch.nn.utils.clip_grad_norm_(mctsnet.parameters(), 1.0)
    optimizer.step()
    mctsnet.update_losses(loss.detach().cpu().numpy())
    mctsnet.delete_rewards()
    mctsnet.delete_saved_log_probs()

In [6]:
# Stop training once the exponential moving average of episode reward exceeds this.
reward_threshold = 0.99
for i_episode in count(n_episodes+1):
    env.reset()
    if render:
        env.render()
    tree = MCTSnetTree(env, embedding_size, device)
    ep_reward = 0
    for t in range(1, n_blocks + 1):
        env_action, action = select_action(tree, n_simuls, mctsnet)
        env_state, env_reward, done = env.step(env_action)
        state = torch.unsqueeze(torch.tensor(env_state[:2]), 0).to(device)
        reward = torch.unsqueeze(torch.tensor([env_reward]), 0).to(device)
        # The child reached by the sampled action receives the real environment
        # outcome and becomes the new root. Its siblings are deleted.
        for (child_id, child) in tree.get_root().get_children():
            if child_id == action:
                child.set_state(state)
                child.set_reward(reward)
                #child.set_action(torch.reshape(action,(1,1)))
                child.set_done(done)
                child.set_probs_mask(tree.get_env().get_mask())
                tree.set_root(child)
            else:
                child.delete()
        if render:
            env.render()
        mctsnet.update_rewards(reward)
        ep_reward += env_reward
        if done:
            break

    # Exponential moving average of episode reward (smoothing factor 0.05).
    running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
    finish_episode(mctsnet, optimizer)

    if i_episode % log_interval == 0:
        print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\tLast loss: {:.2f}'.format(
                i_episode, ep_reward, running_reward, mctsnet.get_last_episode_loss()))

    checkpoint_path = f"{serialization_path}/mctsnet_e{str(i_episode).zfill(5)}_p{running_reward}.pt"
    if serialize_every_n_episodes > 0 and i_episode % serialize_every_n_episodes == 0:
        torch.save(mctsnet, checkpoint_path)
        print("Saved the model!")
        # Round-trip the model through disk, presumably to release memory held
        # by the live module. TODO confirm this is still needed.
        del mctsnet, optimizer
        gc.collect()
        torch.cuda.empty_cache()
        # Reload the exact file just written. Re-globbing and taking the
        # string-sorted last name would pick an older checkpoint once episode
        # numbers exceed five digits.
        mctsnet = torch.load(checkpoint_path, map_location=device)
        n_episodes = i_episode
        print('Loaded MCTSnet from '+ checkpoint_path)
        # NOTE(review): a fresh Adam discards its moment estimates at every
        # checkpoint. Save/restore optimizer.state_dict() if that is unintended.
        optimizer = optim.Adam(mctsnet.parameters(), lr=5e-4)
        #optimizer = torch.optim.SGD(mctsnet.parameters(), lr=5e-4)
    if running_reward > reward_threshold:
        print("Solved! Running reward is now {} and "
                "the last episode runs to {} time steps!".format(running_reward, t))
        torch.save(mctsnet, checkpoint_path)
        print("Saved the model!")
        break
    tree.delete()
    del tree
    gc.collect()
    torch.cuda.empty_cache()

Episode 20001	Last reward: -1.00	Average reward: 0.04	Last loss: -68.48
Episode 20002	Last reward: -1.00	Average reward: -0.01	Last loss: -68.48
Episode 20003	Last reward: -1.00	Average reward: -0.06	Last loss: -68.48
Episode 20004	Last reward: -1.00	Average reward: -0.11	Last loss: -68.48
Episode 20005	Last reward: 1.00	Average reward: -0.05	Last loss: 97.68
Episode 20006	Last reward: 1.00	Average reward: -0.00	Last loss: 97.68
Episode 20007	Last reward: -1.00	Average reward: -0.05	Last loss: -68.48
Episode 20008	Last reward: 1.00	Average reward: 0.00	Last loss: 97.68
Episode 20009	Last reward: -1.00	Average reward: -0.05	Last loss: -68.48
Episode 20010	Last reward: -1.00	Average reward: -0.10	Last loss: -68.48
Episode 20011	Last reward: 1.00	Average reward: -0.04	Last loss: 97.68
Episode 20012	Last reward: 1.00	Average reward: 0.01	Last loss: 97.68
Episode 20013	Last reward: -1.00	Average reward: -0.04	Last loss: -68.48
Episode 20014	Last reward: -1.00	Average reward: -0.09	Last loss

KeyboardInterrupt: 