In [5]:
import functools
import time

from IPython.display import HTML, Image
import gym

import brax



from brax import envs
from brax.io import html
import jax
from jax import numpy as jnp
from jax import random

from builder import *
from builder import distribute
from builder.spawn import Spawner

In [6]:
seed = 42
key = random.PRNGKey(42)

In [7]:
def add_agent(config, agent_mass=1.0, radius=0.5):
    # Add the agent body
    agent = config.bodies.add(name='agent')
    sph = agent.colliders.add().sphere
    sph.radius = radius
    agent.mass = agent_mass
    # agent.damping = 1e-2
    # agent.friction = 0.6


# Adds joint for rolling the agent
def add_joint(config):
    joint = config.joints.add(name='rolling')
    joint.parent = 'ground'
    joint.child = 'agent'
    joint.angle = -1.57
    joint.twist = 1.0
    joint.limit.velocity = 100.0
    joint.limit.torque = 100.0
    joint.spring.stiffness = 1e5


def add_objects(config, key, object_count, cube_mass=1.0, cube_halfsize=0.5):
    key, *subkeys = random.split(key, object_count + 1)
    for i in range(object_count):
        cube = config.bodies.add(name=f'cube_{i}')
        box = cube.colliders.add().box
        box.halfsize.x = 0.5
        box.halfsize.y = 0.5
        box.halfsize.z = 0.5
        cube.mass = cube_mass


def distribute_objects(config, key, object_count):
    default = config.defaults.add()
    key, *subkeys = random.split(key, object_count + 1)
    for i in range(object_count):
        qp = default.qps.add(name=f'cube_{i}')
        qp.pos.x = random.uniform(key=subkeys[i], minval=-10, maxval=10)
        qp.pos.y = random.uniform(key=subkeys[i], minval=-10, maxval=10)
        qp.pos.z = random.uniform(key=subkeys[i], minval=0, maxval=4)

In [8]:
def create_config(object_count):
    # Define the environment config.
    sphere_maze = brax.Config(dt=0.01, substeps=2, dynamics_mode='pbd')
    


    sphere_maze.gravity.x = 0.0
    sphere_maze.gravity.y = 0.0
    sphere_maze.gravity.z = -9.8

    # Add the ground, a frozen (immovable) infinite plane
    ground = sphere_maze.bodies.add(name='ground')
    ground.frozen.all = True
    plane = ground.colliders.add().plane
    plane.SetInParent()  # for setting an empty oneof

    # Add the agent body
    add_agent(sphere_maze)

    # # Add the cubes.
    # add_objects(sphere_maze, key, object_count=object_count)
    # distribute_objects(sphere_maze, key, object_count=object_count)

    return sphere_maze

In [9]:
# Set the actions
def set_action(env, action):
    torque = jnp.array([0., 0., action[0]])
    env.physics.forces['agent', 'rolling'].max_torque = jnp.abs(torque)
    env.physics.forces['agent', 'rolling'].torque = torque

def gen_vis(config):
    # Create a visualization
    vis_config = html.Config.from_config(config)
    vis_config.side_length = 15
    vis = html.Visualization(vis_config)
    return vis

In [11]:
object_count = 5

config = create_config(object_count)

# vis = gen_vis(config=config)
spawner = Spawner(env_dim=(50,50))
spawner.spawn_objects(config, key, space_dim=(50,50), separation=7, obj_type="box", method="poisson")


sys = brax.System(config)
qps = [sys.default_qp()]
act = jnp.array([0.0])

# agent_velocity = 0.5
# qps[-1].vel[1, 0] = agent_velocity


# Run the environment loop
for i in range(1):
    print(i)
    action = jnp.array([jnp.sin(i / 10)])
    qp, _ = sys.step(qps[-1], act)
    qps.append(qp)

HTML(html.render(sys, qps, height=800))

AttributeError: module 'brax' has no attribute 'Config'