In [1]:
import numpy as np
import torch
import torch.optim as optim
import os
from pathlib import Path
import glob
import gc

from mcts2 import MCTS2
from tangram import Tangram
from mcts2tree import MCTS2Tree
from itertools import count
%matplotlib inline

In [2]:
# ---- Global torch behaviour -------------------------------------------------
# All tensors default to double precision; anomaly detection makes autograd
# report the forward op responsible for a failing backward pass.
torch.set_default_dtype(torch.float64)
torch.autograd.set_detect_anomaly(True)

# ---- Run settings -----------------------------------------------------------
gamma = 0.9              # forwarded to every mcts2(...) call (presumably the discount factor)
environment_seed = 543   # seed handed to the Tangram environment
network_seed = 543       # seed used for the torch RNG(s)
render = False           # call env.render() during episodes when True
log_interval = 1         # print a training line every `log_interval` episodes
gpu = False              # request CUDA; downgraded to CPU when unavailable
load_agent = True        # resume from the latest checkpoint if one exists

# ---- Tangram environment settings (passed positionally to Tangram) ----------
n_grid = 20
n_blocks = 4
n_possible_blocks = 6
chunk_type = 7
n_blocks_H = 2
n_distinct_samples = 20
n_samples = 20

# ---- Policy network settings (passed positionally to MCTS2) -----------------
action_dims = [3, n_blocks, n_possible_blocks]
policy_n_residual_blocks = 2
policy_channel_sizes = [32, 32, 32, 16]
policy_kernels = [3, 3, 3, 1]
policy_strides = [1, 1, 1, 1]

# ---- Search / evaluation budget ---------------------------------------------
n_simuls = 10    # second argument of every mcts2(...) call
n_evals = 100    # number of test episodes per evaluation round

# ---- Checkpointing and scheduling -------------------------------------------
serialization_path = (
    f'./models/mcts2/hierarchical_blocks_{n_blocks_H}'
    f'_n_samples_{n_samples}'
    f'_unbiased_dataset_network_seed_{network_seed}'
)
print('serialization_path: ', serialization_path)
serialize_every_n_episodes = 10000
update_every_n_episodes = 1
test_every_n_episodes = 100
# Make sure the checkpoint folder exists before anything tries to write to it.
Path(serialization_path).mkdir(parents=True, exist_ok=True)

serialization_path:  ./models/mcts2/hierarchical_blocks_2_n_samples_20_unbiased_dataset_network_seed_543


In [3]:
def _parse_checkpoint_name(checkpoint_file):
    """Recover ``(episode, running_reward)`` from a checkpoint file name.

    Checkpoints are written as ``mcts2_e<episode>_p<running_reward>.pt``.
    The reward is the plain ``str()`` of a float, so its length varies;
    splitting on the ``_p`` separator is used instead of fixed-width slices.
    """
    stem = Path(checkpoint_file).stem            # drops the trailing ".pt"
    episode_part, reward_part = stem[len('mcts2_e'):].split('_p', 1)
    return int(episode_part), float(reward_part)


# torch.manual_seed seeds the CPU generator (and, per the PyTorch docs, all
# devices), so it is called on both the CPU and the GPU path.
torch.manual_seed(network_seed)
if gpu and torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(network_seed)
else:
    # CUDA was not requested or is not available: fall back to the CPU.
    device = "cpu"
    gpu = False
# Leave one core free, but never ask torch for fewer than one thread
# (os.cpu_count() may return None or 1).
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)
filesave_paths_mcts2 = sorted(glob.glob(f'{serialization_path}/mcts2_e*'))
if load_agent and len(filesave_paths_mcts2) > 0:
    # NOTE(review): torch.load unpickles a whole module here -- only load
    # checkpoints you trust. Recent PyTorch releases default to
    # weights_only=True, which may reject full-module pickles; confirm against
    # the installed version.
    with open(filesave_paths_mcts2[-1], 'rb') as checkpoint_handle:
        mcts2 = torch.load(checkpoint_handle)
    n_episodes, running_reward = _parse_checkpoint_name(filesave_paths_mcts2[-1])
    print('Loaded MCTS2 from '+ filesave_paths_mcts2[-1])
else:
    mcts2 = MCTS2(action_dims,
                  policy_n_residual_blocks,
                  policy_channel_sizes,
                  policy_kernels,
                  policy_strides,
                  device).to(device)
    n_episodes = 0
    running_reward = -1
    print('Initialized new MCTS2')
env = Tangram(environment_seed, n_grid, n_blocks, n_possible_blocks, chunk_type, n_blocks_H, n_distinct_samples, n_samples)
tree = MCTS2Tree(env)
#optimizer_mcts2 = optim.Adam(mcts2.parameters(), lr=5e-4)
optimizer_mcts2 = optim.SGD(mcts2.parameters(), lr=5e-4)

  return torch._C._cuda_getDeviceCount() > 0


