In [1]:
import os
import argparse
import time
import pickle

# Libraries
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
import pathlib
# Resolve the project root only once per kernel. Path() is the current working
# directory, so recomputing this after the os.chdir below would climb one
# directory higher every time the cell is re-run.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = pathlib.Path().absolute().parent
print(PROJECT_ROOT)
os.chdir(PROJECT_ROOT)
os.getcwd()

C:\Users\victo\Documents\CodeBase\AnimalTown\Python\src


'C:\\Users\\victo\\Documents\\CodeBase\\AnimalTown\\Python\\src'

In [3]:
# Local imports
import tom.config
import tom.render_utils
from tom.tom_agents import ToMAgent

In [4]:
class ARGS:
    """Run settings gathered as class attributes.

    Values come from ``tom.config`` plus the literals below; output
    directories are derived from ``paths.tmp_root``, ``exp_name`` and
    ``method``.
    """

    # Values taken from the project configuration
    paths = tom.config.Paths()
    exp_name = tom.config.exp_name
    episodes = tom.config.episodes
    episode_len = tom.config.episode_len
    save_rate = tom.config.save_rate
    method = 'ToM'  # ['chase', 'ToM gt', 'ToM']
    debug = False

    # Scenario and episode limits
    scenario = "simple_chase"
    max_episode_len = episode_len
    num_episodes = episodes

    # Agent setting (both flags follow from `method`)
    use_gt_belief = method == 'ToM gt'
    direct_chase = method == 'chase'
    use_distance = True

    # Checkpoint
    save_dir = os.path.join(paths.tmp_root, 'checkpoints', exp_name, method)

    # Evaluation
    restore = False
    display = False
    display_mode = "all"
    save_screen = False
    benchmark = True
    benchmark_iters = episodes
    benchmark_dir = os.path.join(paths.tmp_root, 'benchmark', exp_name, method)
    plots_dir = os.path.join(paths.tmp_root, 'plots', exp_name, method)


In [5]:
# Settings instance read by the cells below; plots_dir is later overwritten on this instance.
args = ARGS()

In [6]:
# Wall-clock start of the run. Its integer part seeds both NumPy and PyTorch,
# and the same integer is used further down to name this run's output folders,
# so the seed of a past run can be read back from its directory name.
t_start = time.time()
np.random.seed(int(t_start))
torch.manual_seed(int(t_start))

<torch._C.Generator at 0x2bbce9081d0>

In [7]:
from multiagent.environment import MultiAgentEnv
import multiagent.scenarios as scenarios

In [8]:
# Display the configured scenario name.
args.scenario

'simple_chase'

In [9]:
# Load the scenario script named by args.scenario and instantiate its Scenario
# class with open_world=True and the experiment name as its setting.
scenario = scenarios.load(args.scenario + ".py").Scenario(open_world=True, setting=args.exp_name)


In [10]:
# Create the world object that the environment below is built around.
world = scenario.make_world()

In [11]:
# Optional scenario hooks for the multi-agent environment.
# Each hook is looked up on its own, so a scenario that defines only one of
# `done` / `info` still gets that one; a missing hook becomes None.
done_callback = getattr(scenario, "done", None)
info_callback = getattr(scenario, "info", None)

In [12]:
# Build the multi-agent environment from the world and the scenario's hooks.
env = MultiAgentEnv(
    world,
    scenario.reset_world,
    scenario.reward,
    scenario.observation,
    done_callback=done_callback,
    info_callback=info_callback,
)



# Env

In [13]:
# Reset the environment and keep the returned observations (indexed per agent below).
new_obs_n = env.reset()

In [14]:
# Inspect the observations returned by the reset.
new_obs_n

[(array([[0., 0.],
         [0., 0.],
         [0., 0.]]),
  array([[-0.48865396,  0.15470432],
         [ 0.87789463, -0.75192884],
         [-0.63157033,  0.77846646]]),
  array([[0., 0.],
         [0., 0.],
         [0., 0.]]))]

In [15]:
# Build one ToMAgent per policy-controlled agent, handing each the matching
# entry of new_obs_n.
polices = list()
for i, policy_agent in enumerate(env.world.policy_agents):
    polices.append(
        ToMAgent(
            policy_agent,
            env.world,
            scenario,
            new_obs_n[i],
            args.exp_name,
            use_gt_belief=args.use_gt_belief,
            use_distance=args.use_distance,
            direct_chase=args.direct_chase,
        )
    )
if args.restore:
    # Restore from the entry of save_dir that sorts last by name.
    for police in polices:
        police.load_model(os.path.join(args.save_dir, sorted(os.listdir(args.save_dir))[-1]))
