In [1]:
import functools
import time

from IPython.display import HTML, Image
import gym

import brax.v1 as brax

from brax.v1.physics import config_pb2


from brax.v1 import envs
from brax.v1.io import html
import jax
from jax import numpy as jnp
from jax import random

from builder import *
from builder import distribute
from builder.spawn import Spawner

import time
import json

In [2]:
seed = 42
# Derive the PRNG key from `seed` (previously a duplicated literal 42, so
# editing `seed` silently had no effect).
key = random.PRNGKey(seed)

In [3]:
def add_agent(config, agent_mass=1.0, radius=0.5):
    """Add a single sphere-collider body named 'agent' to ``config``.

    Args:
      config: brax v1 config, mutated in place.
      agent_mass: mass assigned to the agent body.
      radius: radius of the agent's sphere collider.
    """
    body = config.bodies.add(name='agent')
    body.colliders.add().sphere.radius = radius
    body.mass = agent_mass


def add_joint(config):
    """Add a joint named 'rolling' connecting 'ground' (parent) to 'agent' (child).

    NOTE(review): add_joint is not called anywhere visible in this notebook, and
    the field names used here (angle, twist, limit.*, spring.*) should be
    confirmed against the brax.v1 config_pb2 Joint schema.
    """
    rolling = config.joints.add(name='rolling')
    rolling.parent, rolling.child = 'ground', 'agent'
    rolling.angle = -1.57
    rolling.twist = 1.0
    limit = rolling.limit
    limit.velocity = limit.torque = 100.0
    rolling.spring.stiffness = 1e5


def add_objects(config, key, object_count, cube_mass=1.0, cube_halfsize=0.5):
    """Add ``object_count`` box bodies named ``cube_0`` .. ``cube_{object_count-1}``.

    Only bodies, colliders and masses are created; positions are not set here.

    Args:
      config: brax v1 config, mutated in place.
      key: unused; kept so the signature matches distribute_objects.
      object_count: number of cubes to add.
      cube_mass: mass of each cube.
      cube_halfsize: half edge length applied to x, y and z of each box collider
        (previously ignored in favour of a hard-coded 0.5).
    """
    for i in range(object_count):
        cube = config.bodies.add(name=f'cube_{i}')
        halfsize = cube.colliders.add().box.halfsize
        halfsize.x = cube_halfsize
        halfsize.y = cube_halfsize
        halfsize.z = cube_halfsize
        cube.mass = cube_mass


def distribute_objects(config, key, object_count):
    """Give cubes ``cube_0`` .. ``cube_{object_count-1}`` random start positions.

    Appends one entry to ``config.defaults`` holding one qp per cube (same
    names as add_objects). ``x`` and ``y`` are drawn uniformly from [-10, 10)
    and ``z`` from [0, 4) via ``jax.random.uniform``.

    Args:
      config: brax v1 config, mutated in place.
      key: JAX PRNG key.
      object_count: number of cubes to position.
    """
    default = config.defaults.add()
    _, *subkeys = random.split(key, object_count + 1)
    for i, subkey in enumerate(subkeys):
        # Split per cube so x, y and z are independent; sampling all three from
        # the same subkey produced x == y for every cube.
        key_x, key_y, key_z = random.split(subkey, 3)
        qp = default.qps.add(name=f'cube_{i}')
        # float(): store plain Python floats rather than 0-d JAX arrays.
        qp.pos.x = float(random.uniform(key_x, minval=-10, maxval=10))
        qp.pos.y = float(random.uniform(key_y, minval=-10, maxval=10))
        qp.pos.z = float(random.uniform(key_z, minval=0, maxval=4))

In [4]:
def create_config(dt=0.01, substeps=2):
    """Build a brax v1 config containing a frozen ground plane and a sphere agent.

    No cubes/obstacles are added here; callers add them separately (e.g. with
    Spawner.spawn_objects).

    Args:
      dt: simulation timestep (previously hard-coded; default unchanged).
      substeps: physics substeps per step (previously hard-coded; default unchanged).

    Returns:
      A ``brax.Config`` using the 'pbd' dynamics mode and gravity (0, 0, -9.8),
      with a body named 'ground' and the body added by ``add_agent``.
    """
    config = brax.Config(dt=dt, substeps=substeps, dynamics_mode='pbd')

    config.gravity.x = 0.0
    config.gravity.y = 0.0
    config.gravity.z = -9.8

    # Ground: a frozen (immovable) infinite plane.
    ground = config.bodies.add(name='ground')
    ground.frozen.all = True
    plane = ground.colliders.add().plane
    plane.SetInParent()  # select the empty `plane` message in the collider oneof

    add_agent(config)

    return config

