In [1]:
%load_ext ipython_beartype
%beartype

%load_ext autoreload
%autoreload 2

In [2]:
# Add these lines to the cell with id 34bb5824
import warnings

import jax.numpy as jnp
import numpy as np
import rerun as rr
from jaxtyping import Array, Float
from lovely_numpy import lo
from simplecv.data.exoego.skeleton.mediapipe import MEDIAPIPE_ID2NAME, MEDIAPIPE_IDS, MEDIAPIPE_LINKS

from pi0_lerobot.mano.mano_optimization_jax import HandSide

warnings.filterwarnings('ignore', category=RuntimeWarning)




In [3]:

def set_pose_annotation_context() -> None:
    """Log the static annotation context used by every hand-keypoint entity.

    Two classes are registered at the root entity path ("/"):

    * class id 0 -- "Triangulated Hand", blue
    * class id 1 -- "Optimized Hand", red

    Both classes use the same keypoint labels (``MEDIAPIPE_ID2NAME``) and the
    same skeleton connections (``MEDIAPIPE_LINKS``). The context is logged with
    ``static=True`` so it is not tied to a timeline position.
    """

    def _hand_class(class_id: int, label: str, color: tuple[int, int, int]) -> rr.ClassDescription:
        # Each class gets its own freshly built keypoint-annotation list.
        keypoint_infos = [
            rr.AnnotationInfo(id=kp_id, label=kp_name) for kp_id, kp_name in MEDIAPIPE_ID2NAME.items()
        ]
        return rr.ClassDescription(
            info=rr.AnnotationInfo(id=class_id, label=label, color=color),
            keypoint_annotations=keypoint_infos,
            keypoint_connections=MEDIAPIPE_LINKS,
        )

    hand_classes = [
        _hand_class(0, "Triangulated Hand", (0, 0, 255)),
        _hand_class(1, "Optimized Hand", (255, 0, 0)),
    ]
    rr.log("/", rr.AnnotationContext(hand_classes), static=True)

In [4]:
from pathlib import Path

from jaxtyping import Float32
from numpy import ndarray
from rerun.blueprint import Blueprint
from simplecv.apis.view_exoego_data import create_blueprint
from simplecv.camera_parameters import PinholeParameters
from simplecv.data.exoego.base_exo_ego import ExoData
from simplecv.data.exoego.hocap import HOCapSequence
from simplecv.rerun_log_utils import log_pinhole

# Location of the HO-Cap sample data, relative to this notebook.
data_path = Path("../../data/hocap/sample")
assert data_path.exists(), f"Data path {data_path} does not exist, please check the path."

# Load one recorded sequence together with its labels.
sequence: HOCapSequence = HOCapSequence(
    data_path=data_path, sequence_name="20231024_180733", subject_id="8", load_labels=True
)

# Every camera is logged below this root entity.
parent_log_path: Path = Path("world")

# One entity path per exocentric camera, plus the path its video frames are logged to.
exo_cam_log_paths: list[Path] = [parent_log_path / cam.name for cam in sequence.exo_cam_list]
exo_video_log_paths: list[Path] = [cam_path / "pinhole" / "video" for cam_path in exo_cam_log_paths]

# Viewer layout shared by all `rr.notebook_show` calls in this notebook.
blueprint: Blueprint = create_blueprint(
    exo_video_log_paths=exo_video_log_paths,
    num_videos_to_log=8,
)

Found HoloLens camera: hololens_kv5h72


Loading videos: 100%|██████████| 8/8 [00:00<00:00, 26153.10it/s]
Loading 3D labels: 100%|██████████| 8/8 [00:02<00:00,  3.82it/s]
Indexing depth images: 100%|██████████| 8/8 [00:00<00:00, 448.35it/s]


## Let's get this working for the first 5 frames

In [5]:
import itertools
from typing import Literal

