In [1]:
print('Setting environment variable to use GPU rendering:')
# MUJOCO_GL selects MuJoCo's OpenGL backend (egl = headless GPU rendering).
# It must be set before `import mujoco` below so the renderer picks it up.
%env MUJOCO_GL=egl

# NOTE(review): jax, mjx and mujoco.viewer are not used by the cells shown here;
# simulation below uses the CPU `mujoco` API and `mediapy` for video display.
import jax
import mujoco
from mujoco import mjx
import mujoco.viewer
import mediapy as media

Setting environment variable to use GPU rendering:
env: MUJOCO_GL=egl


In [2]:
def create_mjx_model(mesh_paths, scale_factors, positions, rotations, colors):
    """Build an MJCF (MuJoCo XML) model string with one free-floating mesh body per input.

    The scene contains a checkered ground plane at z = -0.5, a single light and,
    for each index ``i``, a body ``mesh{i+1}`` with a free joint and a mesh geom
    that references the mesh asset of the same name. Despite the function name,
    the result is plain MJCF, loadable with ``mujoco.MjModel.from_xml_string``.

    Args:
        mesh_paths: Paths to mesh files (e.g. ``.obj``), one per body.
        scale_factors: Per-mesh ``(sx, sy, sz)`` scale.
        positions: Per-body initial ``(x, y, z)`` position.
        rotations: Per-body initial orientation quaternion ``(w, x, y, z)``;
            MuJoCo normalizes it when compiling the model.
        colors: Per-geom ``(r, g, b, a)`` colour.

    Returns:
        The MJCF document as a string.

    Raises:
        ValueError: If the per-object sequences do not all have the same length
            (previously ``zip`` silently dropped the extra objects).
    """
    from xml.sax.saxutils import quoteattr

    # Materialize so any iterables are accepted (as with zip) and lengths can be checked.
    columns = [list(seq) for seq in (mesh_paths, scale_factors, positions, rotations, colors)]
    if len({len(col) for col in columns}) > 1:
        raise ValueError(
            "mesh_paths, scale_factors, positions, rotations and colors must have the "
            f"same length; got {[len(col) for col in columns]}"
        )

    def _vec(values):
        # MJCF encodes vectors as space-separated numbers.
        return " ".join(map(str, values))

    body_parts = []
    asset_parts = []
    for i, (mesh_path, scale, pos, rot, color) in enumerate(zip(*columns)):
        name = f"mesh{i + 1}"
        body_parts.append(f'''
            <body name="{name}" pos="{_vec(pos)}" quat="{_vec(rot)}">
                <geom type="mesh" mesh="{name}" condim="3" rgba="{_vec(color)}" contype="1" conaffinity="1"/>
                <joint type="free"/>
            </body>''')
        # quoteattr escapes & < > and picks a safe quote char, so arbitrary file
        # paths cannot break the XML attribute.
        asset_parts.append(
            f'\n            <mesh name="{name}" file={quoteattr(str(mesh_path))} scale="{_vec(scale)}"/>'
        )

    bodies_xml = "".join(body_parts)
    assets_xml = "".join(asset_parts)

    # Single <asset> section holding both the ground texture/material and the meshes.
    return f"""
    <mujoco>
        <option gravity="0 0 -9.81"/>
        <asset>
            <texture name="grid" type="2d" builtin="checker" rgb1=".1 .2 .3"
            rgb2=".2 .3 .4" width="300" height="300" mark="edge" markrgb=".2 .3 .4"/>
            <material name="grid" texture="grid" texrepeat="2 2" texuniform="true"
            reflectance=".2"/>
            {assets_xml}
        </asset>
        <worldbody>
            <light name="top" pos="0 0 1"/>
            <geom name="ground" type="plane" pos="0 0 -.5" size="20 20 .1" material="grid" solimp=".99 .99 .01" solref=".001 1"/>
            {bodies_xml}
        </worldbody>
    </mujoco>
    """


def simulate_collision(mesh_paths, scale_factors, positions, rotations, colors,
                       duration=10, framerate=100, height=480, width=480):
    """Simulate the given meshes falling in the scene and display the result as a video.

    Builds the scene with ``create_mjx_model`` and steps it with the regular CPU
    MuJoCo API (``mujoco.mj_step``), not MJX. Frames are rendered at roughly
    ``framerate`` per simulated second and shown with ``mediapy``.

    Args:
        mesh_paths, scale_factors, positions, rotations, colors: Forwarded to
            ``create_mjx_model``; one entry per mesh body.
        duration: Simulated time to run, in seconds (default 10).
        framerate: Frames rendered per simulated second; also the playback fps
            (default 100).
        height, width: Rendered frame size in pixels (default 480 x 480).
    """
    xml = create_mjx_model(mesh_paths, scale_factors, positions, rotations, colors)

    mj_model = mujoco.MjModel.from_xml_string(xml)
    mj_data = mujoco.MjData(mj_model)

    # Draw joint markers (the free joints) in the rendered frames.
    scene_option = mujoco.MjvOption()
    scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True

    frames = []
    renderer = mujoco.Renderer(mj_model, height=height, width=width)
    try:
        mujoco.mj_resetData(mj_model, mj_data)
        while mj_data.time < duration:
            mujoco.mj_step(mj_model, mj_data)
            # Render only once simulated time passes the next frame slot, so frames
            # are ~1/framerate apart independent of the physics timestep.
            if len(frames) < mj_data.time * framerate:
                renderer.update_scene(mj_data, scene_option=scene_option)
                frames.append(renderer.render())
    finally:
        # Release the renderer's GL context deterministically rather than waiting
        # for garbage collection (matters when this is called repeatedly).
        renderer.close()

    media.show_video(frames, fps=framerate)

In [3]:
import os

# Directory containing the mesh files. Override via the MESH_DIR environment
# variable instead of editing a hard-coded absolute path.
MESH_DIR = os.environ.get("MESH_DIR", "/ccn2/u/rmvenkat/data/all_flex_meshes")

mesh_paths = [os.path.join(MESH_DIR, "cone.obj"), os.path.join(MESH_DIR, "torus.obj")]
missing = [p for p in mesh_paths if not os.path.isfile(p)]
if missing:
    raise FileNotFoundError(f"Mesh file(s) not found: {missing}; set MESH_DIR to the mesh directory.")

scale_factors = [(1.0, 1.0, 1.0), (1, 1, 1)]
# The torus starts 2 units above the cone along z.
positions = [(0, 0, 1), (0, 0, 3)]
# Quaternions are (w, x, y, z); MuJoCo normalizes (1, 1, 0, 0) to a 90° rotation about x.
rotations = [(1, 1, 0, 0), (1, 1, 0, 0)]
colors = [(1, 0, 0, 1), (0, 1, 0, 1)]  # RGBA: red cone, green torus
simulate_collision(mesh_paths, scale_factors, positions, rotations, colors)

0
This browser does not support the video tag.
