In [1]:
import os, sys

# Directory added to the import path so the project-local packages imported in the
# following cells can be resolved (presumably the MAax checkout — confirm).
# Set the MAAX_ROOT environment variable to point somewhere else; the default keeps
# the original developer-machine location so existing behaviour is unchanged.
MAAX_ROOT = os.environ.get('MAAX_ROOT', '/Users/tom/dev/imperial/FYP/MAax/')

# Re-running this cell in the same kernel must not pile up duplicate entries.
if MAAX_ROOT not in sys.path:
    sys.path.append(MAAX_ROOT)

In [2]:
from typing import Any, Callable, Tuple

In [3]:
from worldgen import Floor, WorldBuilder
from mae_envs.modules.agents import Agents
from mae_envs.modules.walls import RandomWalls, WallScenarios
from mae_envs.modules.world import FloorAttributes, WorldConstants
from mae_envs.modules.objects import Boxes, Cylinders, LidarSites, Ramps
from mae_envs.modules.util import uniform_placement, center_placement
from mae_envs.envs.hide_and_seek import quad_placement


In [4]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from mae_envs.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp

from jax import random

from IPython.display import HTML, clear_output
# Wipe anything this cell has printed so far — presumably import-time warnings or
# log noise from the libraries above (TODO confirm the intent).
clear_output()


In [5]:
from brax.envs.env import State

In [6]:
# Single seed shared by the environment constructor (read inside make_env),
# test_env.gen_sys and the JAX PRNG key below.
seed = 80
# Number of lax.scan iterations in the rollout, and therefore the number of
# frames written to the HTML file.
episode_length = 1
# Root PRNG key: passed to reset and used as the initial key of the rollout carry.
random_key = jax.random.PRNGKey(seed)

In [7]:
def make_env(n_frames=1, horizon=80, deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_agents=2, fixed_agent_spawn=False,
             lock_box=True, grab_box=True, grab_selective=False,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, grab_exclusive=False, grab_out_of_vision=False,
             lock_out_of_vision=True,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='var_quad', p_door_dropout=0.0,
             n_rooms=4, random_room_number=True,
             n_lidar_per_agent=1, visualize_lidar=True, compress_lidar_scale=None,
             n_boxes=2, box_size=0.5, box_only_z_rot=False,
             boxid_obs=True, boxsize_obs=True, pad_ramp_size=True, additional_obs={},
             # lock-box task
             task_type='all', lock_reward=5.0, unlock_penalty=7.0, shaped_reward_scale=0.25,
             return_threshold=0.1,
             # ramps
             n_ramps=1):
    '''
        Bare-bones example of assembling a multi-agent environment from modules.

        A ``Base`` environment is created and the following modules are attached,
        in this order: ``WallScenarios``, ``Agents``, then ``Boxes`` (only when
        ``np.max(n_boxes) > 0``) and ``Ramps`` (only when ``n_ramps > 0``).
        Agents, boxes and ramps are all placed with ``uniform_placement``.

        Notes:
            * The environment seed is NOT a parameter: it is read from the
              notebook-level ``seed`` variable at call time.
            * Only these arguments are consumed by the body: ``n_frames``,
              ``horizon``, ``deterministic_mode``, ``grid_size``, ``door_size``,
              ``n_agents``, ``box_floor_friction``, ``other_friction``,
              ``polar_obs``, ``scenario``, ``p_door_dropout``, ``n_boxes``,
              ``box_only_z_rot``, ``boxid_obs``, ``boxsize_obs``,
              ``pad_ramp_size`` and ``n_ramps``. Every other argument is accepted
              but currently ignored.

        Returns:
            The ``Base`` environment with the modules attached.
    '''
    world = Base(
        n_agents=n_agents,
        n_frames=n_frames,
        horizon=horizon,
        grid_size=grid_size,
        deterministic_mode=deterministic_mode,
        seed=seed,  # notebook-level `seed`, looked up when make_env is called
    )

    # Walls / room layout.
    walls = WallScenarios(
        grid_size=grid_size,
        door_size=door_size,
        scenario=scenario,
        friction=other_friction,
        p_door_dropout=p_door_dropout,
    )
    world.add_module(walls)

    # Agents: every agent shares one RGBA colour, rescaled from 0-255 to 0-1.
    agent_rgba = np.array((66., 235., 244., 255.)) / 255
    agents = Agents(
        n_agents,
        placement_fn=uniform_placement,
        color=[agent_rgba] * n_agents,
        friction=other_friction,
        polar_obs=polar_obs,
    )
    world.add_module(agents)

    # Boxes are optional; np.max is applied to n_boxes before the comparison.
    if np.max(n_boxes) > 0:
        boxes = Boxes(
            n_boxes=n_boxes,
            placement_fn=uniform_placement,
            friction=box_floor_friction,
            polar_obs=polar_obs,
            n_elongated_boxes=0,
            boxid_obs=boxid_obs,
            box_only_z_rot=box_only_z_rot,
            boxsize_obs=boxsize_obs,
            free=True,
        )
        world.add_module(boxes)

    # Ramps are optional as well.
    if n_ramps > 0:
        ramps = Ramps(
            n_ramps=n_ramps,
            placement_fn=uniform_placement,
            friction=other_friction,
            polar_obs=polar_obs,
            pad_ramp_size=pad_ramp_size,
            free=True,
        )
        world.add_module(ramps)

    # The LidarSites and WorldConstants modules are currently not attached:
    # if n_lidar_per_agent > 0 and visualize_lidar:
    #     env.add_module(LidarSites(n_agents=n_agents, n_lidar_per_agent=n_lidar_per_agent))
    # env.add_module(WorldConstants(gravity=gravity))

    return world

In [8]:
# Build the example environment with every argument left at its default.
test_env = make_env()

# Generate the system from the seed — presumably this is what populates
# `test_env.sys`, which later cells read (confirm against Base.gen_sys).
test_env.gen_sys(seed)

# JIT-compile reset and call it once to get the initial state for the rollout.
state = jax.jit(test_env.reset)(random_key)

In [9]:
def randomise_action(act, random_key):
    """Draw a fresh uniformly-random action with the same shape as ``act``.

    Args:
        act: Current action array; only its ``shape`` is read.
        random_key: JAX PRNG key to derive the randomness from.

    Returns:
        A ``(new_act, next_key)`` tuple. ``new_act`` is sampled uniformly from
        ``[-0.25, 0.25)``; ``next_key`` is a key that was not consumed by the
        sampling call, so the caller can safely keep splitting it.
    """
    # Split into a key to carry forward and a one-shot subkey for sampling.
    # Sampling with the very key that is handed back to the caller (as was done
    # before) reuses a PRNG key, which JAX's random design forbids.
    next_key, sample_key = random.split(random_key)
    new_act = random.uniform(sample_key, shape=act.shape, minval=-0.25, maxval=0.25)
    return new_act, next_key

In [10]:
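# NOTE: earlier Python-loop version of the rollout, kept for reference only.
# It refers to names (`maintain_action`, `rng`) that are not defined in the
# visible cells of this notebook; the lax.scan cell below performs the rollout.
# rollout = []

# for i in range(500):
#     print(i)
#     act, rng = jax.lax.cond(i % 50 == 0, randomise_action, maintain_action, act, act_size, rng)
#     rollout.append(state.pipeline_state)
#     state = jit_step_fn(state, act)

# html.save('agents.html', test_env.sys, rollout)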

In [11]:

# JIT-compiled environment step; play_step_fn below calls it once per iteration.
jit_step_fn = jax.jit(test_env.step)
# Size of the action vector, as reported by the system.
act_size = test_env.sys.act_size()


# Initial action is all zeros; play_step_fn swaps it for a random one whenever
# the step index is a multiple of 50 (which includes step 0).
act = jp.zeros(shape=act_size)

@jax.jit
def play_step_fn(state: State, act: Action, random_key: RNGKey, index: int):
    """Advance the environment by one step, resampling the action periodically.

    When ``index`` is a multiple of 50 (step 0 included) the action and key are
    replaced by the result of ``randomise_action``; otherwise both are passed
    through unchanged. The environment is then stepped with the chosen action.

    Returns:
        ``(state, act, random_key, index + 1, pipeline_state)`` where ``state``
        is the post-step state and ``pipeline_state`` is its ``pipeline_state``
        attribute.
    """
    resample_now = index % 50 == 0
    act, random_key = jax.lax.cond(
        resample_now,
        randomise_action,
        # Identity branch: keep holding the current action and key.
        lambda held_act, held_key: (held_act, held_key),
        act,
        random_key,
    )
    next_state = jit_step_fn(state, act)
    return next_state, act, random_key, index + 1, next_state.pipeline_state

def scan_play_step_fn(
    carry: Tuple[State, Action, RNGKey, int], unused_arg: Any
) -> Tuple[Tuple[State, Action, RNGKey, int], PipelineState]:
    """Adapt ``play_step_fn`` to the ``(carry, x) -> (carry, y)`` shape that ``jax.lax.scan`` expects.

    Args:
        carry: ``(state, act, random_key, index)`` from the previous iteration.
        unused_arg: Per-iteration scan input; ignored.

    Returns:
        ``(new_carry, pipeline_state)``. ``new_carry`` has the same four-field
        layout as ``carry``; ``pipeline_state`` is the per-step output that
        ``lax.scan`` collects.
    """
    # play_step_fn returns the four carry fields followed by the pipeline state.
    state, act, random_key, index, pipeline_state = play_step_fn(*carry)
    return (state, act, random_key, index), pipeline_state
    

# Run `episode_length` environment steps inside a single lax.scan.
# The carry starts as (initial state, zero action, root key, step index 0);
# the per-step pipeline states are collected into `rollout`.
(dst_state, dst_act, key, index), rollout = jax.lax.scan(
    scan_play_step_fn,
    (state, act, random_key, 0),
    None,  # no per-step inputs; the iteration count comes from `length`
    length=episode_length,
)

In [12]:
# lax.scan returns the per-step outputs as one pytree whose leaves are stacked
# along a leading step axis; index every leaf at step t to recover one
# pipeline state per step.
states_list = [
    jax.tree_util.tree_map(lambda leaf: leaf[t], rollout)
    for t in range(episode_length)
]

print(len(states_list))
html.save('uniform.html', test_env.sys, states_list)

1
