In [1]:
import os, sys
sys.path.append('/Users/tom/dev/imperial/FYP/MAax/')

In [2]:
from typing import Any, Callable, Tuple
from functools import partial

In [3]:
from worldgen import Floor, WorldBuilder
from mae_envs.modules.agents import Agents
from mae_envs.modules.walls import RandomWalls, WallScenarios
from mae_envs.modules.world import FloorAttributes, WorldConstants
from mae_envs.modules.objects import Boxes, Cylinders, LidarSites, Ramps
from mae_envs.modules.util import uniform_placement, center_placement
from mae_envs.envs.hide_and_seek import quadrant_placement


ImportError: cannot import name 'quadrant_placement' from 'mae_envs.envs.hide_and_seek' (/Users/tom/dev/imperial/FYP/MAax/mae_envs/envs/hide_and_seek.py)

In [None]:
import brax
import numpy as np
from brax.io import mjcf, html
from brax.training.agents.ppo import train as ppo
from maax.envs.base import Base
from mae_envs.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp

from jax import random

import time
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output
clear_output()


In [None]:
from brax.envs.env import State

In [None]:
seed = 7
batch_size = 16
horizon = 100
random_key = jax.random.PRNGKey(seed)

In [None]:
def make_env(n_frames=1, horizon=1000, deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_hiders=1, n_seekers=1, max_n_agents=None,
             n_boxes=1, n_ramps=1, n_elongated_boxes=0,
             rand_num_elongated_boxes=False, n_min_boxes=None,
             box_size=0.5, boxid_obs=True, boxsize_obs=True, box_only_z_rot=True,
             pad_ramp_size=True,
             rew_type='joint_zero_sum',
             lock_box=True, grab_box=True, lock_ramp=True,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, lock_out_of_vision=True, grab_exclusive=False,
             grab_out_of_vision=False, grab_selective=False,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='quadrant', quadrant_game_hider_uniform_placement=False,
             p_door_dropout=0.0,
             n_rooms=4, random_room_number=True, prob_outside_walls=1.0,
             n_lidar_per_agent=0, visualize_lidar=False, compress_lidar_scale=None,
             hiders_together_radius=None, seekers_together_radius=None,
             prep_fraction=0.4, prep_rem=False,
             team_size_obs=False,
             restrict_rect=None, penalize_objects_out=False,
             ):
    '''
        This make_env function is not used anywhere; it exists to provide a simple, bare-bones
        example of how to construct a multi-agent environment using the modules framework.
    '''
    n_agents = n_seekers + n_hiders
    env = Base(n_agents=n_agents, n_frames=n_frames, horizon=horizon, grid_size=grid_size,
               deterministic_mode=deterministic_mode, seed=seed)
    env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size,
                                     scenario=scenario, friction=other_friction,
                                     p_door_dropout=p_door_dropout))
    box_placement_fn = uniform_placement
    ramp_placement_fn = uniform_placement
    agent_placement_fn = uniform_placement

    env.add_module(Agents(n_agents,
                          placement_fn=agent_placement_fn,
                          color=[np.array((66., 235., 244., 255.)) / 255] * n_agents,
                          friction=other_friction,
                          polar_obs=polar_obs))

    if np.max(n_boxes) > 0:
        env.add_module(Boxes(n_boxes=n_boxes, placement_fn=box_placement_fn,
                             friction=box_floor_friction, polar_obs=polar_obs,
                             n_elongated_boxes=0,
                             boxid_obs=boxid_obs,
                             box_only_z_rot=box_only_z_rot,
                             boxsize_obs=boxsize_obs,
                             free=True))

    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps, placement_fn=ramp_placement_fn,
                             friction=other_friction, polar_obs=polar_obs,
                             pad_ramp_size=pad_ramp_size, free=True))

    # if n_lidar_per_agent > 0 and visualize_lidar:
    #     env.add_module(LidarSites(n_agents=n_agents, n_lidar_per_agent=n_lidar_per_agent))

    # env.add_module(WorldConstants(gravity=gravity))

    return env

