In [1]:
import os, sys
# Root of the MAax checkout that provides the `worldgen`, `mae_envs` and `maax`
# packages imported below. Set the MAAX_ROOT environment variable to point at a
# different checkout (e.g. '/rds/general/user/tla19/home/FYP/MAax' on the cluster)
# instead of editing this cell; the previous hard-coded location is the fallback.
MAAX_ROOT = os.environ.get('MAAX_ROOT', '/Users/tom/dev/imperial/FYP/MAax/')
# Guard against duplicate entries when the cell is re-executed in a live kernel.
if MAAX_ROOT not in sys.path:
    sys.path.append(MAAX_ROOT)

In [2]:
from typing import Any, Callable, Tuple
from functools import partial
import time
import json

In [3]:
from worldgen import Floor, WorldBuilder, WorldParams
from mae_envs.modules.agents import Agents
from mae_envs.modules.walls import RandomWalls, WallScenarios
from mae_envs.modules.world import FloorAttributes, WorldConstants
from mae_envs.modules.objects import Boxes, Ramps
from mae_envs.modules.util import uniform_placement, centre_placement
from mae_envs.envs.hide_and_seek import quad_placement


In [4]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from mae_envs.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp

from jax import random

from IPython.display import HTML, clear_output
clear_output()


In [5]:
from brax.envs.env import State

In [6]:
# One seed drives everything in this notebook: it is handed to the environment
# constructor / `gen_sys`, and it initialises the JAX PRNG key used for resets
# and for the random actions of the rollout.
seed = 7
random_key = random.PRNGKey(seed)

In [7]:
def make_env(n_boxes=2,n_frames=1, horizon=1000, deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_hiders=2, n_seekers=2, max_n_agents=None,
             n_ramps=0, n_elongated_boxes=0,
             rand_num_elongated_boxes=False, n_min_boxes=None,
             box_size=0.5, boxid_obs=True, boxsize_obs=True, box_only_z_rot=True,
             pad_ramp_size=True,
             rew_type='joint_zero_sum',
             lock_box=True, grab_box=True, lock_ramp=True,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, lock_out_of_vision=True, grab_exclusive=False,
             grab_out_of_vision=False, grab_selective=False,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='quad', quad_game_hider_uniform_placement=False,
             p_door_dropout=0.0,
             n_rooms=4, random_room_number=True, prob_outside_walls=1.0,
             n_lidar_per_agent=0, visualize_lidar=False, compress_lidar_scale=None,
             hiders_together_radius=None, seekers_together_radius=None,
             prep_fraction=0.4, prep_rem=False,
             team_size_obs=False,
             restrict_rect=None, penalize_objects_out=False,
             ):
    '''
        Build a `Base` environment populated with walls, agents and (optionally) boxes.

        Modules are registered in this order: `WallScenarios`, `Agents`, then `Boxes`
        (the latter only when `np.max(n_boxes) > 0`). Every object is placed with
        `uniform_placement`. The environment seed is taken from the notebook-level
        `seed` variable, not from an argument.

        Arguments actually consumed here:
            n_boxes, n_frames, horizon, deterministic_mode, grid_size, door_size,
            n_hiders, n_seekers, boxid_obs, boxsize_obs, box_only_z_rot,
            box_floor_friction, other_friction, polar_obs, scenario, p_door_dropout.
        All remaining keyword arguments are accepted but ignored by this function.

        Returns:
            The configured `Base` instance.
    '''
    total_agents = n_seekers + n_hiders

    world = Base(n_agents=total_agents, n_frames=n_frames, horizon=horizon,
                 grid_size=grid_size, deterministic_mode=deterministic_mode,
                 seed=seed)

    # Walls first, so later modules are registered after the layout.
    walls = WallScenarios(grid_size=grid_size, door_size=door_size,
                          scenario=scenario, friction=other_friction,
                          p_door_dropout=p_door_dropout)
    world.add_module(walls)

    # Every agent shares one RGBA colour (values scaled from 0-255 down to 0-1).
    agent_rgba = np.array((66., 235., 244., 255.)) / 255
    agents = Agents(total_agents,
                    placement_fn=uniform_placement,
                    color=[agent_rgba] * total_agents,
                    friction=other_friction,
                    polar_obs=polar_obs)
    world.add_module(agents)

    # `n_boxes` may be a scalar or array-like, hence np.max.
    if np.max(n_boxes) > 0:
        boxes = Boxes(n_boxes=n_boxes, placement_fn=uniform_placement,
                      friction=box_floor_friction, polar_obs=polar_obs,
                      n_elongated_boxes=0,  # fixed at 0; the argument of the same name is not forwarded
                      boxid_obs=boxid_obs,
                      box_only_z_rot=box_only_z_rot,
                      boxsize_obs=boxsize_obs,
                      free=False)
        world.add_module(boxes)

    # Ramps are currently disabled:
    # if n_ramps > 0:
    #     world.add_module(Ramps(n_ramps=n_ramps, placement_fn=uniform_placement,
    #                            friction=other_friction, polar_obs=polar_obs,
    #                            pad_ramp_size=pad_ramp_size, free=True))

    return world

