In [1]:
%load_ext ipython_beartype
%beartype

%load_ext autoreload
%autoreload 2

In [None]:
# Add these lines to the cell with id 34bb5824
import warnings

import jax.numpy as jnp
import numpy as np
import rerun as rr
from jaxtyping import Array, Float
from simplecv.data.exoego.skeleton.mediapipe import MEDIAPIPE_ID2NAME, MEDIAPIPE_IDS, MEDIAPIPE_LINKS

from pi0_lerobot.mano.mano_jax import mano_j_left, mano_j_right

warnings.filterwarnings('ignore', category=RuntimeWarning)




In [3]:

def set_pose_annotation_context() -> None:
    rr.log(
        "/",
        rr.AnnotationContext(
            [
                rr.ClassDescription(
                    info=rr.AnnotationInfo(id=0, label="Triangulated Hand", color=(0, 0, 255)),
                    keypoint_annotations=[
                        rr.AnnotationInfo(id=id, label=name) for id, name in MEDIAPIPE_ID2NAME.items()
                    ],
                    keypoint_connections=MEDIAPIPE_LINKS,
                ),
                rr.ClassDescription(
                    info=rr.AnnotationInfo(id=1, label="Optimized Hand", color=(255, 0, 0)),
                    keypoint_annotations=[
                        rr.AnnotationInfo(id=id, label=name) for id, name in MEDIAPIPE_ID2NAME.items()
                    ],
                    keypoint_connections=MEDIAPIPE_LINKS,
                ),
            
            ]
        ),
        static=True,
    )

In [4]:
rr.init("mano_joints")

set_pose_annotation_context()

pose_coeffs: Float[Array, "batch 48"] = jnp.zeros((1, 48))
left_trans: Float[Array, "batch 1 3"] = jnp.array([[[0,0,0.2]]])
right_trans: Float[Array, "batch 1 3"] = jnp.array([[[0,0,-0.2]]])

xyz_left:Float[Array, "batch 21 3"] = mano_j_left(pose_coeffs, left_trans)
xyz_right:Float[Array, "batch 21 3"] = mano_j_right(pose_coeffs, right_trans)




rr.log("left joints", rr.Points3D(positions=xyz_left[0], class_ids=0, keypoint_ids=MEDIAPIPE_IDS, colors=(0, 255, 0), show_labels=False))
rr.log("right joints", rr.Points3D(positions=xyz_right[0], class_ids=0,keypoint_ids=MEDIAPIPE_IDS,colors=(0, 255, 0), show_labels=False))

rr.notebook_show()

[2025-04-21T20:11:15Z WARN  re_log_types::path::parse_path] When parsing the entity path "left joints": Unescaped whitespace. The path will be interpreted as /left\ joints
[2025-04-21T20:11:15Z WARN  re_log_types::path::parse_path] When parsing the entity path "right joints": Unescaped whitespace. The path will be interpreted as /right\ joints


Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



In [5]:
from pathlib import Path

from simplecv.data.exoego.hocap import HOCapSequence

sequence: HOCapSequence = HOCapSequence(
                data_path=Path("../data/hocap/sample"),
                sequence_name="20231024_180733",
                subject_id="8",
                load_labels=True,
            )

Found HoloLens camera: hololens_kv5h72


Loading videos: 100%|██████████| 8/8 [00:00<00:00, 27391.37it/s]
Loading 3D labels: 100%|██████████| 8/8 [00:02<00:00,  3.76it/s]
Indexing depth images: 100%|██████████| 8/8 [00:00<00:00, 423.72it/s]


In [6]:
from jaxtyping import Float32
from numpy import ndarray
from simplecv.camera_parameters import PinholeParameters

exo_cam_list: list[PinholeParameters] = sequence.exo_cam_list
projection_all_list: list[Float32[np.ndarray, "3 4"]] = []
for exo_cam in sequence.exo_cam_list:
    projection_matrix: Float32[ndarray, "3 4"] = exo_cam.projection_matrix.astype(np.float32)
    projection_all_list.append(projection_matrix)

