In [1]:
# Pose Smoothing with Dynamax EKF
# We load ensemble 2D pose predictions from 6 cameras (A–F), compute ensemble variance for observation noise, 
# triangulate a geometric 3D latent state using calibrated camera parameters, and apply the Extended Kalman Smoother (EKF) using Dynamax.

import os
import numpy as np
import pandas as pd
from pathlib import Path
from glob import glob

from aniposelib.boards import CharucoBoard
from aniposelib.cameras import CameraGroup

In [2]:
import jax
import jax.numpy as jnp
from jax import jit

def _rodrigues(rvec):
    """OpenCV-style Rodrigues: rvec (3,) -> R (3,3)."""
    theta = jnp.linalg.norm(rvec)
    def small_angle(_):
        # First-order approx: R ≈ I + [r]_x  (good when theta ~ 0)
        rx, ry, rz = rvec
        K = jnp.array([[0.0, -rz,  ry],
                       [rz,  0.0, -rx],
                       [-ry, rx,  0.0]])
        return jnp.eye(3) + K
    def general(_):
        rx, ry, rz = rvec / theta
        K = jnp.array([[0.0, -rz,  ry],
                       [rz,  0.0, -rx],
                       [-ry, rx,  0.0]])
        s = jnp.sin(theta)
        c = jnp.cos(theta)
        return jnp.eye(3) + s*K + (1.0 - c) * (K @ K)
    return jax.lax.cond(theta < 1e-12, small_angle, general, operand=None)

def _parse_dist(dist_coeffs):
    """
    OpenCV pinhole distortion ordering:
      [k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty]  (tx,ty tilt optional)
    We support up to s1..s4; tilt is ignored here.
    """
    dc = jnp.pad(jnp.asarray(dist_coeffs, dtype=jnp.float64), (0, max(0, 14 - len(dist_coeffs))))  # length ≥ 14
    k1, k2, p1, p2, k3, k4, k5, k6, s1, s2, s3, s4, tx, ty = [dc[i] for i in range(14)]
    return dict(k1=k1, k2=k2, p1=p1, p2=p2, k3=k3, k4=k4, k5=k5, k6=k6, s1=s1, s2=s2, s3=s3, s4=s4)

def make_jax_projection_fn(rvec, tvec, K, dist_coeffs):
    """
    JAX-compatible replacement for cv2.projectPoints (standard pinhole model).

    Args
    ----
    rvec : (3,) Rodrigues rotation vector (world -> camera)
    tvec : (3,) translation (world -> camera), same units as your world coords
    K    : (3,3) camera intrinsic matrix
           [[fx, s, cx],
            [ 0, fy, cy],
            [ 0, 0,  1 ]]
    dist_coeffs : iterable of distortion coefficients in OpenCV order
                  [k1, k2, p1, p2[, k3[, k4, k5, k6[, s1, s2, s3, s4[, tx, ty]]]]]

    Returns
    -------
    project(object_points) -> image_points
      object_points: (..., 3)
      image_points:  (..., 2)
    """
    # cache params as arrays
    rvec = jnp.asarray(rvec, dtype=jnp.float64)
    tvec = jnp.asarray(tvec, dtype=jnp.float64)
    K    = jnp.asarray(K,    dtype=jnp.float64)
    fx, fy, cx, cy, skew = K[0,0], K[1,1], K[0,2], K[1,2], K[0,1]
    d = _parse_dist(dist_coeffs)
    R = _rodrigues(rvec)

    @jit
    def project(object_points):
        # object_points: (..., 3)
        Xw = jnp.asarray(object_points, dtype=jnp.float64)
        # world -> camera
        Xc = Xw @ R.T + tvec  # (..., 3)
        X, Y, Z = Xc[..., 0], Xc[..., 1], Xc[..., 2]

        # normalized coords
        x = X / Z
        y = Y / Z

        r2  = x*x + y*y
        r4  = r2*r2
        r6  = r4*r2
        r8  = r4*r4
        r10 = r8*r2
        r12 = r6*r6

        radial = (
            1.0
            + d["k1"]*r2 + d["k2"]*r4 + d["k3"]*r6
            + d["k4"]*r8 + d["k5"]*r10 + d["k6"]*r12
        )

        x_tan = 2.0*d["p1"]*x*y + d["p2"]*(r2 + 2.0*x*x)
        y_tan = d["p1"]*(r2 + 2.0*y*y) + 2.0*d["p2"]*x*y

        # thin-prism
        x_tp = d["s1"]*r2 + d["s2"]*r4
        y_tp = d["s3"]*r2 + d["s4"]*r4

        xd = x * radial + x_tan + x_tp
        yd = y * radial + y_tan + y_tp

        # intrinsics (allow nonzero skew)
        u = fx * xd + skew * yd + cx
        v = fy * yd + cy

        return jnp.stack([u, v], axis=-1)  # (..., 2)

    return project

