In [1]:
import math
import os
import numpy as np
from pxr import Usd, UsdGeom

import warp as wp
import warp.examples
import warp.sim
import warp.sim.render


# ----------------------------
# helper: load bunny mesh once
# ----------------------------
def load_bunny(scale):
    asset_path = os.path.join(warp.examples.get_asset_directory(), "bunny.usd")
    asset_stage = Usd.Stage.Open(asset_path)
    mesh_geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/root/bunny"))
    points = np.array(mesh_geom.GetPointsAttr().Get())
    indices = np.array(mesh_geom.GetFaceVertexIndicesAttr().Get()).flatten()
    return wp.sim.Mesh(points * scale, indices)


# ------------------------------------
# build a *single* environment template
# ------------------------------------
def build_single_env(scale=0.8, ke=1.0e5, kd=250.0, kf=500.0, num_bodies=8):
    b = wp.sim.ModelBuilder()

    # boxes
    for i in range(num_bodies):
        body = b.add_body(origin=wp.transform((i, 1.0, 0.0), wp.quat_identity()))
        b.add_shape_box(pos=wp.vec3(0.0, 0.0, 0.0),
                        hx=0.5 * scale, hy=0.2 * scale, hz=0.2 * scale,
                        body=body, ke=ke, kd=kd, kf=kf)

    # spheres
    for i in range(num_bodies):
        body = b.add_body(origin=wp.transform((i, 1.0, 2.0), wp.quat_identity()))
        b.add_shape_sphere(pos=wp.vec3(0.0, 0.0, 0.0),
                           radius=0.25 * scale, body=body, ke=ke, kd=kd, kf=kf)

    # capsules
    for i in range(num_bodies):
        body = b.add_body(origin=wp.transform((i, 1.0, 6.0), wp.quat_identity()))
        b.add_shape_capsule(pos=wp.vec3(0.0, 0.0, 0.0),
                            radius=0.25 * scale, half_height=scale * 0.5, up_axis=0,
                            body=body, ke=ke, kd=kd, kf=kf)

    # initial spin (angular vel in body_qd)
    for i in range(len(b.body_qd)):
        b.body_qd[i] = (0.0, 2.0, 10.0, 0.0, 0.0, 0.0)

    # bunny stack
    bunny_mesh = load_bunny(scale=scale)
    for i in range(num_bodies):
        body = b.add_body(
            origin=wp.transform(
                (i * 0.5 * scale, 1.0 + i * 1.7 * scale, 4.0 + i * 0.5 * scale),
                wp.quat_from_axis_angle(wp.vec3(0.0, 1.0, 0.0), math.pi * 0.1 * i),
            )
        )
        b.add_shape_mesh(body=body, mesh=bunny_mesh,
                         pos=wp.vec3(0.0, 0.0, 0.0),
                         scale=wp.vec3(scale, scale, scale),
                         ke=ke, kd=kd, kf=kf, density=1e3)

    return b


# ----------------------------------------------------
# batched example: clone template into many environments
# ----------------------------------------------------
class BatchedExample:
    def __init__(self, num_envs=1000, stage_path=None,
                 fps=60, sim_substeps=10, env_spacing=20.0,
                 scale=0.8, ke=1.0e5, kd=250.0, kf=500.0, num_bodies=8):

        self.num_envs = num_envs
        self.frame_dt = 1.0 / fps
        self.sim_substeps = sim_substeps
        self.sim_dt = self.frame_dt / self.sim_substeps
        self.sim_time = 0.0

        # build template
        template_builder = build_single_env(scale=scale, ke=ke, kd=kd, kf=kf, num_bodies=num_bodies)

        # build top-level builder and replicate
        top_builder = wp.sim.ModelBuilder()
        for env_id in range(self.num_envs):
            tx = env_id * env_spacing
            xform = wp.transform((tx, 0.0, 0.0), wp.quat_identity())
            # add_builder clones all bodies/shapes; separate_collision_group=True isolates collisions per env
            top_builder.add_builder(template_builder,
                                    xform=xform,
                                    update_num_env_count=True,
                                    separate_collision_group=True)

        # finalize (GPU arrays allocated here)
        self.model = top_builder.finalize()
        self.model.ground = True

        self.integrator = wp.sim.SemiImplicitIntegrator()

        # allocate two ping-pong states
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        # initialize FK from joint_q (mostly identity for free bodies)
        wp.sim.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, None, self.state_0)

        # optional: randomize per-environment starts (GPU kernel below)
        self.randomize_starts()

        # renderer (USD) is expensive; disable by default when running large batches
        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=0.25)
        else:
            self.renderer = None

        # capture CUDA graph for one frame's worth of substeps
        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self._simulate_substeps()
            self.graph = capture.graph

    # ---------------------------------------
    # optional GPU-side randomized init/reset
    # ---------------------------------------
    def randomize_starts(self, seed=0):
        # simple demo: add a small random y-jitter to all bodies in each env
        rng = np.random.default_rng(seed)
        body_count = self.model.body_count
        jitter = wp.array(rng.uniform(-0.1, 0.1, size=body_count), dtype=float)

        @wp.kernel
        def add_jitter(q: wp.array(dtype=wp.transform),
                       jitter: wp.array(dtype=float)):
            tid = wp.tid()
            t = q[tid]
            pos = wp.transform_get_translation(t)
            pos = wp.vec3(pos[0], pos[1] + jitter[tid], pos[2])
            q[tid] = wp.transform(pos, wp.transform_get_rotation(t))

        wp.launch(add_jitter, dim=body_count, inputs=[self.state_0.body_q, jitter])

    # ---------------------------------------
    # internal: run substeps (no Python loop per env)
    # ---------------------------------------
    def _simulate_substeps(self):
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            wp.sim.collide(self.model, self.state_0)
            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)
            self.state_0, self.state_1 = self.state_1, self.state_0

    def step(self):
        if self.use_cuda_graph:
            wp.capture_launch(self.graph)
        else:
            self._simulate_substeps()
        self.sim_time += self.frame_dt

    def render(self):
        if self.renderer is None:
            return
        self.renderer.begin_frame(self.sim_time)
        self.renderer.render(self.state_0)
        self.renderer.end_frame()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument("--stage_path", type=lambda x: None if x == "None" else str(x),
                        default='/home/hlwang/code/b3d/test.usd', help="Path to output USD (warning: huge if enabled).")
    parser.add_argument("--num_frames", type=int, default=300, help="Frames to simulate.")
    parser.add_argument("--num_envs", type=int, default=10, help="Number of parallel envs.")
    args = parser.parse_known_args()[0]

    with wp.ScopedDevice(args.device):
        ex = BatchedExample(num_envs=args.num_envs, stage_path=args.stage_path)

        for _ in range(args.num_frames):
            ex.step()
            ex.render()

        if ex.renderer:
            ex.renderer.save()


Warp 1.7.0 initialized:
   CUDA Toolkit 12.8, Driver 12.4
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA H200" (140 GiB, sm_90, mempool enabled)
   Kernel cache:
     /home/hlwang/.cache/warp/1.7.0
Module warp.sim.inertia 3bc41ce load on device 'cuda:0' took 2.08 ms  (cached)
Module warp.sim.collide e2dca21 load on device 'cuda:0' took 2.95 ms  (cached)
Module __main__ 004cda4 load on device 'cuda:0' took 473.69 ms  (compiled)
Module warp.sim.integrator_euler 99d48f5 load on device 'cuda:0' took 1.59 ms  (cached)
Module warp.sim.integrator 3b115ab load on device 'cuda:0' took 1.12 ms  (cached)
Saved the USD stage file at `/home/hlwang/code/b3d/test.usd`


# test warp combined with genjax trace and inference

In [1]:
import json
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = "platform"
from os.path import join
import copy
from jax.random import split

import b3d
import b3d.chisight.dense.dense_model
import b3d.chisight.dense.likelihoods.laplace_likelihood
import b3d.chisight.gen3d.inference.inference as inference
import b3d.chisight.gen3d.settings as settings
import jax
import jax.numpy as jnp
import trimesh
from b3d.chisight.gen3d.dataloading import (
    get_initial_state,
    load_trial,
    resize_rgbds_and_get_masks,
)
from genjax import Pytree, pretty
pretty()

def foreground_background(depth_map, area, val):
    zero_depth_map = jnp.full(depth_map.shape, val)
    zero_depth_map = zero_depth_map.at[area].set(depth_map[area])
    return zero_depth_map

