In [1]:
import numpy as np
import torch
import torch.optim as optim
import os
from torch.distributions import Categorical
from pathlib import Path
import glob
import gc

from mcts1 import MCTS1
from tangram import Tangram
from mctstree import MCTSTree
from itertools import count
%matplotlib inline

In [2]:
# --- Global torch behaviour ---------------------------------------------
torch.set_default_dtype(torch.float64)    # new floating-point tensors default to float64
torch.autograd.set_detect_anomaly(True)   # autograd anomaly detection (debugging aid)

# --- Search / return hyper-parameters -----------------------------------
c = np.sqrt(2)   # forwarded to mcts1(...); presumably the UCT exploration constant - TODO confirm
gamma = 0.9      # discount used for the returns in finish_episode; also forwarded to mcts1(...)

# --- Seeds and run switches ---------------------------------------------
environment_seed = network_seed = 543
render = False       # when True, env.render() is called during training and testing
log_interval = 1     # print a progress line every `log_interval` episodes
gpu = False          # request CUDA; the next cell falls back to CPU when unavailable
load_agent = True    # resume from the newest checkpoint in serialization_path, if any

# --- Tangram environment configuration (all passed to Tangram(...)) ------
n_grid = 20
n_blocks = 4
n_possible_blocks = 6
chunk_type = 7
n_blocks_H = 2
n_distinct_samples = n_samples = 20

# Sizes handed to MCTS1; the flat action index is later decoded with 3 * n_blocks.
action_dims = [3, n_blocks, n_possible_blocks]

# --- MCTS / evaluation budget -------------------------------------------
n_simuls = 10   # forwarded to every mcts1(...) call
n_evals = 100   # number of test episodes per evaluation round

# --- Checkpointing and schedule -----------------------------------------
serialization_path = (
    './models/mcts1/hierarchical_blocks_{}_n_samples_{}_unbiased_dataset_network_seed_{}'
    .format(n_blocks_H, n_samples, network_seed)
)
print('serialization_path: ', serialization_path)
serialize_every_n_episodes = 10000
update_every_n_episodes = 1
test_every_n_episodes = 100

# Make sure the checkpoint folder exists before anything tries to write into it.
Path(serialization_path).mkdir(parents=True, exist_ok=True)

serialization_path:  ./models/mcts1/hierarchical_blocks_2_n_samples_20_unbiased_dataset_network_seed_543


In [3]:
# --- Device selection and seeding ----------------------------------------
# CUDA is used only when it was requested via `gpu` AND is actually available;
# otherwise `gpu` is reset to False so later cells see the effective setting.
gpu = bool(gpu and torch.cuda.is_available())

# torch.manual_seed seeds the CPU generator (and, per the PyTorch docs, the CUDA
# generators as well). Previously it was skipped on the GPU path, which left
# CPU-side randomness unseeded there.
torch.manual_seed(network_seed)
if gpu:
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(network_seed)
else:
    # Kept as the plain string "cpu", exactly as before, because it is handed to MCTS1.
    device = "cpu"

# Leave one core free, but never ask for fewer than one thread: os.cpu_count()
# may return None or 1, and torch.set_num_threads requires a positive integer.
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)

# --- Agent: resume from the newest checkpoint or start fresh -------------
filesave_paths_mcts1 = sorted(glob.glob(f'{serialization_path}/mcts1_e*'))
if load_agent and len(filesave_paths_mcts1) > 0:
    latest_checkpoint = filesave_paths_mcts1[-1]
    # NOTE(review): torch.load unpickles a whole module here; only load checkpoints you trust.
    with open(latest_checkpoint, 'rb') as checkpoint_file:
        mcts1 = torch.load(checkpoint_file, map_location=device)
    # Checkpoints are named "mcts1_e<episode>_p<running_reward>.pt". Split on the
    # markers instead of fixed character offsets: the reward's text length varies
    # (sign, number of digits), which made the old slices read the wrong characters.
    episode_tag, reward_tag = Path(latest_checkpoint).stem[len('mcts1_e'):].split('_p', 1)
    n_episodes = int(episode_tag)
    running_reward = float(reward_tag)
    print('Loaded MCTS1 from '+ latest_checkpoint)
