In [1]:
import os, sys
# Make the local MAax checkout importable. The default is the original HPC path;
# set the MAAX_PATH environment variable to point elsewhere without editing the notebook.
MAAX_PATH = os.environ.get('MAAX_PATH', '/rds/general/user/tla19/home/FYP/MAax')
# Guard so that re-running this cell does not stack duplicate sys.path entries.
if MAAX_PATH not in sys.path:
    sys.path.append(MAAX_PATH)

In [2]:
from typing import Any, Callable, Tuple
from functools import partial
import time
import json

import flax
from brax import envs
from brax.io import model
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

In [3]:
import brax
import numpy as np
from brax.io import mjcf, html
from maax.envs.base import Base
from maax.util.types import RNGKey, PipelineState, Action
from brax.generalized import pipeline

import jax
from jax import numpy as jp
from jax import random

from IPython.display import HTML, clear_output
# Clear whatever this cell printed while running the imports above.
clear_output(wait=False)


In [4]:
# Experiment configuration: RNG seed, rollout length and the batch sizes to benchmark.
seed = 10
random_key = jax.random.PRNGKey(seed)
episode_length = 1000
# Powers of two from 2**1 = 2 up to 2**9 = 512.
batch_sizes = [2 ** exponent for exponent in range(1, 10)]

In [5]:
env_name = 'ant'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'generalized'  # @param ['generalized', 'positional', 'spring']

# Build the selected Brax environment on the selected physics backend.
env = envs.get_environment(env_name=env_name, backend=backend)
# Run one JIT-compiled reset with a fixed key (0) to obtain an initial state.
state = jax.jit(env.reset)(jax.random.PRNGKey(0))

In [6]:
@jax.jit
def randomise_action(act, random_key):
    """Sample a uniform random action with the same shape as ``act``.

    Args:
        act: template array; only its shape is used.
        random_key: JAX PRNG key.

    Returns:
        ``(action, new_key)``: an array of shape ``act.shape`` drawn uniformly
        from [-0.25, 0.25), and the key the caller should use next time.
    """
    # Split once: the subkey is consumed here by the sampler and the other half
    # is handed back. Sampling with the returned key would consume it twice
    # (here and in the caller's next split), which JAX PRNG keys must never be.
    random_key, subkey = random.split(random_key)
    return random.uniform(subkey, shape=act.shape, minval=-0.25, maxval=0.25), random_key

In [7]:
# JIT-compile the environment entry points used by the rollout code:
# a batched (vmapped) reset and a single-environment step.
jit_batch_reset_fn = jax.jit(jax.vmap(env.reset))
jit_step_fn = jax.jit(env.step)

# Zero action template sized to the system's actuator count.
act_size = env.sys.act_size()
act = jp.zeros(shape=act_size)

In [8]:
@jax.jit
def play_step_fn(state, act, random_key):
    """Advance the environment one step under a freshly sampled random action.

    Returns:
        ``(next_state, applied_action, next_key, next_state.pipeline_state)``.
    """
    sampled_act, next_key = randomise_action(act, random_key)
    next_state = jit_step_fn(state, sampled_act)
    return next_state, sampled_act, next_key, next_state.pipeline_state

@partial(jax.jit, static_argnames=("play_step_fn", "episode_length"))
def generate_unroll(
    init_state,
    act,
    random_key,
    episode_length,
    play_step_fn):
    """Unroll ``episode_length`` steps of ``play_step_fn`` with ``jax.lax.scan``.

    Args:
        init_state: first state of the rollout.
        act: initial action; it is carried through the scan, and
            ``play_step_fn`` returns the action for the next step.
        random_key: random key threaded through every step for stochasticity handling.
        episode_length: number of steps to unroll (static: it fixes the scan length).
        play_step_fn: callable ``(state, act, random_key) -> (state, act,
            random_key, per_step_output)`` describing how a step is taken
            (static, so it must be hashable).

    Returns:
        ``(final_state, rollout)``, where ``rollout`` stacks the per-step outputs
        of ``play_step_fn`` along a new leading axis of length ``episode_length``.
    """

    # No separate jax.jit here: this closure is traced as part of the already
    # jitted generate_unroll.
    def scan_play_step_fn(carry, unused_arg):
        state, act, random_key, per_step_output = play_step_fn(*carry)
        return (state, act, random_key), per_step_output

    (final_state, _, _), rollout = jax.lax.scan(
        scan_play_step_fn, (init_state, act, random_key), None, length=episode_length)

    return final_state, rollout