START_T = 0
near_plane = 0.1
far_plane = 100.0
im_width = 350
im_height = 350
width = 1024
height = 1024

hdf5_file_path = "/orcd/data/jbt/001/hlwang/data/lf0"
mesh_file_path = "/orcd/data/jbt/001/hlwang/data/all_flex_meshes/"
pred_file_path = "/orcd/data/jbt/001/hlwang/data/gt_correct.json"
scenario = 'roll'
trial_name = 'pilot_it2_rollingSliding_simple_ramp_tdw_1_dis_1_occ_small_zone_0017'

with open(pred_file_path) as f:
    pred_file_all = json.load(f)
pred_file = pred_file_all[trial_name]

all_meshes = {}
for path, dirs, files in os.walk(mesh_file_path):
    for name in files + dirs:
        if name.endswith(".obj"):
            mesh = trimesh.load(os.path.join(path, name))
            all_meshes[name[:-4]] = mesh

scaling_factor = im_height / height
vfov = 54.43222 / 180.0 * jnp.pi
tan_half_vfov = jnp.tan(vfov / 2.0)
tan_half_hfov = tan_half_vfov * width / float(height)
fx = width / 2.0 / tan_half_hfov
fy = height / 2.0 / tan_half_vfov

renderer = b3d.renderer.renderer_original.RendererOriginal(
    width * scaling_factor,
    height * scaling_factor,
    fx * scaling_factor,
    fy * scaling_factor,
    (width / 2) * scaling_factor,
    (height / 2) * scaling_factor,
    near_plane,
    far_plane,
)

b3d.reload(b3d.chisight.dense.likelihoods.laplace_likelihood)
likelihood_func = b3d.chisight.dense.likelihoods.laplace_likelihood.likelihood_func

b3d.reload(b3d.chisight.dense.dense_model)
dynamic_object_generative_model, viz_trace = (
    b3d.chisight.dense.dense_model.make_dense_multiobject_dynamics_model(
        renderer, likelihood_func
    )
)
importance_jit = jax.jit(dynamic_object_generative_model.importance)

likelihood_args = {
    "fx": renderer.fx,
    "fy": renderer.fy,
    "cx": renderer.cx,
    "cy": renderer.cy,
    "color_noise_variance": 1.0,
    "depth_noise_variance": 0.01,
    "outlier_probability": 0.1,
    "image_width": Pytree.const(renderer.width),
    "image_height": Pytree.const(renderer.height),
    "masked": Pytree.const(True),
    "check_interp": Pytree.const(False),
    "num_mc_sample": Pytree.const(500),
    "interp_penalty": Pytree.const(1e5),
}

physics_args = {
    "mu": Pytree.const(0.25),
    "restitution": Pytree.const(0.4),
    "fps": Pytree.const(100),
    "sim_substeps": Pytree.const(10),
    "g": Pytree.const(-9.80665),
}

inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams

hdf5_file_path = join(
    hdf5_file_path,
    scenario + "_all_movies",
    f"{trial_name}.hdf5",
)

(
    rgbds_original,
    seg_arr_original,
    object_ids,
    object_segmentation_colors,
    background_areas,
    camera_pose,
    
    gt_pos_array,
    gt_rot_array,
    gt_linvel_array,
    gt_angvel_array,
) = load_trial(hdf5_file_path, 250)

hyperparams = settings.hyperparams
hyperparams["camera_pose"] = camera_pose
hyperparams["likelihood_args"] = likelihood_args
hyperparams["physics_args"] = physics_args

initial_state, hyperparams, renderer, state, initial_warp_info = get_initial_state(
    pred_file,
    object_ids,
    object_segmentation_colors,
    all_meshes,
    seg_arr_original[START_T],
    rgbds_original[START_T],
    hyperparams,
)

rgbds, all_areas, background_areas = resize_rgbds_and_get_masks(
    rgbds_original, seg_arr_original, background_areas, im_height, im_width
)
hyperparams["background"] = jnp.asarray(
    [
        foreground_background(rgbds[t], background_areas[t], jnp.inf)
        for t in range(rgbds.shape[0])
    ]
)
initial_state

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


Warp 1.7.0 initialized:
   CUDA Toolkit 12.8, Driver 12.4
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA H200" (140 GiB, sm_90, mempool enabled)
   Kernel cache:
     /home/hlwang/.cache/warp/1.7.0
Module warp.sim.inertia 3bc41ce load on device 'cuda:0' took 1.73 ms  (cached)
Module warp.sim.collide e2dca21 load on device 'cuda:0' took 3.49 ms  (cached)


In [2]:
initial_warp_info

In [3]:
key = jax.random.PRNGKey(156)
trace = inference.get_initial_trace(
    key,
    importance_jit,
    hyperparams,
    initial_state,
    initial_warp_info,
    foreground_background(rgbds[START_T], all_areas[START_T], 0.0),
)
trace

