In [1]:
%load_ext ipython_beartype
%beartype

%load_ext autoreload
%autoreload 2

In [2]:
# Add these lines to the cell with id 34bb5824
import warnings

import jax.numpy as jnp
import numpy as np
import rerun as rr
from jaxtyping import Array, Float
from lovely_numpy import lo
from simplecv.data.exoego.skeleton.mediapipe import MEDIAPIPE_ID2NAME, MEDIAPIPE_IDS, MEDIAPIPE_LINKS

from pi0_lerobot.mano.mano_jax import mano_j_left, mano_j_right

warnings.filterwarnings('ignore', category=RuntimeWarning)




In [3]:

def set_pose_annotation_context() -> None:
    rr.log(
        "/",
        rr.AnnotationContext(
            [
                rr.ClassDescription(
                    info=rr.AnnotationInfo(id=0, label="Triangulated Hand", color=(0, 0, 255)),
                    keypoint_annotations=[
                        rr.AnnotationInfo(id=id, label=name) for id, name in MEDIAPIPE_ID2NAME.items()
                    ],
                    keypoint_connections=MEDIAPIPE_LINKS,
                ),
                rr.ClassDescription(
                    info=rr.AnnotationInfo(id=1, label="Optimized Hand", color=(255, 0, 0)),
                    keypoint_annotations=[
                        rr.AnnotationInfo(id=id, label=name) for id, name in MEDIAPIPE_ID2NAME.items()
                    ],
                    keypoint_connections=MEDIAPIPE_LINKS,
                ),
            
            ]
        ),
        static=True,
    )

In [4]:
rr.init("mano_joints")

set_pose_annotation_context()

pose_coeffs: Float[Array, "batch 48"] = jnp.zeros((1, 48))
left_trans: Float[Array, "batch 1 3"] = jnp.array([[[0,0,0.2]]])

xyz_left:Float[Array, "batch 21 3"] = mano_j_left(pose_coeffs, left_trans)


rr.log("left joints", rr.Points3D(positions=xyz_left[0], class_ids=0, keypoint_ids=MEDIAPIPE_IDS, colors=(0, 255, 0), show_labels=False))

rr.notebook_show()

[2025-04-23T22:54:31Z WARN  re_log_types::path::parse_path] When parsing the entity path "left joints": Unescaped whitespace. The path will be interpreted as /left\ joints


Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



In [5]:
from pathlib import Path

from jaxtyping import Float32
from numpy import ndarray
from simplecv.camera_parameters import PinholeParameters
from simplecv.data.exoego.hocap import HOCapSequence

data_path = Path("../../data/hocap/sample")
assert data_path.exists(), f"Data path {data_path} does not exist, please check the path."
sequence: HOCapSequence = HOCapSequence(
                data_path=data_path,
                sequence_name="20231024_180733",
                subject_id="8",
                load_labels=True,
            )


exo_cam_list: list[PinholeParameters] = sequence.exo_cam_list
projection_all_list: list[Float32[np.ndarray, "3 4"]] = []
for exo_cam in sequence.exo_cam_list:
    projection_matrix: Float32[ndarray, "3 4"] = exo_cam.projection_matrix.astype(np.float32)
    projection_all_list.append(projection_matrix)

Pall: Float32[np.ndarray, "n_views 3 4"] = np.array([P for P in projection_all_list])

Found HoloLens camera: hololens_kv5h72


Loading videos: 100%|██████████| 8/8 [00:00<00:00, 27938.74it/s]
Loading 3D labels: 100%|██████████| 8/8 [00:02<00:00,  2.88it/s]
Indexing depth images: 100%|██████████| 8/8 [00:00<00:00, 444.81it/s]


In [6]:
from simplecv.apis.view_exoego_data import create_blueprint
from simplecv.data.exoego.base_exo_ego import ExoData

exo_data: ExoData = sequence[0]
parent_log_path:Path = Path("world")

exo_cam_log_paths: list[Path] = [parent_log_path / exo_cam.name for exo_cam in sequence.exo_cam_list]
exo_video_log_paths: list[Path] = [cam_log_paths / "pinhole" / "video" for cam_log_paths in exo_cam_log_paths]

blueprint = create_blueprint(exo_video_log_paths=exo_video_log_paths, num_videos_to_log=8)

## Lets get this working for the first 10 frames

In [7]:
import itertools

