In [None]:
import functools
import time

from IPython.display import HTML, Image
import gym

import brax.v1 as brax

from brax.v1.physics import config_pb2


from brax.v1 import envs
from brax.v1.io import html
import jax
from jax import numpy as jnp
from jax import random

from builder import *
from builder import distribute
from builder.spawn import Spawner

import time
import json

In [None]:
# Global RNG seed and the root PRNG key derived from it.
seed = 42
# Derive the key from `seed` (it was previously hard-coded to 42, so editing
# `seed` had no effect).
key = random.PRNGKey(seed)

In [None]:
def add_agent(config, agent_mass=1.0, radius=0.5):
    """Append a sphere body named ``agent`` to ``config.bodies``.

    Args:
      config: brax v1 config proto, modified in place.
      agent_mass: mass assigned to the agent body.
      radius: radius of the agent's single sphere collider.
    """
    body = config.bodies.add(name='agent')
    body.colliders.add().sphere.radius = radius
    body.mass = agent_mass


def add_joint(config):
    """Add a joint named ``rolling`` with parent ``ground`` and child ``agent``.

    Sets angle -1.57, twist 1.0, velocity/torque limits of 100.0 and a spring
    stiffness of 1e5.

    NOTE(review): this helper is not called in the visible notebook cells, and
    fields such as ``angle``, ``twist``, ``limit`` and ``spring.stiffness``
    should be checked against the brax v1 ``Joint`` proto before use — TODO confirm.
    """
    rolling = config.joints.add(name='rolling')
    rolling.parent, rolling.child = 'ground', 'agent'
    rolling.angle, rolling.twist = -1.57, 1.0
    limit = rolling.limit
    limit.velocity = 100.0
    limit.torque = 100.0
    rolling.spring.stiffness = 1e5


def add_objects(config, key, object_count, cube_mass=1.0, cube_halfsize=0.5):
    """Append ``object_count`` box bodies named ``cube_0 .. cube_{n-1}`` to ``config``.

    Only bodies are created here; this function sets no positions.

    Args:
      config: brax v1 config proto, modified in place.
      key: unused; kept so existing callers keep working.
      object_count: number of cubes to add.
      cube_mass: mass assigned to each cube.
      cube_halfsize: half edge length used for the x, y and z half-sizes of
        each cube's box collider.
    """
    for i in range(object_count):
        cube = config.bodies.add(name=f'cube_{i}')
        box = cube.colliders.add().box
        # Honour `cube_halfsize`; it used to be ignored in favour of a hard-coded 0.5.
        box.halfsize.x = cube_halfsize
        box.halfsize.y = cube_halfsize
        box.halfsize.z = cube_halfsize
        cube.mass = cube_mass


def distribute_objects(config, key, object_count):
    """Give bodies ``cube_0 .. cube_{n-1}`` random default start positions.

    Appends one entry to ``config.defaults`` holding a qp per cube, with x and
    y drawn uniformly from [-10, 10) and z from [0, 4).

    Args:
      config: brax v1 config proto, modified in place.
      key: JAX PRNG key.
      object_count: number of cubes; names must match those of ``add_objects``.
    """
    default = config.defaults.add()
    _, *subkeys = random.split(key, object_count + 1)
    for i in range(object_count):
        # One key per coordinate. Reusing subkeys[i] for x, y and z made
        # pos.x == pos.y and z a linear function of x, so every cube landed on
        # the same line.
        key_x, key_y, key_z = random.split(subkeys[i], 3)
        qp = default.qps.add(name=f'cube_{i}')
        # float() turns the 0-d JAX array into a plain Python float for the proto field.
        qp.pos.x = float(random.uniform(key=key_x, minval=-10, maxval=10))
        qp.pos.y = float(random.uniform(key=key_y, minval=-10, maxval=10))
        qp.pos.z = float(random.uniform(key=key_z, minval=0, maxval=4))

