In [1]:
import functools
import time

from IPython.display import HTML, Image
import gym

import brax.v1 as brax

from brax.v1.physics import config_pb2


from brax.v1 import envs
from brax.v1.io import html
import jax
from jax import numpy as jnp
from jax import random

from builder import *
from builder import distribute
from builder.spawn import Spawner

import time
import json

In [2]:
seed = 42
# Derive the root PRNG key from `seed` so changing the seed actually changes the run.
key = random.PRNGKey(seed)

In [3]:
def add_agent(config, agent_mass=1.0, radius=0.5):
    """Append a body named 'agent' with a single sphere collider to ``config``.

    Args:
      config: brax config whose ``bodies`` collection receives the new body.
      agent_mass: mass assigned to the agent body.
      radius: radius of the agent's sphere collider.
    """
    body = config.bodies.add(name='agent')
    body.colliders.add().sphere.radius = radius
    body.mass = agent_mass
    # Optional tuning knobs, currently disabled:
    #   body.damping = 1e-2
    #   body.friction = 0.6


def add_joint(config):
    """Add a joint named 'rolling' linking parent 'ground' to child 'agent'.

    Sets the joint's angle and twist, its velocity and torque limits (both
    100.0), and its spring stiffness (1e5). Not called by create_config().

    NOTE(review): confirm that these field names (angle, twist, limit, spring)
    exist on the brax.v1 joint proto before relying on this helper.
    """
    joint = config.joints.add(name='rolling')
    joint.parent, joint.child = 'ground', 'agent'
    joint.angle = -1.57  # close to -pi/2; presumably radians -- TODO confirm
    joint.twist = 1.0
    for limit_field in ('velocity', 'torque'):
        setattr(joint.limit, limit_field, 100.0)
    joint.spring.stiffness = 1e5


def add_objects(config, key, object_count, cube_mass=1.0, cube_halfsize=0.5):
    """Add ``object_count`` box bodies named ``cube_0`` .. ``cube_{n-1}`` to ``config``.

    Args:
      config: brax config whose ``bodies`` collection receives the cubes.
      key: accepted for signature compatibility with ``distribute_objects``;
        not used, since cube geometry is deterministic.
      object_count: number of cubes to add.
      cube_mass: mass assigned to every cube.
      cube_halfsize: half-extent of each cube along x, y and z.
    """
    for i in range(object_count):
        cube = config.bodies.add(name=f'cube_{i}')
        box = cube.colliders.add().box
        # Honour the requested size on every axis.
        box.halfsize.x = cube_halfsize
        box.halfsize.y = cube_halfsize
        box.halfsize.z = cube_halfsize
        cube.mass = cube_mass


def distribute_objects(config, key, object_count):
    """Add one ``defaults`` entry that places ``cube_0`` .. ``cube_{n-1}`` randomly.

    x and y are drawn uniformly from [-10, 10) and z from [0, 4). Each
    coordinate of each cube uses its own subkey. Drawing x and y from the same
    key with the same range would make them identical.

    Args:
      config: brax config whose ``defaults`` collection receives the entry.
      key: JAX PRNG key used to derive all positions.
      object_count: number of cubes (``cube_i`` names) to position.
    """
    default = config.defaults.add()
    _, *cube_keys = random.split(key, object_count + 1)
    for i, cube_key in enumerate(cube_keys):
        key_x, key_y, key_z = random.split(cube_key, 3)
        qp = default.qps.add(name=f'cube_{i}')
        # Convert the 0-d JAX arrays to Python floats before assigning them to
        # the scalar fields.
        qp.pos.x = float(random.uniform(key_x, minval=-10, maxval=10))
        qp.pos.y = float(random.uniform(key_y, minval=-10, maxval=10))
        qp.pos.z = float(random.uniform(key_z, minval=0, maxval=4))

In [4]:
def create_config():
    """Build the base brax v1 config: downward gravity, a frozen ground plane, and the agent.

    Obstacles are not added here; the benchmark loop spawns them into the
    returned config via ``Spawner.spawn_objects``.

    Returns:
      A ``brax.Config`` with dt=0.01, substeps=2 and dynamics_mode='pbd'.
    """
    config = brax.Config(dt=0.01, substeps=2, dynamics_mode='pbd')
    config.gravity.x, config.gravity.y, config.gravity.z = 0.0, 0.0, -9.8

    # Ground: a frozen (immovable) infinite plane.
    ground = config.bodies.add(name='ground')
    ground.frozen.all = True
    # `plane` carries no fields; SetInParent marks this empty oneof member as set.
    ground.colliders.add().plane.SetInParent()

    add_agent(config)
    return config

In [5]:
def set_action(env, action):
    """Apply ``action[0]`` as a z-axis torque on the ('agent', 'rolling') force entry.

    Writes both ``torque`` and ``max_torque``; the latter is the element-wise
    absolute value of the torque vector.

    NOTE(review): assumes ``env.physics.forces`` is indexable by
    (body, joint) tuples and yields objects with ``torque``/``max_torque``
    attributes; confirm against the env implementation.
    """
    slot = ('agent', 'rolling')
    z_torque = jnp.array([0., 0., action[0]])
    env.physics.forces[slot].max_torque = jnp.abs(z_torque)
    env.physics.forces[slot].torque = z_torque