In [8]:
def randomise_action(act, random_key):
    """Draw a fresh uniform random action with the same shape as `act`.

    Args:
        act: current action array; only its shape is used.
        random_key: JAX PRNG key.

    Returns:
        A `(new_action, next_key)` pair. `new_action` is sampled uniformly from
        [-0.25, 0.25) with `act.shape`; `next_key` is the key to carry forward.
    """
    # Split into a key to carry forward and a one-off subkey for sampling.
    # The key that is returned must not be the one consumed by `uniform`,
    # otherwise the same key is used for both sampling and later splitting.
    next_key, sample_key = random.split(random_key)
    new_action = random.uniform(sample_key, shape=act.shape, minval=-0.25, maxval=0.25)
    return new_action, next_key

In [9]:
# Benchmark sweep configuration.
obj_max = 10           # box counts 0 .. obj_max - 1 are benchmarked (used as range(obj_max))
iterations = 6         # number of times the whole sweep over box counts is repeated
episode_length = 1000  # number of environment steps per timed rollout (scan length)


In [10]:
@jax.jit
def play_step_fn(state: State, act: Action, random_key: RNGKey, index: int):
    """Advance the environment by one step, resampling the action every 50 steps.

    When `index % 50 == 0` the action is replaced via `randomise_action`;
    otherwise the incoming action and key are passed through unchanged. The
    environment is stepped with the module-level `jit_step_fn`, which must be
    bound before this function is called.

    Returns:
        `(state, act, random_key, index + 1, state.pipeline_state)` where
        `state` is the result of the step.
    """
    def _keep(current_act, current_key):
        # Pass-through branch: leave both the action and the key untouched.
        return current_act, current_key

    resample = index % 50 == 0
    act, random_key = jax.lax.cond(resample, randomise_action, _keep, act, random_key)

    next_state = jit_step_fn(state, act)
    return next_state, act, random_key, index + 1, next_state.pipeline_state

def scan_play_step_fn(
    carry: Tuple[State, Action, RNGKey, int], unused_arg: Any
) -> Tuple[Tuple[State, Action, RNGKey, int], PipelineState]:
    """Adapter that gives `play_step_fn` the `(carry, x) -> (carry, y)` shape used by `jax.lax.scan`.

    Args:
        carry: `(state, act, random_key, index)` from the previous step.
        unused_arg: per-step scan input; ignored.

    Returns:
        `(new_carry, pipeline_state)`: the updated 4-tuple carry, and the
        pipeline state of the new environment state as the per-step output.
    """
    state, act, random_key, index, pipeline_state = play_step_fn(*carry)
    return (state, act, random_key, index), pipeline_state

# Wall-clock rollout times keyed by number of boxes; each outer iteration
# appends one measurement per box count.
obj_times = dict()
jit_scan_fn = jax.jit(scan_play_step_fn)

for i in range(iterations):
    for obj_cnt in range(obj_max):
        test_env = make_env(n_boxes=obj_cnt)
        test_env.gen_sys(seed)

        act = jp.zeros(test_env.sys.act_size())
        # `play_step_fn` reads this module-level name, so it has to be rebound
        # to the current environment before the rollout starts.
        jit_step_fn = jax.jit(test_env.step)
        # Wait for the reset to finish so its cost is not charged to the rollout.
        state = jax.block_until_ready(jax.jit(test_env.reset)(random_key))

        st = time.perf_counter()
        (dst_state, dst_act, key, index), rollout = jax.lax.scan(jit_scan_fn, (state, act, random_key, 0), None, length=episode_length)
        # JAX dispatches work asynchronously: without blocking on the results the
        # timer could stop before the rollout has actually been computed.
        # NOTE: the measurement still includes any tracing/compilation done by this call.
        jax.block_until_ready((dst_state, dst_act, key, index, rollout))
        et = time.perf_counter()
        dt = et - st
        print(f"Rollout time for obj cnt {obj_cnt} : {dt}")
        obj_times.setdefault(obj_cnt, []).append(dt)

Rollout time for obj cnt 0 : 19.12911891937256
Rollout time for obj cnt 1 : 29.463647842407227
Rollout time for obj cnt 2 : 38.35092115402222
Rollout time for obj cnt 3 : 51.94494271278381
Rollout time for obj cnt 4 : 68.76998591423035
Rollout time for obj cnt 5 : 108.74006986618042
Rollout time for obj cnt 6 : 154.81184196472168
Rollout time for obj cnt 7 : 205.7165207862854
Rollout time for obj cnt 8 : 256.93288397789
Rollout time for obj cnt 9 : 359.30365014076233
Rollout time for obj cnt 0 : 13.044253826141357
Rollout time for obj cnt 1 : 22.828088998794556
Rollout time for obj cnt 2 : 31.58177399635315
Rollout time for obj cnt 3 : 43.52061486244202
Rollout time for obj cnt 4 : 61.2439169883728
Rollout time for obj cnt 5 : 97.8584349155426
Rollout time for obj cnt 6 : 136.43447303771973
Rollout time for obj cnt 7 : 187.94957900047302
Rollout time for obj cnt 8 : 248.3357629776001
Rollout time for obj cnt 9 : 341.78411388397217
Rollout time for obj cnt 0 : 12.936373949050903
Rollout

KeyboardInterrupt: 

In [None]:
# Persist the collected timings. `obj_times` is keyed by integer box count;
# JSON object keys are strings, so the counts come back as strings on load.
with open('boxes_times_slide_avg.json', 'w') as results_file:
    results_file.write(json.dumps(obj_times))