pure print body_q before collide: Traced<ShapedArray(float32[3,7])>with<DynamicJaxprTrace>
Module b3d.physics.physics_utils 115c70c load on device 'cuda:0' took 1.85 ms  (cached)
pure print body_q before simulate: Traced<ShapedArray(float32[3,7])>with<DynamicJaxprTrace>
NOT STEPPING


  second_moment = np.nanmean(


In [None]:
T = 0
if T == 0:
    xyz = False
    infer_vel = False
else:
    xyz = True
    infer_vel = True

relevant_objects = [2]
addresses = [Pytree.const(f"object_pose_{o_id}") for o_id in relevant_objects]

In [5]:
# assert len(inference_hyperparams.pose_proposal_args) == len(inference_hyperparams.vel_proposal_args)
# key, subkey = split(key)
# trace = inference.advance_time(subkey, trace, foreground_background(rgbds[T], all_areas[T], 0.0))

In [6]:
import itertools
from b3d.chisight.gen3d.inference.utils import logmeanexp, update_field

@jax.jit
def inference_step(
    key,
    trace,
    observed_rgbd,
    inference_hyperparams,
    addresses,
    xyz=True,
    infer_vel=True,
    include_previous_pose=True,
    k=50,
):
    assert len(inference_hyperparams.pose_proposal_args) == len(inference_hyperparams.vel_proposal_args)
    key, subkey = split(key)
    trace = inference.advance_time(subkey, trace, observed_rgbd)

    @jax.jit
    def c2f_pose_step(
        key,
        trace,
        pose_proposal_args,
        vel_proposal_args,
        addr,
    ):
        addr = addr.unwrap()
        k1, k2, k3 = split(key, 3)

        # Propose the poses
        generation_keys = split(k1, inference_hyperparams.n_poses_vels)
        proposed_poses, log_q_poses = jax.vmap(
            inference.propose_pose, in_axes=(0, None, None, None, None)
        )(generation_keys, trace, addr, pose_proposal_args, xyz)
        # jax.debug.print("proposed_poses: {v}", v=proposed_poses)
        # jax.debug.print("log_q_poses before: {v}", v=log_q_poses)

        proposed_poses, log_q_poses = inference.maybe_swap_in_previous_pose(
            proposed_poses,
            log_q_poses,
            trace,
            addr,
            include_previous_pose,
            pose_proposal_args,
            xyz,
        )
        # jax.debug.print("rank after: {v}", v=ss.rankdata(log_q_poses))
        # jax.debug.print("score after: {v}", v=log_q_poses)
        def update_and_get_scores_pose(key, proposed_pose, trace, addr):
            key, subkey = split(key)
            updated_trace = update_field(subkey, trace, addr, proposed_pose)
            return updated_trace, updated_trace.get_score()

        param_generation_keys = split(k3, inference_hyperparams.n_poses_vels)
        _, p_scores = jax.vmap(update_and_get_scores_pose, in_axes=(0, 0, None, None))(
            param_generation_keys, proposed_poses, trace, addr
        )
        # jax.debug.print("p_scores: {x}", x=p_scores)
        # jax.debug.print("log_q_poses: {x}", x=log_q_poses)
        # jax.debug.print("log_q_vels: {x}", x=log_q_vels)
        # jax.debug.print("log_q_poses+log_q_vels: {x}", x=log_q_poses+log_q_vels)
        # Scoring + resampling
        weights = jnp.where(
            inference_hyperparams.include_q_scores_at_top_level,
            p_scores - log_q_poses,
            p_scores,
        )
        # jax.debug.print("weights: {x}", x=weights)

        # chosen_index = jax.random.categorical(k3, weights)
        chosen_index = weights.argmax()
        resampled_trace, _ = update_and_get_scores_pose(
            param_generation_keys[chosen_index],
            proposed_poses[chosen_index],
            trace,
            addr,
        )
        return (
            resampled_trace,
            logmeanexp(weights),
            proposed_poses[chosen_index],
            inference.get_zreo_vel(None),
            proposed_poses,
            jax.vmap(inference.get_zreo_vel, in_axes=(0,))(generation_keys),
            weights,
        )

    @jax.jit
    def c2f_pose_vel_step(
        key,
        trace,
        pose_proposal_args,
        vel_proposal_args,
        addr,
    ):
        addr = addr.unwrap()
        k1, k2, k3 = split(key, 3)

        # Propose the poses
        generation_keys = split(k1, inference_hyperparams.n_poses_vels)
        proposed_poses, log_q_poses = jax.vmap(
            inference.propose_pose, in_axes=(0, None, None, None, None)
        )(generation_keys, trace, addr, pose_proposal_args, xyz)
        # jax.debug.print("proposed_poses: {v}", v=proposed_poses)
        # jax.debug.print("log_q_poses before: {v}", v=log_q_poses)
        generation_keys = split(k2, inference_hyperparams.n_poses_vels)
        proposed_vels, log_q_vels = jax.vmap(
            inference.propose_vel, in_axes=(0, None, None, None)
        )(generation_keys, trace, addr.replace('pose', 'vel'), vel_proposal_args)
        # jax.debug.print("proposed_vels: {v}", v=proposed_vels)
        # jax.debug.print("log_q_vels before: {v}", v=log_q_vels)

        proposed_poses, log_q_poses = inference.maybe_swap_in_previous_pose(
            proposed_poses,
            log_q_poses,
            trace,
            addr,
            include_previous_pose,
            pose_proposal_args,
            xyz,
        )
        proposed_vels, log_q_vels = inference.maybe_swap_in_previous_vel(
            proposed_vels,
            log_q_vels,
            trace,
            addr.replace('pose', 'vel'),
            include_previous_pose,
            vel_proposal_args,
        )
        
        # jax.debug.print("rank after: {v}", v=ss.rankdata(log_q_poses))
        # jax.debug.print("score after: {v}", v=log_q_poses)

        def update_and_get_scores_pose_vel(key, proposed_pose, proposed_vel, trace, addr_pose, addr_vel):
            key, subkey = split(key)
            updated_trace = inference.update_fields(subkey, trace, [addr_pose, addr_vel],
                [proposed_pose, proposed_vel])
            return updated_trace, updated_trace.get_score()

        param_generation_keys = split(k3, inference_hyperparams.n_poses_vels)
        return param_generation_keys
        _, p_scores = jax.vmap(update_and_get_scores_pose_vel, in_axes=(0, 0, 0, None, None, None))(
            param_generation_keys, proposed_poses, proposed_vels, trace, addr, addr.replace('pose', 'vel')
        )
        
        # jax.debug.print("p_scores: {x}", x=p_scores)
        # jax.debug.print("log_q_poses: {x}", x=log_q_poses)
        # jax.debug.print("log_q_vels: {x}", x=log_q_vels)
        # jax.debug.print("log_q_poses+log_q_vels: {x}", x=log_q_poses+log_q_vels)
        # Scoring + resampling
        weights = jnp.where(
            inference_hyperparams.include_q_scores_at_top_level,
            p_scores - (log_q_poses+log_q_vels),
            p_scores,
        )

        # chosen_index = jax.random.categorical(k3, weights)
        chosen_index = weights.argmax()
        
        # jax.debug.print("chosen_index: {x}", x=chosen_index)
        # jax.debug.print("len(proposed_poses): {x}", x=len(proposed_poses))
        # jax.debug.print("len(proposed_vels): {x}", x=len(proposed_vels))
        # resampled_trace, _ = update_and_get_scores_pose_vel(
        #     param_generation_keys[chosen_index],
        #     proposed_poses[chosen_index],
        #     proposed_vels[chosen_index],
        #     trace,
        #     addr,
        #     addr.replace('pose', 'vel'),
        # )
        chosen_index = 0
        return (
            k1,
            1.0,
            proposed_poses[chosen_index],
            proposed_vels[chosen_index],
            proposed_poses,
            proposed_vels,
            jnp.full((2000,), 1.0),
        )
        return (
            param_generation_keys[chosen_index],
            logmeanexp(weights),
            proposed_poses[chosen_index],
            proposed_vels[chosen_index],
            proposed_poses,
            proposed_vels,
            weights,
        )
        return (
            trace,
            # resampled_trace,
            logmeanexp(weights),
            proposed_poses[chosen_index],
            proposed_vels[chosen_index],
            proposed_poses,
            proposed_vels,
            weights,
        )
        
    # for i, (addr, (pose_proposal_args, vel_proposal_args)) in enumerate(
    #         itertools.product(addresses, zip([inference_hyperparams.pose_proposal_args[0]], [inference_hyperparams.vel_proposal_args[0]]))
    #     ):
    #     trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = jax.lax.cond(infer_vel, c2f_pose_vel_step, c2f_pose_step, subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addresses[0])

    key, subkey = split(key)
    # trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_vel_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addresses[0])
    # generation_key, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_vel_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addresses[0])
    index = c2f_pose_vel_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addresses[0])
    # trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addresses[0])
    return index

In [7]:
key, subkey = split(key)
# new_trace, _ = inference_step(
#                 subkey,
#                 trace,
#                 foreground_background(rgbds[T], all_areas[T], 0.0),
#                 inference_hyperparams,
#                 addresses,
#                 xyz,
#                 infer_vel,
index = inference_step(
                subkey,
                trace,
                foreground_background(rgbds[T], all_areas[T], 0.0),
                inference_hyperparams,
                addresses,
                xyz,
                infer_vel,
            )

NOT STEPPING


In [None]:
def update_and_get_scores_pose_vel(key, proposed_pose, proposed_vel, trace, addr_pose, addr_vel):
    key, subkey = split(key)
    updated_trace = inference.update_fields(subkey, trace, [addr_pose, addr_vel],
        [proposed_pose, proposed_vel])
    return updated_trace, updated_trace.get_score()

In [None]:
resampled_trace, _ = update_and_get_scores_pose_vel(
        generation_key,
        best_pose,
        best_vel,
        trace,
        addresses[0],
        addresses[0].unwrap().replace('pose', 'vel'),
    )

In [8]:
# for i, (addr, (pose_proposal_args, vel_proposal_args)) in enumerate(
#         itertools.product(addresses, zip(inference_hyperparams.pose_proposal_args, inference_hyperparams.vel_proposal_args))
#     ):
#     key, subkey = split(key)
#     trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_step(subkey, trace, pose_proposal_args, vel_proposal_args, addr)


In [None]:
T = 0
print(f"step {T}:")
if T == 0:
    xyz = False
    infer_vel = False
else:
    xyz = True
    infer_vel = True

relevant_objects = [2]

key, subkey = split(key)
particle, _ = inference.inference_step(
                subkey,
                trace,
                foreground_background(rgbds[T], all_areas[T], 0.0),
                inference_hyperparams,
                [Pytree.const(f"object_pose_{o_id}") for o_id in relevant_objects],
                xyz,
                infer_vel,
            )

step 0:
Traceback (most recent call last):
  File "/home/hlwang/.conda/envs/b3dipe/lib/python3.12/site-packages/warp/jax_experimental/ffi.py", line 563, in ffi_callback
    self.func(*arg_list)
  File "/orcd/home/002/hlwang/code/b3d/src/b3d/physics/physics_utils.py", line 735, in simulate
    wp.launch(
  File "/home/hlwang/.conda/envs/b3dipe/lib/python3.12/site-packages/warp/context.py", line 5732, in launch
    pack_args(fwd_args, params)
  File "/home/hlwang/.conda/envs/b3dipe/lib/python3.12/site-packages/warp/context.py", line 5699, in pack_args
    params.append(pack_arg(kernel, arg_type, arg_name, a, device, adjoint))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hlwang/.conda/envs/b3dipe/lib/python3.12/site-packages/warp/context.py", line 5332, in pack_arg
    raise RuntimeError(
RuntimeError: Error launching kernel 'eval_rigid_contacts', argument 'body_q' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).



E0716 14:21:36.002199 3163655 pjrt_stream_executor_client.cc:3077] Execution of replica 0 failed: INTERNAL: Failed to capture gpu graph: FFI callback error: RuntimeError: Error launching kernel 'eval_rigid_contacts', argument 'body_q' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).