if not args.debug:
    # Checkpoint folder for this run; exist_ok lets the cell be re-run in the
    # same session without raising FileExistsError.
    os.makedirs(os.path.join(args.save_dir, str(int(t_start))), exist_ok=True)

In [16]:
# Inspect the list of agents that was just built.
polices

[<tom.tom_agents.ToMAgent at 0x2bb810ae048>]

In [17]:
# History of the agent class's class-level weight vector: starts with zero rows
# and one column per weight entry; the training loop appends one row per episode.
weights = np.zeros((0, polices[0].__class__.weight.shape[0]))

In [18]:
# Per-episode records filled in by the training loop: total return, info
# entries, and belief estimates (three independent lists).
episode_returns, episode_info, episode_estimatio_belief = [], [], []

In [19]:
# for different rendering modes
display_modes = tom.render_utils.get_display_modes(args.display_mode)

# Initialize saving directories
# Build the run-specific folder from the class-level base path rather than from
# args.plots_dir itself, so re-running this cell does not append the timestamp
# a second time.
args.plots_dir = os.path.join(ARGS.plots_dir, str(int(t_start)))
# exist_ok keeps a re-run of this cell from raising FileExistsError.
os.makedirs(os.path.join(args.plots_dir, 'belief_value'), exist_ok=True)
os.makedirs(os.path.join(args.plots_dir, 'dist_value'), exist_ok=True)

# Train

In [20]:
def plot_weights(arglist, weights, filename1=None, filename2=None):
    """Plot the most recent row of ``weights`` as two figures.

    Figure 0 is a bar plot of the first two entries of the last row; figure 1
    is a line plot of the entries from index 3 onward. The entry at index 2 is
    not drawn by either figure.

    Args:
        arglist: Settings object; only ``arglist.debug`` is read.
        weights: 2-D array of weight snapshots; only the last row is plotted.
        filename1: Output path for figure 0.
        filename2: Output path for figure 1.

    Each figure is saved and closed when its filename is truthy and
    ``arglist.debug`` is falsy; otherwise it is shown briefly with
    ``plt.pause``.
    """

    def _save_or_show(path):
        # Without a path, or in debug mode, just flush the figure to screen.
        if not path or arglist.debug:
            plt.pause(0.001)
            return
        plt.savefig(path)
        plt.close()

    # Figure 0: the two leading weights as bars.
    plt.figure(0)
    plt.clf()
    sns.barplot(x=np.arange(2), y=weights[-1, :2], color='c')
    _save_or_show(filename1)

    # Figure 1: the weights from index 3 onward as a line.
    plt.figure(1)
    plt.clf()
    tail_positions = np.arange(weights.shape[1] - 3)
    plt.plot(tail_positions, weights[-1, 3:], color='m')
    _save_or_show(filename2)