else:
    mcts1 = MCTS1(action_dims,
                  device).to(device)
    n_episodes = 0
    running_reward = -1
    print('Initialized new MCTS1')

# --- Environment, search tree and optimizer -------------------------------
env = Tangram(environment_seed, n_grid, n_blocks, n_possible_blocks, chunk_type, n_blocks_H, n_distinct_samples, n_samples)
tree = MCTSTree(env)
#optimizer_mcts1 = optim.Adam(mcts1.parameters(), lr=5e-4)
optimizer_mcts1 = optim.SGD(mcts1.parameters(), lr=5e-4)

Running on Device =  cpu
Initialized new MCTS1
Generating an unbiased Tangram environment...


  return torch._C._cuda_getDeviceCount() > 0


Connectivity matrix:
 [[ 0. 16.  6.  6.]
 [16.  0.  3.  6.]
 [ 6.  3.  0. 16.]
 [ 6.  6. 16.  0.]]
Uniformity threshhold: 50.50%


In [4]:
def select_action(tree, n_simuls, mcts1):
    """Sample one action from the masked policy and record it on the agent.

    Calls ``mcts1(tree, n_simuls, c, gamma)`` to obtain action probabilities
    and an action mask, zeroes the masked-out entries, renormalises, and draws
    an action from the resulting categorical distribution. The log-probability
    and the entropy of the distribution are stored on ``mcts1`` through
    ``update_saved_log_probs`` / ``update_saved_entropies``.

    Reads the notebook globals ``c``, ``gamma``, ``device`` and ``n_blocks``.

    Returns:
        A tuple ``(env_action, action_id, log_prob)`` where ``env_action`` is
        ``np.array([block, loc])`` decoded from the flat index ``action_id``,
        and ``log_prob`` is the log-probability of the sampled action.
    """
    policy_out, mask_raw = mcts1(tree, n_simuls, c, gamma)
    valid = torch.tensor(mask_raw).to(device)
    weights = torch.squeeze(policy_out) * valid

    # If the policy puts no mass on any allowed action, fall back to the mask
    # itself so that sampling stays restricted to allowed actions.
    if not (torch.sum(weights) > 0):
        weights += valid
    weights /= torch.sum(weights)

    dist = Categorical(weights)
    sampled = dist.sample()
    action_id = sampled.item()

    # Flat index -> (block, location): each block owns 3 * n_blocks consecutive ids.
    block, loc = divmod(action_id, 3 * n_blocks)
    env_action = np.array([block, loc])

    mcts1.update_saved_log_probs(dist.log_prob(sampled))
    mcts1.update_saved_entropies(dist.entropy())
    return env_action, action_id, dist.log_prob(sampled)

In [5]:
def finish_episode(mcts1, optimizer, episode_num):
    """Turn the episode recorded on ``mcts1`` into a policy-gradient update.

    The loss is ``sum_t(-log_prob_t * G_t) / update_every_n_episodes`` where
    ``G_t`` is the discounted return from step ``t`` onwards (discount
    ``gamma``). Gradients are accumulated over ``update_every_n_episodes``
    consecutive episodes and applied with a single ``optimizer.step()`` on the
    last episode of each window.

    Fix over the previous version: gradients used to be zeroed right before the
    backward pass of the *stepping* episode, which threw away everything
    accumulated earlier in the window. They are now zeroed on the first episode
    of a window instead. For ``update_every_n_episodes == 1`` nothing changes.

    Reads the notebook globals ``gamma``, ``update_every_n_episodes`` and
    ``device``.

    Args:
        mcts1: agent exposing ``get_rewards``, ``get_saved_log_probs``,
            ``update_losses`` and the ``delete_*`` buffer-clearing methods.
        optimizer: optimizer over the agent's parameters.
        episode_num: 1-based episode counter; decides where accumulation
            windows start and end.
    """
    # Discounted return-to-go for every step, built back-to-front.
    discounted = []
    running_return = 0
    for reward in mcts1.get_rewards()[::-1]:
        running_return = reward + gamma * running_return
        discounted.append(running_return)
    discounted.reverse()
    returns = torch.tensor(discounted).to(device)

    step_losses = [
        torch.unsqueeze(-log_prob * step_return, 0)
        for log_prob, step_return in zip(mcts1.get_saved_log_probs(), returns)
    ]
    # Dividing by the window length makes the accumulated gradient an average.
    mcts1_loss = torch.cat(step_losses).sum() / update_every_n_episodes

    # First episode of an accumulation window: start from clean gradients.
    if (episode_num - 1) % update_every_n_episodes == 0:
        optimizer.zero_grad()
    mcts1_loss.backward()
    #torch.nn.utils.clip_grad_norm_(mcts1.parameters(), 1.0)
    # Last episode of the window: apply everything accumulated so far.
    if episode_num % update_every_n_episodes == 0:
        optimizer.step()

    mcts1.update_losses(mcts1_loss.clone().cpu().detach().numpy())
    mcts1.delete_rewards()
    mcts1.delete_saved_log_probs()
    mcts1.delete_saved_entropies()