XlaRuntimeError: INTERNAL: Failed to capture gpu graph: FFI callback error: RuntimeError: Error launching kernel 'eval_rigid_contacts', argument 'body_q' expects an array with 1 dimension(s) but the passed array has 2 dimension(s).

In [4]:
from b3d.chisight.gen3d.inference.inference import *

In [5]:
def update_and_get_scores_pose(key, proposed_pose, trace, addr):
        key, subkey = split(key)
        updated_trace = update_field(subkey, trace, addr, proposed_pose)
        return updated_trace, updated_trace.get_score()

def update_and_get_scores_pose_vel(key, proposed_pose, proposed_vel, trace, addr_pose, addr_vel):
        key, subkey = split(key)
        updated_trace = update_fields(subkey, trace, [addr_pose, addr_vel],
            [proposed_pose, proposed_vel])
        return updated_trace, updated_trace.get_score()

@jax.jit
def c2f_pose_step(
    key,
    trace,
    pose_proposal_args,
    vel_proposal_args,
    addr,
    include_previous_pose=True,
    xyz=True,
):
    addr = addr.unwrap()
    k1, k2, k3 = split(key, 3)

    # Propose the poses
    generation_keys = split(k1, inference_hyperparams.n_poses_vels)
    proposed_poses, log_q_poses = jax.vmap(
        propose_pose, in_axes=(0, None, None, None, None)
    )(generation_keys, trace, addr, pose_proposal_args, xyz)
    # jax.debug.print("proposed_poses: {v}", v=proposed_poses)
    # jax.debug.print("log_q_poses before: {v}", v=log_q_poses)

    proposed_poses, log_q_poses = maybe_swap_in_previous_pose(
        proposed_poses,
        log_q_poses,
        trace,
        addr,
        include_previous_pose,
        pose_proposal_args,
        xyz,
    )
    # jax.debug.print("rank after: {v}", v=ss.rankdata(log_q_poses))
    # jax.debug.print("score after: {v}", v=log_q_poses)

    param_generation_keys = split(k3, inference_hyperparams.n_poses_vels)
    _, p_scores = jax.vmap(update_and_get_scores_pose, in_axes=(0, 0, None, None))(
        param_generation_keys, proposed_poses, trace, addr
    )
    # jax.debug.print("p_scores: {x}", x=p_scores)
    # jax.debug.print("log_q_poses: {x}", x=log_q_poses)
    # jax.debug.print("log_q_vels: {x}", x=log_q_vels)
    # jax.debug.print("log_q_poses+log_q_vels: {x}", x=log_q_poses+log_q_vels)
    # Scoring + resampling
    weights = jnp.where(
        inference_hyperparams.include_q_scores_at_top_level,
        p_scores - log_q_poses,
        p_scores,
    )
    # jax.debug.print("weights: {x}", x=weights)

    # chosen_index = jax.random.categorical(k3, weights)
    chosen_index = weights.argmax()
    resampled_trace, _ = update_and_get_scores_pose(
            param_generation_keys[chosen_index],
            proposed_poses[chosen_index],
            trace,
            addr,
        )
    return (
        resampled_trace,
        logmeanexp(weights),
        proposed_poses[chosen_index],
        get_zreo_vel(None),
        proposed_poses,
        jax.vmap(get_zreo_vel, in_axes=(0,))(generation_keys),
        weights,
    )
    # return (
    #     param_generation_keys[chosen_index],
    #     logmeanexp(weights),
    #     proposed_poses[chosen_index],
    #     get_zreo_vel(None),
    #     proposed_poses,
    #     jax.vmap(get_zreo_vel, in_axes=(0,))(generation_keys),
    #     weights,
    # )

# @jax.jit
def c2f_pose_vel_step(
    key,
    trace,
    pose_proposal_args,
    vel_proposal_args,
    addr,
    include_previous_pose=True,
    xyz=True,
):
    addr = addr.unwrap()
    k1, k2, k3 = split(key, 3)

    # Propose the poses
    generation_keys = split(k1, inference_hyperparams.n_poses_vels)
    proposed_poses, log_q_poses = jax.vmap(
        propose_pose, in_axes=(0, None, None, None, None)
    )(generation_keys, trace, addr, pose_proposal_args, xyz)
    # jax.debug.print("proposed_poses: {v}", v=proposed_poses)
    # jax.debug.print("log_q_poses before: {v}", v=log_q_poses)
    generation_keys = split(k2, inference_hyperparams.n_poses_vels)
    proposed_vels, log_q_vels = jax.vmap(
        propose_vel, in_axes=(0, None, None, None)
    )(generation_keys, trace, addr.replace('pose', 'vel'), vel_proposal_args)
    # jax.debug.print("proposed_vels: {v}", v=proposed_vels)
    # jax.debug.print("log_q_vels before: {v}", v=log_q_vels)

    proposed_poses, log_q_poses = maybe_swap_in_previous_pose(
        proposed_poses,
        log_q_poses,
        trace,
        addr,
        include_previous_pose,
        pose_proposal_args,
        xyz,
    )
    proposed_vels, log_q_vels = maybe_swap_in_previous_vel(
        proposed_vels,
        log_q_vels,
        trace,
        addr.replace('pose', 'vel'),
        include_previous_pose,
        vel_proposal_args,
    )
    
    # jax.debug.print("rank after: {v}", v=ss.rankdata(log_q_poses))
    # jax.debug.print("score after: {v}", v=log_q_poses)
    
    param_generation_keys = split(k3, inference_hyperparams.n_poses_vels)
    return param_generation_keys, proposed_poses, proposed_vels, trace, addr, addr.replace('pose', 'vel')
    _, p_scores = jax.vmap(update_and_get_scores_pose_vel, in_axes=(0, 0, 0, None, None, None))(
        param_generation_keys, proposed_poses, proposed_vels, trace, addr, addr.replace('pose', 'vel')
    )
    
    # jax.debug.print("p_scores: {x}", x=p_scores)
    # jax.debug.print("log_q_poses: {x}", x=log_q_poses)
    # jax.debug.print("log_q_vels: {x}", x=log_q_vels)
    # jax.debug.print("log_q_poses+log_q_vels: {x}", x=log_q_poses+log_q_vels)
    # Scoring + resampling
    weights = jnp.where(
        inference_hyperparams.include_q_scores_at_top_level,
        p_scores - (log_q_poses+log_q_vels),
        p_scores,
    )
    # jax.debug.print("weights: {x}", x=weights)

    # chosen_index = jax.random.categorical(k3, weights)
    chosen_index = weights.argmax()
    
    return (
        param_generation_keys[chosen_index],
        logmeanexp(weights),
        proposed_poses[chosen_index],
        proposed_vels[chosen_index],
        proposed_poses,
        proposed_vels,
        weights,
    )


## T=0

In [6]:
T = 0
print(f"step {T}:")
if T == 0:
    xyz = False
else:
    xyz = True

relevant_object = 2
addr = Pytree.const(f"object_pose_{relevant_object}")

key, subkey = split(key)
trace = inference.advance_time(subkey, trace, foreground_background(rgbds[T], all_areas[T], 0.0))
trace

step 0:
NOT STEPPING


  second_moment = np.nanmean(


In [7]:
key, subkey = split(key)
# if T != 0:
param_generation_keys, proposed_poses, proposed_vels, trace, addr_pose, addr_vel = c2f_pose_vel_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addr, xyz=xyz)
# else:
    # generation_key, _, best_pose, _, proposed_poses, _, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addr, xyz=xyz)
    

In [8]:
param_generation_keys

In [9]:
proposed_poses

In [10]:
proposed_vels

In [11]:
addr_pose

In [12]:
addr_vel

In [13]:
trace

In [14]:
update_and_get_scores_pose_vel_vmap = jax.vmap(update_and_get_scores_pose_vel, in_axes=(0, 0, 0, None, None, None))
update_and_get_scores_pose_vel_vmap

In [15]:
_, p_scores = update_and_get_scores_pose_vel_vmap(param_generation_keys, proposed_poses, proposed_vels, trace, addr_pose, addr_vel)

NOT STEPPING


2025-07-17 21:55:57.328771: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3021] Can't reduce memory use below 19.05GiB (20451117758 bytes) by rematerialization; only reduced to 91.27GiB (98000000000 bytes), down from 91.27GiB (98000000000 bytes) originally


In [7]:
key, subkey = split(key)
trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addr, xyz=xyz)

