In [1]:
import os, sys
# Make the local MAax checkout importable. Set the MAAX_ROOT environment
# variable to use a different checkout instead of editing this developer path.
MAAX_ROOT = os.environ.get('MAAX_ROOT', '/Users/tom/dev/imperial/FYP/MAax/')
# Guard so re-running this cell does not keep appending duplicate entries.
if MAAX_ROOT not in sys.path:
    sys.path.append(MAAX_ROOT)

In [2]:
from typing import Any, Callable, Tuple
import cProfile

In [3]:
from worldgen import Floor, WorldBuilder
from maax.modules.agents import Agents
from maax.modules.walls import RandomWalls, WallScenarios
from maax.modules.world import FloorAttributes, WorldConstants
from maax.modules.objects import Boxes, Ramps
from maax.modules.util import uniform_placement, centre_placement, proximity_placement
from maax.envs.hide_and_seek import quad_placement, outside_quad_placement, HideAndSeekRewardWrapper, TrackStatWrapper
from maax.envs.base import Base
from maax.wrappers.multi_agent import (SplitMultiAgentActions,
                                           SplitObservations, SelectObsWrapper)
from maax.wrappers.util import (ConcatenateObsWrapper,
                                    MaskActionWrapper, SpoofEntityWrapper,
                                    AddConstantObservationsWrapper, MWrapper)
from maax.wrappers.manipulation import (LockObjWrapper, LockAllWrapper)
from maax.wrappers.line_of_sight import AgentAgentContactMask2D
from maax.wrappers.prep_phase import (PreparationPhase, NoActionsInPrepPhase,
                                          MaskPrepPhaseAction)
from maax.wrappers.limit_mvmnt import RestrictAgentsRect
from maax.wrappers.team import TeamMembership


In [4]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from maax.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp

from jax import random

from IPython.display import HTML, clear_output
# Clear whatever this import cell printed (e.g. library warnings) right away.
clear_output(wait=False)


In [5]:
from maax.envs.base import State

In [6]:
# Experiment-wide settings: `seed` feeds the PRNG key below as well as the
# environment construction, `horizon` is both the env horizon and the rollout
# length, `n_frames` is forwarded to the Base environment.
seed, horizon, n_frames = 3994, 800, 1
random_key = jax.random.PRNGKey(seed)

In [7]:
# Scenario composition: team sizes, movable-object counts and wall layout.
n_hiders, n_seekers = 2, 2
n_boxes, n_ramps = 2, 1
scenario = "quad"

