In [1]:
import os, sys
sys.path.append('/rds/general/user/tla19/home/FYP/MAax')

In [2]:
from typing import Any, Callable, Tuple
from functools import partial
import time
import json

In [3]:
from worldgen import Floor, WorldBuilder, WorldParams
from brax.envs.env import State
from maax.modules.agents import Agents
from maax.modules.walls import RandomWalls, WallScenarios
from maax.modules.world import FloorAttributes, WorldConstants
from maax.modules.objects import Boxes, Ramps
from maax.modules.util import uniform_placement, centre_placement
from maax.envs.hide_and_seek import quad_placement


In [4]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from maax.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp
from jax import random

from IPython.display import HTML, clear_output
clear_output()


In [5]:
seed = 10
batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256, 512]
episode_length = 1000
random_key = jax.random.PRNGKey(seed)

In [6]:
def make_env(n_frames=1, horizon=80, deterministic_mode=False,
             floor_size=6.0, grid_size=30, door_size=2,
             n_hiders=2, n_seekers=2, max_n_agents=None,
             n_boxes=1, n_ramps=1, n_elongated_boxes=0,
             rand_num_elongated_boxes=False, n_min_boxes=None,
             box_size=0.5, boxid_obs=True, boxsize_obs=True, box_only_z_rot=True,
             pad_ramp_size=True,
             rew_type='joint_zero_sum',
             lock_box=True, grab_box=True, lock_ramp=True,
             lock_type='any_lock_specific',
             lock_grab_radius=0.25, lock_out_of_vision=True, grab_exclusive=False,
             grab_out_of_vision=False, grab_selective=False,
             box_floor_friction=0.2, other_friction=0.01, gravity=[0, 0, -50],
             action_lims=(-0.9, 0.9), polar_obs=True,
             scenario='quad', quad_game_hider_uniform_placement=False,
             p_door_dropout=0.0,
             n_rooms=4, random_room_count=True, prob_outside_walls=1.0,
             n_lidar_per_agent=0, visualize_lidar=False, compress_lidar_scale=None,
             hiders_together_radius=None, seekers_together_radius=None,
             prep_fraction=0.4, prep_rem=False,
             team_size_obs=False,
             restrict_rect=None, penalize_objects_out=False,
             ):
    '''
        This make_env function is not used anywhere; it exists to provide a simple, bare-bones
        example of how to construct a multi-agent environment using the modules framework.
    '''
    n_agents = n_seekers + n_hiders
    env = Base(n_agents=n_agents, n_frames=n_frames, horizon=horizon, grid_size=grid_size,
               deterministic_mode=deterministic_mode, seed=seed)
#     env.add_module(WallScenarios(grid_size=grid_size, door_size=door_size,
#                                      scenario=scenario, friction=other_friction,
#                                      p_door_dropout=p_door_dropout))
    box_placement_fn = uniform_placement
    ramp_placement_fn = uniform_placement
    agent_placement_fn = uniform_placement

    env.add_module(Agents(n_agents,
                          placement_fn=agent_placement_fn,
                          color=[np.array((66., 235., 244., 255.)) / 255] * n_agents,
                          friction=other_friction,
                          polar_obs=polar_obs))

    if np.max(n_boxes) > 0:
        env.add_module(Boxes(n_boxes=n_boxes, placement_fn=box_placement_fn,
                             friction=box_floor_friction, polar_obs=polar_obs,
                             n_elongated_boxes=0,
                             boxid_obs=boxid_obs,
                             box_only_z_rot=box_only_z_rot,
                             boxsize_obs=boxsize_obs,
                             free=True))

    if n_ramps > 0:
        env.add_module(Ramps(n_ramps=n_ramps, placement_fn=ramp_placement_fn,
                             friction=other_friction, polar_obs=polar_obs,
                             pad_ramp_size=pad_ramp_size, free=True))

    # if n_lidar_per_agent > 0 and visualize_lidar:
    #     env.add_module(LidarSites(n_agents=n_agents, n_lidar_per_agent=n_lidar_per_agent))

    # env.add_module(WorldConstants(gravity=gravity))

    return env

In [7]:
env = make_env()
env.gen_sys(seed)

In [8]:
@jax.jit
def randomise_action(act, random_key):
    random_key, _ = random.split(random_key)
    return random.uniform(random_key, shape=act.shape, minval=-0.25, maxval=0.25), random_key

In [9]:
# Create the inital environment states
act_size = env.sys.act_size()

jit_step_fn = jax.jit(env.step)
jit_batch_reset_fn = jax.jit(jax.vmap(env.reset))

In [10]:
@jax.jit
def play_step_fn(state: State, act: Action, random_key: RNGKey):
    act, random_key = randomise_action(act, random_key)
    state = jit_step_fn(state, act)
    return state, act, random_key, state.pipeline_state

@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state: State,
    act: Action,
    random_key: RNGKey,
    episode_length: int,
    play_step_fn) -> Tuple[State, Action, RNGKey]:
    """Generates an episode according to random action, returns the final state of
    the episode and the transitions of the episode.

    Args:
        init_state: first state of the rollout.
        act: The initial action
        random_key: random key for stochasiticity handling.
        episode_length: length of the rollout.
        index: index of the rollout.
        play_step_fn: function describing how a step need to be taken.

    Returns:
        A new state, the experienced transition.
    """
    @jax.jit
    def scan_play_step_fn(
        carry: Tuple[State, Action, RNGKey], unused_arg: Any) ->Tuple[Tuple[State, Action, RNGKey], PipelineState]:
        state, act, random_key, p_states = play_step_fn(*carry)
        return (state, act, random_key), p_states


    (dst_state, dst_act, key), rollout = jax.lax.scan(
        scan_play_step_fn, (init_state, act, random_key), None, length=episode_length)

    return dst_state, rollout