In [None]:
import os
import cv2
import jax.numpy as jnp
from jax import jit, vmap
import numpy as np
import pandas as pd
from aniposelib.cameras import CameraGroup
from sklearn.decomposition import PCA
from typeguard import typechecked
from typing import Tuple, Callable
from eks.core import ensemble
from eks.marker_array import (
    MarkerArray,
    input_dfs_to_markerArray,
    mA_to_stacked_array,
    stacked_array_to_mA,
)
import jax
jax.config.update("jax_enable_x64", True)
from eks.stats import compute_mahalanobis, compute_pca
from eks.utils import center_predictions, format_data, make_dlc_pandas_index
from eks.multicam_smoother import mA_compute_maha, initialize_kalman_filter_pca

def fit_eks_multicam(
    input_source: str | list,
    save_dir: str,
    bodypart_list: list | None = None,
    smooth_param: float | list | None = None,
    s_frames: list | None = None,
    camera_names: list | None = None,
    quantile_keep_pca: float = 95.0,
    avg_mode: str = 'median',
    var_mode: str = 'confidence_weighted_var',
    inflate_vars: bool = False,
    verbose: bool = False,
    n_latent: int = 3,
    backend: str = 'jax',
    camgroup=None
) -> tuple:

    # Load and format input files
    # NOTE: input_dfs_list is a list of camera-specific lists of Dataframes
    input_dfs_list, keypoint_names = format_data(input_source, camera_names=camera_names)
    if bodypart_list is None:
        bodypart_list = keypoint_names

    marker_array = input_dfs_to_markerArray(input_dfs_list, bodypart_list, camera_names)

    # Run the ensemble Kalman smoother for multi-camera data
    camera_dfs, smooth_params_final, h_cams, ys_3d = ensemble_kalman_smoother_multicam(
        marker_array=marker_array,
        keypoint_names=bodypart_list,
        smooth_param=smooth_param,
        quantile_keep_pca=quantile_keep_pca,
        camera_names=camera_names,
        s_frames=s_frames,
        avg_mode=avg_mode,
        var_mode=var_mode,
        verbose=verbose,
        inflate_vars=inflate_vars,
        n_latent=n_latent,
        backend=backend,
        camgroup=camgroup
    )
    # Save output DataFrames to CSVs (one per camera view)
    os.makedirs(save_dir, exist_ok=True)
    for c, camera in enumerate(camera_names):
        save_filename = f'multicam_{camera}_results.csv'
        camera_dfs[c].to_csv(os.path.join(save_dir, save_filename))
    return camera_dfs, smooth_params_final, input_dfs_list, bodypart_list, marker_array, h_cams, ys_3d

def initialize_kalman_filter_geometric(ys: np.ndarray) -> Tuple[jnp.ndarray, ...]:
    """
    Initialize Kalman filter parameters for geometric (3D) keypoints.

    Args:
        ys: Array of shape (K, T, 3) — triangulated keypoints.

    Returns:
        Tuple of Kalman filter parameters:
            - m0s: (K, 3) initial means
            - S0s: (K, 3, 3) initial covariances
            - As: (K, 3, 3) transition matrices
            - Qs: (K, 3, 3) process noise covariances
            - Cs: (K, 3, 3) observation matrices
    """
    K, T, D = ys.shape

    # Initial state means (can also use ys[:, 0, :] if preferred)
    m0s = np.zeros((K, D))
    # Use variance across time to estimate initial uncertainty
    S0s = np.array([
        np.diag([
            np.nanvar(ys[k, :, d]) + 1e-4  # avoid degenerate matrices
            for d in range(D)
        ])
        for k in range(K)
    ])  # (K, 3, 3)

    # Identity matrices
    As = np.tile(np.eye(D), (K, 1, 1))
    Cs = np.tile(np.eye(D), (K, 1, 1))
    Qs = np.tile(np.eye(D), (K, 1, 1)) * 1e-3  # small default process noise

    return (
        jnp.array(m0s),
        jnp.array(S0s),
        jnp.array(As),
        jnp.array(Qs),
        jnp.array(Cs),
    )