def plot_avg_return(arglist, episode_returns, filename=None, window=100):
    """Plot the running mean of the episode returns on figure 3.

    Point ``i`` of the curve is the mean of ``episode_returns`` over the most
    recent ``window`` episodes ending at ``i`` (fewer at the start).

    Args:
        arglist: Unused; kept so the signature matches ``plot_weights``.
        episode_returns: Sequence of per-episode returns; may be empty.
        filename: If truthy, the figure is saved there and closed; otherwise
            it is shown briefly with ``plt.pause``.
        window: Number of episodes averaged per point (default 100).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if window < 1:
        raise ValueError("window must be at least 1")

    plt.figure(3)
    # Start from a clean canvas so repeated interactive calls do not stack
    # curves on top of each other (same pattern as plot_weights).
    plt.clf()

    episode_count = len(episode_returns)
    running_mean = [
        np.mean(episode_returns[max(0, i - window + 1):i + 1])
        for i in range(episode_count)
    ]
    plt.plot(np.arange(episode_count), running_mean, 'r')
    plt.xlabel('Number of episodes')
    plt.ylabel('Avg return')

    if filename:
        plt.savefig(filename)
        plt.close()
    else:
        plt.pause(0.001)


In [21]:
# training loop: one outer iteration per episode, one inner iteration per environment step
print("Starting training")
sns.set(style="white", context="paper", palette="muted", color_codes=True)
# NOTE(review): the episode cap is hard-coded to 1000; args.num_episodes is not consulted here.
while weights.shape[0] <= 1000:
    print("episode:",weights.shape[0])
    # Append the current class-level weight vector as a new history row.
    weights = np.vstack((weights, polices[0].__class__.weight))
    print("weights", weights, "stack", polices[0].__class__.weight)
    # Plots are written every episode; the save_rate gate is currently disabled.
    if True: #weights.shape[0] % args.save_rate == 0:
        plot_weights(args, weights, os.path.join(args.plots_dir, 'belief_value', "belief_value_{}.png".format(weights.shape[0])),
                     os.path.join(args.plots_dir, 'dist_value', "dist_value_{}.png".format(weights.shape[0])))
        plot_avg_return(args, episode_returns, os.path.join(args.plots_dir, "avg_return.png"))

    # Initialization
    # Keep the observations returned by the reset: otherwise the first action of
    # this episode would be selected from the previous episode's last observation.
    new_obs_n = env.reset()
    episode_returns.append(0)
    episode_info.append([])
    episode_estimatio_belief.append([])
    action_n = np.zeros((len(polices), 5))
    episode_step = 0
    prev_values = np.zeros(len(polices))
    for police in polices:
        police.initialize()

    if args.save_screen:
        save_screen_dirs = tom.render_utils.get_save_dirs(display_modes, args.plots_dir, str(weights.shape[0]))

    # training each episode
    while True:
        # get current value
        curr_values = [police.compute_value() for police in polices]

        # choose action based on current approximation
        for i, police in enumerate(polices):
            action_n[i] = police.select_action(new_obs_n[i])

        # update the environment
        new_obs_n, reward_n, done_n, info_n = env.step(action_n)
        # The episode ends as soon as any entry of done_n is truthy.
        done = any(done_n)

        # FIXME: only record the first agent's history
        episode_returns[-1] += reward_n[0]
        episode_info[-1].append(info_n['n'][0][0])
        episode_estimatio_belief[-1].append(polices[0].belief_mean.detach().squeeze().cpu().numpy()[0])

        # get next value
        next_values = [police.compute_value() for police in polices]
        # Agents are only updated when not running from a restored model.
        if not args.restore:
            for i, police in enumerate(polices):
                police.train(reward_n[i], done, prev_values[i], curr_values[i], next_values[i], new_obs_n[i])
        prev_values = next_values

        thief = env.world.thieves[0]
        # for displaying learned policies
        if args.display:
            # time.sleep(0.1)
            tom.render_utils.render_image(env, args.display_mode, thief.belief, True)

        # Save screen
        if args.save_screen:
            for i, d in enumerate(display_modes):
                filename = os.path.join(save_screen_dirs[i], "step{}.png".format(episode_step))
                tom.render_utils.save_screen(env, d, filename, belief=thief.belief)

        episode_step += 1
        if done or episode_step >= args.max_episode_len:
            # The final entry of this episode's info list is the done flag itself.
            episode_info[-1].append(done)
            break

        # End episode loop

    if not args.debug:
        # Save trained model every few training steps
        if weights.shape[0] % args.save_rate == 0:
            polices[0].save_model(os.path.join(args.save_dir, str(int(t_start))))

    # End training loop

Starting training
episode: 0
weights [[0.00331528 0.00135623 0.00310449 0.00798787 0.00615955 0.0072921
  0.00441572 0.00263972 0.00680271]] stack [0.00331528 0.00135623 0.00310449 0.00798787 0.00615955 0.0072921
 0.00441572 0.00263972 0.00680271]
episode: 1
weights [[ 0.00331528  0.00135623  0.00310449  0.00798787  0.00615955  0.0072921
   0.00441572  0.00263972  0.00680271]
 [-1.04881908 -0.4432815   0.00310449  0.00798787  0.00615955 -0.33534641
  -0.98670366 -0.13854761 -0.01502417]] stack [-1.04881908 -0.4432815   0.00310449  0.00798787  0.00615955 -0.33534641
 -0.98670366 -0.13854761 -0.01502417]
episode: 2
weights [[ 0.00331528  0.00135623  0.00310449  0.00798787  0.00615955  0.0072921
   0.00441572  0.00263972  0.00680271]
 [-1.04881908 -0.4432815   0.00310449  0.00798787  0.00615955 -0.33534641
  -0.98670366 -0.13854761 -0.01502417]
 [-1.2944218  -1.06146848 -0.02133034 -0.13886558 -0.26459236 -0.6991995
  -1.01613731 -0.16701037 -0.01502417]] stack [-1.2944218  -1.06146848 -0

In [22]:
# Show the root directory under which checkpoints, benchmarks and plots are placed.
args.paths.tmp_root

'C:\\Users\\victo\\Documents\\CodeBase\\AnimalTown\\Python\\tmp'

In [23]:
# Show the run-specific plot directory (base plots path plus the integer start time).
args.plots_dir

'C:\\Users\\victo\\Documents\\CodeBase\\AnimalTown\\Python\\tmp\\plots\\fast_thief\\ToM\\1583208244'