unroll_fn = partial(
    generate_unroll,
    episode_length=episode_length,
    play_step_fn=play_step_fn,
)

In [11]:
# # Run rollouts and time them
batch_rollout_fn = jax.jit(jax.vmap(unroll_fn))
batch_time = dict()
iterations = 6

# Perform rollout for each batch size
for i in range(iterations):
    for batch_size in batch_sizes:
        random_key, subkey = jax.random.split(random_key)
        keys = jax.random.split(subkey, num=batch_size)
        # Define initial batches states and actions
        init_states = jit_batch_reset_fn(keys)
        acts = jp.zeros(shape=(batch_size, env.sys.act_size()), dtype=jp.float32)
        start_time = time.time()
        dst_states, rollouts = batch_rollout_fn(init_states, acts, keys)
        et = time.time()
        dt = et - start_time
        print(f"Rollout time for batch size {batch_size} : {dt}")
        if batch_size in batch_time:
            batch_time[batch_size].append(dt)
        else:
            batch_time[batch_size] = [dt]



Rollout time for batch size 2 : 26.836660146713257
Rollout time for batch size 4 : 23.991141080856323
Rollout time for batch size 8 : 23.871275901794434
Rollout time for batch size 16 : 23.65753674507141
Rollout time for batch size 32 : 23.599459171295166
Rollout time for batch size 64 : 25.184409856796265
Rollout time for batch size 128 : 36.60212469100952
Rollout time for batch size 256 : 57.67041349411011


2023-06-16 17:13:13.041596: W external/xla/xla/service/hlo_rematerialization.cc:2209] Can't reduce memory use below 17.73GiB (19041386496 bytes) by rematerialization; only reduced to 21.05GiB (22606288036 bytes)
2023-06-16 17:13:33.044087: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 8.00GiB (rounded to 8594884608)requested by op 
2023-06-16 17:13:33.044609: W external/tsl/tsl/framework/bfc_allocator.cc:497] **************************************************************************************______________
2023-06-16 17:13:33.045664: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2469] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 8594884472 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   10.62MiB
              constant allocation:     8.8KiB
        maybe_live_out allocation:    9.98GiB
     preallocated temp allocati

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 8594884472 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   10.62MiB
              constant allocation:     8.8KiB
        maybe_live_out allocation:    9.98GiB
     preallocated temp allocation:    8.00GiB
                 total allocation:   18.00GiB
Peak buffers:
	Buffer 1:
		Size: 6.04GiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/broadcast_in_dim[shape=(1000, 512, 132, 24) broadcast_dimensions=()]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: broadcast
		Shape: f32[1000,512,132,24]
		==========================

	Buffer 2:
		Size: 6.04GiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,132,24]
		==========================

	Buffer 3:
		Size: 1.10GiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/broadcast_in_dim[shape=(1000, 512, 24, 24) broadcast_dimensions=()]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: broadcast
		Shape: f32[1000,512,24,24]
		==========================

	Buffer 4:
		Size: 1.10GiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,24,24]
		==========================

	Buffer 5:
		Size: 1.10GiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: fusion
		Shape: f32[512,1000,24,24]
		==========================

	Buffer 6:
		Size: 257.81MiB
		XLA Label: copy
		Shape: f32[1000,512,132]
		==========================

	Buffer 7:
		Size: 257.81MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,132]
		==========================

	Buffer 8:
		Size: 257.81MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,132]
		==========================

	Buffer 9:
		Size: 140.62MiB
		XLA Label: fusion
		Shape: f32[1000,512,24,3]
		==========================

	Buffer 10:
		Size: 140.62MiB
		XLA Label: fusion
		Shape: f32[1000,512,24,3]
		==========================

	Buffer 11:
		Size: 140.62MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,24,3]
		==========================

	Buffer 12:
		Size: 140.62MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,24,3]
		==========================

	Buffer 13:
		Size: 140.62MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,24,3]
		==========================

	Buffer 14:
		Size: 140.62MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/while[cond_nconsts=0 body_nconsts=0]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: transpose
		Shape: f32[512,1000,24,3]
		==========================

	Buffer 15:
		Size: 105.47MiB
		Operator: op_name="jit(<unnamed function>)/jit(main)/vmap(jit(generate_unroll))/broadcast_in_dim[shape=(1000, 512, 6, 3, 3) broadcast_dimensions=()]" source_file="/var/tmp/pbs.7759721.pbs/ipykernel_3851952/2081353713.py" source_line=35
		XLA Label: broadcast
		Shape: f32[1000,512,6,3,3]
		==========================



In [None]:
# batch_size = 128
# random_key, subkey = jax.random.split(random_key)
# keys = jax.random.split(subkey, num=batch_size)
# # Define initial batches states and actions
# init_states = jit_batch_reset_fn(keys)
# acts = jp.zeros(shape=(batch_size, env.sys.act_size()), dtype=jp.float32)
# start_time = time.time()
# dst_states, rollouts = batch_rollout_fn(init_states, acts, keys)
# et = time.time()
# dt = et - start_time
# print(f"Rollout time for batch size {batch_size} : {dt}")
# batch_time[batch_size] = dt

In [None]:
# Save batch times
with open('batch_times_no_walls_2{}.json'.format(episode_length), 'w') as f: 
    json.dump(batch_time, f)