def ensemble_kalman_smoother_multicam(
    marker_array: MarkerArray,
    keypoint_names: list,
    smooth_param: float | list | None = None,
    quantile_keep_pca: float = 95.0,
    camera_names: list | None = None,
    s_frames: list | None = None,
    avg_mode: str = 'median',
    var_mode: str = 'confidence_weighted_var',
    inflate_vars: bool = False,
    inflate_vars_kwargs: dict = {},
    verbose: bool = False,
    pca_object: PCA | None = None,
    n_latent: int = 3,
    backend: str = 'jax',
    camgroup=None,
) -> tuple:

    n_models, n_cameras, n_frames, n_keypoints, _ = marker_array.shape

    # === Ensemble Mean/Var per camera/keypoint ===
    ensemble_marker_array = ensemble(marker_array, avg_mode=avg_mode, var_mode=var_mode)
    emA_unsmoothed_preds = ensemble_marker_array.slice_fields("x", "y")
    emA_vars = ensemble_marker_array.slice_fields("var_x", "var_y")
    emA_likes = ensemble_marker_array.slice_fields("likelihood")

    # === Triangulate all 3D positions ===
    triangulated_3d_models = np.zeros((n_models, n_keypoints, n_frames, 3))
    raw_array = marker_array.get_array()
    for m in range(n_models):
        for k in range(n_keypoints):
            for t in range(n_frames):
                xy_views = [raw_array[m, c, t, k, :2] for c in range(n_cameras)]
                triangulated_3d_models[m, k, t] = camgroup.triangulate(np.array(xy_views))

    ys_3d = triangulated_3d_models.mean(axis=0)          # (K, T, 3)
    ensemble_vars_3d = triangulated_3d_models.var(axis=0)  # (K, T, 3)

    # === Define a single multi-view h_fn (ℝ³ → ℝ^{2V}) ===
    h_cams = []
    for cam in camgroup.cameras:
        print(cam.get_size())
        rot = np.array(cam.get_rotation())
        # Convert to Rodrigues vector if needed
        rvec = cv2.Rodrigues(rot)[0].ravel() if rot.shape == (3, 3) else rot.ravel()
        tvec = np.array(cam.get_translation()).ravel()
        K    = np.array(cam.get_camera_matrix())
        dist = np.array(cam.get_distortions()).ravel()  # distortion coeffs: k1,k2,p1,p2,k3,...

        h_cams.append(
            make_jax_projection_fn(
                jnp.array(rvec),
                jnp.array(tvec),
                jnp.array(K),
                jnp.array(dist)
            )
        )

    def make_combined_h_fn(h_list):
        def h_fn(x):
            return jnp.concatenate([h(x) for h in h_list], axis=0)
        return h_fn

    h_fn_combined = make_combined_h_fn(h_cams)

    # === Initialize Kalman filter ===

    m0s, S0s, As, cov_mats, Cs = initialize_kalman_filter_geometric(ys_3d)
    m0s = np.array([ys_3d[k, :10].mean(axis=0) for k in range(n_keypoints)])
    s_finals = np.full(len(keypoint_names), smooth_param) if np.isscalar(smooth_param) else np.asarray(smooth_param)

    # === Apply EKF in latent 3D space using projected 2D observations ===
    ms_all, Vs_all = [], []
    for k in range(n_keypoints):
        y_proj = np.concatenate([vmap(h)(ys_3d[k]) for h in h_cams], axis=1)  # (T, 2V)
        r_proj = np.concatenate([ensemble_vars_3d[k][:, :2] for _ in range(n_cameras)], axis=1)  # (T, 2V)
        
        ms, Vs = dynamax_ekf_smooth_routine(
            y=ys_3d[k],
            m0=m0s[k],
            S0=S0s[k],
            A=As[k],
            Q=s_finals[k] * cov_mats[k],
            C=np.eye(3),
            ensemble_vars=ensemble_vars_3d[k],
            f_fn=None,
            h_fn=None, 
        )

        # ms, Vs = dynamax_ekf_smooth_routine(
        #     y=y_proj,
        #     m0=m0s[k],
        #     S0=S0s[k],
        #     A=As[k],
        #     Q=s_finals[k] * cov_mats[k],
        #     C=None,
        #     ensemble_vars=r_proj,
        #     f_fn=None,
        #     h_fn=h_fn_combined, 
        # )


        ms_all.append(np.array(ms))
        Vs_all.append(np.array(Vs))

    ms_all = np.stack(ms_all, axis=0)  # (K, T, 3)
    Vs_all = np.stack(Vs_all, axis=0)  # (K, T, 3, 3)


    # === Reproject smoothed 3D estimates back to each camera ===
    camera_arrs = [[] for _ in camera_names]
    for k, keypoint in enumerate(keypoint_names):
        ms_k = ms_all[k]
        Vs_k = Vs_all[k]
        inflated_vars_k = ensemble_vars_3d[k]
        
        # rebuild a no-distortion projector per cam using the same rvec,tvec,K
        print("camgroup order:", [getattr(cam, "name", f"cam{i}") for i,cam in enumerate(camgroup.cameras)])
        print("marker_array order:", marker_array.get_camera_names() if hasattr(marker_array, "get_camera_names") else "unknown")

        # Compare one frame k=0,t=0
        k=0; t=0
        for c in range(len(camgroup.cameras)):
            obs = emA_unsmoothed_preds.slice("keypoints", k).slice("cameras", c).get_array(squeeze=True)[t]
            prj = np.array(h_cams[c](ms_all[k][t]))
            print(f"c{c}: obs={obs}, proj={prj}, diff={obs-prj}")
        
        for c, camera in enumerate(camgroup.cameras):
            #xy_proj = camera.project(ms_k).reshape(-1, 2)
            xy_proj = np.array(vmap(h_cams[c])(ms_k))     # (T, 2)
            xy_obs = emA_unsmoothed_preds.slice("keypoints", k).slice("cameras", c).get_array(squeeze=True)  # (T,2)
            resid = xy_obs - xy_proj  # (T,2)
            print(f"cam {c} mean residual (px):", resid.mean(axis=0), " std:", resid.std(axis=0))
            try:
                cov2d_proj = camera.project_covariance(ms_k, Vs_k)
                var_x = cov2d_proj[:, 0, 0] + inflated_vars_k[:, 0]
                var_y = cov2d_proj[:, 1, 1] + inflated_vars_k[:, 1]
            except AttributeError:
                var_x = np.full(ms_k.shape[0], np.nan)
                var_y = np.full(ms_k.shape[0], np.nan)

            data_arr = camera_arrs[c]
            data_arr.extend([
                xy_proj[:, 0],
                xy_proj[:, 1],
                emA_likes.slice("keypoints", k).slice("cameras", c).get_array(squeeze=True),
                emA_unsmoothed_preds.slice("keypoints", k).slice("cameras", c).slice_fields("x").get_array(squeeze=True),
                emA_unsmoothed_preds.slice("keypoints", k).slice("cameras", c).slice_fields("y").get_array(squeeze=True),
                emA_vars.slice("keypoints", k).slice("cameras", c).slice_fields("var_x").get_array(squeeze=True),
                emA_vars.slice("keypoints", k).slice("cameras", c).slice_fields("var_y").get_array(squeeze=True),
                var_x,
                var_y,
            ])

    # === Format output ===
    labels = ['x', 'y', 'likelihood', 'x_ens_median', 'y_ens_median',
              'x_ens_var', 'y_ens_var', 'x_posterior_var', 'y_posterior_var']
    pdindex = make_dlc_pandas_index(keypoint_names, labels=labels)
    camera_dfs = [pd.DataFrame(np.asarray(arr).T, columns=pdindex) for arr in camera_arrs]

    return camera_dfs, s_finals, h_cams, ys_3d