In [8]:
def make_env(deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             max_n_agents=None, n_elongated_boxes=0,
             rand_num_elongated_boxes=False, n_min_boxes=None,
             box_size=0.5, boxid_obs=False, box_only_z_rot=True,
             rew_type='joint_zero_sum',
             lock_box=True, lock_ramp=True,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, lock_out_of_vision=True,
             box_floor_friction=0.2, other_friction=0.01, gravity=(0, 0, -50),
             action_lims=(-0.9, 0.9), polar_obs=False,
             scenario='quad', quadrant_game_hider_uniform_placement=False,
             p_door_dropout=0.0,
             n_rooms=4, random_room_count=True, prob_outside_walls=1.0,
             n_lidar_per_agent=0, visualize_lidar=False, compress_lidar_scale=None,
             hiders_together_radius=None, seekers_together_radius=1.0,
             prep_fraction=0.4, prep_rem=False,
             team_size_obs=False,
             restrict_rect=None, penalize_objects_out=False):
    """Build the wrapped hide-and-seek environment.

    Hidden dependencies: team sizes and object counts are read from the
    notebook globals ``n_hiders``, ``n_seekers``, ``n_boxes`` and ``n_ramps``,
    and the globals ``n_frames``, ``horizon`` and ``seed`` are forwarded to
    ``Base``. The cells defining them must run before this is called.

    Args:
        deterministic_mode: Forwarded to ``Base``.
        floor_size, grid_size, action_lims: Forwarded to ``Base``.
        door_size, p_door_dropout: Forwarded to ``WallScenarios``.
        n_elongated_boxes, boxid_obs, box_only_z_rot, box_floor_friction:
            Forwarded to ``Boxes``; ``n_elongated_boxes`` also decides ramp
            padding.
        other_friction: Friction for walls, agents and ramps.
        polar_obs: Forwarded to the agent, box and ramp modules.
        rew_type: Forwarded to ``HideAndSeekRewardWrapper``.
        scenario: Wall layout; only ``'quad'`` is implemented.
        quadrant_game_hider_uniform_placement: Place hiders with
            ``uniform_placement`` instead of ``quad_placement``.
        prep_fraction: Forwarded to ``PreparationPhase``.
        prep_rem: When true, additionally wrap with ``TrackStatWrapper``.

    The remaining parameters (``max_n_agents``, ``rand_num_elongated_boxes``,
    ``n_min_boxes``, ``box_size``, the ``lock_*`` options, ``gravity``,
    ``n_rooms``, ``random_room_count``, ``prob_outside_walls``, the lidar
    options, the ``*_together_radius`` options, ``team_size_obs``,
    ``restrict_rect`` and ``penalize_objects_out``) are accepted for signature
    compatibility but are not used by this implementation.

    Returns:
        The ``Base`` environment with wall/agent/box/ramp modules, wrapped in
        ``TeamMembership``, ``AgentAgentContactMask2D``,
        ``AddConstantObservationsWrapper``, ``HideAndSeekRewardWrapper``,
        ``PreparationPhase``, optionally ``TrackStatWrapper``, and
        ``NoActionsInPrepPhase``.

    Raises:
        ValueError: If ``scenario`` is not ``'quad'``.
    """
    # Fail fast, before building anything: only the quadrant layout defines
    # the placement functions used below. Other values previously crashed with
    # UnboundLocalError after Base had already been constructed.
    if scenario != 'quad':
        raise ValueError(
            f"Unsupported scenario {scenario!r}; only 'quad' is implemented")

    n_agents = n_hiders + n_seekers

    env = Base(n_agents=n_agents, n_frames=n_frames, horizon=horizon,
               floor_size=floor_size, grid_size=grid_size,
               action_lims=action_lims,
               deterministic_mode=deterministic_mode, seed=seed)

    # Add modules

    env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size,
                                 scenario=scenario, friction=other_friction,
                                 p_door_dropout=p_door_dropout))
    box_placement_fn = uniform_placement
    ramp_placement_fn = uniform_placement
    hider_placement = uniform_placement if quadrant_game_hider_uniform_placement else quad_placement
    # Agents are ordered hiders first, then seekers; the team membership,
    # 'hider' observation and prep-phase wrappers below all rely on this order.
    agent_placement_fn = [hider_placement] * n_hiders + [outside_quad_placement] * n_seekers

    # RGBA colours scaled from 0-255 to 0-1.
    hider_color = np.array((66., 235., 244., 255.)) / 255
    seeker_color = np.array((240., 20., 50., 255.)) / 255
    env.add_module(Agents(n_agents,
                          placement_fn=agent_placement_fn,
                          color=[hider_color] * n_hiders + [seeker_color] * n_seekers,
                          friction=other_friction,
                          polar_obs=polar_obs))
    if np.max(n_boxes) > 0:
        env.add_module(Boxes(n_boxes=n_boxes, placement_fn=box_placement_fn,
                             friction=box_floor_friction, polar_obs=polar_obs,
                             n_elongated_boxes=n_elongated_boxes,
                             boxid_obs=boxid_obs, box_only_z_rot=box_only_z_rot))
    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps, placement_fn=ramp_placement_fn, friction=other_friction, polar_obs=polar_obs,
                             pad_ramp_size=(np.max(n_elongated_boxes) > 0)))

    # Add wrappers

    # Team id per agent: 0 for hiders, 1 for seekers.
    env = TeamMembership(env, np.append(np.zeros((n_hiders,)), np.ones((n_seekers,))))
    env = AgentAgentContactMask2D(env)
    # Constant per-agent 'hider' flag: 1 for hiders, 0 for seekers.
    hider_obs = jp.array([[1]] * n_hiders + [[0]] * n_seekers)
    env = AddConstantObservationsWrapper(env, new_obs={'hider': hider_obs})
    env = HideAndSeekRewardWrapper(env, n_hiders=n_hiders, n_seekers=n_seekers,
                                   rew_type=rew_type)

    env = PreparationPhase(env, prep_fraction=prep_fraction)

    if prep_rem:
        env = TrackStatWrapper(env, np.max(n_boxes), n_ramps)
    # Applied to the seeker indices; presumably freezes seekers during the
    # preparation phase (per the wrapper name) — confirm in maax.wrappers.
    env = NoActionsInPrepPhase(env, np.arange(n_hiders, n_hiders + n_seekers))
    return env