In [5]:
def set_action(env, action):
    """Apply ``action[0]`` as the z component of the agent's 'rolling' torque.

    Sets ``torque`` to (0, 0, action[0]) and ``max_torque`` to its element-wise
    absolute value.

    NOTE(review): assumes ``env.physics.forces`` is indexable by
    ('agent', 'rolling'); set_action is not called anywhere visible — confirm
    against the brax.v1 environment API.
    """
    force = env.physics.forces['agent', 'rolling']
    torque = jnp.array([0., 0., action[0]])
    force.max_torque = jnp.abs(torque)
    force.torque = torque

def gen_vis(config):
    """Return an ``html.Visualization`` for ``config`` with ``side_length`` 15.

    NOTE(review): elsewhere this notebook only uses ``html.render`` /
    ``html.save_html``; confirm ``html.Config.from_config`` and
    ``html.Visualization`` exist in brax.v1.io.html.
    """
    settings = html.Config.from_config(config)
    settings.side_length = 15
    return html.Visualization(settings)

In [6]:
# Benchmarking params
separation = 3  # intended object spacing for Spawner.spawn_objects(separation=...)
iterations = 25  # number of repetitions of the full sweep over env_dims
env_dims = [(side, side) for side in (5, 10, 15, 20, 25, 30, 40, 60, 100)]  # square (side, side) env dims

In [7]:
obj_cnts = {}  # env side length (dim[0]) -> list of spawned-object counts

In [8]:
# Record how many objects Spawner places for each env size, repeated
# `iterations` times.
for j in range(iterations):
    print("Iteration: ", j)
    # Derive a distinct key per iteration. Reusing the same global `key` every
    # time meant all repetitions sampled the same layout (assuming Spawner's
    # placement is driven by the key it is given).
    iter_key = random.fold_in(key, j)
    for dim in env_dims:
        print("Env dim: ", dim)
        config = create_config()

        # Distribute the objects.
        # NOTE(review): this previously hard-coded separation=5, ignoring the
        # `separation` benchmarking param; it now uses that param.
        spawner = Spawner(env_dim=dim)
        obj_cnt = spawner.spawn_objects(config, iter_key, dim, separation=separation, obj_type="box", method="poisson")

        obj_cnts.setdefault(dim[0], []).append(obj_cnt)


Iteration:  0
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env dim:  (20, 20)
Env dim:  (25, 25)
Env dim:  (30, 30)
Env dim:  (40, 40)
Env dim:  (60, 60)
Env dim:  (100, 100)
Iteration:  1
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env dim:  (20, 20)
Env dim:  (25, 25)
Env dim:  (30, 30)
Env dim:  (40, 40)
Env dim:  (60, 60)
Env dim:  (100, 100)
Iteration:  2
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env dim:  (20, 20)
Env dim:  (25, 25)
Env dim:  (30, 30)
Env dim:  (40, 40)
Env dim:  (60, 60)
Env dim:  (100, 100)
Iteration:  3
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env dim:  (20, 20)
Env dim:  (25, 25)
Env dim:  (30, 30)
Env dim:  (40, 40)
Env dim:  (60, 60)
Env dim:  (100, 100)
Iteration:  4
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env dim:  (20, 20)
Env dim:  (25, 25)
Env dim:  (30, 30)
Env dim:  (40, 40)
Env dim:  (60, 60)
Env dim:  (100, 100)
Iteration:  5
Env dim:  (5, 5)
Env dim:  (10, 10)
Env dim:  (15, 15)
Env di

In [9]:
# Save the object-count dictionary to a JSON file (integer keys are written
# as JSON strings). Serialize first so a non-serializable value raises before
# the file is opened, instead of leaving a truncated obj_cnts.json behind.
serialized = json.dumps(obj_cnts)
with open('obj_cnts.json', 'w') as fp:
    fp.write(serialized)

In [10]:
# HTML(html.render(sys, qps, height=800))

In [11]:
# html.save_html('v1.html', sys, qps)