In [4]:
from dynamax.nonlinear_gaussian_ssm.inference_ekf import extended_kalman_smoother, extended_kalman_filter
from dynamax.nonlinear_gaussian_ssm.models import (
    ParamsNLGSSM,
)

import jax
import jax.numpy as jnp
import numpy as np
from typing import Union, Tuple, Callable
from typeguard import typechecked

ArrayLike = Union[np.ndarray, jax.Array]

def dynamax_ekf_smooth_routine(
    y: ArrayLike,
    m0: ArrayLike,
    S0: ArrayLike,
    A: ArrayLike,
    Q: ArrayLike,
    C: ArrayLike | None,
    ensemble_vars: ArrayLike,  # shape (T, obs_dim)
    f_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None,
    h_fn: Callable[[jnp.ndarray], jnp.ndarray] | None = None
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Extended Kalman smoother using the Dynamax nonlinear interface,
    allowing for time-varying observation noise.

    By default, uses linear dynamics and emissions: f(x) = Ax, h(x) = Cx.

    Args:
        y: (T, obs_dim) observation sequence.
        m0: (state_dim,) initial mean.
        S0: (state_dim, state_dim) initial covariance.
        A: (state_dim, state_dim) dynamics matrix.
        Q: (state_dim, state_dim) process noise covariance.
        C: (obs_dim, state_dim) emission matrix (optional).
        ensemble_vars: (T, obs_dim) per-timestep observation noise variance.
        f_fn: optional dynamics function f(x).
        h_fn: optional emission function h(x).

    Returns:
        smoothed_means: (T, state_dim)
        smoothed_covariances: (T, state_dim, state_dim)
    """
    y, m0, S0, A, Q, ensemble_vars = map(jnp.asarray, (y, m0, S0, A, Q, ensemble_vars))
    C = jnp.asarray(C) if C is not None else None

    if f_fn is None:
        f_fn = lambda x: A @ x
    if h_fn is None:
        if C is None:
            raise ValueError("Must provide either emission matrix C or a nonlinear emission function h_fn.")
        h_fn = lambda x: C @ x
    # Dynamically determine obs_dim from h_fn output
    obs_dim = y.shape[1]
    R_t = jnp.stack([jnp.diag(var_t[:obs_dim]) for var_t in ensemble_vars], axis=0)  # shape (T, obs_dim, obs_dim)
    params = ParamsNLGSSM(
        initial_mean=m0,
        initial_covariance=S0,
        dynamics_function=f_fn,
        dynamics_covariance=Q,
        emission_function=h_fn,
        emission_covariance=R_t,
    )
    #with jax.disable_jit():
    posterior = extended_kalman_smoother(params, y)
    return posterior.filtered_means, posterior.filtered_covariances

In [5]:
from eks.utils import plot_results

input_source = "./data/chickadee_uncropped"
camera_names = ["lBack", "lFront", "lTop", "rBack", "rFront", "rTop"]
keypoints = ["topBeak", "topHead", "backHead", "centerChes", "baseTail", "tipTail", "leftEye", "leftNeck", "leftWing", "leftAnkle", "leftFoot", "rightEye", "rightNeck", "rightWing", "rightAnkle", "rightFoot"]
camgroup = CameraGroup.load("./data/chickadee/calibration.toml")
# input_source = "./data/fly"
# camera_names = ["Cam-A", "Cam-B", "Cam-C", "Cam-D", "Cam-E", "Cam-F"]
# keypoints = ["L1A", "L1B"]
# camgroup = CameraGroup.load("./data/fly/calibration.toml")

save_dir = "./outputs/"

# Load calibration file


camera_dfs, s_finals, input_dfs, bodypart_list, marker_array, h_cams, ys_3d = fit_eks_multicam(
    input_source=input_source,
    save_dir=save_dir,
    bodypart_list=keypoints,
    smooth_param=10,
    camera_names=camera_names,
    quantile_keep_pca=95,
    verbose=True,
    inflate_vars=False,
    n_latent=3,
    backend="dynamax-ekf",
    camgroup=camgroup
)

keypoint_i = -1
camera_c = -1
plot_results(
    output_df=camera_dfs[camera_c],
    input_dfs_list=input_dfs[camera_c],
    key=f'{bodypart_list[keypoint_i]}',
    idxs=(0, 500),
    s_final=s_finals[keypoint_i],
    nll_values=None,
    save_dir=save_dir,
    smoother_type='multicam',
)


[2816, 1408]
[2816, 1408]
[2816, 1696]
[2816, 1408]
[2816, 1408]
[2816, 1696]
camgroup order: ['lBack', 'lFront', 'lTop', 'rBack', 'rFront', 'rTop']
marker_array order: unknown
c0: obs=[1750.93206108  391.50040874], proj=[1751.38181764  385.1042529 ], diff=[-0.44975657  6.39615584]
c1: obs=[1594.26820338  463.19929254], proj=[1609.76057455  480.0672997 ], diff=[-15.49237117 -16.86800716]
c2: obs=[1691.0761517   778.44187536], proj=[1692.78650143  778.0042768 ], diff=[-1.71034972  0.43759856]
c3: obs=[1280.40348787  326.09742532], proj=[1283.34233378  326.58320903], diff=[-2.93884591 -0.48578371]
c4: obs=[1006.14625554  404.40209646], proj=[1015.72107357  401.72426204], diff=[-9.57481803  2.67783442]
c5: obs=[1113.13084931  556.79887314], proj=[1121.75050458  544.48723124], diff=[-8.61965526 12.3116419 ]
cam 0 mean residual (px): [-2.00586461 -0.04818341]  std: [4.65336109 3.65989187]
cam 1 mean residual (px): [-0.400378   -0.55426035]  std: [4.47750558 4.57930606]
cam 2 mean residual (

In [6]:
import numpy as np
import jax.numpy as jnp

# === Settings ===
frame_idx = 120
keypoint_idx = 1
model_idx = 0  # or average across models

# === Step 1: Extract 2D predictions from all cameras ===
raw_array = marker_array.get_array()  # (n_models, n_cameras, n_frames, n_keypoints, 2+)
xy_views = [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(len(camgroup.cameras))]
xy_views_np = np.stack(xy_views)  # (n_cameras, 2)

# === Step 2: Triangulate to get 3D point ===
x_3d = camgroup.triangulate(xy_views_np)  # shape (3,)

print(f"Triangulated 3D point: {x_3d}")

# === Step 3: Reproject into each view ===
projected_views = [h(jnp.array(x_3d)) for h in h_cams]
projected_views_np = np.stack([np.array(p) for p in projected_views])  # (n_cameras, 2)

# === Step 4: Compute reprojection error per view ===
for i, (orig, proj) in enumerate(zip(xy_views_np, projected_views_np)):
    err = np.linalg.norm(orig - proj)
    print(f"Camera {i}: reprojection error = {err:.3f} pixels")

Triangulated 3D point: [-0.05837184 -0.28620351  0.42230134]
Camera 0: reprojection error = 5.255 pixels
Camera 1: reprojection error = 5.623 pixels
Camera 2: reprojection error = 2.997 pixels
Camera 3: reprojection error = 2.839 pixels
Camera 4: reprojection error = 0.815 pixels
Camera 5: reprojection error = 6.390 pixels


In [7]:
import numpy as np
from sklearn.decomposition import PCA
import jax.numpy as jnp

# === Settings ===
frame_idx = 45
keypoint_idx = 0
model_idx = 0  # pick one or average if needed

# --- helper: make camgroup.project output (n_cams, 2) ---
def cg_project_point(camgroup, x3d):
    """Project a single 3D point with aniposelib CameraGroup, return (n_cams, 2)."""
    x = np.asarray(x3d, dtype=float).reshape(1, 3)
    out = camgroup.project(x)  # library-dependent shape
    # Try common shapes: (n_cams, 1, 2), (n_cams, 2), dict of cam->(1,2)
    if isinstance(out, dict):
        proj = np.stack([np.asarray(out[cam.name])[0] for cam in camgroup.cameras], axis=0)
    else:
        arr = np.asarray(out)
        if arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] == 2:
            proj = arr[:, 0, :]
        elif arr.ndim == 2 and arr.shape == (len(camgroup.cameras), 2):
            proj = arr
        elif arr.ndim == 2 and arr.shape == (2, len(camgroup.cameras)):
            proj = arr.T
        else:
            raise ValueError(f"Unexpected camgroup.project shape: {arr.shape}")
    return proj  # (n_cams, 2)

# === 1) 2D observations for this keypoint+frame from all cameras ===
raw_array = marker_array.get_array()  # (n_models, n_cams, n_frames, n_keypoints, 2+)
n_cams = len(camgroup.cameras)
xy_views_np = np.stack(
    [raw_array[model_idx, c, frame_idx, keypoint_idx, :2] for c in range(n_cams)],
    axis=0
)  # (n_cams, 2)

# === 2) Triangulated 3D point ===
x_triang = camgroup.triangulate(xy_views_np)  # (3,)

# === 3) PCA-reconstructed 3D point ===
# Assume ys_3d: (K, T, 3)
ys_3d_reshaped = ys_3d.reshape(-1, 3)
pca = PCA(n_components=3)
Z = pca.fit_transform(ys_3d_reshaped)
ys_3d_pca = pca.inverse_transform(Z).reshape(ys_3d.shape)
x_pca = ys_3d_pca[keypoint_idx, frame_idx]  # (3,)

# === 4) Project both 3D points with:
#     (a) your custom JAX projectors h_cams
#     (b) camgroup.project (OpenCV-based)
reproj_triang_custom = np.stack([np.array(h(jnp.array(x_triang))) for h in h_cams], axis=0)  # (n_cams, 2)
reproj_pca_custom    = np.stack([np.array(h(jnp.array(x_pca)))    for h in h_cams], axis=0)

reproj_triang_cg = cg_project_point(camgroup, x_triang)  # (n_cams, 2)
reproj_pca_cg    = cg_project_point(camgroup, x_pca)

# === 5) Print comparison ===
print(f"Triangulated point:     {x_triang}")
print(f"PCA-reconstructed point:{x_pca}\n")

for i in range(n_cams):
    obs = xy_views_np[i]

    # errors vs observations
    err_tri_custom = np.linalg.norm(obs - reproj_triang_custom[i])
    err_tri_cg     = np.linalg.norm(obs - reproj_triang_cg[i])

    err_pca_custom = np.linalg.norm(obs - reproj_pca_custom[i])
    err_pca_cg     = np.linalg.norm(obs - reproj_pca_cg[i])

    # difference between projectors (should be ~0 if both are consistent)
    diff_tri = np.linalg.norm(reproj_triang_custom[i] - reproj_triang_cg[i])
    diff_pca = np.linalg.norm(reproj_pca_custom[i]    - reproj_pca_cg[i])

    print(
        f"Camera {i}: "
        f"tri e_custom={err_tri_custom:.2f}px, e_cg={err_tri_cg:.2f}px, projΔ={diff_tri:.2f}px | "
        f"pca e_custom={err_pca_custom:.2f}px, e_cg={err_pca_cg:.2f}px, projΔ={diff_pca:.2f}px"
    )


Triangulated point:     [-0.09515544 -0.22536958  0.36407084]
PCA-reconstructed point:[-0.09511244 -0.22479253  0.36471556]

Camera 0: tri e_custom=4.01px, e_cg=4.01px, projΔ=0.00px | pca e_custom=3.46px, e_cg=3.46px, projΔ=0.00px
Camera 1: tri e_custom=12.93px, e_cg=12.93px, projΔ=0.00px | pca e_custom=13.93px, e_cg=13.93px, projΔ=0.00px
Camera 2: tri e_custom=2.23px, e_cg=2.23px, projΔ=0.00px | pca e_custom=1.27px, e_cg=1.27px, projΔ=0.00px
Camera 3: tri e_custom=2.84px, e_cg=2.84px, projΔ=0.00px | pca e_custom=2.60px, e_cg=2.60px, projΔ=0.00px
Camera 4: tri e_custom=5.52px, e_cg=5.52px, projΔ=0.00px | pca e_custom=4.95px, e_cg=4.95px, projΔ=0.00px
Camera 5: tri e_custom=1.84px, e_cg=1.84px, projΔ=0.00px | pca e_custom=2.30px, e_cg=2.30px, projΔ=0.00px