In [None]:
def create_config():
    """Build the base scene config for the benchmark.

    The scene uses PBD dynamics with ``dt=0.01`` and ``substeps=2``, and
    gravity (0, 0, -9.8). It contains a frozen ground plane and the sphere
    ``agent`` body from ``add_agent``. No cubes are added here.

    Returns:
      A new ``brax.Config``.
    """
    cfg = brax.Config(dt=0.01, substeps=2, dynamics_mode='pbd')

    gravity = cfg.gravity
    gravity.x, gravity.y, gravity.z = 0.0, 0.0, -9.8

    # Immovable ground body with a plane collider.
    ground = cfg.bodies.add(name='ground')
    ground.frozen.all = True
    # `plane` carries no fields; SetInParent() marks it as the selected oneof.
    ground.colliders.add().plane.SetInParent()

    add_agent(cfg)
    return cfg

In [None]:
# Set the actions
def set_action(env, action):
    """Drive the agent's ``rolling`` force with a torque about z.

    ``action[0]`` becomes the z component of the torque vector
    ``(0, 0, action[0])``. Its element-wise absolute value is written to
    ``max_torque`` first, then the signed vector is written to ``torque``.

    NOTE(review): not called in the visible notebook cells; assumes
    ``env.physics.forces`` is indexable by ('agent', 'rolling') and yields a
    mutable object — confirm against the env API.
    """
    z_torque = jnp.array([0., 0., action[0]])
    forces = env.physics.forces
    slot = ('agent', 'rolling')
    forces[slot].max_torque = jnp.abs(z_torque)
    forces[slot].torque = z_torque

def gen_vis(config):
    """Return an ``html.Visualization`` built from ``config`` with ``side_length`` 15.

    NOTE(review): not called in the visible notebook cells, whose
    (commented-out) render cells use ``html.render`` / ``html.save_html``.
    Confirm that ``html.Config.from_config`` and ``html.Visualization`` exist
    in the installed ``brax.v1.io.html``.
    """
    settings = html.Config.from_config(config)
    settings.side_length = 15
    return html.Visualization(settings)

In [None]:
# Benchmarking params
horizon = 1000     # jitted physics steps per rollout
separation = 3     # NOTE(review): not read by the benchmark loop, which passes separation=5 — confirm intent
iterations = 6     # repeats of the full sweep over env_dims
# Square environment sizes passed to Spawner(env_dim=...).
env_dims = [(side, side) for side in (5, 10, 15, 20, 25, 30, 40)]

In [None]:
# Per-env-dim timing samples, keyed by the env_dims tuple; each value is a list
# with one entry per benchmark iteration.
spawn_times, load_times, step_times, rollout_times = {}, {}, {}, {}

In [None]:
# Benchmark loop. For each of `iterations` repeats and each size in `env_dims`,
# build a fresh config, spawn objects, build the brax System and roll it out for
# `horizon` jitted steps. Recorded per (iteration, dim):
#   spawn_time   - Spawner construction + spawn_objects
#   load_time    - brax.System construction + default_qp()
#   step_time    - mean time of one jitted step, excluding the first call
#                  (which traces and compiles the step function)
#   rollout_time - total time of all `horizon` steps, including that first call
# JAX dispatches work asynchronously, so results are blocked on before each
# clock read; otherwise only the dispatch time would be measured.
assert horizon >= 1, "horizon must be at least 1 to time a rollout"


def _wait(tree):
    """Block until every JAX array leaf of `tree` is computed; return `tree`."""
    return jax.tree_util.tree_map(
        lambda leaf: leaf.block_until_ready() if hasattr(leaf, 'block_until_ready') else leaf,
        tree)