# Visualise the first 5 frames of the sequence: 3D ground-truth keypoints per
# hand, and for every exocentric camera its pinhole, image, 2D keypoints and a
# keypoint-derived bounding box.
rr.init("sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)


exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 5)):
    rr.set_time_sequence("frame idx", sequence=idx)

    # 3D ground-truth keypoints, one entity per hand.
    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()  # Get "left" or "right"
        hand_idx: int = hand_enum.value

        xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        rr.log(
            hand_side,
            rr.Points3D(
                xyz_gt,
                colors=(0, 255, 0),
                class_ids=0,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )

    for cam_param, bgr in zip(exo_data.cam_params_list, exo_data.bgr_list, strict=True):
        cam_log_path: Path = parent_log_path / cam_param.name
        image_log_path: Path = cam_log_path / "pinhole" / "video"
        log_pinhole(
            camera=cam_param,
            cam_log_path=cam_log_path,
            image_plane_distance=0.1,
            static=False,
        )
        rr.log(
            f"{image_log_path}",
            rr.Image(bgr, color_model=rr.ColorModel.BGR).compress(jpeg_quality=15),
        )
        for hand_enum in HandSide:
            hand_side: Literal["left", "right"] = hand_enum.name.lower()  # Get "left" or "right"
            hand_idx: int = hand_enum.value

            # Work on a copy: the array belongs to `exo_data`, and rewriting the
            # sentinel in place would silently alter the dataset's labels for any
            # later consumer of the same object.
            uv: Float32[ndarray, "21 2"] = exo_data.uv_dict[cam_param.name][hand_idx].copy()
            # -1 marks a joint that was not confidently detected; use np.nan instead
            # so the nan-aware reductions below ignore it.
            uv[uv == -1] = np.nan

            # tight bounding box around the detected keypoints (no padding is applied)
            uv_min: Float32[ndarray, "2"] = np.nanmin(uv, axis=0)
            uv_max: Float32[ndarray, "2"] = np.nanmax(uv, axis=0)

            xyxy: Float32[ndarray, "4"] = np.concatenate([uv_min, uv_max], axis=0)
            # an all-NaN column means the hand is not visible in this view:
            # skip logging keypoints and bbox
            if np.any(np.isnan(xyxy)):
                continue

            rr.log(
                f"{image_log_path}/{hand_side}_keypoints",
                rr.Points2D(
                    uv,
                    colors=(0, 255, 0),
                    class_ids=0,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )
            rr.log(
                f"{image_log_path}/{hand_side}_xyxy",
                rr.Boxes2D(
                    array=xyxy,
                    array_format=rr.Box2DFormat.XYXY,
                    labels=[hand_side],
                ),
            )

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



# Vectorized Projection using batched matrix multiplication with jax

In [6]:
import jax.numpy as npj
from einops import rearrange
from jaxtyping import Float64, Int

from pi0_lerobot.mano.mano_optimization_jax import proj_3d_vectorized

# Pick `num_frames_to_select` equally spaced frames across the whole sequence
# (np.linspace includes both the first and the last frame).
num_frames_to_select: int = 10
total_frames: int = len(sequence)
frame_indices: list[int] = np.linspace(0, total_frames - 1, num_frames_to_select).astype(int).tolist()

# Load the ExoData of every selected frame
exo_data_list: list[ExoData] = [sequence[selected_idx] for selected_idx in frame_indices]

# Stack the per-frame 3D keypoints (both hands) and append a homogeneous coordinate of 1
xyz_batch: Float[ndarray, "n_frames 2 21 3"] = np.stack([frame.xyz for frame in exo_data_list], axis=0)
xyz_hom_batch: Float[ndarray, "n_frames 2 21 4"] = np.concatenate(
    [xyz_batch, np.ones_like(xyz_batch[..., 0:1])],
    axis=-1,
)

# Split by hand: index 0 is the left hand, index 1 the right hand
left_xyz_hom_batch: Float[ndarray, "n_frames 21 4"] = xyz_hom_batch[:, 0]
right_xyz_hom_batch: Float[ndarray, "n_frames 21 4"] = xyz_hom_batch[:, 1]

# One 3x4 projection matrix per exocentric camera, stacked and cast to float32
exo_cam_list: list[PinholeParameters] = sequence.exo_cam_list
projection_all_list: list[Float64[np.ndarray, "3 4"]] = [cam.projection_matrix for cam in exo_cam_list]
Pall: Float32[np.ndarray, "n_views 3 4"] = np.stack(projection_all_list, axis=0).astype(np.float32)

# Project every hand into every view in a single vectorized call per hand
uv_batch_left: Float[Array, "n_frames n_views 21 2"] = proj_3d_vectorized(
    xyz_hom=npj.array(left_xyz_hom_batch),
    P=npj.array(Pall),
)
uv_batch_right: Float[Array, "n_frames n_views 21 2"] = proj_3d_vectorized(
    xyz_hom=npj.array(right_xyz_hom_batch),
    P=npj.array(Pall),
)
# Re-insert the hand axis (left, right) after the view axis
uv_batch: Float[ndarray, "n_frames n_views 2 21 2"] = np.stack([uv_batch_left, uv_batch_right], axis=2)

## Visualize vectorized projections to make sure things are correct

In [7]:
# Sanity check of the vectorized projection: log the 3D ground truth per hand
# and overlay the projected 2D keypoints on every camera image.
rr.init("Random Sequence")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

# All entities in this cell use the "Triangulated Hand" class.
class_id = 0

for frame_idx, exo_data in enumerate(exo_data_list):
    rr.set_time_sequence("frame idx", sequence=frame_idx)

    # Camera and image do not depend on the hand, so log them once per view
    # instead of once per (hand, view) pair.
    for cam_param, bgr in zip(exo_data.cam_params_list, exo_data.bgr_list, strict=True):
        cam_log_path: Path = parent_log_path / cam_param.name
        image_log_path: Path = cam_log_path / "pinhole" / "video"

        log_pinhole(
            camera=cam_param,
            cam_log_path=cam_log_path,
            image_plane_distance=0.1,
            static=False,
        )
        rr.log(
            f"{image_log_path}",
            rr.Image(bgr, color_model=rr.ColorModel.BGR).compress(jpeg_quality=75),
        )

    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value

        xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        # projections of this hand into every view for the current frame
        uv_views: Float[ndarray, "n_views 21 2"] = uv_batch[frame_idx, :, hand_idx, ...]
        rr.log(
            hand_side,
            rr.Points3D(
                xyz_gt,
                colors=(0, 255, 0),
                class_ids=class_id,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )
        # strict=True keeps the check that there is exactly one projection per camera
        for cam_param, uv in zip(exo_data.cam_params_list, uv_views, strict=True):
            image_log_path: Path = parent_log_path / cam_param.name / "pinhole" / "video"
            rr.log(
                f"{image_log_path}/{hand_side}_keypoints",
                rr.Points2D(
                    uv,
                    colors=(0, 255, 0),
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )


Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



## Example optimization using 3d keypoints

In [8]:
import jax
import jax.numpy as npj
from jax import jit
from jaxopt import LevenbergMarquardt
from jaxopt._src.levenberg_marquardt import LevenbergMarquardtState
from jaxtyping import Int

from pi0_lerobot.mano.mano_optimization_jax import LMOptimJointOnly, LossWeights, OptimizationResults

# Only the 2D keypoint term has a non-zero weight; depth and temporal terms are off.
loss_weights = LossWeights(keypoint_2d=0.1, depth=0.0, temp=0.0)

# Joint-only Levenberg-Marquardt fitter over all camera views, capped at 30 iterations.
megatrack_fitter = LMOptimJointOnly(Pall=Pall, loss_weights=loss_weights, num_iters=30)



Tracing JIT, can take a while...
Trace Done


In [9]:
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1200, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 300)):
    rr.set_time_sequence("frame idx",sequence=idx)

    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value
        
        output: tuple[OptimizationResults, LevenbergMarquardtState] = megatrack_fitter(xyz_pred_batch=exo_data.xyz)
        optimized_result: OptimizationResults = output[0]
        state: LevenbergMarquardtState = output[1]

        xyz_pred: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
        rr.log("log", rr.TextLog(f"num_iters = {state.iter_num}"))

        match hand_side:
            case "left":
                xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[0]
                xyz_template_left = xyz_pred - xyz_pred[0:1, :]
            case "right":
                xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[1]
                xyz_template_right = xyz_pred - xyz_pred[0:1, :]
            case _: 
                raise ValueError(f"Invalid hand side: {hand_side}")
            
        rr.log(
            f"{hand_side}",
            rr.Points3D(
                xyz_pred,
                colors=(0, 255, 0),
                class_ids=0,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )

        rr.log(
            f"{hand_side}_optimized",
            rr.Points3D(
                xyz,
                colors=(0, 255, 0),
                class_ids=1,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )

# log the final template for left/right

rr.log(
    "left_template",
    rr.Points3D(
        xyz_template_left,
        colors=(0, 255, 0),
        class_ids=1,
        keypoint_ids=MEDIAPIPE_IDS,
        show_labels=False,
    ),
)
rr.log(
    "right_template",
    rr.Points3D(
        xyz_template_right,
        colors=(0, 255, 0),
        class_ids=1,
        keypoint_ids=MEDIAPIPE_IDS,
        show_labels=False,
    ),
)


Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



## Example Optimization on a single timestep using multiview 2d keypoints

In [10]:
from collections.abc import Callable

from jaxtyping import Bool

from pi0_lerobot.mano.mano_jax import JointsOnly, mp_to_mano

# Signature of a per-hand forward-kinematics function: it is called with three
# arrays (in this notebook: so3, translation, scale) and returns 21 3D joints
# per batch item.
FwdKinematics = Callable[
    [Float[Array, "b 48"], Float[Array, "b 1 3"], Float[Array, "1 1"]],
    Float[Array, "b 21 3"],
]

# Signature of the residual function handed to jaxopt's LevenbergMarquardt.
# NOTE(review): callers in this notebook pass one projection matrix per camera
# view as Pall (n_views 3 4), not one per batch item -- confirm whether the
# leading "b" in the Pall shape below is intended.
ResidualFn = Callable[
    [Float[Array, "_"],               # flattened params + scale
     Float[Array, "b 3 4"],           # Pall
     Float[Array, "b n_views 21 2"],  # uv_pred
     "LossWeights",
     bool | Bool[Array, ""],],
    Float[Array, "_"],                # flat residual vector
]


def make_mv_scaled_residual(
    xyz_template_left:  Float[Array, "21 3"],
    xyz_template_right: Float[Array, "21 3"],
) -> tuple[ResidualFn, FwdKinematics, FwdKinematics]:
    """
    Build the multi-view 2D reprojection residual for a pair of hand templates.

    The per-hand forward-kinematics functions are created once from the given
    templates (re-ordered with `mp_to_mano`) and captured by the residual, so
    nothing needs to live in module globals.

    Args:
        xyz_template_left: Template joints of the left hand, shape (21, 3).
        xyz_template_right: Template joints of the right hand, shape (21, 3).

    Returns:
        A tuple `(residual_fn, mano_fwd_left, mano_fwd_right)`: the JIT-compiled
        residual, suitable as `residual_fun` of `jaxopt.LevenbergMarquardt`, and
        the two JIT-compiled forward-kinematics functions it uses.

    Example
    -------
    >>> mv_scaled_residual, fwd_left, fwd_right = make_mv_scaled_residual(
    ...     xyz_template_left , xyz_template_right
    ... )
    >>> solver = LevenbergMarquardt(residual_fun=mv_scaled_residual, ...)
    """

    # Forward kinematics per hand; the templates are fixed for the lifetime of the closure.
    fwd_left = jit(JointsOnly(template_joints=xyz_template_left[mp_to_mano, :]))
    fwd_right = jit(JointsOnly(template_joints=xyz_template_right[mp_to_mano, :]))

    def _left_branch(so3_in, trans_in, scale_in):
        return fwd_left(so3_in, trans_in, scale_in)

    def _right_branch(so3_in, trans_in, scale_in):
        return fwd_right(so3_in, trans_in, scale_in)

    @jit
    def mv_2d_scaled_residual(
        param_to_optimize: Float[Array, "_"],
        Pall: Float[Array, "b 3 4"],
        uv_pred: Float[Array, "b n_views 21 2"],
        loss_weights: LossWeights,
        is_left:bool | Bool[Array, ""],
    ) -> Float[Array, "_"]:
        """
        Weighted difference between projected MANO joints and target 2D keypoints.

        Args:
            param_to_optimize: Flat parameter vector of length `b * 51 + 1`:
                for each batch item 48 pose values followed by 3 translation
                values, and one trailing scale value shared by the whole batch.
                It is flat because jaxopt solvers operate on 1D parameters.
            Pall: Projection matrices handed to `proj_3d_vectorized`; the callers
                in this notebook pass one 3x4 matrix per camera view.
            uv_pred: Target 2D keypoints, shape (b, n_views, 21, 2). NaN entries
                produce a zero residual.
            loss_weights: Loss weights; only the "keypoint_2d" entry is read.
            is_left: Selects the left (True) or right (False) forward kinematics.

        Returns:
            Flat 1D array with the weighted residual of every batch item, view,
            joint and image coordinate.
        """
        n_batch: int = uv_pred.shape[0]

        # Split the flat vector: trailing scalar is the scale, the rest is the pose block.
        scale: Float[Array, "1 1"] = param_to_optimize[-1].reshape(1, 1)
        pose: Float[Array, "b 51"] = param_to_optimize[:-1].reshape(n_batch, 51)

        so3: Float[Array, "b 48"] = pose[:, 0:48]
        trans: Float[Array, "b 1 3"] = pose[:, npj.newaxis, 48:51]

        # lax.cond keeps the hand selection traceable under jit.
        xyz_mano: Float[Array, "b 21 3"] = jax.lax.cond(is_left, _left_branch, _right_branch, so3, trans, scale)

        # Append a homogeneous coordinate of 1 and project into every view.
        ones_col: Float[Array, "b 21 1"] = npj.ones_like(xyz_mano)[..., 0:1]
        xyz_mano_hom: Float[Array, "b 21 4"] = npj.concatenate([xyz_mano, ones_col], axis=-1)
        uv_mano: Float[Array, "b n_views 21 2"] = proj_3d_vectorized(xyz_hom=xyz_mano_hom, P=Pall)

        # Weighted reprojection error; NaN entries are replaced by 0 so they do not contribute.
        weighted: Float[Array, "b n_views 21 2"] = (uv_mano - uv_pred) * loss_weights["keypoint_2d"]
        return npj.nan_to_num(weighted, nan=0.0).flatten()

    return mv_2d_scaled_residual, fwd_left, fwd_right

In [11]:
# Full weight on the 2D keypoint term; depth and temporal terms stay disabled.
loss_weights = LossWeights(keypoint_2d=1.0, depth=0.0, temp=0.0)

# Residual + per-hand forward kinematics built from the templates computed above.
residual_fn, mano_fwd_left, mano_fwd_right = make_mv_scaled_residual(
    xyz_template_left=npj.array(xyz_template_left),
    xyz_template_right=npj.array(xyz_template_right),
)
optimizer = LevenbergMarquardt(
    residual_fun=residual_fn,
    maxiter=1000,
    solver="cholesky",
    jit=True,
    xtol=1e-8,
    gtol=1e-8,
)

# Inputs of the first frame: one projection matrix and one set of 2D keypoints per view.
exo_data: ExoData = sequence[0]
P_list: list[ndarray] = [view_cam.projection_matrix for view_cam in exo_data.cam_params_list]
uv_list: list[ndarray] = [exo_data.uv_dict[view_cam.name] for view_cam in exo_data.cam_params_list]
P_all: Float[Array, "n_views 3 4"] = npj.stack(P_list, axis=0)
uv_gt: Float[Array, "n_views 2 21 2"] = npj.stack(uv_list, axis=0)
# left hand only (index 0), with a leading frame axis of size 1
uv_left_gt_batch: Float[Array, "n_frames n_views 21 2"] = uv_gt[:, 0, ...][npj.newaxis, ...]

# Initial guess: 51 zeros (48 pose + 3 translation) followed by a scale of 1.
scale_param: Float[Array, "1"] = npj.array([1.0])
init_params = npj.concatenate([npj.zeros((1, 51)).flatten(), scale_param], axis=0)

# Run once so tracing happens here, then wrap the solver entry point in jit.
print("Tracing JIT, can take a while...")
_, _ = optimizer.run(
    init_params,
    Pall=P_all,
    uv_pred=uv_left_gt_batch,
    loss_weights=loss_weights,
    is_left=True,
)
optimizer = jit(optimizer.run)



Tracing JIT, can take a while...


In [12]:
# optimize the first `num_samples` frames of the left hand
from simplecv.camera_parameters import PinholeParameters

# Fit the left hand of the first `num_samples` frames to its multi-view 2D
# keypoints, and log the images with the target keypoints for inspection.
num_samples = 1
# This cell only handles the left hand; set it explicitly instead of relying on
# whatever value an earlier loop left behind in `hand_side`.
hand_side: Literal["left", "right"] = "left"
exo_data_list: list[ExoData] = [sequence[frame_idx] for frame_idx in range(num_samples)]

# create batch of frames and views of uv
uv_batch_list: list[Float[Array, "n_views 2 21 2"]] = []
for exo_data in exo_data_list:
    uv_list: list[ndarray] = [exo_data.uv_dict[cam_param.name] for cam_param in exo_data.cam_params_list]
    uv_gt: Float[Array, "n_views 2 21 2"] = npj.stack(uv_list, axis=0)
    uv_batch_list.append(uv_gt)

uv_batch: Float[Array, "n_frames n_views 2 21 2"] = npj.stack(uv_batch_list, axis=0)
# -1 marks a joint that was not confidently detected. Turn it into NaN so the
# residual (which zeroes NaN entries) ignores it instead of pulling the hand
# towards pixel (-1, -1).
uv_batch = npj.where(uv_batch == -1, npj.nan, uv_batch)

# The cameras of the first selected frame are used for every frame.
cam_param_list: list[PinholeParameters] = exo_data_list[0].cam_params_list
P_list: list[Float[ndarray, "3 4"]] = [exo_cam.projection_matrix for exo_cam in cam_param_list]
Pall: Float[Array, "n_views 3 4"] = npj.stack(P_list, axis=0)

uv_left_gt_batch: Float[Array, "n_frames n_views 21 2"] = uv_batch[:, :, 0, ...]

rr.init("Random Sequence")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

for frame_idx in range(num_samples):
    rr.set_time_sequence("frame idx", sequence=frame_idx)

    uv_gt_left: Float[Array, "n_views 21 2"] = uv_left_gt_batch[frame_idx, ...]
    bgr_list: list[ndarray] = exo_data_list[frame_idx].bgr_list

    for uv_left, cam_param, bgr in zip(uv_gt_left, cam_param_list, bgr_list, strict=True):
        cam_log_path: Path = parent_log_path / cam_param.name
        image_log_path: Path = cam_log_path / "pinhole" / "video"
        rr.log(
            f"{image_log_path}",
            rr.Image(bgr, color_model=rr.ColorModel.BGR).compress(jpeg_quality=75),
        )
        rr.log(
            f"{image_log_path}/{hand_side}_keypoints",
            rr.Points2D(
                uv_left,
                colors=(0, 255, 0),
                class_ids=0,
                keypoint_ids=MEDIAPIPE_IDS,
                show_labels=False,
            ),
        )
        log_pinhole(
            camera=cam_param,
            cam_log_path=cam_log_path,
            image_plane_distance=0.1,
            static=False,
        )

# Initial guess: zero pose and translation per frame, plus one shared scale of 1.
batch_size: int = uv_left_gt_batch.shape[0]
so3_init: Float[Array, "b 48"] = npj.zeros((batch_size, 48))
trans_init: Float[Array, "b 3"] = npj.zeros((batch_size, 3))
init_params: Float[Array, "b 51"] = npj.concatenate([so3_init, trans_init], axis=-1)
scale_init: Float[Array, "1"] = npj.ones((1))
init_params = npj.concatenate([init_params.flatten(), scale_init], axis=0)


# Use the projection matrices built in this cell rather than a leftover from an earlier cell.
optimized_params, state = optimizer(
    init_params,
    Pall=Pall,
    uv_pred=uv_left_gt_batch,
    loss_weights=loss_weights,
    is_left=True,
)
# Unpack: trailing scalar is the scale, the rest is (b, 48 pose + 3 translation).
optimized_scale: Float[Array, ""] = optimized_params[-1]
optimized_params: Float[Array, "_"] = optimized_params[:-1]
optimized_params: Float[Array, "b 51"] = optimized_params.reshape(batch_size, 51)
optimized_so3: Float[Array, "b 48"] = optimized_params[:, 0:48]
optimized_trans: Float[Array, "b 3"] = optimized_params[:, 48:51]

print(optimized_scale)

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



1.0027553


In [13]:
# Compare, per frame, the MANO joints at the initial guess, the MANO joints at
# the optimized parameters, and the ground-truth 3D keypoints of the left hand.
# Everything here is left-hand data, so the entity prefix is set explicitly
# rather than taken from a variable left over from an earlier loop.
hand_side: Literal["left", "right"] = "left"

xyz_mano_left_init: Float[Array, "n_frames 21 3"] = mano_fwd_left(
    so3_init, trans_init[:, npj.newaxis, :], scale_init[npj.newaxis, ...]
)
xyz_mano_left_batch: Float[Array, "n_frames 21 3"] = mano_fwd_left(
    optimized_so3, optimized_trans[:, npj.newaxis, :], optimized_scale[npj.newaxis, npj.newaxis, ...]
)
xyz_gt_left_batch: Float[Array, "n_frames 21 3"] = npj.stack([exo_data.xyz[0] for exo_data in exo_data_list], axis=0)
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

for idx, (xyz_gt_left, xyz_mano_left, xyz_mano_init) in enumerate(
    zip(xyz_gt_left_batch, xyz_mano_left_batch, xyz_mano_left_init, strict=True)
):
    rr.set_time_sequence("frame idx", sequence=idx)
    rr.log(
        f"{hand_side}_init",
        rr.Points3D(
            xyz_mano_init,
            colors=(0, 255, 0),
            class_ids=1,
            keypoint_ids=MEDIAPIPE_IDS,
            show_labels=False,
        ),
    )

    rr.log(
        f"{hand_side}_optimized",
        rr.Points3D(
            xyz_mano_left,
            colors=(0, 255, 0),
            class_ids=1,
            keypoint_ids=MEDIAPIPE_IDS,
            show_labels=False,
        ),
    )
    rr.log(
        f"{hand_side}",
        rr.Points3D(
            xyz_gt_left,
            colors=(0, 255, 0),
            # ground truth uses class 0 ("Triangulated Hand"), as in the other cells
            class_ids=0,
            keypoint_ids=sequence.hand_ids,
            show_labels=False,
        ),
    )
    

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



In [None]:
from dataclasses import dataclass


@dataclass
class NewOptimizationResults:
    """
    Per-frame result of `JointAndScaleOptimization` for both hands.

    In every array the first axis indexes the hand: 0 is the left hand and
    1 is the right hand.
    """

    # 3D joints produced by the per-hand forward kinematics
    xyz_mano: Float[ndarray, "2 21 3"]
    # optimized pose parameters (first 48 values of each hand's parameter block)
    so3: Float[ndarray, "2 48"]
    # optimized translation (last 3 values of each hand's parameter block)
    trans: Float[ndarray, "2 3"]

class JointAndScaleOptimization:
    def __init__(
        self,
        xyz_template: Float[ndarray, "2 21 3"],
        Pall: Float[ndarray, "n_views 3 4"],
        loss_weights: LossWeights,
        num_iters: int = 30,
    ) -> None:
        """
        Pall - n, 3, 4 projection matrix
        loss_weights - dictionary containing how much value to give each portion
            of the cost function (2d, 3d, temporal)
        num_iters - how many iterations to optimize
        """

        batch_size = 1
        assert batch_size == 1, "Batch size must be 1 for this optimization"

        n_views: int = Pall.shape[0]

        self.num_iters: int = num_iters
        # Projection Matrix (n, 3, 4) where n is the number of cameras
        self.Pall: Float[Array, "batch 3 4"] = npj.array(Pall)

        self.loss_weights: LossWeights = loss_weights
        # use previous values to initialize, there should only ever be 1
        # hand model per frame
        self.so3_left_prev: Float[Array, "1 48"] = npj.zeros((1, 48))
        self.trans_left_prev: Float[Array, "1 3"] = npj.zeros((1, 3))

        self.so3_right_prev: Float[Array, "1 48"] = npj.zeros((1, 48))
        self.trans_right_prev: Float[Array, "1 3"] = npj.zeros((1, 3))

        # scale parameter is shared between left and right hand
        self.scale_init: Float[Array, "1"] = npj.ones((1))  # noqa UP037

        output_fns: tuple[ResidualFn, FwdKinematics, FwdKinematics] = make_mv_scaled_residual(
            xyz_template_left=npj.array(xyz_template[0]), xyz_template_right=npj.array(xyz_template[1])
        )

        residual_fn: ResidualFn = output_fns[0]
        self.mano_fwd_left: FwdKinematics = output_fns[1]
        self.mano_fwd_right: FwdKinematics = output_fns[2]

        # remove the need for two different optimizers, solvers ‘cholesky’, ‘inv’
        self.optimizer = LevenbergMarquardt(
            residual_fun=residual_fn, maxiter=self.num_iters, solver="cholesky", jit=True, xtol=1e-6, gtol=1e-6
        )
        # add jit
        print("Tracing JIT, can take a while...")
        init_params: Float[Array, "1 51"] = npj.concatenate([so3_init, trans_init], axis=-1)
        init_params = npj.concatenate([init_params.flatten(), self.scale_init], axis=0)

        uv_batch_init: Float[Array, "n_frames n_views 21 2"] = npj.zeros((1, n_views, 21, 2))
        _, _ = self.optimizer.run(
            init_params.flatten(),
            Pall=npj.array(P_all),
            uv_pred=uv_batch_init,
            loss_weights=loss_weights,
            is_left=True,
        )
        self.optimizer = jit(self.optimizer.run)

        print("Trace Done")

    def __call__(
        self,
        uv_left_pred_batch: Float[ndarray, "n_views 21 2"],
        uv_right_pred_batch: Float[ndarray, "n_views 21 2"],
        calibrate: bool = False,
    ) -> tuple[NewOptimizationResults, LevenbergMarquardtState]:
        """
        pose_predictions_dict
            pose_predictions
            camera_dict
        """
        so3_optimized: Float[ndarray, "2 48"] = np.zeros((2, 48))
        trans_optimized: Float[ndarray, "2 3"] = np.zeros((2, 3))
        xyz_mano: Float[ndarray, "2 21 3"] = np.zeros((2, 21, 3))

        for hand_enum in HandSide:
            hand_side: Literal["left", "right"] = hand_enum.name.lower()
            hand_idx: int = hand_enum.value
            # get previous values, and extract pose_predictions
            match hand_side:
                case "left":
                    so3_prev: Float[Array, "1 48"] = self.so3_left_prev.copy()
                    trans_prev: Float[Array, "1 3"] = self.trans_left_prev.copy()
                    uv_pred_batch:Float[Array, "1 n_views 21 2"] = npj.array(uv_left_pred_batch)[npj.newaxis, ...]

                case "right":
                    so3_prev: Float[Array, "1 48"] = self.so3_right_prev.copy()
                    trans_prev: Float[Array, "1 3"] = self.trans_right_prev.copy()
                    uv_pred_batch:Float[Array, "1 n_views 21 2"] = npj.array(uv_right_pred_batch)[npj.newaxis, ...]

            # TODO initialize only rotation from wrist form either 3d procustus or mano preds
            so3_init: Float[Array, "1 48"] = so3_prev
            trans_init: Float[Array, "1 3"] = trans_prev

            init_params: Float[Array, "1 51"] = npj.concatenate([so3_init, trans_init], axis=-1)
            init_params: Float[Array, "_"] = npj.concatenate([init_params.flatten(), self.scale_init], axis=0)

            optimized_params, state = self.optimizer(
                init_params,
                Pall=self.Pall,
                uv_pred=uv_pred_batch,
                loss_weights=self.loss_weights,
                is_left=hand_side == "left",
            )

            # if np.isnan(optimized_params).any():
            #     continue
            
            optimized_scale: Float[Array, ""] = optimized_params[-1]
            optimized_params: Float[Array, "_"] = optimized_params[:-1]
            optimized_params: Float[Array, "1 51"] = optimized_params.reshape(batch_size, 51)

            so3: Float[Array, "1 48"] = optimized_params[:, 0:48]
            trans: Float[Array, "1 3"] = optimized_params[:, 48:51]

            so3_optimized[0 if hand_side == "left" else 1] = np.array(so3[0])
            trans_optimized[0 if hand_side == "left" else 1] = np.array(trans[0])

            # pass optimized values to mano to extract 3d joints
            match hand_side:
                case "left":
                    self.so3_left_prev = so3
                    self.trans_left_prev = trans

                    xyz_mano_left: Float[Array, "1 21 3"] = self.mano_fwd_left(
                        so3, trans[:, npj.newaxis, :]
                    )
                    xyz_mano[0] = np.array(xyz_mano_left[0])

                case "right":
                    self.so3_right_prev = so3
                    self.trans_right_prev = trans

                    xyz_mano_right: Float[Array, "1 21 3"] = self.mano_fwd_right(
                        so3, trans[:, npj.newaxis, :]
                    )
                    xyz_mano[1] = np.array(xyz_mano_right[0])

        optimization_results = NewOptimizationResults(
            xyz_mano=xyz_mano,
            so3=so3_optimized,
            trans=trans_optimized,
        )

        return optimization_results, state

# Projection matrices of the first selected frame's cameras, stacked per view.
P_list: list[Float[ndarray, "3 4"]] = [view_cam.projection_matrix for view_cam in exo_data_list[0].cam_params_list]
Pall: Float[ndarray, "n_views 3 4"] = np.stack(P_list)

# Templates are stacked as (left, right) to match the optimizer's hand indexing.
joint_and_scale_optimizer = JointAndScaleOptimization(
    xyz_template=np.stack((xyz_template_left, xyz_template_right)),
    Pall=Pall,
    loss_weights=loss_weights,
    num_iters=30,
)



Tracing JIT, can take a while...
Trace Done


In [16]:
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1200, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(sequence):
    rr.set_time_sequence("frame idx",sequence=idx)

    uv_batch_left: None
    uv_batch_right: None
    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value

        xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]

        # create a batch of xyz points
        xyz_gt_hom: Float[ndarray, "21 4"] = np.concatenate([xyz_gt, np.ones_like(xyz_gt)[..., 0:1]], axis=-1)
        xyz_gt_hom: Float[Array, "1 21 4"] = npj.array(xyz_gt_hom)[npj.newaxis, ...]

        # create a batch of uv points in vectorized form
        uv_batch:Float[Array, "1 n_views 21 2"] = proj_3d_vectorized(
            xyz_hom=xyz_gt_hom,
            P=npj.array(Pall)
        )
        match hand_side:
            case "left":
                uv_batch_left = uv_batch
            case "right":
                uv_batch_right = uv_batch
            case _:
                raise ValueError(f"Invalid hand side: {hand_side}")
            
        rr.log(
            f"{hand_side}",
            rr.Points3D(
                xyz_gt,
                colors=(0, 255, 0),
                class_ids=0,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )

    optim_out: tuple[NewOptimizationResults, LevenbergMarquardtState] = joint_and_scale_optimizer(
        uv_left_pred_batch=np.array(uv_batch_left)[0],
        uv_right_pred_batch=np.array(uv_batch_right)[0]
    )
    optimized_result: NewOptimizationResults = optim_out[0]

    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value

        match hand_side:
            case "left":
                xyz_optim = optimized_result.xyz_mano[0]
            case "right":
                xyz_optim = optimized_result.xyz_mano[1]
            case _:
                raise ValueError(f"Invalid hand side: {hand_side}")
            
        rr.log(
            f"{hand_side}_optim",
            rr.Points3D(
                xyz_optim,
                colors=(0, 255, 0),
                class_ids=1,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )


Viewer()

In [22]:
from pi0_lerobot.mano.mano_optimization_jax import JointAndScaleOptimization, OptimizationResults

P_list: list[Float[ndarray, "3 4"]] = [exo_cam.projection_matrix for exo_cam in exo_data_list[0].cam_params_list]
Pall: Float[ndarray, "n_views 3 4"] = np.stack(P_list, axis=0)

new_joint_and_scale_optimizer = JointAndScaleOptimization(
    xyz_template=np.stack([xyz_template_left, xyz_template_right], axis=0),
    Pall=Pall,
    loss_weights=loss_weights,
    num_iters=30)

rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1200, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(sequence):
    rr.set_time_sequence("frame idx",sequence=idx)

    uv_batch_left: None
    uv_batch_right: None
    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value

        xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]

        # create a batch of xyz points
        xyz_gt_hom: Float[ndarray, "21 4"] = np.concatenate([xyz_gt, np.ones_like(xyz_gt)[..., 0:1]], axis=-1)
        xyz_gt_hom: Float[Array, "1 21 4"] = npj.array(xyz_gt_hom)[npj.newaxis, ...]

        # create a batch of uv points in vectorized form
        uv_batch:Float[Array, "1 n_views 21 2"] = proj_3d_vectorized(
            xyz_hom=xyz_gt_hom,
            P=npj.array(Pall)
        )
        match hand_side:
            case "left":
                uv_batch_left = uv_batch
            case "right":
                uv_batch_right = uv_batch
            case _:
                raise ValueError(f"Invalid hand side: {hand_side}")
            
        rr.log(
            f"{hand_side}",
            rr.Points3D(
                xyz_gt,
                colors=(0, 255, 0),
                class_ids=0,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )

    optim_out: tuple[OptimizationResults, LevenbergMarquardtState] = new_joint_and_scale_optimizer(
        uv_left_pred_batch=np.array(uv_batch_left)[0],
        uv_right_pred_batch=np.array(uv_batch_right)[0]
    )
    optimized_result: OptimizationResults = optim_out[0]

    for hand_enum in HandSide:
        hand_side: Literal["left", "right"] = hand_enum.name.lower()
        hand_idx: int = hand_enum.value

        match hand_side:
            case "left":
                xyz_optim = optimized_result.xyz_mano[0]
            case "right":
                xyz_optim = optimized_result.xyz_mano[1]
            case _:
                raise ValueError(f"Invalid hand side: {hand_side}")
            
        rr.log(
            f"{hand_side}_optim",
            rr.Points3D(
                xyz_optim,
                colors=(0, 255, 0),
                class_ids=1,
                keypoint_ids=sequence.hand_ids,
                show_labels=False,
            ),
        )




Tracing JIT, can take a while...
Trace Done


Viewer()