In [1]:
import os, sys
sys.path.append('/Users/tom/dev/imperial/FYP/MAax/')

In [2]:
from typing import Any, Callable, Tuple
import cProfile

In [3]:
from worldgen import Floor, WorldBuilder, WorldParams
from maax.modules.agents import Agents
from maax.modules.walls import RandomWalls, WallScenarios
from maax.modules.world import FloorAttributes, WorldConstants
from maax.modules.objects import Boxes, Ramps
from maax.modules.util import uniform_placement, centre_placement, proximity_placement
from maax.envs.hide_and_seek import quad_placement, outside_quad_placement, HideAndSeekRewardWrapper, TrackStatWrapper
from maax.envs.base import Base
from maax.wrappers.multi_agent import (SplitMultiAgentActions,
                                           SplitObservations, SelectObsWrapper)
from maax.wrappers.util import (ConcatenateObsWrapper,
                                    MaskActionWrapper, SpoofEntityWrapper,
                                    AddConstantObservationsWrapper, MWrapper)
from maax.wrappers.manipulation import (LockObjWrapper, LockAllWrapper)
from maax.wrappers.lidar import Lidar
from maax.wrappers.line_of_sight import (AgentAgentObsMask2D, AgentAgentContactMask2D,
                                            AgentGeomObsMask2D, AgentSiteObsMask2D)
from maax.wrappers.prep_phase import (PreparationPhase, NoActionsInPrepPhase,
                                          MaskPrepPhaseAction)
from maax.wrappers.limit_mvmnt import RestrictAgentsRect
from maax.wrappers.team import TeamMembership
from maax.wrappers.food import FoodHealthWrapper, AlwaysEatWrapper


In [4]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from maax.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp

from jax import random

from IPython.display import HTML, clear_output
clear_output()


In [5]:
from maax.envs.base import State

In [6]:
# Run configuration shared by the cells below.
seed = 3  # feeds the PRNG key here, Base(seed=...) in make_env, and test_env.gen_sys
horizon, n_frames = 1, 1  # read as globals inside make_env; horizon also sets the rollout length
random_key = random.PRNGKey(seed)

In [7]:
def _group_placement_fns(env, n_group, first_agent_idx, together_radius,
                         cell_size, metadata_key):
    """Build the list of placement functions for one team (hiders or seekers).

    The team's first agent is placed with `uniform_placement`. If
    `together_radius` is given, the remaining agents use `proximity_placement`
    relative to agent `first_agent_idx`. The radius, converted to grid cells,
    is stored in `env.metadata[metadata_key]` for that placement function to read.

    Returns:
        A list of exactly `n_group` placement functions.
    """
    first_placement = uniform_placement
    # With n_group == 0, `[first] + [near] * -1` would still produce one entry
    # and desynchronise the placement list from the agent count.
    if together_radius is None or n_group == 0:
        return [first_placement] * n_group

    radius_in_cells = np.ceil(together_radius / cell_size).astype(int)
    env.metadata[metadata_key] = radius_in_cells
    near_first_placement = proximity_placement("agent", first_agent_idx, metadata_key)
    return [first_placement] + [near_first_placement] * (n_group - 1)