Running on Device =  cpu
Initialized new MCTS2
Generating an unbiased Tangram environment...
Connectivity matrix:
 [[ 0. 16.  6.  6.]
 [16.  0.  3.  6.]
 [ 6.  3.  0. 16.]
 [ 6.  6. 16.  0.]]
Uniformity threshhold: 50.50%


In [4]:
def select_action(tree, n_simuls, mcts2):
    """Query the agent for a flat action and decode it for the environment.

    Args:
        tree: search tree handed to ``mcts2``.
        n_simuls: simulation budget handed to ``mcts2``.
        mcts2: callable agent; invoked as ``mcts2(tree, n_simuls, gamma)``
            using the module-level ``gamma``.

    Returns:
        ``(env_action, action)`` where ``action`` is the value returned by the
        agent and ``env_action`` is ``np.array([block, loc])`` with
        ``block = action // (3 * n_blocks)`` and ``loc`` the remainder.
    """
    flat_action = mcts2(tree, n_simuls, gamma)
    # Each block index spans 3 * n_blocks consecutive flat actions
    # (presumably 3 options per existing block -- TODO confirm with MCTS2).
    actions_per_block = 3 * n_blocks
    block_index = flat_action // actions_per_block
    location_index = flat_action - block_index * actions_per_block
    return np.array([block_index, location_index]), flat_action

In [5]:
def finish_episode(mcts2, optimizer, episode_num, update_every=None):
    """Backpropagate one episode's policy-gradient loss and maybe step.

    The loss is ``sum(-log_prob * R)`` over the pairs yielded by
    ``mcts2.get_saved_log_probs()`` and ``mcts2.get_returns()``, divided by
    the accumulation window so that gradients summed over the window are
    averaged.

    Gradients are cleared at the *start* of each window and the optimizer
    steps at its *end*. (Previously the gradients were zeroed right before
    the final backward of the window, discarding every earlier episode's
    contribution whenever the window was longer than one episode.)

    Args:
        mcts2: agent exposing the saved log-probs/returns and the
            ``update_losses`` / ``delete_*`` bookkeeping methods.
        optimizer: torch optimizer over the agent's parameters.
        episode_num: 1-based episode counter.
        update_every: number of episodes per optimizer step; defaults to the
            module-level ``update_every_n_episodes``.
    """
    if update_every is None:
        update_every = update_every_n_episodes

    loss_terms = [
        torch.unsqueeze(-log_prob * R, 0)
        for log_prob, R in zip(mcts2.get_saved_log_probs(), mcts2.get_returns())
    ]
    mcts2_loss = torch.cat(loss_terms).sum() / update_every

    # First episode of an accumulation window: start from clean gradients.
    if (episode_num - 1) % update_every == 0:
        optimizer.zero_grad()
    mcts2_loss.backward()
    #torch.nn.utils.clip_grad_norm_(mcts2.parameters(), 1.0)
    # Last episode of the window: apply the accumulated gradients.
    if episode_num % update_every == 0:
        optimizer.step()

    mcts2.update_losses(mcts2_loss.detach().cpu().numpy())
    mcts2.delete_returns()
    mcts2.delete_saved_log_probs()
    mcts2.delete_saved_entropies()

In [6]:
# Stop criteria: either the exponentially smoothed training reward or the
# test success percentage must exceed its threshold.
reward_threshold = 0.99
success_threshold = 99
success_ratio = 0


def _advance_tree(search_tree, action, env_reward, done):
    """Re-root ``search_tree`` at the child for ``action``; delete the siblings.

    The chosen child is annotated with the observed reward, the done flag and
    the action mask read from the tree's environment before becoming the root.
    """
    for (child_id, child) in search_tree.get_root().get_children():
        if child_id == action:
            child.set_reward(env_reward)
            child.set_done(done)
            child.set_action_mask(search_tree.get_env().get_mask())
            search_tree.set_root(child)
        else:
            child.delete()


def _checkpoint_file(episode, reward):
    """Build the checkpoint path ``mcts2_e<episode:6 digits>_p<reward>.pt``."""
    return f"{serialization_path}/mcts2_e{str(episode).zfill(6)}_p{reward}.pt"