In [6]:
# Training stops once the running reward or the test success rate clears these.
reward_threshold = 0.99
success_threshold = 99
success_ratio = 0


def _advance_tree(search_tree, chosen_id, reward, is_done):
    """Re-root ``search_tree`` at the child reached by ``chosen_id`` and delete its siblings."""
    for (child_id, child) in search_tree.get_root().get_children():
        if child_id == chosen_id:
            child.set_reward(reward)
            child.set_done(is_done)
            child.set_action_mask(search_tree.get_env().get_mask())
            search_tree.set_root(child)
        else:
            child.delete()


def _checkpoint_file(episode, reward):
    """Build the checkpoint path ``mcts1_e<episode, 6 digits>_p<reward>.pt``."""
    return f"{serialization_path}/mcts1_e{str(episode).zfill(6)}_p{reward}.pt"


for i_episode in count(n_episodes+1):
    # ---------------- one training episode ----------------
    env.reset('train')
    if render:
        env.render()
    tree = MCTSTree(env)
    ep_reward = 0
    for t in range(1, n_blocks + 1):
        env_action, action_id, _ = select_action(tree, n_simuls, mcts1)
        env_state, env_reward, done = env.step(env_action)
        _advance_tree(tree, action_id, env_reward, done)
        if render:
            env.render()
        mcts1.update_rewards(env_reward)
        ep_reward += env_reward
        if done:
            break

    # Exponential moving average of the episode reward (weight 0.05 on the newest).
    running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
    mcts1.update_running_rewards(running_reward)
    finish_episode(mcts1, optimizer_mcts1, i_episode)

    if i_episode % log_interval == 0:
        print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\tLast loss: {:.2f}'.format(
            i_episode, ep_reward, running_reward, mcts1.get_last_episode_loss()))

    # ---------------- periodic checkpoint + reload ----------------
    if serialize_every_n_episodes > 0 and i_episode % serialize_every_n_episodes == 0:
        checkpoint_path = _checkpoint_file(i_episode, running_reward)
        torch.save(mcts1, checkpoint_path)
        print("Saved the model!")
        del mcts1, optimizer_mcts1
        # Reload exactly the file that was just written. The old code took the
        # lexicographically last file in the folder (possibly a stale checkpoint
        # from an earlier run) and re-parsed episode/reward from fixed character
        # offsets, which raised ValueError whenever the reward's text length
        # differed from the one assumed. running_reward is already the value
        # that was written into the file name, so it is simply kept.
        with open(checkpoint_path, 'rb') as checkpoint_file:
            mcts1 = torch.load(checkpoint_file, map_location=device)
        filesave_paths_mcts1 = sorted(glob.glob(f'{serialization_path}/mcts1_e*'))
        n_episodes = i_episode
        print('Loaded MCTS1 from '+ checkpoint_path)
        #optimizer_mcts1 = optim.Adam(mcts1.parameters(), lr=5e-4)
        optimizer_mcts1 = optim.SGD(mcts1.parameters(), lr=5e-4)
        gc.collect()
        torch.cuda.empty_cache()

    # ---------------- periodic greedy evaluation ----------------
    if i_episode % test_every_n_episodes == 0:
        print('Testing...')
        mcts1.eval()
        try:
            with torch.no_grad():
                n_successes = 0
                for eval_num in range(1, n_evals + 1):
                    env.reset('test')
                    # Separate name so the training tree is still the one deleted
                    # at the end of this iteration.
                    eval_tree = MCTSTree(env)
                    done = False
                    while not done:
                        if render:
                            env.render()
                        probs, action_mask = mcts1(eval_tree, n_simuls, c, gamma)
                        probs = torch.squeeze(probs)
                        action_mask = torch.tensor(action_mask).to(device)
                        masked_probs = probs*action_mask
                        # No mass on any allowed action: fall back to the mask itself.
                        if not torch.sum(masked_probs) > 0:
                            masked_probs += action_mask
                        masked_probs /= torch.sum(masked_probs)
                        # Greedy choice at test time (training samples instead).
                        action_id = torch.argmax(masked_probs).item()
                        block = action_id//(3*n_blocks)
                        loc = action_id - block*3*n_blocks
                        env_action = np.array([block,loc])
                        env_state, env_reward, done = env.step(env_action)
                        _advance_tree(eval_tree, action_id, env_reward, done)
                    # An episode counts as a success when its final reward is exactly 1.
                    if env_reward == 1:
                        n_successes += 1
                    if render:
                        env.render()
                    # Release the evaluation tree the same way training trees are
                    # released at the end of each episode.
                    eval_tree.delete()
                    del eval_tree
                success_ratio = n_successes * (100/n_evals)
                print("Success ratio: {}%".format(success_ratio))
                mcts1.update_success_ratios(success_ratio)
        finally:
            # Restore training mode even if evaluation is interrupted.
            mcts1.train()

    # ---------------- stopping criterion ----------------
    if running_reward > reward_threshold or success_ratio > success_threshold:
        print("Solved! Running reward is now {} and "
                "the last episode runs to {} time steps!".format(running_reward, t))
        torch.save(mcts1, _checkpoint_file(i_episode, running_reward))
        print("Saved the model!")
        break

    # ---------------- per-episode cleanup ----------------
    tree.delete()
    del tree
    gc.collect()
    torch.cuda.empty_cache()