def gen_vis(config):
    """Return an ``html.Visualization`` built from ``config`` with side_length 15.

    NOTE(review): this helper is not called in this notebook; confirm that
    ``html.Config.from_config`` and ``html.Visualization`` exist in
    ``brax.v1.io.html`` before using it.
    """
    visual_cfg = html.Config.from_config(config)
    visual_cfg.side_length = 15
    return html.Visualization(visual_cfg)

In [6]:
# Benchmarking params
horizon = 1000     # physics steps per rollout
separation = 3     # NOTE(review): not read by the benchmark loop, which passes separation=5 directly
iterations = 6     # number of full sweeps over env_dims
env_dims = [(side, side) for side in (5, 10, 15, 20, 25, 30, 40, 50)]

In [7]:
# Timing samples in seconds, keyed by env dim, one list entry per iteration.
spawn_times, load_times, step_times, rollout_times = {}, {}, {}, {}

In [8]:
for j in range(iterations):
    print("Iteration: ", j)
    for dim in env_dims:
        print("Env dim: ", dim)
        config = create_config()

        # --- Spawn: distribute the obstacles into the config. ---
        # Give the spawner its own subkey so it never shares randomness with
        # the action sampling below.
        # NOTE(review): the params cell defines `separation = 3`, but this call
        # passes 5 explicitly; confirm which value is intended.
        key, spawn_key = random.split(key)
        t_start = time.perf_counter()
        spawner = Spawner(env_dim=dim)
        spawner.spawn_objects(config, spawn_key, dim, separation=5, obj_type="box", method="poisson")
        t_spawned = time.perf_counter()

        # --- Load: build the system and its initial state. ---
        sys = brax.System(config=config)
        qps = [sys.default_qp()]
        t_loaded = time.perf_counter()

        # --- "Step": wrap sys.step with jax.jit. ---
        # jax.jit is lazy, so this interval only measures the wrapping; tracing
        # and compilation happen on the first call in the rollout below and are
        # therefore included in rollout_time.
        jit_step_fn = jax.jit(sys.step)
        t_jitted = time.perf_counter()

        # --- Rollout: `horizon` steps with uniformly random actions. ---
        for _ in range(horizon):
            # Split before sampling so `key` is never used both to sample and
            # to derive new keys.
            key, act_key = random.split(key)
            act = random.uniform(act_key, (1, 3), minval=-1, maxval=1)
            qp, _ = jit_step_fn(qps[-1], act)
            qps.append(qp)
        # JAX dispatches work asynchronously; block on the final state so the
        # measured interval covers the actual computation.
        jax.tree_util.tree_map(
            lambda leaf: leaf.block_until_ready() if hasattr(leaf, 'block_until_ready') else leaf,
            qps[-1])
        t_done = time.perf_counter()

        # Record this run under `dim`, appending to samples from earlier iterations.
        for store, elapsed in ((spawn_times, t_spawned - t_start),
                               (load_times, t_loaded - t_spawned),
                               (step_times, t_jitted - t_loaded),
                               (rollout_times, t_done - t_jitted)):
            store.setdefault(dim, []).append(elapsed)


Iteration:  0
Env dim:  (5, 5)
[(-0.05529342595955378, 0.026116907918627752)]
Env dim:  (10, 10)
[(1.7562045719204082, -2.3721295339101367)]
Env dim:  (15, 15)
[(-3.710763211404428, -1.3042187164880121), (4.995254943153302, -1.8483085605650116), (-4.8999648960888065, 4.333970892690102), (0.5385525979759596, 4.818114282402428)]
Env dim:  (20, 20)
[(-6.3520125476647165, -8.194764297589249), (6.979531101509598, -7.381727807565597), (-5.978536232008177, -2.840069988545981), (0.51426743831433, -3.4896808486319495), (5.882708733410236, -0.40262027984835846), (0.6940839164615689, 2.1452900565134194), (-8.25276022691711, 6.998400795292072), (-3.1622113947259813, 6.095566505428959), (7.560228575166006, 7.332105031736813)]
Env dim:  (25, 25)
[(-10.145625861454079, -9.703654813880895), (-2.7779355828492776, -10.029575243845745), (3.1339053356537656, -8.758944198589205), (-8.773649530588644, -4.051637283827702), (-3.405194683033674, -3.649249124597585), (1.2454473607738832, -1.4716663558901284), (

KeyboardInterrupt: 

In [None]:
# Save the timing dictionaries to JSON files.
# JSON object keys must be strings. The timing dicts are keyed by env_dims
# tuples such as (5, 5), and json.dump raises TypeError on tuple keys, so
# convert each key to a string like "5x5" first.


def _json_key(dim):
    """Return a JSON-safe string key for an env_dims entry, e.g. (5, 5) -> '5x5'."""
    if isinstance(dim, (tuple, list)):
        return 'x'.join(str(d) for d in dim)
    return str(dim)


for name, timings in (('spawn_times', spawn_times),
                      ('load_times', load_times),
                      ('step_times', step_times),
                      ('rollout_times', rollout_times)):
    with open(f'{name}.json', 'w') as fp:
        json.dump({_json_key(dim): samples for dim, samples in timings.items()}, fp)

In [None]:
# HTML(html.render(sys, qps, height=800))

In [None]:
# html.save_html('v1.html', sys, qps)