rr.init("sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 10)):
    rr.set_time_sequence("frame idx",sequence=idx)
    for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            rr.log(
                hand_side,
                rr.Points3D(
                    xyz_gt,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )

    for cam_param, bgr in zip(exo_data.cam_params_list, exo_data.bgr_list, strict=True):
        for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            z_relative: Float32[ndarray, "21"] = xyz_gt[:, 2] - xyz_gt[9, 2]
            # np.nan is used to indicate that a joint is not confidently detected
            uv: Float32[ndarray, "21 2"] = exo_data.uv_dict[cam_param.name][hand_idx]
            # np.nan is used to indicate that a joint is not confidently detected
            uv[uv == -1] = np.nan

            uvz:Float32[ndarray, "21 3"] = np.concatenate([uv, z_relative[:, None]], axis=1)

            # find min/max to generate a bounding box, add a small padding
            uv_min:Float32[ndarray, "2"] = np.nanmin(uv, axis=0)
            uv_max:Float32[ndarray, "2"] = np.nanmax(uv, axis=0)

            xyxy:Float32[ndarray, "4"] = np.concatenate([uv_min, uv_max], axis=0)
            if np.any(np.isnan(xyxy)):
                print(f"xyxy is nan: {xyxy}")
                continue

            image_log_path: Path = parent_log_path / cam_param.name / "pinhole" / "video"
            rr.log(
                f"{image_log_path}/{hand_side}_keypoints",
                rr.Points2D(
                    uv,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )

            rr.log(
                    f"{image_log_path}/{hand_side}_xyxy",
                    rr.Boxes2D(
                        array=xyxy,
                        array_format=rr.Box2DFormat.XYXY
                    )
            )
            rr.log(
                f"{image_log_path}",
                rr.Image(
                    bgr,
                    color_model=rr.ColorModel.BGR
                ).compress(jpeg_quality=75),
            )

Viewer()

If not, consider setting `RERUN_NOTEBOOK_ASSET`. Consult https://pypi.org/project/rerun-notebook/0.22.1/ for details.



xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]
xyxy is nan: [nan nan nan nan]


In [8]:
import jax
import jax.numpy as npj
from jax import jit
from jaxopt import LevenbergMarquardt
from jaxopt._src.levenberg_marquardt import LevenbergMarquardtState

from pi0_lerobot.mano.mano_optimization_jax import LMOptimJointOnly, LossWeights, OptimizationResults

loss_weights = LossWeights(
    keypoint_2d=0.1,
    depth=0.0,
    temp=0.0,
)
megatrack_fitter = LMOptimJointOnly(
    Pall=Pall,
    loss_weights=loss_weights,
    num_iters=30,
)



Tracing JIT, can take a while...
Trace Done


# Vectorized Projection using batched matrix multiplication with jax

In [9]:
from einops import rearrange
from simplecv.rerun_log_utils import log_pinhole



@jit
def proj_3d_other(
    xyz_hom: Float[Array, "n_frames n_joints 4"],
    P: Float[Array, "n_views 3 4"]
    ) -> Float[Array, "n_frames n_views n_joints 2"]:
    """
    Projects 3D points to 2D using the projection matrix for a batch of frames and views.

    xyz_hom: [n_frames, 21, 4] [x, y, z, 1]
    P: [n_views, 3, 4] (projection matrix - includes extrensic (R, t) and intrinsic (K))

    return kp2d: [n_frames, n_views, n_joints, 2] (squeeze out if 1)
    """
    # rearrange for batch matrix multiplication
    xyz_hom: Float[Array, "n_frames 1 4 21"] = rearrange(xyz_hom, "n_frames n_joints xyz_hom -> n_frames 1 xyz_hom n_joints")
    P: Float[Array, "1 n_views 3 4"] = rearrange(P, "n_views n m -> 1 n_views n m")

    # [1 n_views, 3, 4] @ [n_frames, 1, 4, 21] -> [n_frames, n_views, 3, 21]
    uv_hom: Float[Array, "n_frames n_views 3 21"] = P @ xyz_hom
    uv_hom = rearrange(uv_hom, "n_frames n_views xyz_hom n_joints -> n_frames n_views n_joints xyz_hom")
    # convert back from homogeneous coordinates
    uv:Float[Array, "n_frames n_views 21 2"] = uv_hom[..., :2] / uv_hom[..., 2:]

    return uv

# Calculate indices for 5 equally spaced frames
num_frames_to_select = 5
total_frames = len(sequence)
# Use np.linspace to get equally spaced indices, including start and end
frame_indices = np.linspace(0, total_frames - 1, num_frames_to_select).astype(int)

# Get the ExoData for the selected frames
exo_data_list: list[ExoData] = [sequence[int(frame_idx)] for frame_idx in frame_indices]

# create a batch of xyz points
xyz_batch: list[Float[ndarray, "2 21 3"]] = [exo_data.xyz for exo_data in exo_data_list]
xyz_batch: Float[ndarray, "n_frames 2 21 3"] = np.stack(xyz_batch, axis=0)
xyz_hom_batch: Float[ndarray, "n_frames 2 21 4"] = np.concatenate([xyz_batch, np.ones_like(xyz_batch)[..., 0:1]], axis=-1)

left_xyz_hom_batch: Float[ndarray, "n_frames 21 4"] = xyz_hom_batch[:, 0]
right_xyz_hom_batch: Float[ndarray, "n_frames 21 4"] = xyz_hom_batch[:, 1]


uv_batch_left:Float[Array, "n_frames n_views 21 2"] = proj_3d_other(
    xyz_hom=npj.array(left_xyz_hom_batch),
    P=npj.array(Pall)
)
uv_batch_right:Float[Array, "n_frames n_views 21 2"] = proj_3d_other(
    xyz_hom=npj.array(right_xyz_hom_batch),
    P=npj.array(Pall)
)
uv_batch: Float[ndarray, "n_frames n_views 2 21 2"] = np.stack([uv_batch_left, uv_batch_right], axis=2)

In [10]:
rr.init("Random Sequence")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)
for frame_idx, (frame_ts, exo_data) in enumerate(zip(frame_indices, exo_data_list, strict=True)):
    rr.set_time_sequence("frame idx", sequence=frame_ts)
    for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            xyz_gt: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            uv_views: Float[ndarray, "n_views 21 2"] = uv_batch[frame_idx, :, hand_idx, ...]
            rr.log(
                hand_side,
                rr.Points3D(
                    xyz_gt,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )
            for cam_param, bgr, uv in zip(exo_data.cam_params_list, exo_data.bgr_list, uv_views, strict=True): 
                cam_log_path: Path = parent_log_path / cam_param.name   
                image_log_path: Path = cam_log_path / "pinhole" / "video"

                rr.log(
                    f"{image_log_path}",
                    rr.Image(
                        bgr,
                        color_model=rr.ColorModel.BGR
                    ).compress(jpeg_quality=75),
                )
                rr.log(
                    f"{image_log_path}/{hand_side}_keypoints",
                    rr.Points2D(
                        uv,
                        colors=color,
                        class_ids=class_id,
                        keypoint_ids=MEDIAPIPE_IDS,
                        show_labels=False,
                    ),
                )

                log_pinhole(
                    camera=cam_param,
                    cam_log_path=cam_log_path,
                    image_plane_distance=0.1,
                    static=False
                )


Viewer()

## Example optimization using 3d keypoints

In [11]:
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

exo_data: ExoData
for idx, exo_data in enumerate(itertools.islice(sequence, 15)):
    rr.set_time_sequence("frame idx",sequence=idx)
    for hand_idx, (hand_side, color, class_id) in enumerate(
            (
                ("left", (0, 255, 0), 0),
                ("right", (0, 255, 0), 0),
            )
        ):
            
            output: tuple[OptimizationResults, LevenbergMarquardtState] = megatrack_fitter(xyz_pred_batch=exo_data.xyz)
            optimized_result: OptimizationResults = output[0]
            state: LevenbergMarquardtState = output[1]
            xyz_pred: Float32[ndarray, "21 3"] = exo_data.xyz[hand_idx]
            rr.log("log", rr.TextLog(f"num_iters = {state.iter_num}"))
            match hand_side:
                case "left":
                    xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[0]
                case "right":
                    xyz: Float[ndarray, "21 3"] = optimized_result.xyz_mano[1]
                case _: 
                    raise ValueError(f"Invalid hand side: {hand_side}")
            rr.log(
                f"{hand_side}_optimized",
                rr.Points3D(
                    xyz,
                    colors=color,
                    class_ids=1,
                    keypoint_ids=MEDIAPIPE_IDS,
                    show_labels=False,
                ),
            )
            rr.log(
                f"{hand_side}",
                rr.Points3D(
                    xyz_pred,
                    colors=color,
                    class_ids=class_id,
                    keypoint_ids=sequence.hand_ids,
                    show_labels=False,
                ),
            )

Viewer()

## Example Optimization on a single timestep using multiview 2d keypoints

In [82]:
from jax.debug import print as jprint

HAND_TYPE= ["left", "right"]

@jit
def mv_2d_residual(
    param_to_optimize: Float[Array, "_"],
    Pall: Float[Array, "b 3 4"],
    uv_pred: Float[Array, "b n_views 21 2"],
    loss_weights: LossWeights,
    is_left,
) -> Float[Array, "_"]:
    """
    Calculates the residual error between projected MANO keypoints and target 2D keypoints.

    Args:
        param_to_optimize: Flattened MANO parameters (pose coefficients and translation).
                           Must be a 1D array (batch_size * 51) because jaxopt optimizers
                           like LevenbergMarquardt expect a flat vector of parameters.
        Pall: Projection matrices for each camera view, shape (b, 3, 4).
              'b' here refers to the batch size (number of frames/samples).
        uv_pred: Target 2D keypoints for each view and joint, shape (b, n_views, 21, 2).
                 'n_views' is the number of camera views.
        loss_weights: Dictionary containing weights for different loss components (e.g., 'keypoint_2d').
        is_left: Boolean indicating whether to use the left or right MANO model.

    Returns:
        A flattened 1D array containing the weighted residual errors for all keypoints, views, and batch items.
    """
    batch_size: int = uv_pred.shape[0]
    # extract parameters that are being optimized and add batch dimension
    param_to_optimize: Float[Array, "b 51"] = param_to_optimize.reshape(batch_size, 51)
    so3: Float[Array, "b 48"] = param_to_optimize[:, 0:48]
    trans: Float[Array, "b 1 3"] = param_to_optimize[:, npj.newaxis, 48:]

    def left_func(x:tuple[Float[Array, "b 48"], Float[Array, "b 1 3"]])->Float[Array, "b 21 3"]:
        return mano_j_left(x[0], x[1])

    def right_func(x:tuple[Float[Array, "b 48"], Float[Array, "b 1 3"]])->Float[Array, "b 21 3"]:
        return mano_j_right(x[0], x[1])

    xyz_mano: Float[Array, "b 21 3"] = jax.lax.cond(is_left, left_func, right_func, (so3, trans))
    xyz_mano_hom: Float[Array, "b 21 4"] = npj.concatenate([xyz_mano, npj.ones_like(xyz_mano)[..., 0:1]], axis=-1)

    uv_mano: Float[Array, "b n_views 21 2"] = proj_3d_other(
        xyz_hom=xyz_mano_hom,
        P=Pall
    )

    # calculate residuals
    res_2d: Float[Array, "b n_views 21 2"] = uv_mano - uv_pred
    res_2d = jnp.nan_to_num(res_2d * loss_weights["keypoint_2d"], nan=0.0)


    # Return the flattened vector of valid, weighted residuals
    return res_2d.flatten()




# init_params: Float[Array, "51"] = npj.concatenate([optimized_result.so3[0], optimized_result.trans[0]])
init_params: Float[Array, "b 51"] = npj.zeros((1, 51))

exo_data: ExoData = sequence[0]
P_list: list[ndarray] = [exo_cam.projection_matrix for exo_cam in exo_data.cam_params_list]
uv_list: list[ndarray] = [exo_data.uv_dict[cam_param.name] for cam_param in exo_data.cam_params_list]
P_all: Float[Array, "n_views 3 4"] = npj.stack(P_list, axis=0)
uv_gt: Float[Array, "n_views 2 21 2"] = npj.stack(uv_list, axis=0)
uv_left_gt_batch: Float[Array, "n_views 21 2"] = uv_gt[:, 0, ...]
uv_left_gt_batch: Float[Array, "n_frames n_views 21 2"] = rearrange(uv_left_gt_batch, "n_views n_joints uv -> 1 n_views n_joints uv")

init_params: Float[Array, "b 51"] = npj.zeros((1, 51))

residual = mv_2d_residual(
    param_to_optimize=init_params.flatten(),
    Pall=P_all,
    uv_pred=uv_left_gt_batch,
    loss_weights=loss_weights,
    is_left=True
)
print(lo(np.array(residual)))

array[336] f32 1.3Kb x∈[-13.673, 13.244] μ=0.322 σ=5.815


In [83]:
optimizer = LevenbergMarquardt(
    residual_fun=mv_2d_residual, maxiter=100, solver="cholesky", jit=True, xtol=1e-6, gtol=1e-6
)
# add jit
print("Tracing JIT, can take a while...")
_, _ = optimizer.run(
    init_params.flatten(),
    Pall=P_all,
    uv_pred=uv_left_gt_batch,
    loss_weights=loss_weights,
    is_left=True
)
optimizer = jit(optimizer.run)



Tracing JIT, can take a while...


In [86]:
# optimize 5 frames of the left hand
exo_data_list: list[ExoData] = [sequence[int(frame_idx)] for frame_idx in range(1)]

# create batch of frames and views of uv
uv_batch_list: list[Float[ndarray, "n_views 2 21 2"]] = []
for exo_data in exo_data_list:
    uv_list: list[ndarray] = [exo_data.uv_dict[cam_param.name] for cam_param in exo_data.cam_params_list]
    uv_gt: Float[Array, "n_views 2 21 2"] = npj.stack(uv_list, axis=0)
    uv_batch_list.append(uv_gt)

uv_batch: Float[Array, "n_frames n_views 2 21 2"] = npj.stack(uv_batch_list, axis=0)
P_list: list[Float[ndarray, "3 4"]] = [exo_cam.projection_matrix for exo_cam in exo_data.cam_params_list]
Pall: Float[Array, "n_views 3 4"] = npj.stack(P_list, axis=0)

uv_left_gt_batch: Float[Array, "n_frames n_views 21 2"] = uv_batch[:, :, 0, ...]
batch_size: int = uv_left_gt_batch.shape[0]
init_params: Float[Array, "b 51"] = npj.zeros((batch_size, 51))

optimized_params, state = optimizer(
    init_params.flatten(),
    Pall=P_all,
    uv_pred=uv_left_gt_batch,
    loss_weights=loss_weights,
    is_left=True
)
optimized_params: Float[Array, "b 51"] = optimized_params.reshape(batch_size, 51)

In [87]:
xyz_mano_left_batch: Float[Array, "n_frames 21 3"] = mano_j_left(optimized_params[:, 0:48], optimized_params[:, npj.newaxis, 48:])
xyz_gt_left_batch: Float[Array, "n_frames 21 3"] = npj.stack([exo_data.xyz[0] for exo_data in exo_data_list], axis=0)
rr.init("Optimization over sequence data")
set_pose_annotation_context()
rr.notebook_show(width=1000, height=700, blueprint=blueprint)

for idx, (xyz_gt_left, xyz_mano_left) in enumerate(zip(xyz_gt_left_batch, xyz_mano_left_batch,strict=True)):
    rr.set_time_sequence("frame idx",sequence=idx)
    rr.log(
        f"{hand_side}_optimized",
        rr.Points3D(
            xyz_mano_left,
            colors=color,
            class_ids=1,
            keypoint_ids=MEDIAPIPE_IDS,
            show_labels=False,
        ),
    )
    rr.log(
        f"{hand_side}",
        rr.Points3D(
            xyz_gt_left,
            colors=color,
            class_ids=class_id,
            keypoint_ids=sequence.hand_ids,
            show_labels=False,
        ),
    )

Viewer()

In [13]:
class LMOptimJointOnlyCalibration:

    def __init__(
        self,
        Pall: Float[ndarray, "batch 3 4"],
        loss_weights: LossWeights,
        num_iters: int = 30,
        optimize_scale_factor: bool = False,
    ) -> None:
        """
        Pall - n, 3, 4 projection matrix
        loss_weights - dictionary containing how much value to give each portion
            of the cost function (2d, 3d, temporal)
        num_iters - how many iterations to optimize
        """

        batch_size: int = Pall.shape[0]

        self.num_iters: int = num_iters
        # Projection Matrix (n, 3, 4) where n is the number of cameras
        self.Pall: Float[ndarray, "batch 3 4"] = Pall

        self.loss_weights: LossWeights = loss_weights
        # use previous values to initialize, there should only ever be 1
        # hand model per frame
        self.so3_left_prev: Float[Array, "48"] = npj.zeros(48)  # noqa: UP037
        self.trans_left_prev: Float[Array, "3"] = npj.array([0.2, 0, 1.5])  # noqa: UP037

        self.so3_right_prev: Float[Array, "48"] = npj.zeros(48)  # noqa: UP037
        self.trans_right_prev: Float[Array, "3"] = npj.array([-0.2, 0, 1.5])  # noqa: UP037

        # remnant from mano, not needed
        self.beta: Float[Array, "1 10"] = npj.zeros((1, 10))

        # remove the need for two different optimizers, solvers ‘cholesky’, ‘inv’
        self.optimizer = LevenbergMarquardt(
            residual_fun=mv_residual, maxiter=self.num_iters, solver="cholesky", jit=True, xtol=1e-6, gtol=1e-6
        )
        # add jit
        print("Tracing JIT, can take a while...")
        init_params: Float[Array, "51"] = npj.concatenate([self.so3_left_prev, self.trans_left_prev])
        _, _ = self.optimizer.run(
            init_params,
            cameras=npj.array(Pall),
            xyz_pred=npj.zeros((batch_size, 21, 3)),
            loss_weights=self.loss_weights,
            is_left=True,
        )
        self.optimizer = jit(self.optimizer.run)

        print("Trace Done")

    def __call__(
        self,
        uv_pred_batch: Float[ndarray, "b n_views 21 3"] | None = None,
    ) -> tuple[OptimizationResults, LevenbergMarquardtState]:
        """
        pose_predictions_dict
            pose_predictions
            camera_dict
        """
        so3_optimized: Float[ndarray, "2 48"] = np.zeros((2, 48))
        trans_optimized: Float[ndarray, "2 3"] = np.zeros((2, 3))
        xyz_mano: Float[ndarray, "2 21 3"] = np.zeros((2, 21, 3))
        for hand_type in self.HAND_TYPE:
            # get previous values, and extract pose_predictions
            match hand_type:
                case "left":
                    xyz_pred: Float[ndarray, "21 3"] = xyz_pred_batch[0]
                    so3_prev: Float[Array, "48"] = self.so3_left_prev.copy()
                    trans_prev: Float[Array, "3"] = self.trans_left_prev.copy()

                case "right":
                    xyz_pred: Float[ndarray, "21 3"] = xyz_pred_batch[1]
                    so3_prev: Float[Array, "48"] = self.so3_right_prev.copy()
                    trans_prev: Float[Array, "3"] = self.trans_right_prev.copy()

            # TODO initialize only rotation from wrist form either 3d procustus or mano preds
            so3_init: Float[Array, "48"] = so3_prev
            trans_init: Float[Array, "3"] = trans_prev

            cam_param_list: Float[Array, "batch 3 4"] = npj.array(self.Pall)
            init_params: Float[Array, "51"] = npj.concatenate([so3_init, trans_init])
            params, state = self.optimizer(
                init_params,
                cameras=cam_param_list,
                xyz_pred=xyz_pred[npj.newaxis, ...],
                loss_weights=self.loss_weights,
                is_left=hand_type == "left",
            )

            if np.isnan(params).any():
                continue

            so3: Float[Array, "48"] = params[:48]
            trans: Float[Array, "3"] = params[48:]

            so3_optimized[0 if hand_type == "left" else 1] = np.array(so3)
            trans_optimized[0 if hand_type == "left" else 1] = np.array(trans)

            # pass optimized values to mano to extract 3d joints
            match hand_type:
                case "left":
                    self.so3_left_prev = so3
                    self.trans_left_prev = trans

                    xyz_mano_left: Float[Array, "1 21 3"] = mano_j_left(
                        so3[npj.newaxis, ...], trans[npj.newaxis, npj.newaxis, ...]
                    )
                    xyz_mano[0] = np.array(xyz_mano_left[0])

                case "right":
                    self.so3_right_prev = so3
                    self.trans_right_prev = trans

                    xyz_mano_right: Float[Array, "1 21 3"] = mano_j_right(
                        so3[npj.newaxis, ...], trans[npj.newaxis, npj.newaxis, ...]
                    )
                    xyz_mano[1] = np.array(xyz_mano_right[0])

        optimization_results = OptimizationResults(
            xyz_mano=xyz_mano,
            so3=so3_optimized,
            trans=trans_optimized,
            betas=np.concatenate([self.beta, self.beta], axis=0),
        )

        return optimization_results, state