In [1]:
import jax
from jax import numpy as jp
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt

from jax import random

try:
  import brax
except ImportError:
  from IPython.display import clear_output 
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

from typing import Any, Callable, Tuple

import time

In [2]:
seed = 1 
rng = random.PRNGKey(seed)
horizon = 1000

In [3]:

from brax.io import mjcf

sys = mjcf.loads(
    """
    <mujoco>
	<compiler angle="radian" coordinate="local" meshdir="/Users/tom/dev/imperial/FYP/MAax/assets/stls" texturedir="/Users/tom/dev/imperial/FYP/MAax/assets/textures" autolimits="true"></compiler>
	<option timestep="0.01">
		<flag warmstart="enable"></flag>
	</option>
	<worldbody>
		<geom name="floor0" pos="3 3 0" size="3 3 1" type="plane" condim="3"></geom>
		<body name="floor0" pos="3 3 0"></body>
		<body pos="0.15 0.15 0.15" name="agent0_agent0">
			<joint axis="1 0 0" damping="0.01" name="agent0_slide0" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 1 0" damping="0.01" name="agent0_slide1" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 0 1" damping="0.01" name="agent0_hinge0" pos="0 0 0" type="hinge" limited="auto" range="-100 100"></joint>
			<geom name="agent0_agent" mass="1" pos="0 0 0" rgba="0.258824 0.921569 0.956863 1" size="0.15" type="sphere" euler="1.57 0 0" friction="0.01"></geom>
		</body>
		<body pos="0.15 0.15 0.15" name="agent1_agent1">
			<joint axis="1 0 0" damping="0.01" name="agent1_slide0" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 1 0" damping="0.01" name="agent1_slide1" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 0 1" damping="0.01" name="agent1_hinge0" pos="0 0 0" type="hinge" limited="auto" range="-100 100"></joint>
			<geom name="agent1_agent" mass="1" pos="0 0 0" rgba="0.258824 0.921569 0.956863 1" size="0.15" type="sphere" euler="1.57 0 0" friction="0.01"></geom>
		</body>
		<body pos="0.15 0.15 0.15" name="agent2_agent2">
			<joint axis="1 0 0" damping="0.01" name="agent2_slide0" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 1 0" damping="0.01" name="agent2_slide1" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 0 1" damping="0.01" name="agent2_hinge0" pos="0 0 0" type="hinge" limited="auto" range="-100 100"></joint>
			<geom name="agent2_agent" mass="1" pos="0 0 0" rgba="0.258824 0.921569 0.956863 1" size="0.15" type="sphere" euler="1.57 0 0" friction="0.01"></geom>
		</body>
		<body pos="0.15 0.15 0.15" name="agent3_agent3">
			<joint axis="1 0 0" damping="0.01" name="agent3_slide0" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 1 0" damping="0.01" name="agent3_slide1" pos="0 0 0" type="slide" limited="auto" range="-100 100"></joint>
			<joint axis="0 0 1" damping="0.01" name="agent3_hinge0" pos="0 0 0" type="hinge" limited="auto" range="-100 100"></joint>
			<geom name="agent3_agent" mass="1" pos="0 0 0" rgba="0.258824 0.921569 0.956863 1" size="0.15" type="sphere" euler="1.57 0 0" friction="0.01"></geom>
		</body>
		<body name="moveable-box0" pos="0.25 0.25 0.25">
			<joint name="moveable-box0_slide0" axis="1 0 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box0_slide1" axis="0 1 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box0_slide2" axis="0 0 1" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<geom size="0.25 0.25 0.25" type="box" condim="3" name="moveable-box0" rgba="1 0.5 0.8 1" mass="1" friction="0.2"></geom>
		</body>
		<body name="moveable-box1" pos="0.25 0.25 0.25">
			<joint name="moveable-box1_slide0" axis="1 0 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box1_slide1" axis="0 1 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box1_slide2" axis="0 0 1" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<geom size="0.25 0.25 0.25" type="box" condim="3" name="moveable-box1" rgba="1 0.5 0.8 1" mass="1" friction="0.2"></geom>
		</body>
		<body name="moveable-box2" pos="0.25 0.25 0.25">
			<joint name="moveable-box2_slide0" axis="1 0 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box2_slide1" axis="0 1 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box2_slide2" axis="0 0 1" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<geom size="0.25 0.25 0.25" type="box" condim="3" name="moveable-box2" rgba="1 0.5 0.8 1" mass="1" friction="0.2"></geom>
		</body>
		<body name="moveable-box3" pos="0.25 0.25 0.25">
			<joint name="moveable-box3_slide0" axis="1 0 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box3_slide1" axis="0 1 0" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<joint name="moveable-box3_slide2" axis="0 0 1" type="slide" damping="0.01" pos="0 0 0" limited="auto"></joint>
			<geom size="0.25 0.25 0.25" type="box" condim="3" name="moveable-box3" rgba="1 0.5 0.8 1" mass="1" friction="0.2"></geom>
		</body>
	</worldbody>
	<actuator>
		<motor gear="100" joint="agent0_slide0"></motor>
		<motor gear="100" joint="agent0_slide1"></motor>
		<motor gear="100" joint="agent0_hinge0"></motor>
		<motor gear="100" joint="agent1_slide0"></motor>
		<motor gear="100" joint="agent1_slide1"></motor>
		<motor gear="100" joint="agent1_hinge0"></motor>
		<motor gear="100" joint="agent2_slide0"></motor>
		<motor gear="100" joint="agent2_slide1"></motor>
		<motor gear="100" joint="agent2_hinge0"></motor>
		<motor gear="100" joint="agent3_slide0"></motor>
		<motor gear="100" joint="agent3_slide1"></motor>
		<motor gear="100" joint="agent3_hinge0"></motor>
	</actuator>
</mujoco>
  """)

In [4]:
from brax.generalized import pipeline


qd = jp.zeros(shape=sys.qd_size())
state = jax.jit(pipeline.init)(sys, sys.init_q, qd)
act = jp.zeros(shape=(sys.act_size(),))

jit_step_fn = jax.jit(pipeline.step)



In [5]:
def randomise_action(act, random_key):
    random_key, sub_key = random.split(random_key)
    return random.uniform(random_key, shape=act.shape, minval=-0.25, maxval=0.25), random_key

In [6]:
state = jit_step_fn(sys, state, act)
act = jp.zeros(shape=(sys.act_size(),))

In [7]:
for i in range(100):
    # state = test_env.step(state, act)
    state = jit_step_fn(sys, state, act)


In [8]:
@jax.jit
def play_step_fn(state, sys, act, random_key, index: int):
    act, random_key = jax.lax.cond(index % 50 == 0, randomise_action, lambda x, y: (x, y), act, random_key)
    state = jit_step_fn(state, sys, act)
    return state, sys, act, random_key, index + 1, state.pipeline_state

def scan_play_step_fn(
    carry, unused_arg
):
    state, sys, act, random_key, index, p_states = play_step_fn(*carry)
    return (state, sys, act, random_key, index), p_states
    
st = time.time()
(dst_state, sys, dst_act, key, index), rollout = jax.lax.scan(scan_play_step_fn, (state, sys, act, rng, 0), None, length=horizon)
print(time.time() - st)

AttributeError: 'System' object has no attribute 'q'