Episode 1	Last reward: -1.00	Average reward: -1.00	Last loss: -3.22
Episode 2	Last reward: -1.00	Average reward: -1.00	Last loss: -3.21
Episode 3	Last reward: -1.00	Average reward: -1.00	Last loss: -3.66
Episode 4	Last reward: -1.00	Average reward: -1.00	Last loss: -3.95
Episode 5	Last reward: -1.00	Average reward: -1.00	Last loss: -2.67
Episode 6	Last reward: -1.00	Average reward: -1.00	Last loss: -3.58
Episode 7	Last reward: -1.00	Average reward: -1.00	Last loss: -3.33
Episode 8	Last reward: 1.00	Average reward: -0.90	Last loss: 3.29
Episode 9	Last reward: -1.00	Average reward: -0.90	Last loss: -4.00
Episode 10	Last reward: -1.00	Average reward: -0.91	Last loss: -3.74
Episode 11	Last reward: -1.00	Average reward: -0.91	Last loss: -3.40
Episode 12	Last reward: -1.00	Average reward: -0.92	Last loss: -4.35
Episode 13	Last reward: -1.00	Average reward: -0.92	Last loss: -4.82
Episode 14	Last reward: -1.00	Average reward: -0.93	Last loss: -3.26
Episode 15	Last reward: -1.00	Average reward:

KeyboardInterrupt: 

In [None]:
# Manual checkpoint of the current agent, named after the last episode index
# and the current running reward (same naming scheme as the training loop).
manual_checkpoint_path = f"{serialization_path}/mcts1_e{str(i_episode).zfill(6)}_p{running_reward}.pt"
torch.save(mcts1, manual_checkpoint_path)