for i_episode in count(n_episodes+1):
    # ------------------------------------------------------------------
    # One training episode (at most n_blocks environment steps).
    # ------------------------------------------------------------------
    env.reset('train')
    if render:
        env.render()
    tree = MCTS2Tree(env)
    ep_reward = 0
    for t in range(1, n_blocks + 1):
        env_action, action = select_action(tree, n_simuls, mcts2)
        env_state, env_reward, done = env.step(env_action)
        _advance_tree(tree, action, env_reward, done)
        if render:
            env.render()
        ep_reward += env_reward
        if done:
            break

    # Exponential moving average of the episode reward (weight 0.05).
    running_reward = 0.05 * ep_reward + (1 - 0.05) * running_reward
    mcts2.update_running_rewards(running_reward)
    finish_episode(mcts2, optimizer_mcts2, i_episode)

    if i_episode % log_interval == 0:
        print('Episode {}\tLast reward: {:.2f}\tAverage reward: {:.2f}\tLast loss: {:.2f}'.format(
            i_episode, ep_reward, running_reward, mcts2.get_last_episode_loss()))

    # ------------------------------------------------------------------
    # Periodic checkpoint: save, drop the live objects, reload from disk.
    # ------------------------------------------------------------------
    if serialize_every_n_episodes > 0 and i_episode % serialize_every_n_episodes == 0:
        saved_file = _checkpoint_file(i_episode, running_reward)
        torch.save(mcts2, saved_file)
        print("Saved the model!")
        del mcts2, optimizer_mcts2
        filesave_paths_mcts2 = sorted(glob.glob(f'{serialization_path}/mcts2_e*'))
        # Reload exactly the file just written (not "whatever sorts last"),
        # and close the handle once loading is done.
        with open(saved_file, 'rb') as checkpoint_handle:
            mcts2 = torch.load(checkpoint_handle)
        # The episode index and running reward are already known in memory;
        # re-parsing them from the file name with fixed-width slices could
        # truncate or reject the reward because its string length varies.
        n_episodes = i_episode
        print('Loaded MCTS2 from '+ saved_file)
        #optimizer_mcts2 = optim.Adam(mcts2.parameters(), lr=5e-4)
        optimizer_mcts2 = optim.SGD(mcts2.parameters(), lr=5e-4)
        gc.collect()
        torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Periodic evaluation on the 'test' split, without gradients.
    # ------------------------------------------------------------------
    if i_episode % test_every_n_episodes == 0:
        print('Testing...')
        mcts2.eval()
        with torch.no_grad():
            success_ratio = 0
            for eval_num in range(1,n_evals+1):
                env.reset('test')
                # A separate name keeps the training tree reachable so it is
                # still deleted at the end of this iteration.
                eval_tree = MCTS2Tree(env)
                done = False
                while not done:
                    if render:
                        env.render()
                    env_action, action = select_action(eval_tree, n_simuls, mcts2)
                    env_state, env_reward, done = env.step(env_action)
                    _advance_tree(eval_tree, action, env_reward, done)
                # An evaluation episode counts as solved when its final
                # reward equals 1.
                if env_reward == 1:
                    success_ratio += 1
                if render:
                    env.render()
                mcts2.delete_returns()
                mcts2.delete_saved_log_probs()
                mcts2.delete_saved_entropies()
                # Free every evaluation tree, not only the last one.
                eval_tree.delete()
                del eval_tree
            success_ratio *= 100/n_evals
            print("Success ratio: {}%".format(success_ratio))
            mcts2.update_success_ratios(success_ratio)
        mcts2.train()

    if running_reward > reward_threshold or success_ratio > success_threshold:
        print("Solved! Running reward is now {} and "
                "the last episode runs to {} time steps!".format(running_reward, t))
        torch.save(mcts2, _checkpoint_file(i_episode, running_reward))
        print("Saved the model!")
        break

    # Release this episode's search tree before starting the next one.
    tree.delete()
    del tree
    gc.collect()
    torch.cuda.empty_cache()

Episode 1	Last reward: 1.00	Average reward: -0.90	Last loss: -14.11
Episode 2	Last reward: -1.00	Average reward: -0.90	Last loss: -31.01
Episode 3	Last reward: -1.00	Average reward: -0.91	Last loss: -41.26
Episode 4	Last reward: 1.00	Average reward: -0.81	Last loss: 22.86
Episode 5	Last reward: 1.00	Average reward: -0.72	Last loss: -24.49
Episode 6	Last reward: 1.00	Average reward: -0.64	Last loss: -10.65
Episode 7	Last reward: -1.00	Average reward: -0.66	Last loss: -35.29
Episode 8	Last reward: 1.00	Average reward: -0.57	Last loss: -7.48
Episode 9	Last reward: 1.00	Average reward: -0.49	Last loss: -1.55
Episode 10	Last reward: 1.00	Average reward: -0.42	Last loss: -18.33
Episode 11	Last reward: 1.00	Average reward: -0.35	Last loss: -8.59
Episode 12	Last reward: 1.00	Average reward: -0.28	Last loss: -10.01


KeyboardInterrupt: 

In [None]:
# Manual checkpoint of the current agent (e.g. after interrupting training);
# uses the same mcts2_e<episode>_p<running_reward>.pt naming as the loop.
torch.save(
    mcts2,
    "{}/mcts2_e{}_p{}.pt".format(serialization_path, str(i_episode).zfill(6), running_reward),
)