def make_env(deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_hiders=2, n_seekers=2, max_n_agents=None,
             n_boxes=2, n_ramps=1, n_elongated_boxes=0,
             rand_num_elongated_boxes=False, n_min_boxes=None,
             box_size=0.5, boxid_obs=False, box_only_z_rot=True,
             rew_type='joint_zero_sum',
             lock_box=True, grab_box=True, lock_ramp=True,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, lock_out_of_vision=True,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='randomwalls', quad_game_hider_uniform_placement=False,
             p_door_dropout=0.0,
             n_rooms=4, random_room_count=True, prob_outside_walls=0.0,
             hiders_together_radius=None, seekers_together_radius=None,
             prep_fraction=0.4, prep_obs=False,
             team_size_obs=False,
             restrict_rect=None, penalize_objects_out=False,
             ):
    """Build the hide-and-seek environment on top of the MAax `Base` env.

    The active steps are:
    - adding the wall, agent, box and ramp modules;
    - wrapping the env in `TeamMembership`.

    The rest of the hide-and-seek wrapper stack is kept below as commented-out
    code. Several keyword arguments are therefore currently unused, e.g.
    `rew_type`, `prep_fraction` and the lock/grab options.

    NOTE: `n_frames`, `horizon` and `seed` are read from the notebook's
    module-level configuration cell, not from arguments.

    Raises:
        ValueError: if `scenario` is neither 'randomwalls' nor 'quad'.
    """
    # Validate before building anything.
    if scenario not in ('randomwalls', 'quad'):
        raise ValueError(f"Scenario {scenario} not supported.")

    # Only consumed by the (currently disabled) LockObjWrapper calls below.
    lock_radius_multiplier = lock_grab_radius / box_size

    env = Base(n_agents=n_hiders + n_seekers, n_frames=n_frames, horizon=horizon,
               floor_size=floor_size, grid_size=grid_size,
               action_lims=action_lims,
               deterministic_mode=deterministic_mode, seed=seed)

    if scenario == 'randomwalls':
        env.add_module(RandomWalls(
            grid_size=grid_size, num_rooms=n_rooms,
            random_room_count=random_room_count, min_room_size=6,
            door_size=door_size,
            prob_outside_walls=prob_outside_walls, gen_door_obs=False))
        box_placement_fn = uniform_placement
        ramp_placement_fn = uniform_placement
        cell_size = floor_size / grid_size

        # Hiders occupy agent indices [0, n_hiders); seekers follow.
        agent_placement_fn = _group_placement_fns(
            env, n_hiders, 0, hiders_together_radius, cell_size,
            'hiders_together_radius')
        agent_placement_fn += _group_placement_fns(
            env, n_seekers, n_hiders, seekers_together_radius, cell_size,
            'seekers_together_radius')
    else:  # scenario == 'quad'
        env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size,
                                     scenario=scenario, friction=other_friction,
                                     p_door_dropout=p_door_dropout))
        box_placement_fn = uniform_placement
        ramp_placement_fn = uniform_placement
        hider_placement = uniform_placement if quad_game_hider_uniform_placement else quad_placement
        agent_placement_fn = [hider_placement] * n_hiders + [outside_quad_placement] * n_seekers

    # RGBA colours scaled to [0, 1]: cyan-ish hiders, red-ish seekers.
    hider_colour = np.array((66., 235., 244., 255.)) / 255
    seeker_colour = np.array((240., 20., 50., 255.)) / 255
    env.add_module(Agents(n_hiders + n_seekers,
                          placement_fn=agent_placement_fn,
                          color=[hider_colour] * n_hiders + [seeker_colour] * n_seekers,
                          friction=other_friction,
                          polar_obs=polar_obs))
    # np.max lets n_boxes be a scalar or an array of candidate counts.
    if np.max(n_boxes) > 0:
        env.add_module(Boxes(n_boxes=n_boxes, placement_fn=box_placement_fn,
                             friction=box_floor_friction, polar_obs=polar_obs,
                             n_elongated_boxes=n_elongated_boxes,
                             boxid_obs=boxid_obs, box_only_z_rot=box_only_z_rot))
    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps, placement_fn=ramp_placement_fn, friction=other_friction, polar_obs=polar_obs,
                             pad_ramp_size=(np.max(n_elongated_boxes) > 0)))

    # if box_floor_friction is not None:
    #     env.add_module(FloorAttributes(friction=box_floor_friction))

    # Observation key groups; only consumed by the disabled
    # SplitObservations / SelectKeysWrapper calls below.
    keys_self = ['agent_q_qd', 'hider', 'prep_obs']
    keys_mask_self = ['mask_aa_con']
    keys_external = ['agent_q_qd']
    keys_copy = ['you_lock', 'team_lock', 'ramp_you_lock', 'ramp_team_lock', 'door_obs']
    keys_mask_external = []
    # env = SplitMultiAgentActions(env)
    if team_size_obs:
        keys_self += ['team_size']
    # Team id per agent: 0 for hiders, 1 for seekers.
    env = TeamMembership(env, np.append(np.zeros((n_hiders,)), np.ones((n_seekers,))))

    # The remaining wrapper stack is currently disabled.
    # env = AgentAgentObsMask2D(env)
    # env = AgentAgentContactMask2D(env)
    # hider_obs = np.array([[1]] * n_hiders + [[0]] * n_seekers)
    # env = AddConstantObservationsWrapper(env, new_obs={'hider': hider_obs})
    # env = HideAndSeekRewardWrapper(env, n_hiders=n_hiders, n_seekers=n_seekers,
    #                                rew_type=rew_type)

    # env = PreparationPhase(env, prep_fraction=prep_fraction)

    # if np.max(n_boxes) > 0:
    #     # env = AgentGeomObsMask2D(env, pos_obs_key='box_pos', mask_obs_key='mask_ab_obs',
    #     #                          geom_idxs_obs_key='box_geom_idxs')
    #     keys_external += ['box_obs']

    # # if lock_box and np.max(n_boxes) > 0:
    # #     env = LockObjWrapper(env, body_names=[f'moveable_box{i}' for i in range(np.max(n_boxes))],
    # #                          agent_idx_allowed_to_lock=np.arange(n_hiders+n_seekers),
    # #                          lock_type=lock_type, radius_multiplier=lock_radius_multiplier,
    # #                          obj_in_game_metadata_keys=["curr_n_boxes"],
    # #                          agent_allowed_to_lock_keys=None if lock_out_of_vision else ["mask_ab_obs"])

    # if n_ramps > 0:
    #     # if lock_ramp:
    #     #     env = LockObjWrapper(env, body_names=[f'ramp{i}:ramp' for i in range(n_ramps)],
    #     #                          agent_idx_allowed_to_lock=np.arange(n_hiders+n_seekers),
    #     #                          lock_type=lock_type, ac_obs_prefix='ramp_',
    #     #                          radius_multiplier=lock_radius_multiplier,
    #     #                          obj_in_game_metadata_keys=['curr_n_ramps'],
    #     #                          agent_allowed_to_lock_keys=None if lock_out_of_vision else ["mask_ar_obs"])
    #     keys_external += ['ramp_obs']

    # if prep_obs:
    #     env = TrackStatWrapper(env, np.max(n_boxes), n_ramps)
    # env = SplitObservations(env, keys_self + keys_mask_self, keys_copy=keys_copy, keys_self_matrices=keys_mask_self)
    # # env = SpoofEntityWrapper(env, np.max(n_boxes), ['box_obs', 'you_lock', 'team_lock', 'obj_lock'], ['mask_ab_obs'])

    # # if max_n_agents is not None:
    # #     env = SpoofEntityWrapper(env, max_n_agents, ['agent_q_qd', 'hider', 'prep_obs'], ['mask_aa_obs'])
    # # env = LockAllWrapper(env, remove_object_specific_lock=True)
    # env = NoActionsInPrepPhase(env, np.arange(n_hiders, n_hiders + n_seekers))
    # # env = ConcatenateObsWrapper(env, {'agent_q_qd': ['agent_q_qd'],
    # #                                   'box_obs': ['box_obs'],
    # #                                   'ramp_obs': ['ramp_obs']})
    # env = SelectKeysWrapper(env, keys_self=keys_self,
    #                         keys_other=keys_external + keys_mask_self + keys_mask_external)
    return env