Pall: Float32[np.ndarray, "n_views 3 4"] = np.array([P for P in projection_all_list])

In [7]:
from simplecv.apis.view_exoego_data import create_blueprint
from simplecv.data.exoego.base_exo_ego import ExoData

exo_data: ExoData = next(iter(sequence))
parent_log_path:Path = Path("world")

exo_cam_log_paths: list[Path] = [parent_log_path / exo_cam.name for exo_cam in sequence.exo_cam_list]
exo_video_log_paths: list[Path] = [cam_log_paths / "pinhole" / "video" for cam_log_paths in exo_cam_log_paths]

blueprint = create_blueprint(exo_video_log_paths=exo_video_log_paths, num_videos_to_log=8)

In [8]:
rr.init("sequence data")
set_pose_annotation_context()


for hand_idx, (hand_side, color, class_id) in enumerate(
        (
            ("left", (0, 255, 0), 0),
            ("right", (0, 255, 0), 0),
        )
    ):
        xyz: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        rr.log(
            hand_side,
            rr.Points3D(
                xyz,
                colors=color,
                class_ids=class_id,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )

for cam_param, bgr in zip(exo_data.cam_params_list, exo_data.bgr_list, strict=True):
    for hand_idx, (hand_side, color, class_id) in enumerate(
        (
            ("left", (0, 255, 0), 0),
            ("right", (0, 255, 0), 0),
        )
    ):
        xyz: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        z_relative: Float32[ndarray, "21"] = xyz[:, 2] - xyz[9, 2]
        # np.nan is used to indicate that a joint is not confidently detected
        uv: Float32[ndarray, "21 2"] = exo_data.uv_dict[cam_param.name][hand_idx]
        # np.nan is used to indicate that a joint is not confidently detected
        uv[uv == -1] = np.nan

        uvz:Float32[ndarray, "21 3"] = np.concatenate([uv, z_relative[:, None]], axis=1)

        # find min/max to generate a bounding box, add a small padding
        uv_min:Float32[ndarray, "2"] = np.nanmin(uv, axis=0)
        uv_max:Float32[ndarray, "2"] = np.nanmax(uv, axis=0)

        uv_min = uv_min - 0.05
        uv_max = uv_max + 0.05

        xyxy:Float32[ndarray, "4"] = np.concatenate([uv_min, uv_max], axis=0)

        image_log_path: Path = parent_log_path / cam_param.name / "pinhole" / "video"
        rr.log(
            f"{image_log_path}/{hand_side}_keypoints",
            rr.Points2D(
                uv,
                colors=color,
                class_ids=class_id,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )

        rr.log(
                f"{image_log_path}/{hand_side}_xyxy",
                rr.Boxes2D(
                    array=xyxy,
                    array_format=rr.Box2DFormat.XYXY
                )
        )
        rr.log(
            f"{image_log_path}",
            rr.Image(
                bgr,
                color_model=rr.ColorModel.BGR
            ).compress(jpeg_quality=75),
        )

rr.notebook_show(width=1000, height=700, blueprint=blueprint)

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



In [9]:
from pi0_lerobot.mano.mano_optimization_jax import LMOptimJointOnly, LossWeights

loss_weights = LossWeights(
    keypoint_2d=0.1,
    depth=0.0,
    temp=0.0,
)
megatrack_fitter = LMOptimJointOnly(
    Pall=Pall,
    loss_weights=loss_weights,
    num_iters=30,
)



Tracing JIT, can take a while...
Trace Done


In [10]:
from jaxopt._src.levenberg_marquardt import LevenbergMarquardtState

from pi0_lerobot.mano.mano_optimization_jax import OptimizationResults

output: tuple[OptimizationResults, LevenbergMarquardtState] = megatrack_fitter(xyz_pred_batch=exo_data.xyz)
optimized_result: OptimizationResults = output[0]
state: LevenbergMarquardtState = output[1]

In [11]:
print(state._fields)
print(state.error)
print(state.iter_num)

('iter_num', 'damping_factor', 'increase_factor', 'residual', 'value', 'delta', 'error', 'gradient', 'jac', 'jt', 'jtj', 'hess_res', 'aux')
0.00043537063
4


In [None]:
rr.init("Refined sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)
for hand_idx, (hand_side, color, class_id) in enumerate(
        (
            ("left", (0, 255, 0), 0),
            ("right", (0, 255, 0), 0),
        )
    ):
        xyz_pred: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        match hand_side:
            case "left":
                xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[0]
            case "right":
                xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[1]
            case _: 
                raise ValueError(f"Invalid hand side: {hand_side}")
        rr.log(
            f"{hand_side}_optimized",
            rr.Points3D(
                xyz,
                colors=color,
                class_ids=1,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )
        rr.log(
            f"{hand_side}",
            rr.Points3D(
                xyz_pred,
                colors=color,
                class_ids=class_id,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



## Lets get this working for the first 10 frames instead of just a single timestep

In [13]:
import itertools

rr.init("sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 50)):
    rr.set_time_sequence("frame idx",sequence=idx)
    for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            xyz: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            rr.log(
                hand_side,
                rr.Points3D(
                    xyz,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )

    for cam_param, bgr in zip(exo_data.cam_params_list, exo_data.bgr_list, strict=True):
        for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            xyz: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            z_relative: Float32[ndarray, "21"] = xyz[:, 2] - xyz[9, 2]
            # np.nan is used to indicate that a joint is not confidently detected
            uv: Float32[ndarray, "21 2"] = exo_data.uv_dict[cam_param.name][hand_idx]
            # np.nan is used to indicate that a joint is not confidently detected
            uv[uv == -1] = np.nan

            uvz:Float32[ndarray, "21 3"] = np.concatenate([uv, z_relative[:, None]], axis=1)

            # find min/max to generate a bounding box, add a small padding
            uv_min:Float32[ndarray, "2"] = np.nanmin(uv, axis=0)
            uv_max:Float32[ndarray, "2"] = np.nanmax(uv, axis=0)

            xyxy:Float32[ndarray, "4"] = np.concatenate([uv_min, uv_max], axis=0)
            if np.any(np.isnan(xyxy)):
                print(f"xyxy is nan: {xyxy}")
                continue

            image_log_path: Path = parent_log_path / cam_param.name / "pinhole" / "video"
            rr.log(
                f"{image_log_path}/{hand_side}_keypoints",
                rr.Points2D(
                    uv,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )

            rr.log(
                    f"{image_log_path}/{hand_side}_xyxy",
                    rr.Boxes2D(
                        array=xyxy,
                        array_format=rr.Box2DFormat.XYXY
                    )
            )
            rr.log(
                f"{image_log_path}",
                rr.Image(
                    bgr,
                    color_model=rr.ColorModel.BGR
                ).compress(jpeg_quality=75),
            )

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is 

In [19]:
megatrack_fitter = LMOptimJointOnly(
    Pall=Pall,
    loss_weights=loss_weights,
    num_iters=50,
)



Tracing JIT, can take a while...
Trace Done


In [20]:
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 100)):
    rr.set_time_sequence("frame idx",sequence=idx)
    for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            
            output: tuple[OptimizationResults, LevenbergMarquardtState] = megatrack_fitter(xyz_pred_batch=exo_data.xyz)
            optimized_result: OptimizationResults = output[0]
            state: LevenbergMarquardtState = output[1]
            xyz_pred: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            rr.log("log", rr.TextLog(f"num_iters = {state.iter_num}"))
            match hand_side:
                case "left":
                    xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[0]
                case "right":
                    xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[1]
                case _: 
                    raise ValueError(f"Invalid hand side: {hand_side}")
            rr.log(
                f"{hand_side}_optimized",
                rr.Points3D(
                    xyz,
                    colors=color,
                    class_ids=1,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )
            rr.log(
                f"{hand_side}",
                rr.Points3D(
                    xyz_pred,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=sequence.hand_ids,
                    show_labels=False,
                ),
            )

Viewer()