NOT STEPPING
NOT STEPPING


In [8]:
key, subkey = split(key)
trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[1], inference_hyperparams.vel_proposal_args[1], addr, xyz=xyz)

NOT STEPPING
NOT STEPPING


In [9]:
key, subkey = split(key)
trace, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[2], inference_hyperparams.vel_proposal_args[2], addr, xyz=xyz)

NOT STEPPING
NOT STEPPING


In [10]:
if T == 0:
    resampled_trace, _ = update_and_get_scores_pose(
            generation_key,
            best_pose,
            trace,
            addr.unwrap(),
        )
else:
    resampled_trace, _ = update_and_get_scores_pose_vel(
        generation_key,
        best_pose,
        best_vel,
        trace,
        addr,
        addr.unwrap().replace('pose', 'vel'),
    )
resampled_trace

  second_moment = np.nanmean(


## T=1

In [10]:
T = 1
print(f"step {T}:")
if T == 0:
    xyz = False
else:
    xyz = True

relevant_object = 2
addr = Pytree.const(f"object_pose_{relevant_object}")

key, subkey = split(key)
trace = inference.advance_time(subkey, resampled_trace, foreground_background(rgbds[T], all_areas[T], 0.0))
trace

step 1:
prev pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1496917 ,  0.0054369 ]], dtype=float32), quaternion=Array([[-0.0000000e+00, -0.0000000e+00, -0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  1.0000000e+00, -0.0000000e+00, -4.3711388e-08],
       [ 5.6135392e-01,  6.2854314e-01,  5.2036524e-01, -1.3796860e-01]],      dtype=float32))
prev velocities: Velocity(linvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32), angvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32))
stepped pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1491523 ,  0.0054369 ]], dtype=float32), quaternion=Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00, -4.3711388e-08],
       [

In [11]:
key, subkey = split(key)
if T != 0:
    generation_key, _, best_pose, best_vel, proposed_poses, proposed_vels, weights = c2f_pose_vel_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addr, xyz=xyz)
else:
    generation_key, _, best_pose, _, proposed_poses, _, weights = c2f_pose_step(subkey, trace, inference_hyperparams.pose_proposal_args[0], inference_hyperparams.vel_proposal_args[0], addr, xyz=xyz)

prev pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1496917 ,  0.0054369 ]], dtype=float32), quaternion=Array([[-0.0000000e+00, -0.0000000e+00, -0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  1.0000000e+00, -0.0000000e+00, -4.3711388e-08],
       [ 5.6135392e-01,  6.2854314e-01,  5.2036524e-01, -1.3796860e-01]],      dtype=float32))
prev velocities: Velocity(linvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32), angvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32))
stepped pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1491523 ,  0.0054369 ]], dtype=float32), quaternion=Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00, -4.3711388e-08],
       [ 5.61353

In [12]:
best_pose

In [13]:
gt_pos_array[1][2,:]


In [14]:
gt_rot_array[0][2,:]

In [15]:
best_vel

In [16]:
gt_linvel_array[1][2,:]

In [17]:
gt_angvel_array[0][2,:]

In [18]:
if T == 0:
    resampled_trace, _ = update_and_get_scores_pose(
            generation_key,
            best_pose,
            trace,
            addr.unwrap(),
        )
else:
    resampled_trace, _ = update_and_get_scores_pose_vel(
        generation_key,
        best_pose,
        best_vel,
        trace,
        addr.unwrap(),
        addr.unwrap().replace('pose', 'vel'),
    )
resampled_trace

prev pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1496917 ,  0.0054369 ]], dtype=float32), quaternion=Array([[-0.0000000e+00, -0.0000000e+00, -0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  1.0000000e+00, -0.0000000e+00, -4.3711388e-08],
       [ 5.6135392e-01,  6.2854314e-01,  5.2036524e-01, -1.3796860e-01]],      dtype=float32))
prev velocities: Velocity(linvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32), angvel=Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32))
stepped pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.8050178 ,  1.1491523 ,  0.0054369 ]], dtype=float32), quaternion=Array([[ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00],
       [ 0.0000000e+00,  1.0000000e+00,  0.0000000e+00, -4.3711388e-08],
       [ 5.61353

  second_moment = np.nanmean(


## T=2

In [21]:
T = 2
print(f"step {T}:")
if T == 0:
    xyz = False
else:
    xyz = True

relevant_object = 2
addr = Pytree.const(f"object_pose_{relevant_object}")

key, subkey = split(key)
trace = inference.advance_time(subkey, resampled_trace, foreground_background(rgbds[T], all_areas[T], 0.0))
trace

step 2:
prev pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.7903974 ,  1.1353531 , -0.01451672]], dtype=float32), quaternion=Array([[-0.0000000e+00, -0.0000000e+00, -0.0000000e+00,  1.0000000e+00],
       [-0.0000000e+00,  1.0000000e+00, -0.0000000e+00, -4.3711388e-08],
       [ 5.3443199e-01,  6.6312522e-01,  5.2060491e-01, -6.0149744e-02]],      dtype=float32))
prev velocities: Velocity(linvel=Array([[ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [-0.45709187,  0.23146074, -0.20536013]], dtype=float32), angvel=Array([[ 0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ],
       [ 0.04684507, -0.37553844,  0.03289122]], dtype=float32))
stepped pose: Pose(position=Array([[ 1.25      ,  0.        , -0.74991846],
       [-0.625     ,  0.01      ,  0.        ],
       [-0.79008096,  1.1308991 , -0.01402882]], dtype=float32), qua

# HMM test case

In [6]:
import jax
import jax.numpy as jnp
from jax import jit
import genjax

In [11]:
N = 800
n_repeats = 100
variance = jnp.eye(N)
key, subkey = jax.random.split(key)
initial_state = jax.random.normal(subkey, (N,))


@genjax.gen
def hmm_step(x, _):
    new_x = genjax.mv_normal(x, variance) @ "new_x"
    return new_x, None


hmm = hmm_step.scan(n=100)

key, subkey = jax.random.split(key)
jitted = jit(hmm.repeat(n=n_repeats).simulate)
trace = jitted(subkey, (initial_state, None))
key, subkey = jax.random.split(key)
%timeit jitted(subkey, (initial_state, None))

343 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
jitted = jit(hmm.simulate)


def hmm_debatched(key, initial_state):
    keys = jax.random.split(key, n_repeats)
    traces = {}
    for i in range(n_repeats):
        trace = jitted(keys[i], (initial_state, None))
        traces[i] = trace
    return traces


key, subkey = jax.random.split(key)
# About 4x slower on arm64 CPU and 40x on a Google Colab GPU
%timeit hmm_debatched(subkey, initial_state)

4.4 s ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# sim_time = 0.0
# for i in range(250):
#     print(f"step: {i}")
#     stepped_model, stepped_state = step(initial_state["prev_model"], initial_state["prev_state"], hyperparams["physics_args"])
#     # print(stepped_state._body_q)
#     initial_state["prev_model"] = stepped_model
#     initial_state["prev_state"] = stepped_state
#     state.body_q = stepped_state._body
# 
# 
# 
# 
# 
# 
# 
# 
# 
# _q
#     state.body_qd = stepped_state._body_qd
#     renderer.begin_frame(sim_time)
#     renderer.render(state)
#     renderer.end_frame()
#     sim_time += 1/100
# renderer.save()

In [8]:
# for i in range(250):
#     print(i)
#     # initial_state["prev_state"].clear_forces()
#     # initial_state["prev_model"].clear_old_count()
#     stepped_model, stepped_state = step(initial_state["prev_model"], initial_state["prev_state"], hyperparams)
#     initial_state["prev_model"] = stepped_model
#     initial_state["prev_state"] = stepped_state

In [9]:
# for i in range(250):
#     print(i)
#     stepped_model, stepped_state = step(initial_state["prev_model"], initial_state["prev_state"], hyperparams)
#     initial_state["prev_model"] = stepped_model
#     initial_state["prev_state"] = stepped_state

In [4]:
stepped_model

<b3d.physics.core.Model at 0x7fcc8411ae50>

In [13]:
import warp as wp

a = wp.array([1])
a

<warp.types.array at 0x7f96ffd11d10>

In [19]:
a.list()[0]

1

In [12]:
stepped_model["rigid_contact_broad_shape0"].size

1

In [1]:
import warp as wp

In [2]:
a = wp.array([1.0, 2.0, 3.0], dtype=float)
a

Warp 1.7.0.dev20250223 initialized:
   CUDA Toolkit 12.8, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA A40" (44 GiB, sm_86, mempool enabled)
     "cuda:1"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:2"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:3"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:4"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:5"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:6"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:7"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
   CUDA peer access:
     Supported fully (all-directional)
   Kernel cache:
     /home/haw027/.cache/warp/1.7.0.dev20250223


<warp.types.array at 0x7f18ce475bd0>

In [2]:
wp.vec3(1.0, 2.0, 3.0)

<warp.types.vec3f at 0x7f6a7ad60440>

In [4]:
wp.vec3(a[0], a[1], a[2])

RuntimeError: Item indexing is not supported on wp.array objects

In [7]:
wp.array(a, dtype=wp.vec3)

<warp.types.array at 0x7f2b5c854950>

In [1]:
import jax
import jax.numpy as jnp
import warp as wp
from warp.jax_experimental.ffi import jax_callable

@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    tid = wp.tid()
    output[tid] = a[tid] * s

def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
    ):
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])