In [8]:
# Build the environment with default settings. make_env's first positional
# parameter is the boolean `deterministic_mode`, so the PRNG key must not be
# passed to it; the key is only needed by `reset`.
test_env = make_env(deterministic_mode=False)

test_env.gen_sys(seed)

state = jax.jit(test_env.reset)(random_key)
# Eager alternative, handy when debugging tracing errors:
# state = test_env.reset(random_key)


TypeError: 'NoneType' object is not subscriptable

In [None]:
def randomise_action(act, random_key):
    """Sample a fresh uniform random action with the same shape/dtype as `act`.

    Args:
        act: current action array; only its shape and dtype are used.
        random_key: JAX PRNG key.

    Returns:
        (new_act, new_key): `new_act` is drawn uniformly from [-0.25, 0.25).
        `new_key` is the key for the caller to carry forward.

    The sample is drawn from a separate sub-key, so the returned key is never
    one that has already been consumed by a sampler. Keeping `act.dtype` makes
    the output match the identity branch when used under `jax.lax.cond`.
    """
    random_key, sub_key = random.split(random_key)
    new_act = random.uniform(sub_key, shape=act.shape, dtype=act.dtype,
                             minval=-0.25, maxval=0.25)
    return new_act, random_key

In [None]:
# JIT-compile the step function once and start from an all-zero action.
jit_step_fn = jax.jit(test_env.step)
act_size = test_env.sys.act_size()

