In [90]:
import jax
from jax import numpy as jp
from matplotlib.lines import Line2D
from matplotlib.patches import Circle
import matplotlib.pyplot as plt

import brax

In [91]:
# Imports first: IPython plus its display helpers, then the brax.io submodules.
# NOTE: `from brax.io import json` binds the name `json` to brax's module in
# this notebook's namespace, not to the standard-library `json`.
import IPython
from IPython.display import HTML, clear_output
from brax.io import html
from brax.io import json
from brax.io import model

# Report the installed versions, brax first and IPython second, one per line.
print(brax.__version__, IPython.__version__, sep="\n")

0.9.0
8.13.2


In [92]:
from brax.io import mjcf

# MJCF model text, kept in a named constant so the loader call stays short.
# Contents, as written in the XML below:
#   * a plane geom;
#   * a "particle" body with three slide joints (agentx / agenty / agentz along
#     the x / y / z axes) and a red sphere geom of size 0.15;
#   * an "annotation:outer_bound" body holding a tiny, fully transparent box
#     geom with contype and conaffinity set to 0;
#   * two motors (gear 100) on agentx and agenty; agentz has no motor.
PARTICLE_WITH_OUTER_BOUND_XML = """
<mujoco>
	<option timestep="0.005"/>
	<worldbody>
	    <geom size="40 40 40" type="plane"/>
		<body name="particle" pos="0.15 0.15 0.15">
			<joint axis="1 0 0" damping="0.1" name="agentx" pos="0 0 0" type="slide"></joint>
			<joint axis="0 1 0" damping="0.1" name="agenty" pos="0 0 0" type="slide"></joint>
			<joint axis="0 0 1" damping="0.1" name="agentz" pos="0 0 0" type="slide"></joint>
			<geom name="agent" mass="1" pos="0 0 0" rgba="1 0 0 1" size="0.15" type="sphere" euler="1.57 0 0"></geom>
		</body>
		<body name="annotation:outer_bound" pos="0.15 0.15 0.15">
			<geom conaffinity="0" contype="0" mass="0" pos="0 0 0" rgba="0.417 0.7203 0.0001 0.0" size="0.0001 0.0001 0.0001" type="box"></geom>
		</body>
	</worldbody>
	<actuator>
		<motor gear="100" joint="agentx"></motor>
		<motor gear="100" joint="agenty"></motor>
	</actuator>
</mujoco>
"""

# Parse the XML into a brax system.
# NOTE(review): a later cell in this notebook assigns `agent` again from a
# different XML string, so on a top-to-bottom run this value is replaced.
agent = mjcf.loads(PARTICLE_WITH_OUTER_BOUND_XML)

In [93]:
# MJCF model text for the "agent0" particle, held in a named constant.
# Contents, as written in the XML below:
#   * a plane geom;
#   * a "particle" body with three slide joints (agent0:agentx / agenty /
#     agentz along x / y / z); the z joint uses damping 1, the others 0.1;
#   * a sphere geom of size 0.15 with friction 0.01;
#   * two motors (gear 100) on the x and y joints; the z joint has no motor.
AGENT0_PARTICLE_XML = """
<mujoco>
	<option timestep="0.005"/>
	<worldbody>
	    <geom size="40 40 40" type="plane"/>
		<body name="particle" pos="0.15 0.15 0.15">
			<joint axis="1 0 0" damping="0.1" name="agent0:agentx" pos="0 0 0" type="slide"></joint>
			<joint axis="0 1 0" damping="0.1" name="agent0:agenty" pos="0 0 0" type="slide"></joint>
			<joint axis="0 0 1" damping="1" name="agent0:agentz" pos="0 0 0" type="slide"></joint>
			<geom name="agent0:agent" mass="1" pos="0 0 0" rgba="0.258824 0.921569 0.956863 1" size="0.15" type="sphere" euler="1.57 0 0" friction="0.01"></geom>
		</body>
	</worldbody>
	<actuator>
		<motor gear="100" joint="agent0:agentx"></motor>
		<motor gear="100" joint="agent0:agenty"></motor>
	</actuator>
</mujoco>
"""

# Parse the XML into the brax system bound to `agent`.
# `mjcf` must already be in scope (it is imported in an earlier cell).
agent = mjcf.loads(AGENT0_PARTICLE_XML)


In [94]:
# Print the `init_q` array stored on the loaded system.
# Presumably this holds one initial position per slide joint (the recorded
# output shows three zeros) — TODO confirm against the brax documentation.
print(agent.init_q)

[0. 0. 0.]


In [95]:
#@title { run: "auto"}
from brax.generalized import pipeline

# Colab form sliders. NOTE(review): neither value is referenced by the code in
# this cell; they are kept only so the cell's global bindings stay unchanged.
elasticity = 0.85 #@param { type:"slider", min: 0.5, max: 1.0, step:0.05 }
ball_velocity = 1 #@param { type:"slider", min:-5, max:5, step: 0.5 }

# Start from rest: the velocity vector is all zeros, sized by the system.
qd = jp.zeros(agent.qd_size())
state = jax.jit(pipeline.init)(agent, agent.init_q, qd)

# Wrap the step function once, outside the loop, instead of calling
# `jax.jit(...)` again on every iteration.
jit_step = jax.jit(pipeline.step)

# Constant two-element action applied at every step; it does not change
# between iterations, so it is built once here.
act = jp.asarray([0.1, 0.1])

# Each state is recorded *before* stepping, so the rollout holds the initial
# state plus the results of the first 99 steps.
rollout = []
for i in range(100):
    rollout.append(state)
    state = jit_step(agent, state, act)

# Write the recorded trajectory to an HTML file in the working directory.
html.save('agent.html', agent, rollout)