jax_func = jax_callable(example_func, num_outputs=2)

o = wp.array([1, 2, 3])
p = wp.to_jax(o)


Warp 1.7.0.dev20250223 initialized:
   CUDA Toolkit 12.8, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA A40" (44 GiB, sm_86, mempool enabled)
     "cuda:1"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:2"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:3"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:4"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:5"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:6"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:7"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
   CUDA peer access:
     Supported fully (all-directional)
   Kernel cache:
     /home/haw027/.cache/warp/1.7.0.dev20250223


ArgumentError: argument 3: TypeError: expected CFunctionType instance instead of CFunctionType

In [2]:
from warp.jax_experimental.ffi import jax_callable
import warp as wp
import jax
import jax.numpy as jnp
import genjax

@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    tid = wp.tid()
    output[tid] = a[tid] * s


# The Python function to call.
# Note the argument type annotations, just like Warp kernels.
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    # launch multiple kernels
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])


jax_func = jax_callable(example_func, num_outputs=2)

# @jax.jit
# def f():
#     # inputs
#     a = jnp.arange(10, dtype=jnp.float32)
#     b = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
#     s = 2.0

#     # output shapes
#     output_dims = {"c": a.shape, "d": b.shape}

#     c, d = jax_func(a, b, s, output_dims=output_dims)

#     return c, d

def make_genjax_func():
    @genjax.gen
    def f():
        # inputs
        a = jnp.arange(10, dtype=jnp.float32)
        b = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
        s = 2.0

        # output shapes
        output_dims = {"c": a.shape, "d": b.shape}

        c, d = jax_func(a, b, s, output_dims=output_dims)

        return c, d
    return f

f = make_genjax_func()
importance_jit = jax.jit(f.importance)
# r1, r2 = f()
# print(r1)
# print(r2)

In [1]:
from warp.jax_experimental.ffi import jax_callable
import warp as wp
import jax.numpy as jnp
import jax
from warp.jax_experimental.ffi import register_ffi_callback
from warp.jax import get_jax_device

In [5]:
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    tid = wp.tid()
    output[tid] = a[tid] * s


# The Python function to call.
# Note the argument type annotations, just like Warp kernels.
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    # launch multiple kernels
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])
    wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])


jax_func = jax_callable(example_func, num_outputs=2)

@jax.jit
def f():
    # inputs
    a = jnp.arange(10, dtype=jnp.float32)
    b = jnp.arange(10, dtype=jnp.float32).reshape((5, 2))  # wp.vec2
    s = 2.0

    # output shapes
    # output_dims = {"c": a.shape, "d": b.shape}

    # c, d = jax_func(a, b, s, output_dims=output_dims)
    c, d = jax_func(a, b, s)

    return c, d

r1, r2 = f()
print(r1)
print(r2)

[ 0.  2.  4.  6.  8. 10. 12. 14. 16. 18.]
[[ 0.  2.]
 [ 4.  6.]
 [ 8. 10.]
 [12. 14.]
 [16. 18.]
 [ 0.  0.]
 [ 0.  0.]
 [ 0.  0.]
 [ 0.  0.]
 [ 0.  0.]]


In [7]:
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s

@wp.kernel
def scale_twice_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = a[tid] * s * s

# The Python function to call.
# Note the argument type annotations, just like Warp kernels.
def example_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    s: float,
    # outputs
    c: wp.array(dtype=float),
):
    # launch multiple kernels
    wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[b])
    wp.launch(scale_twice_kernel, dim=b.shape, inputs=[b, s], outputs=[c])


jax_func = jax_callable(example_func, num_outputs=1)

@jax.jit
def f():
    # inputs
    a = jnp.arange(10, dtype=jnp.float32)
    b = jnp.arange(10, dtype=jnp.float32)
    s = 2.0

    # output shapes
    # output_dims = {"c": a.shape}

    # c = jax_func(a, b, s, output_dims=output_dims)
    c = jax_func(a, b, s)

    return c

r1 = f()
print(r1)

[Array([ 0.,  8., 16., 24., 32., 40., 48., 56., 64., 72.], dtype=float32)]


In [None]:
from warp.sim.integrator import integrate_bodies

In [None]:
@wp.kernel
def eval_rigid_contacts(
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
    body_com: wp.array(dtype=wp.vec3),
    ke: wp.array(dtype=float),
    kd: wp.array(dtype=float),
    kf: wp.array(dtype=float),
    ka: wp.array(dtype=float),
    mu: wp.array(dtype=float),
    shape_body: wp.array(dtype=int),
    contact_count: wp.array(dtype=int),
    contact_point0: wp.array(dtype=wp.vec3),
    contact_point1: wp.array(dtype=wp.vec3),
    contact_normal: wp.array(dtype=wp.vec3),
    contact_shape0: wp.array(dtype=int),
    contact_shape1: wp.array(dtype=int),
    force_in_world_frame: bool,
    friction_smoothing: float,
    # outputs
    body_f: wp.array(dtype=wp.spatial_vector),
):
    tid = wp.tid()

    count = contact_count[0]
    if tid >= count:
        return

    # retrieve contact thickness, compute average contact material properties
    ke = 0.0  # contact normal force stiffness
    kd = 0.0  # damping coefficient
    kf = 0.0  # friction force stiffness
    ka = 0.0  # adhesion distance
    mu = 0.0  # friction coefficient
    mat_nonzero = 0
    thickness_a = 0.0
    thickness_b = 0.0
    shape_a = contact_shape0[tid]
    shape_b = contact_shape1[tid]
    if shape_a == shape_b:
        return
    body_a = -1
    body_b = -1
    if shape_a >= 0:
        mat_nonzero += 1
        ke += ke[shape_a]
        kd += kd[shape_a]
        kf += kf[shape_a]
        ka += ka[shape_a]
        mu += mu[shape_a]
        thickness_a = 1.e-05
        body_a = shape_body[shape_a]
    if shape_b >= 0:
        mat_nonzero += 1
        ke += ke[shape_b]
        kd += kd[shape_b]
        kf += kf[shape_b]
        ka += ka[shape_b]
        mu += mu[shape_b]
        thickness_b = 1.e-05
        body_b = shape_body[shape_b]
    if mat_nonzero > 0:
        ke /= float(mat_nonzero)
        kd /= float(mat_nonzero)
        kf /= float(mat_nonzero)
        ka /= float(mat_nonzero)
        mu /= float(mat_nonzero)

    # contact normal in world space
    n = contact_normal[tid]
    bx_a = contact_point0[tid]
    bx_b = contact_point1[tid]
    r_a = wp.vec3(0.0)
    r_b = wp.vec3(0.0)
    if body_a >= 0:
        X_wb_a = body_q[body_a]
        X_com_a = body_com[body_a]
        bx_a = wp.transform_point(X_wb_a, bx_a) - thickness_a * n
        r_a = bx_a - wp.transform_point(X_wb_a, X_com_a)

    if body_b >= 0:
        X_wb_b = body_q[body_b]
        X_com_b = body_com[body_b]
        bx_b = wp.transform_point(X_wb_b, bx_b) + thickness_b * n
        r_b = bx_b - wp.transform_point(X_wb_b, X_com_b)

    d = wp.dot(n, bx_a - bx_b)

    if d >= ka:
        return

    # compute contact point velocity
    bv_a = wp.vec3(0.0)
    bv_b = wp.vec3(0.0)
    if body_a >= 0:
        body_v_s_a = body_qd[body_a]
        body_w_a = wp.spatial_top(body_v_s_a)
        body_v_a = wp.spatial_bottom(body_v_s_a)
        if force_in_world_frame:
            bv_a = body_v_a + wp.cross(body_w_a, bx_a)
        else:
            bv_a = body_v_a + wp.cross(body_w_a, r_a)

    if body_b >= 0:
        body_v_s_b = body_qd[body_b]
        body_w_b = wp.spatial_top(body_v_s_b)
        body_v_b = wp.spatial_bottom(body_v_s_b)
        if force_in_world_frame:
            bv_b = body_v_b + wp.cross(body_w_b, bx_b)
        else:
            bv_b = body_v_b + wp.cross(body_w_b, r_b)

    # relative velocity
    v = bv_a - bv_b

    # print(v)

    # decompose relative velocity
    vn = wp.dot(n, v)
    vt = v - n * vn

    # contact elastic
    fn = d * ke

    # contact damping
    fd = wp.min(vn, 0.0) * kd * wp.step(d)

    # viscous friction
    # ft = vt*kf

    # Coulomb friction (box)
    # lower = mu * d * ke
    # upper = -lower

    # vx = wp.clamp(wp.dot(wp.vec3(kf, 0.0, 0.0), vt), lower, upper)
    # vz = wp.clamp(wp.dot(wp.vec3(0.0, 0.0, kf), vt), lower, upper)

    # ft = wp.vec3(vx, 0.0, vz)

    # Coulomb friction (smooth, but gradients are numerically unstable around |vt| = 0)
    ft = wp.vec3(0.0)
    if d < 0.0:
        # use a smooth vector norm to avoid gradient instability at/around zero velocity
        vs = wp.norm_huber(vt, delta=friction_smoothing)
        if vs > 0.0:
            fr = vt / vs
            ft = fr * wp.min(kf * vs, -mu * (fn + fd))

    f_total = n * (fn + fd) + ft
    # f_total = n * (fn + fd)
    # f_total = n * fn

    if body_a >= 0:
        if force_in_world_frame:
            wp.atomic_add(body_f, body_a, wp.spatial_vector(wp.cross(bx_a, f_total), f_total))
        else:
            wp.atomic_sub(body_f, body_a, wp.spatial_vector(wp.cross(r_a, f_total), f_total))

    if body_b >= 0:
        if force_in_world_frame:
            wp.atomic_sub(body_f, body_b, wp.spatial_vector(wp.cross(bx_b, f_total), f_total))
        else:
            wp.atomic_add(body_f, body_b, wp.spatial_vector(wp.cross(r_b, f_total), f_total))