for j in range(iterations):
    print("Iteration: ", j)
    for dim_idx, dim in enumerate(env_dims):
        print("Env dim: ", dim)
        config = create_config()

        # Derive per-run keys from the seeded global `key` without reassigning
        # it. Every (iteration, dim) run is then reproducible, and re-running
        # this cell replays the same layouts and actions.
        run_key = random.fold_in(random.fold_in(key, j), dim_idx)
        spawn_key, act_key = random.split(run_key)
        # Sample all actions up front so RNG work stays out of the step timings.
        acts = _wait(list(random.uniform(act_key, (horizon, 1, 3), minval=-1, maxval=1)))

        st = time.perf_counter()
        spawner = Spawner(env_dim=dim)
        # NOTE(review): separation is hard-coded to 5 here, while the benchmark
        # params cell defines `separation = 3` — confirm which is intended.
        spawner.spawn_objects(config, spawn_key, dim, separation=5, obj_type="box", method="poisson")
        spawn_time = time.perf_counter() - st

        # Loading the system
        st = time.perf_counter()
        sys = brax.System(config=config)
        qps = [_wait(sys.default_qp())]
        load_time = time.perf_counter() - st

        jit_step_fn = jax.jit(sys.step)

        rollout_start = time.perf_counter()
        qp, _ = jit_step_fn(qps[-1], acts[0])  # first call also traces + compiles
        qps.append(_wait(qp))
        steady_start = time.perf_counter()
        for act in acts[1:]:
            qp, _ = jit_step_fn(qps[-1], act)
            qps.append(qp)
        _wait(qps[-1])
        rollout_end = time.perf_counter()

        rollout_time = rollout_end - rollout_start
        step_time = (rollout_end - steady_start) / max(horizon - 1, 1)

        spawn_times.setdefault(dim, []).append(spawn_time)
        load_times.setdefault(dim, []).append(load_time)
        step_times.setdefault(dim, []).append(step_time)
        rollout_times.setdefault(dim, []).append(rollout_time)


Iteration:  0
Env dim:  (5, 5)
[(0.2771204664937307, 0.08057567866888027)]
Env dim:  (10, 10)
[(-2.4661975855962006, -2.873019700382543), (2.654405174266083, -2.283855518054232), (-2.9321784481999016, 2.628381320065197)]
Env dim:  (15, 15)
[(-5.361465780858335, -4.371315111538589), (-0.24760443290704526, -5.427694309059383), (4.271857269654575, -2.727855635404819), (-2.0131994057175846, -0.5348248384072161), (3.7280063951772924, 2.38434109784743), (-2.4272493247946274, 4.656539617091061)]
Env dim:  (20, 20)
[(-5.487391210578444, -8.101996579451178), (-0.5156713377441284, -4.138620353246744), (4.4502612309517815, -7.0615309729936095), (7.773554483338026, -2.4327785163320375), (-6.891210999613146, 0.9259956127606506), (0.9665515911638289, 1.476987985650596), (-7.745279605729811, 7.742404367207744), (-1.8015401612122934, 7.5706247695504185), (7.052068021000583, 7.409814418484453)]
Env dim:  (25, 25)
[(-10.473197925781172, -7.555159618990992), (-1.563408516617658, -7.808099789616352), (3.5

In [None]:
# Environment sizes that have recorded timings.
recorded_dims = spawn_times.keys()
print(recorded_dims)

In [None]:
def _by_side_length(times_by_dim):
    """Re-key ``{dim_tuple: samples}`` by ``dim_tuple[0]``.

    The per-dim sample lists are reused, not copied. Raises ``ValueError`` if
    two dims share a first element. The original comprehension let the later
    dim silently overwrite the earlier one's samples, e.g. for (10, 10) and (10, 20).
    """
    rekeyed = {}
    for dim, samples in times_by_dim.items():
        side = dim[0]
        if side in rekeyed:
            raise ValueError(
                f"env dim {dim} shares first element {side} with another dim; "
                "re-keying would drop samples")
        rekeyed[side] = samples
    return rekeyed


# json.dump rejects tuple keys, so collapse each (side, side) key to its side length.
spawn_times1 = _by_side_length(spawn_times)
load_times1 = _by_side_length(load_times)
step_times1 = _by_side_length(step_times)
rollout_times1 = _by_side_length(rollout_times)

In [None]:
# Inspect the re-keyed spawn timings before they are saved.
print(str(spawn_times1))

In [None]:
# Save each re-keyed timing dictionary to its own JSON file in the working directory.
timing_outputs = {
    'spawn_times.json': spawn_times1,
    'load_times.json': load_times1,
    'step_times.json': step_times1,
    'rollout_times.json': rollout_times1,
}
for out_path, timings in timing_outputs.items():
    with open(out_path, 'w') as fp:
        json.dump(timings, fp)

In [None]:
# HTML(html.render(sys, qps, height=800))

In [None]:
# html.save_html('v1.html', sys, qps)