In [9]:
# Create the physical system and initial state.
# make_env's first positional parameter is `deterministic_mode`, so the PRNG
# key must not be passed positionally (it used to be); options go by keyword,
# including the `scenario` chosen in the config cell.
test_env = make_env(scenario=scenario)

test_env.gen_sys(seed)

# NOTE(review): `random_key` is also split later for action sampling; if
# reset splits it internally too, the two random streams overlap — consider
# deriving a dedicated reset key.
state = jax.jit(test_env.reset)(random_key)

In [10]:
def randomise_action(act, random_key, minval=-0.25, maxval=0.25):
    """Draw a fresh uniform random action with the same shape as ``act``.

    Args:
        act: Current action array; only its shape is used.
        random_key: JAX PRNG key to advance.
        minval: Inclusive lower bound of the sampled values.
        maxval: Exclusive upper bound of the sampled values.

    Returns:
        ``(new_act, new_key)``. ``new_key`` has not been used for sampling, so
        it is safe to pass to the next call.
    """
    # Sample with the sub-key and return the other half of the split. The
    # previous version sampled with the same key it returned, so the next call
    # split an already-consumed key (JAX keys must never be reused).
    random_key, sub_key = random.split(random_key)
    new_act = random.uniform(sub_key, shape=act.shape, minval=minval, maxval=maxval)
    return new_act, random_key

In [11]:
# Define initial action and compile step-function.
# jax.jit is lazy: the actual compilation happens on the first call in the
# rollout loop.
jit_step_fn = jax.jit(test_env.step)
act_size = test_env.sys.act_size()

# The (n_agents, act_size // n_agents) layout assumes every agent has the same
# number of actuators; fail loudly instead of silently truncating with //.
assert act_size % test_env.n_agents == 0, (
    f"act_size={act_size} is not divisible by n_agents={test_env.n_agents}")
act = jp.zeros(shape=(test_env.n_agents, act_size // test_env.n_agents))

In [12]:
# # Scan the rollout function
# @jax.jit
# def play_step_fn(state: State, act: Action, random_key: RNGKey, index: int):
#     act, random_key = jax.lax.cond(index % 50 == 0, randomise_action, lambda x, y: (x, y), act, random_key)
#     state = jit_step_fn(state, act)
#     return state, act, random_key, index + 1, state.pipeline_state

# def scan_play_step_fn(
#     carry: Tuple[State, Action, RNGKey, int], unused_arg: Any
# ) -> Tuple[Tuple[State, Action, RNGKey, int], PipelineState]:
#     state, act, random_key, index, p_states = play_step_fn(*carry)
#     return (state, act, random_key, index), p_states
    

# (dst_state, dst_act, key, index), rollout = jax.lax.scan(scan_play_step_fn, (state, act, random_key, 0), None, length=horizon)

In [13]:
# Roll out the environment with random actions, resampled every
# ACTION_RESAMPLE_INTERVAL steps, recording the pipeline state before each
# step (the state after the final step is not recorded).
ACTION_RESAMPLE_INTERVAL = 50

rollout = []
# Work on copies so re-running this cell replays the same rollout instead of
# continuing from wherever a previous run left `state` / `act` / `random_key`.
rollout_state, rollout_act, rollout_key = state, act, random_key
for step in range(horizon):
    # `step` is a concrete Python int, so a plain `if` is enough;
    # jax.lax.cond is only needed when the predicate is a traced value.
    if step % ACTION_RESAMPLE_INTERVAL == 0:
        rollout_act, rollout_key = randomise_action(rollout_act, rollout_key)
    rollout.append(rollout_state.pipeline_state)
    rollout_state = jit_step_fn(rollout_state, rollout_act)

In [14]:

# Render the recorded rollout (one pipeline state per step) with brax's HTML
# visualiser for the environment's system and write it to example.html.
html.save('example.html', test_env.sys, rollout)

In [15]:
# # Visualise the rollout

# states_list = []

# for i in range(horizon):
#     s = jax.tree_util.tree_map(lambda x: x[i], rollout)
#     states_list.append(s)


# html.render(test_env.sys, states_list)
# html.save('demo.html', test_env.sys, states_list)