In [None]:
def simulate(
    # inputs
    rigid_contact_max: int,
    body_q: wp.array(dtype=wp.transform),
    body_qd: wp.array(dtype=wp.spatial_vector),
    body_com: wp.array(dtype=wp.vec3),
    ke: wp.array(dtype=float),
    kd: wp.array(dtype=float),
    kf: wp.array(dtype=float),
    ka: wp.array(dtype=float),
    mu: wp.array(dtype=float),
    shape_body: wp.array(dtype=int),
    rigid_contact_count: wp.array(dtype=int),
    rigid_contact_point0: wp.array(dtype=wp.vec3),
    rigid_contact_point1: wp.array(dtype=wp.vec3),
    rigid_contact_normal: wp.array(dtype=wp.vec3),
    rigid_contact_shape0: wp.array(dtype=int),
    rigid_contact_shape1: wp.array(dtype=int),
    force_in_world_frame: bool = False,
    friction_smoothing: float = 1.0,
    # outputs
    body_f: wp.array(dtype=wp.spatial_vector),

    # inputs
    body_mass: wp.array(dtype=float),
    body_inertia: wp.array(dtype=wp.mat33),
    body_inv_mass: wp.array(dtype=float),
    body_inv_inertia: wp.array(dtype=wp.mat33),
    gravity: wp.vec3,
    angular_damping: float = 0.05,
    dt: float,
    body_count: int,
    # outputs
    body_q_new: wp.array(dtype=wp.transform),
    body_qd_new: wp.array(dtype=wp.spatial_vector),
):
    # collision
    if model.shape_contact_pair_count or model.ground and model.shape_ground_contact_pair_count:
        # clear old count
        model.rigid_contact_count.zero_()

        model.rigid_contact_broad_shape0.fill_(-1)
        model.rigid_contact_broad_shape1.fill_(-1)

    if model.shape_contact_pair_count:
        wp.launch(
            kernel=broadphase_collision_pairs,
            dim=model.shape_contact_pair_count,
            inputs=[
                model.shape_contact_pairs,
                state.body_q,
                model.shape_transform,
                model.shape_body,
                model.body_mass,
                model.shape_count,
                model.shape_geo,
                model.shape_collision_radius,
                model.rigid_contact_max,
                model.rigid_contact_margin,
                model.rigid_mesh_contact_max,
                iterate_mesh_vertices,
            ],
            outputs=[
                model.rigid_contact_count,
                model.rigid_contact_broad_shape0,
                model.rigid_contact_broad_shape1,
                model.rigid_contact_point_id,
                model.rigid_contact_point_limit,
            ],
            record_tape=False,
        )

    if model.ground and model.shape_ground_contact_pair_count:
        wp.launch(
            kernel=broadphase_collision_pairs,
            dim=model.shape_ground_contact_pair_count,
            inputs=[
                model.shape_ground_contact_pairs,
                state.body_q,
                model.shape_transform,
                model.shape_body,
                model.body_mass,
                model.shape_count,
                model.shape_geo,
                model.shape_collision_radius,
                model.rigid_contact_max,
                model.rigid_contact_margin,
                model.rigid_mesh_contact_max,
                iterate_mesh_vertices,
            ],
            outputs=[
                model.rigid_contact_count,
                model.rigid_contact_broad_shape0,
                model.rigid_contact_broad_shape1,
                model.rigid_contact_point_id,
                model.rigid_contact_point_limit,
            ],
            record_tape=False,
        )

    if model.shape_contact_pair_count or model.ground and model.shape_ground_contact_pair_count:
        model.rigid_contact_count.zero_()
        model.rigid_contact_tids.zero_()
        if model.rigid_contact_pairwise_counter is not None:
            model.rigid_contact_pairwise_counter.zero_()
        model.rigid_contact_shape0.fill_(-1)
        model.rigid_contact_shape1.fill_(-1)

        wp.launch(
            kernel=handle_contact_pairs,
            dim=model.rigid_contact_max,
            inputs=[
                state.body_q,
                model.shape_transform,
                model.shape_body,
                model.shape_geo,
                model.rigid_contact_margin,
                model.rigid_contact_broad_shape0,
                model.rigid_contact_broad_shape1,
                model.shape_count,
                model.rigid_contact_point_id,
                model.rigid_contact_point_limit,
                edge_sdf_iter,
            ],
            outputs=[
                model.rigid_contact_count,
                model.rigid_contact_shape0,
                model.rigid_contact_shape1,
                model.rigid_contact_point0,
                model.rigid_contact_point1,
                model.rigid_contact_offset0,
                model.rigid_contact_offset1,
                model.rigid_contact_normal,
                model.rigid_contact_thickness,
                model.rigid_contact_pairwise_counter,
                model.rigid_contact_tids,
            ],
        )

    # compute forces
    wp.launch(
        kernel=eval_rigid_contacts,
        dim=rigid_contact_max,
        inputs=[
            body_q,
            body_qd,
            body_com,
            ke,
            kd,
            kf,
            ka,
            mu,
            shape_body,
            rigid_contact_count,
            rigid_contact_point0,
            rigid_contact_point1,
            rigid_contact_normal,
            rigid_contact_shape0,
            rigid_contact_shape1,
            force_in_world_frame,
            friction_smoothing,
        ],
        outputs=[body_f],
    )
    
    # integrate
    wp.launch(
        kernel=integrate_bodies,
        dim=body_count,
        inputs=[
            body_q,
            body_qd,
            body_f,
            body_com,
            body_mass,
            body_inertia,
            body_inv_mass,
            body_inv_inertia,
            gravity,
            angular_damping,
            dt,
        ],
        outputs=[body_q_new, body_qd_new],
    )

# test physion

In [1]:
import numpy as np
import trimesh
import warp as wp
import warp.sim.render
from os.path import isfile, join
from os import listdir
import h5py
import jax