# Bind the static configuration up front so the resulting callable only takes
# (init_state, act, random_key), the arguments later vmapped over the batch.
unroll_fn = partial(
    generate_unroll,
    play_step_fn=play_step_fn,
    episode_length=episode_length,
)

In [9]:
# Run rollouts and time them.
# JAX dispatches work asynchronously, so the timed region blocks on the rollout
# outputs before reading the clock; otherwise only dispatch time would be measured.
# NOTE: on the first pass each new batch shape triggers JIT compilation, so the
# iteration-0 entries in batch_time include compile time.
batch_rollout_fn = jax.jit(jax.vmap(unroll_fn))
batch_time = dict()
iterations = 6

# Perform rollout for each batch size
for i in range(iterations):
    for batch_size in batch_sizes:
        # Independent keys for reset and rollout so no key is consumed twice.
        random_key, reset_subkey, rollout_subkey = jax.random.split(random_key, 3)
        keys = jax.random.split(reset_subkey, num=batch_size)
        rollout_keys = jax.random.split(rollout_subkey, num=batch_size)
        # Define initial batches states and actions
        init_states = jit_batch_reset_fn(keys)
        acts = jp.zeros(shape=(batch_size, env.sys.act_size()), dtype=jp.float32)
        # Finish the setup work before starting the clock so it is not counted.
        jax.block_until_ready((init_states, acts))
        start_time = time.perf_counter()
        dst_states, rollouts = jax.block_until_ready(
            batch_rollout_fn(init_states, acts, rollout_keys))
        et = time.perf_counter()
        dt = et - start_time
        print(f"Rollout time for batch size {batch_size} : {dt}")
        batch_time.setdefault(batch_size, []).append(dt)


Rollout time for batch size 2 : 16.944783210754395
Rollout time for batch size 4 : 15.438798189163208
Rollout time for batch size 8 : 15.39379072189331
Rollout time for batch size 16 : 15.923971176147461
Rollout time for batch size 32 : 17.367851734161377
Rollout time for batch size 64 : 19.5241641998291
Rollout time for batch size 128 : 22.770678758621216
Rollout time for batch size 256 : 31.373990297317505
Rollout time for batch size 512 : 57.236847162246704
Rollout time for batch size 2 : 9.158714056015015
Rollout time for batch size 4 : 9.362637042999268
Rollout time for batch size 8 : 9.544764280319214
Rollout time for batch size 16 : 10.09183955192566
Rollout time for batch size 32 : 11.41162657737732
Rollout time for batch size 64 : 12.919147729873657
Rollout time for batch size 128 : 16.521596431732178
Rollout time for batch size 256 : 25.45203185081482
Rollout time for batch size 512 : 51.08055019378662
Rollout time for batch size 2 : 9.215003252029419
Rollout time for batch s

In [10]:
# batch_size = 128
# random_key, subkey = jax.random.split(random_key)
# keys = jax.random.split(subkey, num=batch_size)
# # Define initial batches states and actions
# init_states = jit_batch_reset_fn(keys)
# acts = jp.zeros(shape=(batch_size, env.sys.act_size()), dtype=jp.float32)
# start_time = time.time()
# dst_states, rollouts = batch_rollout_fn(init_states, acts, keys)
# et = time.time()
# dt = et - start_time
# print(f"Rollout time for batch size {batch_size} : {dt}")
# batch_time.setdefault(batch_size, []).append(dt)  # keep per-size lists, as the timing loop does

In [12]:
# Save batch times. The filename follows the selected environment, so runs with a
# different `env_name` do not overwrite (or get mislabeled as) the ant results.
batch_times_path = f'{env_name}_batch_times{episode_length}.json'
with open(batch_times_path, 'w') as f:
    json.dump(batch_time, f)