# The flat actuator vector is split evenly across agents. Fail loudly if it
# does not divide evenly, rather than silently truncating with `//`.
per_agent_act_size, leftover = divmod(act_size, test_env.n_agents)
if leftover:
    raise ValueError(
        f"act_size={act_size} is not divisible by n_agents={test_env.n_agents}")
act = jp.zeros(shape=(test_env.n_agents, per_agent_act_size))

In [None]:
# Roll out `horizon` steps. A random action is resampled every
# ACTION_RESAMPLE_PERIOD steps and held in between. `i` is a plain Python int,
# so an ordinary `if` is enough; `jax.lax.cond` is only needed when the
# predicate is a traced value inside a jitted function.
ACTION_RESAMPLE_PERIOD = 50
rollout = []

for i in range(horizon):
    if i % ACTION_RESAMPLE_PERIOD == 0:
        act, random_key = randomise_action(act, random_key)
    # Record the state *before* stepping, so frame 0 is the reset state.
    rollout.append(state.pipeline_state)
    state = jit_step_fn(state, act)

html.save('no_outer.html', test_env.sys, rollout)

d_obs:  {'agent_q_qd': Traced<ShapedArray(float32[4,6])>with<DynamicJaxprTrace(level=1/0)>, 'agent_pos': Traced<ShapedArray(float32[4,3])>with<DynamicJaxprTrace(level=1/0)>, 'box_obs': Traced<ShapedArray(float32[2,15])>with<DynamicJaxprTrace(level=1/0)>, 'box_angle': Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>, 'box_pos': Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>, 'ramp_obs': Traced<ShapedArray(float32[1,15])>with<DynamicJaxprTrace(level=1/0)>, 'ramp_angle': Traced<ShapedArray(float32[1,3])>with<DynamicJaxprTrace(level=1/0)>, 'ramp_q': Traced<ShapedArray(float32[1,9])>with<DynamicJaxprTrace(level=1/0)>}


In [None]:
# @jax.jit
# def play_step_fn(state: State, act: Action, random_key: RNGKey, index: int):
#     act, random_key = jax.lax.cond(index % 50 == 0, randomise_action, lambda x, y: (x, y), act, random_key)
#     state = jit_step_fn(state, act)
#     return state, act, random_key, index + 1, state.pipeline_state

# def scan_play_step_fn(
#     carry: Tuple[State, Action, RNGKey, int], unused_arg: Any
# ) ->Tuple[Tuple[State, RNGKey, int], PipelineState]:
#     state, act, random_key, index, p_states = play_step_fn(*carry)
#     return (state, act, random_key, index), p_states
    

# (dst_state, dst_act, key, index), rollout = jax.lax.scan(scan_play_step_fn, (state, act, random_key, 0), None, length=horizon)

In [None]:
# states_list = []
# print(dst_state.info)
# print(dst_state.reward)

# for i in range(horizon):
#     s = jax.tree_util.tree_map(lambda x: x[i], rollout)
#     states_list.append(s)


# html.save('agents.html', test_env.sys, states_list)