In [2]:
def euler_angles_to_quaternion(euler: np.ndarray) -> np.ndarray:
    """
    Convert Euler angles to a quaternion.

    Source: https://pastebin.com/riRLRvch

    :param euler: The Euler angles vector.

    :return: The quaternion representation of the Euler angles.
    """
    pitch = np.radians(euler[0] * 0.5)
    cp = np.cos(pitch)
    sp = np.sin(pitch)

    yaw = np.radians(euler[1] * 0.5)
    cy = np.cos(yaw)
    sy = np.sin(yaw)

    roll = np.radians(euler[2] * 0.5)
    cr = np.cos(roll)
    sr = np.sin(roll)

    x = sy * cp * sr + cy * sp * cr
    y = sy * cp * cr - cy * sp * sr
    z = cy * cp * sr - sy * sp * cr
    w = cy * cp * cr + sy * sp * sr
    return np.array([x, y, z, w])

class Example:
    def __init__(self, stage_path="example_rigid_contact.usd", hdf5_path=''):
        builder = wp.sim.ModelBuilder()

        self.sim_time = 0.0
        fps = 100
        self.frame_dt = 1.0 / fps

        self.sim_substeps = 10
        self.sim_dt = self.frame_dt / self.sim_substeps

        self.mu = 0.25
        self.restitution = 0.4

        self.mesh_path = '/ccn2/u/rmvenkat/data/all_flex_meshes'
        self.mesh_lib = dict([(f, join(self.mesh_path, f)) for f in listdir(self.mesh_path) if isfile(join(self.mesh_path, f)) and join(self.mesh_path, f).endswith('.obj')])
        initial_positions, initial_rotations, model_names, scales = self.load_hdf5(hdf5_path)

        # meshes
        for i, (pos, rot, name, scale) in enumerate(zip(initial_positions, initial_rotations, model_names, scales)):
            b = builder.add_body(
                origin=wp.transform(
                    pos, euler_angles_to_quaternion(rot),
                )
            )

            mesh = self.load_mesh(self.mesh_lib[name.decode('utf-8')+'.obj'])
            builder.add_shape_mesh(
                body=b,
                mesh=mesh,
                pos=wp.vec3(0.0, 0.0, 0.0),
                scale=scale,
                # density=5,
                restitution=self.restitution,
                mu=self.mu,
                # ke=self.ke,
                # kd=self.kd,
                # kf=self.kf,
                density=1e3,
                has_ground_collision=True,
                has_shape_collision=True
            )
        builder.set_ground_plane(mu=self.mu)
        
        # finalize model
        self.model = builder.finalize()
        self.model.ground = True
        print(f"self.model.rigid_contact_count: {self.model.rigid_contact_count} {type(self.model.rigid_contact_count)}")
        print(f"to jax: {wp.to_jax(self.model.rigid_contact_count)} {type(wp.to_jax(self.model.rigid_contact_count))}")

        print(f"self.model.rigid_mesh_contact_max: {self.model.rigid_mesh_contact_max},\n self.model.ground: {self.model.ground},\n self.model.shape_materials: {self.model.shape_materials},\n self.model.shape_geo: {self.model.shape_geo},\n self.model.device: {self.model.device} {type(self.model.device)},\n self.model.spring_count: {self.model.spring_count},\n self.model.tri_count: {self.model.tri_count},\n self.model.enable_tri_collisions: {self.model.enable_tri_collisions},\n self.model.edge_count: {self.model.edge_count},\n self.model.particle_count: {self.model.particle_count},\n self.model.tet_count: {self.model.tet_count},\n self.model.rigid_contact_max: {self.model.rigid_contact_max},\n self.model.ground: {self.model.ground},\n self.model.shape_ground_contact_pair_count: {self.model.shape_ground_contact_pair_count},\n self.model.shape_contact_pair_count: {self.model.shape_contact_pair_count},\n self.model.joint_count: {self.model.joint_count},\n self.model.particle_count: {self.model.particle_count},\n self.model.shape_count: {self.model.shape_count},\n self.model.muscle_count: {self.model.muscle_count}")

        

        self.integrator = wp.sim.SemiImplicitIntegrator()

        if stage_path:
            self.renderer = wp.sim.render.SimRenderer(self.model, stage_path, scaling=0.5)
        else:
            self.renderer = None

        self.state_0 = self.model.state()
        self.state_1 = self.model.state()

        wp.sim.eval_fk(self.model, self.model.joint_q, self.model.joint_qd, None, self.state_0)

        self.use_cuda_graph = wp.get_device().is_cuda
        if self.use_cuda_graph:
            with wp.ScopedCapture() as capture:
                self.simulate()
            self.graph = capture.graph

    def load_hdf5(self, path):
        with h5py.File(path, "r") as f:
            distractors = (
                np.array(f["static"]["distractors"])
                if np.array(f["static"]["distractors"]).size != 0
                else []
            )
            occluders = (
                np.array(f["static"]["occluders"])
                if np.array(f["static"]["occluders"]).size != 0
                else []
            )
            distractors_occluders = np.concatenate([distractors, occluders])

            model_names = np.array(f["static"]["model_names"])[1:]

            scales = np.array(
                f["static"]["scale"]
            )[1:]
            initial_positions = np.array(f["static"]["initial_position"])[1:]
            initial_rotations = np.array(f["static"]["initial_rotation"])[1:]
            
            if len(distractors_occluders):
                model_names = model_names[: -len(distractors_occluders)]
                scales = scales[
                    : -len(distractors_occluders)
                ]
                initial_positions = initial_positions[
                    : -len(distractors_occluders)
                ]
                initial_rotations = initial_rotations[
                    : -len(distractors_occluders)
                ]
        return initial_positions, initial_rotations, model_names, scales

    def load_mesh(self, path):
        m = trimesh.load(path)
        mesh_points = np.array(m.vertices)
        mesh_indices = np.array(m.faces, dtype=np.int32).flatten()
        mesh = wp.sim.Mesh(mesh_points, mesh_indices)
        return mesh

    def simulate(self):
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            wp.sim.collide(self.model, self.state_0)
            self.integrator.simulate(self.model, self.state_0, self.state_1, self.sim_dt)
            self.state_0, self.state_1 = self.state_1, self.state_0
            
    def step(self):
        with wp.ScopedTimer("step", active=True):
            if self.use_cuda_graph:
                wp.capture_launch(self.graph)
                print(self.state_0.body_q, self.state_1.body_q)
            # else:
            #     self.simulate()
        self.sim_time += self.frame_dt

    def render(self):
        if self.renderer is None:
            return

        with wp.ScopedTimer("render", active=True):
            self.renderer.begin_frame(self.sim_time)
            self.renderer.render(self.state_0)
            self.renderer.end_frame()

In [None]:
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
    parser.add_argument(
        "--stage_path",
        type=lambda x: None if x == "None" else str(x),
        default="example_rigid_contact.usd",
        help="Path to the output USD file.",
    )
    parser.add_argument(
        "--hdf5_path",
        type=lambda x: None if x == "None" else str(x),
        default="/ccn2/u/rmvenkat/data/testing_physion/regenerate_from_old_commit/test_humans_consolidated/lf_0/support_all_movies/pilot_towers_nb3_fr015_SJ025_mono1_dis0_occ0_tdwroom_unstable_0004.hdf5",
        help="Path to the input hdf5 file.",
    )
    parser.add_argument("--num_frames", type=int, default=250, help="Total number of frames.")
    args = parser.parse_known_args()[0]

    with wp.ScopedDevice(args.device):
        example = Example(stage_path=args.stage_path, hdf5_path=args.hdf5_path)

        for _ in range(args.num_frames):
            example.step()
            example.render()

        if example.renderer:
            example.renderer.save()

Warp 1.7.0.dev20250223 initialized:
   CUDA Toolkit 12.8, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA A40" (44 GiB, sm_86, mempool enabled)
     "cuda:1"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:2"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:3"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:4"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:5"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:6"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
     "cuda:7"   : "NVIDIA A40" (47 GiB, sm_86, mempool enabled)
   CUDA peer access:
     Supported fully (all-directional)
   Kernel cache:
     /home/haw027/.cache/warp/1.7.0.dev20250223
Module warp.sim.inertia 185ee85 load on device 'cuda:0' took 1.05 ms  (cached)
Module warp.sim.collide 26674e6 load on device 'cuda:0' took 1.29 ms  (cached)
self.model.rigid_contact_count: [0] <class 'warp.types.array'>
to jax: [0] <class 'jaxlib.xla_ex

ImportError: Failed to import pxr. Please install USD (e.g. via `pip install usd-core`).

: 