In [None]:
env = make_env()

env.gen_sys(seed)

In [None]:
@jax.jit
def randomise_action(act, random_key):
    random_key, _ = random.split(random_key)
    return random.uniform(random_key, shape=act.shape, minval=-0.25, maxval=0.25), random_key

In [None]:
# Create the inital environment states
act_size = env.sys.act_size()

jit_step_fn = jax.jit(env.step)
jit_batch_reset_fn = jax.jit(jax.vmap(env.reset))

random_key, subkey = jax.random.split(random_key)
keys = jax.random.split(subkey, num=batch_size)

# Define initial batches states and actions
init_states = jit_batch_reset_fn(keys)
acts = jp.zeros(shape=(batch_size, act_size), dtype=jp.float32)

Traced<ShapedArray(float32[79])>with<DynamicJaxprTrace(level=1/0)>


In [None]:
@jax.jit
def play_step_fn(state: State, act: Action, random_key: RNGKey):
    act, random_key = randomise_action(act, random_key)
    state = jit_step_fn(state, act)
    return state, act, random_key, state.pipeline_state

@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state: State,
    act: Action,
    random_key: RNGKey,
    episode_length: int,
    play_step_fn) -> Tuple[State, Action, RNGKey]:
    """Generates an episode according to random action, returns the final state of
    the episode and the transitions of the episode.

    Args:
        init_state: first state of the rollout.
        act: The initial action
        random_key: random key for stochasiticity handling.
        episode_length: length of the rollout.
        index: index of the rollout.
        play_step_fn: function describing how a step need to be taken.

    Returns:
        A new state, the experienced transition.
    """
    def scan_play_step_fn(
        carry: Tuple[State, Action, RNGKey], unused_arg: Any) ->Tuple[Tuple[State, Action, RNGKey], PipelineState]:
        state, act, random_key, p_states = play_step_fn(*carry)
        return (state, act, random_key), p_states


    (dst_state, dst_act, key), rollout = jax.lax.scan(
        scan_play_step_fn, (init_state, act, random_key), None, length=episode_length)

    return dst_state, rollout, key

In [None]:
train_fn = partial(ppo.train, num_timesteps=2_000_000, num_evals=20, episode_length=1000, 
normalize_observations=True, action_repeat=1, batch_size=batch_size, seed=seed)

In [None]:
max_y = 1000
min_y = -1000

xdata, ydata = [], []
times = [time.time()]

def progress(num_steps, metrics):
  times.append(time.time())
  xdata.append(num_steps)
  ydata.append(metrics['eval/episode_reward'])
  clear_output(wait=True)
  plt.xlim([0, train_fn.keywords['num_timesteps']])
  plt.ylim([min_y, max_y])
  plt.xlabel('# environment steps')
  plt.ylabel('reward per episode')
  plt.plot(xdata, ydata)
  plt.show()

In [None]:
make_inference_fn, params, _ = train_fn(environment=env, progress_fn=progress)

Traced<ShapedArray(float32[79])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(float32[79])>with<DynamicJaxprTrace(level=1/0)>


ValueError: Incompatible shapes for broadcasting: shapes=[(128, 2), (128,), (128, 2)]

In [None]:
print(f'time to jit: {times[1] - times[0]}')
print(f'time to train: {times[-1] - times[1]}')

In [None]:
unroll_fn = partial(
    generate_unroll,
    episode_length=horizon,
    play_step_fn=play_step_fn,
)

dst_states, rollouts, random_key = jax.vmap(unroll_fn)(init_states, acts, keys)


# (dst_state, dst_act, key, index), rollout = jax.lax.scan(scan_play_step_fn, (state, act, random_key, 0), None, length=episode_length)

In [None]:
states_list = []
# for j in range(batch_size):
#     for i in range(episode_length):
#         s = jax.tree_util.tree_map(lambda x: x[j][i], rollouts)
#         states_list.append(s)


# print(len(states_list))
# html.save('parallel.html', test_env.sys, states_list[:episode_length])