# Notebook 1: Procrustes Alignment for Walking Data

This notebook loads the Berlin tethered walking dataset, performs Procrustes alignment
to a reference fly model pose, and exports scaled keypoint data ready for STAC registration.

**Workflow:**
1. Load CSV dataset with walking bout keypoints
2. Extract and transform keypoints to model reference frame
3. Get reference pose from MuJoCo fly model
4. Apply Procrustes alignment with scaling
5. Apply ground contact alignment
6. Export aligned keypoints to H5 for STAC

## 1. Environment Setup and Imports

In [None]:
# Environment setup (must be before JAX import)
import os

# Runtime switches for XLA / MuJoCo / OpenGL, applied in one call so they are
# all in place before the JAX import that follows.
os.environ.update(
    {
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        # Both MuJoCo and PyOpenGL are pointed at the EGL backend.
        "MUJOCO_GL": "egl",
        "PYOPENGL_PLATFORM": "egl",
        "XLA_FLAGS": "--xla_gpu_triton_gemm_any=True",
        "CUDA_VISIBLE_DEVICES": "0",  # Adjust GPU as needed
    }
)

In [None]:
%load_ext autoreload
%autoreload 2

# Standard library
import sys
import warnings
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax import jit, vmap

# MuJoCo
import mujoco

# Visualization
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Progress bars
from tqdm.auto import tqdm

# H5 file I/O
import h5py

# JAX cache setup: point the persistent compilation cache at /tmp/jax_cache
# and set both cache thresholds (entry size, compile time) explicitly.
for _option, _value in (
    ("jax_compilation_cache_dir", "/tmp/jax_cache"),
    ("jax_persistent_cache_min_entry_size_bytes", -1),
    ("jax_persistent_cache_min_compile_time_secs", 0),
):
    jax.config.update(_option, _value)

print(f"JAX devices: {jax.devices()}")

## 2. Configuration

In [None]:
# === PATH CONFIGURATION ===
# Root directory of the local checkout - adjust this to your environment
BASE_PATH = Path("/home/talmolab/Desktop/SalkResearch")
DATA_PATH = BASE_PATH.joinpath("data")
STAC_MJX_PATH = BASE_PATH.joinpath("stac-mjx")

# Input: CSV holding the walking-bout keypoints
CSV_DATA_PATH = DATA_PATH.joinpath("wt_berlin_tethered_dataset.csv")

# MuJoCo XML of the reference fly model
MODEL_PATH = STAC_MJX_PATH.joinpath("models", "fruitfly", "fruitfly_force.xml")

# Output: H5 file that receives the aligned keypoints
OUTPUT_PATH = DATA_PATH.joinpath("aligned_walking_keypoints.h5")

# === ALIGNMENT PARAMETERS ===
FLOOR_HEIGHT = -0.125  # Target floor Z height in model units
GROUND_CONTACT_PERCENTILE = 5.0  # Percentile for ground contact detection
# Leg tips (claws): every fifth keypoint starting at index 4 -> 4, 9, ..., 29
END_EFFECTOR_INDICES = jnp.arange(4, 30, 5)

print(f"CSV data path: {CSV_DATA_PATH}")
print(f"Model path: {MODEL_PATH}")
print(f"Output path: {OUTPUT_PATH}")

## 3. Load Walking Dataset

In [None]:
# Read the complete walking dataset into a single dataframe
print(f"Loading data from {CSV_DATA_PATH}...")
full_df = pd.read_csv(CSV_DATA_PATH)
print(f"Loaded dataframe with shape: {full_df.shape}")
# The second entry of the shape is the number of columns
print(f"Columns: {full_df.shape[1]}")

In [None]:
# Define column mappings for keypoints
# Data uses L1, R1, L2, R2, L3, R3 for legs and A, B, C, D, E for joints
legs_data = ["L1", "R1", "L2", "R2", "L3", "R3"]
joints_data = ["A", "B", "C", "D", "E"]  # coxa, femur, tibia, tarsus, claw
coords_data = ["_x", "_y", "_z"]

# Build column names for keypoint positions.
# The order is leg-major, then joint, then coordinate, so one row of these
# columns reshapes directly to (n_keypoints, 3).
joint_pos_columns = [
    leg + joint + coord
    for leg in legs_data
    for joint in joints_data
    for coord in coords_data
]

print(f"Number of keypoint columns: {len(joint_pos_columns)}")
# Derived from the lists above so the message cannot drift out of sync.
print(
    f"Expected: {len(legs_data)} legs x {len(joints_data)} joints x "
    f"{len(coords_data)} coords = "
    f"{len(legs_data) * len(joints_data) * len(coords_data)}"
)

# Verify columns exist. Selecting these columns later raises a KeyError anyway
# if any are absent, so stop here with a message that names the problem.
missing_cols = [col for col in joint_pos_columns if col not in full_df.columns]
if missing_cols:
    raise KeyError(
        f"{len(missing_cols)} keypoint column(s) missing from the CSV, "
        f"e.g. {missing_cols[:5]}"
    )
print("All keypoint columns found.")

In [None]:
def transform_bout(bout_kp):
    """
    Map one bout of keypoints from the data frame into the model frame.

    The first two coordinates are exchanged, the new second coordinate is
    negated, and every value is multiplied by 0.1 (mm -> cm). The axis
    reordering uses advanced indexing, which yields a new array, so the
    caller's array is left untouched.

    Args:
        bout_kp: ndarray of shape (T, 30, 3) - keypoints for one bout

    Returns:
        Transformed keypoints in model frame (cm units)
    """
    axis_order = [1, 0, 2]  # (x, y, z) -> (y, x, z)
    remapped = bout_kp[:, :, axis_order]

    # Unit conversion first, then the sign flip; both are in-place on the
    # fresh array and the order of the two multiplications is irrelevant.
    remapped *= 0.1
    flipped_axis = remapped[:, :, 1]  # view into `remapped`
    flipped_axis *= -1

    return remapped

In [None]:
# Extract walking bouts
# Skip bout 0 (usually non-walking frames). Rows whose bout number is NaN fail
# the comparison as well, so only positive bout numbers remain.
walking_df = full_df[full_df["walking_bout_number"] > 0]

# Group once instead of re-filtering the whole dataframe for every bout.
# sort=True yields the bouts in ascending bout-number order, and the row order
# inside each group is the original row order.
bout_groups = walking_df.groupby("walking_bout_number", sort=True)
all_bout_nums = sorted(bout_groups.groups)

print(f"Found {len(all_bout_nums)} walking bouts")

# Keypoints per frame, derived from the column layout rather than hard-coded.
n_keypoints = len(legs_data) * len(joints_data)

# Create bout dictionary with keypoint data
bout_dict = {}

for n, (bout_num, bout_df) in enumerate(
    tqdm(bout_groups, total=len(all_bout_nums), desc="Extracting bouts")
):
    # Extract keypoint positions and reshape to (T, n_keypoints, 3)
    kp_raw = bout_df[joint_pos_columns].values.reshape(-1, n_keypoints, 3)

    # Transform to model reference frame (on a copy, so the dataframe's
    # underlying buffer can never be modified)
    kp_transformed = transform_bout(kp_raw.copy())

    # Keys are numbered consecutively, independent of the raw bout number
    bout_key = f"walking_bout{n:04d}"
    bout_dict[bout_key] = {
        "orig_kp": kp_transformed,
        "bout_number": bout_num,
        "n_frames": len(bout_df),
    }

print(f"\nProcessed {len(bout_dict)} bouts")
# Only show an example when there is one; an empty dataset must not raise here.
if bout_dict:
    print(f"Example bout shape: {next(iter(bout_dict.values()))['orig_kp'].shape}")

## 4. Load Reference Pose from MuJoCo Model

In [None]:
# Parse the model XML into a spec and compile it
print(f"Loading MuJoCo model from {MODEL_PATH}...")
spec = mujoco.MjSpec().from_file(str(MODEL_PATH))
mj_model = spec.compile()

# Select the tracking sites once; names and ids are then read from the same
# filtered list, so the two sequences are index-aligned.
tracking_sites = [site for site in spec.sites if "tracking" in site.name]
site_names = [site.name for site in tracking_sites]
site_idxs = jnp.array([site.id for site in tracking_sites])

print(f"Found {len(site_names)} tracking sites")
print(f"Site names: {site_names[:5]}...")

In [None]:
# Run forward kinematics on the model's default configuration
mj_data = mujoco.MjData(mj_model)
mujoco.mj_forward(mj_model, mj_data)

# World positions of the tracking sites form the reference pose
ref_pose = mj_data.site_xpos[site_idxs].copy()

# Re-origin the pose at its first site
# (presumably L1A, the coxa of the left front leg - the print shows the name)
pose_origin = ref_pose[0].copy()
print(f"Centering reference pose at: {pose_origin} ({site_names[0]})")
ref_pose = ref_pose - pose_origin

print(f"Reference pose shape: {ref_pose.shape}")

# Per-axis extent of the centered pose
pose_min = ref_pose.min(axis=0)
pose_max = ref_pose.max(axis=0)
print(f"Reference pose range: X=[{pose_min[0]:.4f}, {pose_max[0]:.4f}]")
print(f"                      Y=[{pose_min[1]:.4f}, {pose_max[1]:.4f}]")
print(f"                      Z=[{pose_min[2]:.4f}, {pose_max[2]:.4f}]")

## 5. Procrustes Alignment Functions

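These functions implement:
1. **Procrustes alignment with scaling**: A similarity transform (rotation, translation, and uniform scale) that aligns keypoints to the reference pose. It is estimated once per clip from the clip-average pose and then applied to every frame.
2. **Ground contact alignment**: Fits a plane to the lowest end-effector (leg tip) positions, rotates that plane to horizontal, shifts it to the target floor height, and clips any keypoints that end up below the floor.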

In [None]:
def procrustes_with_scaling(source, target):
    """
    Compute optimal rotation, translation, and scale to align source to target.

    Uses Procrustes analysis with uniform scaling: the rotation comes from an
    SVD of the cross-covariance of the centered point sets, and the scale is
    the least-squares scale for that rotation.

    Args:
        source: (N, 3) source points
        target: (N, 3) target points

    Returns:
        aligned: (N, 3) aligned source points
        info: dict with transformation parameters. ``rotation`` is the matrix
            R that acts on column vectors (``R @ p``); for points stored as
            rows it must be applied as ``points @ R.T``.
    """
    source_centroid = jnp.mean(source, axis=0)
    target_centroid = jnp.mean(target, axis=0)

    # Center both point sets
    source_centered = source - source_centroid
    target_centered = target - target_centroid

    # Compute optimal rotation using SVD of the 3x3 cross-covariance
    H = source_centered.T @ target_centered
    U, _, Vt = jnp.linalg.svd(H)

    # Handle reflection case: flip the last axis so that det(R) = +1
    d = jnp.sign(jnp.linalg.det(Vt.T @ U.T))
    D = jnp.diag(jnp.array([1.0, 1.0, d]))

    R = Vt.T @ D @ U.T

    # Compute optimal scale for the rotated source
    source_rotated = source_centered @ R.T
    scale = jnp.sum(target_centered * source_rotated) / jnp.sum(
        source_rotated * source_rotated
    )

    # Apply transformation
    aligned = scale * source_rotated + target_centroid

    # Compute alignment error (RMS point-to-point distance)
    error = jnp.sqrt(jnp.mean(jnp.sum((aligned - target) ** 2, axis=1)))

    info = {
        "rotation": R,
        "scale": scale,
        "source_centroid": source_centroid,
        "target_centroid": target_centroid,
        "error": error,
    }

    return aligned, info


def vectorized_procrustes_with_scaling_avg(kp_sequence, reference):
    """
    Apply Procrustes alignment using clip average (JIT-compatible).

    The transform is estimated once from the time-averaged pose and then
    applied unchanged to every frame of the clip.

    Args:
        kp_sequence: (T, N, 3) keypoint sequence
        reference: (N, 3) reference pose

    Returns:
        aligned: (T, N, 3) aligned keypoints
        transform_info: the dict returned by ``procrustes_with_scaling`` for
            the average pose, plus ``mean_error`` (per-frame RMS distance to
            the reference, averaged over frames)
    """
    # Use average pose for computing transformation
    avg_pose = jnp.mean(kp_sequence, axis=0)
    _, transform_info = procrustes_with_scaling(avg_pose, reference)

    R = transform_info["rotation"]
    scale = transform_info["scale"]
    # Centroid of the average pose == centroid over all frames and keypoints
    source_centroid = transform_info["source_centroid"]
    target_centroid = transform_info["target_centroid"]

    # Apply same transformation to all frames.
    # Points are rows, so R is applied as `p @ R.T` ("tni,ji->tnj"), exactly
    # as in procrustes_with_scaling. Contracting over R's first index instead
    # would apply the inverse rotation.
    centered = kp_sequence - source_centroid
    aligned = scale * jnp.einsum("tni,ji->tnj", centered, R) + target_centroid

    # Compute mean error across all frames
    errors = jnp.sqrt(jnp.mean(jnp.sum((aligned - reference) ** 2, axis=-1), axis=-1))
    transform_info["mean_error"] = jnp.mean(errors)

    return aligned, transform_info


# JIT compile the average-based alignment (this is the one we'll use)
jit_vectorized_procrustes_avg = jax.jit(vectorized_procrustes_with_scaling_avg)

In [None]:
def fit_ground_plane(points):
    """
    Least-squares plane through a set of 3D points.

    The plane passes through the centroid of the points; its normal is the
    right singular vector of the centered points that belongs to the smallest
    singular value. The normal is oriented so that its Z component is not
    negative.

    Args:
        points: (N, 3) points to fit

    Returns:
        normal: (3,) unit normal vector
        d: plane offset (plane equation: normal . x = d)
    """
    center = jnp.mean(points, axis=0)

    # Rows of Vt are ordered by decreasing singular value, so the last row is
    # the direction of least spread, i.e. the plane normal.
    right_vectors = jnp.linalg.svd(points - center)[2]
    candidate = right_vectors[-1]

    # Flip a downward-pointing normal so that it points towards +Z
    orientation = jnp.where(candidate[2] < 0, -1.0, 1.0)
    normal = orientation * candidate

    return normal, jnp.dot(normal, center)


def rotation_matrix_to_align_vectors(v_from, v_to):
    """
    Compute rotation matrix to align v_from to v_to.
    Uses Rodrigues' rotation formula.

    Both inputs are normalised internally, so only their directions matter.
    Parallel inputs give the identity. Anti-parallel inputs give a half turn
    about an axis perpendicular to ``v_from`` (the rotation axis is not unique
    in that case; one is picked deterministically). The function is written
    without Python branching on values, so it can be traced by ``jax.jit``.

    Args:
        v_from: (3,) non-zero vector to rotate
        v_to: (3,) non-zero vector to rotate onto

    Returns:
        R: (3, 3) rotation matrix such that ``R @ v_from`` points along
            ``v_to``
    """
    identity = jnp.eye(3)
    v_from = v_from / jnp.linalg.norm(v_from)
    v_to = v_to / jnp.linalg.norm(v_to)

    # For unit vectors |a x b| = sin(angle) and a . b = cos(angle), so the
    # angle itself never has to be recovered with arccos.
    axis = jnp.cross(v_from, v_to)
    sin_angle = jnp.linalg.norm(axis)
    cos_angle = jnp.clip(jnp.dot(v_from, v_to), -1.0, 1.0)

    # The cross product vanishes for parallel AND anti-parallel inputs
    is_degenerate = sin_angle < 1e-6

    # General case: Rodrigues' formula. The divisor is replaced by 1 in the
    # degenerate case so that this (then unused) branch stays finite.
    unit_axis = axis / jnp.where(is_degenerate, 1.0, sin_angle)
    K = jnp.array(
        [
            [0, -unit_axis[2], unit_axis[1]],
            [unit_axis[2], 0, -unit_axis[0]],
            [-unit_axis[1], unit_axis[0], 0],
        ]
    )
    general = identity + sin_angle * K + (1.0 - cos_angle) * (K @ K)

    # Anti-parallel case: rotate by pi about any axis perpendicular to v_from.
    # Crossing with the basis vector least aligned with v_from guarantees a
    # non-zero perpendicular; a half turn about unit u is 2 u u^T - I.
    helper = identity[jnp.argmin(jnp.abs(v_from))]
    perpendicular = jnp.cross(v_from, helper)
    perpendicular = perpendicular / jnp.linalg.norm(perpendicular)
    half_turn = 2.0 * jnp.outer(perpendicular, perpendicular) - identity

    degenerate = jnp.where(cos_angle < 0.0, half_turn, identity)

    return jnp.where(is_degenerate, degenerate, general)


def align_to_ground_plane_with_contact(
    kp_sequence, end_eff_indices, percentile, target_z
):
    """
    Level the clip on a horizontal floor defined by end-effector contacts.

    Procedure:
    1. Per end effector, take the frames whose Z is at or below that end
       effector's ``percentile`` Z value and average their positions
    2. Fit a plane through these per-end-effector contact points
    3. Rotate all keypoints (about the origin) so the plane normal becomes +Z
    4. Shift along Z so the pooled ``percentile`` of end-effector heights
       lands on ``target_z``
    5. Raise every keypoint that is still below ``target_z`` onto the floor

    Args:
        kp_sequence: (T, N, 3) keypoint sequence
        end_eff_indices: indices of end effector keypoints
        percentile: percentile for ground contact detection
        target_z: target floor height

    Returns:
        aligned: (T, N, 3) aligned keypoints
        info: dict with alignment info
    """
    n_frames = kp_sequence.shape[0]
    n_end_eff = len(end_eff_indices)

    # End-effector trajectories, (T, E, 3), and their heights, (T, E)
    tips = kp_sequence[:, end_eff_indices]
    tip_heights = tips[..., 2]

    # One height threshold per end effector, shape (E,)
    z_thresholds = jnp.percentile(tip_heights, percentile, axis=0)

    # Mask out non-contact frames with NaN and average what is left, giving
    # one mean contact position per end effector, shape (E, 3)
    in_contact = tip_heights <= z_thresholds
    contact_only = jnp.where(in_contact[..., None], tips, jnp.nan)
    ground_points = jnp.nanmean(contact_only, axis=0)

    # Plane through the contact points (the offset is not needed here)
    normal, _ = fit_ground_plane(ground_points)

    # Rotation that turns the fitted normal into the Z axis
    up = jnp.array([0.0, 0.0, 1.0])
    R = rotation_matrix_to_align_vectors(normal, up)
    rotated = jnp.einsum("ij,tnj->tni", R, kp_sequence)

    # Vertical shift that puts the pooled low percentile on the floor
    floor_estimate = jnp.percentile(rotated[:, end_eff_indices][..., 2], percentile)
    translation = jnp.array([0.0, 0.0, target_z - floor_estimate])
    translated = rotated + translation

    # Nothing may remain below the floor
    heights = translated[..., 2]
    clipped = translated.at[..., 2].set(jnp.maximum(heights, target_z))

    # Statistics: how many values were raised, and how many end-effector
    # samples now sit within 0.001 of the floor
    n_clipped = jnp.sum(heights < target_z)
    tip_gap = jnp.abs(clipped[:, end_eff_indices, 2] - target_z)
    n_contacts = jnp.sum(tip_gap < 0.001)

    info = {
        "rotation_matrix": R,
        "translation": translation,
        "ground_normal": normal,
        "z_thresholds": z_thresholds,
        "n_clipped": n_clipped,
        "n_contacts": n_contacts,
        "n_ground_points": n_end_eff,
        "contact_ratio": n_contacts / (n_frames * n_end_eff),
    }

    return clipped, info

In [None]:
def complete_alignment_pipeline(
    kp_sequence, reference, end_eff_indices, percentile=5.0, target_z=-0.125
):
    """
    Run both alignment stages on one clip.

    Stage 1 scales and aligns the clip to the reference pose (transform
    estimated from the clip average); stage 2 levels the result on the floor
    using end-effector contacts.

    Args:
        kp_sequence: (T, N, 3) raw keypoint sequence
        reference: (N, 3) reference pose
        end_eff_indices: indices of end effector keypoints
        percentile: percentile for ground contact detection
        target_z: target floor height

    Returns:
        aligned: (T, N, 3) fully aligned keypoints
        info: dict with the stage reports under "procrustes" and
            "ground_contact"
    """
    # Stage 1: similarity alignment to the reference pose
    scaled_kp, scaling_report = vectorized_procrustes_with_scaling_avg(
        kp_sequence, reference
    )

    # Stage 2: floor levelling on the stage-1 output
    grounded_kp, grounding_report = align_to_ground_plane_with_contact(
        scaled_kp, end_eff_indices, percentile, target_z
    )

    return grounded_kp, {
        "procrustes": scaling_report,
        "ground_contact": grounding_report,
    }


# JIT-compiled entry point used by the cells below
jit_complete_alignment = jax.jit(complete_alignment_pipeline)

## 6. Visualize Alignment (Optional)

In [None]:
def visualize_alignment(
    original, aligned, reference, frame_idx=0, title="Alignment", floor_height=None
):
    """
    Visualize original, aligned, and reference poses in 3D.

    Draws three side-by-side 3D panels: the original pose, the aligned pose
    together with the reference, and the same overlay with a legend and a
    translucent floor plane.

    Args:
        original: (T, N, 3) sequence or single (N, 3) pose before alignment
        aligned: (T, N, 3) sequence or single (N, 3) pose after alignment
        reference: (N, 3) reference pose
        frame_idx: frame shown when a 3D sequence is passed
        title: text placed in the title of the overlay panel
        floor_height: Z height of the floor plane in the overlay panel. When
            None, the notebook-level FLOOR_HEIGHT is used, so existing calls
            behave as before.
    """
    if floor_height is None:
        floor_height = FLOOR_HEIGHT

    orig_frame = original[frame_idx] if original.ndim == 3 else original
    aligned_frame = aligned[frame_idx] if aligned.ndim == 3 else aligned

    fig = plt.figure(figsize=(15, 5))

    def add_panel(position, heading):
        # One labelled 3D subplot; the three panels differ only in content.
        ax = fig.add_subplot(position, projection="3d")
        ax.set_title(heading)
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        return ax

    def scatter_pose(ax, pose, **style):
        # Scatter an (N, 3) pose with the marker size shared by all panels.
        ax.scatter(pose[:, 0], pose[:, 1], pose[:, 2], s=20, **style)

    # Original
    ax1 = add_panel(131, "Original")
    scatter_pose(ax1, orig_frame, c="blue")

    # Aligned
    ax2 = add_panel(132, "Aligned (green) vs Reference (red)")
    scatter_pose(ax2, aligned_frame, c="green")
    scatter_pose(ax2, reference, c="red", alpha=0.5)

    # Overlay
    ax3 = add_panel(133, f"{title} - Frame {frame_idx}")
    scatter_pose(ax3, aligned_frame, c="green", label="Aligned")
    scatter_pose(ax3, reference, c="red", alpha=0.5, label="Reference")

    # Draw floor plane over the XY extent of the aligned pose
    xx, yy = np.meshgrid(
        np.linspace(aligned_frame[:, 0].min(), aligned_frame[:, 0].max(), 10),
        np.linspace(aligned_frame[:, 1].min(), aligned_frame[:, 1].max(), 10),
    )
    ax3.plot_surface(xx, yy, np.full_like(xx, floor_height), alpha=0.2, color="gray")
    ax3.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Smoke-test the pipeline on one clip before processing everything
test_clip_idx = 25
test_clip_key = f"walking_bout{test_clip_idx:04d}"

if test_clip_key not in bout_dict:
    # Fewer clips than expected: report instead of failing
    print(f"Clip {test_clip_key} not found. Available clips: {len(bout_dict)}")
else:
    ref_pose_jax = jnp.array(ref_pose)
    kp_test = jnp.array(bout_dict[test_clip_key]["orig_kp"])

    print(f"Testing alignment on {test_clip_key}")
    print(f"Input shape: {kp_test.shape}")

    # Procrustes + ground contact in one jitted call
    aligned_test, info = jit_complete_alignment(
        kp_test,
        ref_pose_jax,
        END_EFFECTOR_INDICES,
        percentile=GROUND_CONTACT_PERCENTILE,
        target_z=FLOOR_HEIGHT,
    )

    procrustes_error = info["procrustes"]["mean_error"]
    contact_ratio = info["ground_contact"]["contact_ratio"]
    print(f"Procrustes mean error: {procrustes_error:.6f}")
    print(f"Ground contact ratio: {contact_ratio:.1%}")

    # Show the first frame before/after alignment
    visualize_alignment(
        kp_test, aligned_test, ref_pose_jax, frame_idx=0, title=test_clip_key
    )

## 7. Process All Bouts

In [None]:
# Run every bout through the alignment pipeline and collect statistics
ref_pose_jax = jnp.array(ref_pose)

alignment_errors = []
contact_ratios = []

# NOTE: jax.jit specialises on input shapes, so each new clip length triggers
# a fresh trace/compile of the pipeline.
for clip_key, clip_data in tqdm(bout_dict.items(), desc="Aligning bouts"):
    kp_clip = jnp.array(clip_data["orig_kp"])

    aligned_clip, pipeline_info = jit_complete_alignment(
        kp_clip,
        ref_pose_jax,
        END_EFFECTOR_INDICES,
        percentile=GROUND_CONTACT_PERCENTILE,
        target_z=FLOOR_HEIGHT,
    )

    # Keep the result as a NumPy array next to the original keypoints
    clip_data["aligned_kp"] = np.array(aligned_clip)

    # Per-clip scalars for the summary below
    alignment_errors.append(float(pipeline_info["procrustes"]["mean_error"]))
    contact_ratios.append(float(pipeline_info["ground_contact"]["contact_ratio"]))

error_mean, error_std = np.mean(alignment_errors), np.std(alignment_errors)
ratio_mean, ratio_std = np.mean(contact_ratios), np.std(contact_ratios)

print("\nAlignment complete!")
print(f"Mean alignment error: {error_mean:.6f} +/- {error_std:.6f}")
print(f"Mean contact ratio: {ratio_mean:.1%} +/- {ratio_std:.1%}")

In [None]:
# Histograms of the per-clip alignment error and ground contact ratio
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (values, x label, panel title, format spec for the mean in the legend)
panels = (
    (
        alignment_errors,
        "Procrustes Alignment Error",
        "Alignment Error Distribution",
        ".4f",
    ),
    (
        contact_ratios,
        "Ground Contact Ratio",
        "Ground Contact Ratio Distribution",
        ".1%",
    ),
)

for ax, (values, x_label, heading, mean_spec) in zip(axes, panels):
    mean_value = np.mean(values)

    ax.hist(values, bins=30, edgecolor="black")
    ax.set_xlabel(x_label)
    ax.set_ylabel("Count")
    ax.set_title(heading)
    # Dashed red line marks the mean of the distribution
    ax.axvline(
        mean_value,
        color="r",
        linestyle="--",
        label=f"Mean: {mean_value:{mean_spec}}",
    )
    ax.legend()

plt.tight_layout()
plt.show()

## 8. Export Aligned Keypoints

In [None]:
# Sanity-check the clips before writing anything to disk
print(f"Preparing data for export to {OUTPUT_PATH}...")

for clip_key, clip_data in bout_dict.items():
    # A clip without aligned keypoints was skipped by the alignment loop
    if "aligned_kp" not in clip_data:
        print(f"WARNING: {clip_key} missing aligned_kp!")
        continue

    # Alignment must not change the array shape
    orig_shape = clip_data["orig_kp"].shape
    aligned_shape = clip_data["aligned_kp"].shape
    if aligned_shape != orig_shape:
        print(
            f"WARNING: {clip_key} shape mismatch: orig={orig_shape}, aligned={aligned_shape}"
        )

print(f"Total clips to export: {len(bout_dict)}")

In [None]:
# Write metadata, reference pose and all clips to the output H5 file
print(f"Saving aligned keypoints to {OUTPUT_PATH}...")

# Keypoint names in the same leg-major order as the keypoint arrays
kp_names = [leg + joint for leg in legs_data for joint in joints_data]

# Frame count of every clip, in export order
clip_lengths = [clip_data["n_frames"] for clip_data in bout_dict.values()]

file_metadata = {
    "description": "Procrustes-aligned walking keypoints for STAC registration",
    "n_clips": len(bout_dict),
    "floor_height": FLOOR_HEIGHT,
    "ground_contact_percentile": GROUND_CONTACT_PERCENTILE,
    "source_csv": str(CSV_DATA_PATH),
    "model_path": str(MODEL_PATH),
}

with h5py.File(OUTPUT_PATH, "w") as f:
    # File-level attributes
    for attr_name, attr_value in file_metadata.items():
        f.attrs[attr_name] = attr_value

    # Names are stored as fixed-width byte strings
    f.create_dataset("kp_names", data=np.array(kp_names, dtype="S"))

    f.create_dataset("reference_pose", data=ref_pose, compression="gzip")

    # One sub-group per clip with both keypoint versions and its metadata
    clips_group = f.create_group("clips")
    for clip_key, clip_data in tqdm(bout_dict.items(), desc="Saving clips"):
        clip_grp = clips_group.create_group(clip_key)
        for dataset_name in ("orig_kp", "aligned_kp"):
            clip_grp.create_dataset(
                dataset_name, data=clip_data[dataset_name], compression="gzip"
            )
        for attr_name in ("bout_number", "n_frames"):
            clip_grp.attrs[attr_name] = clip_data[attr_name]

    f.create_dataset("clip_lengths", data=np.array(clip_lengths))

print(f"\nSaved {len(bout_dict)} clips to {OUTPUT_PATH}")
print(f"Total frames: {sum(clip_lengths)}")
print(f"Clip length range: {min(clip_lengths)} - {max(clip_lengths)} frames")

In [None]:
# Re-open the file read-only and print what was written
print("\nVerifying saved file...")
with h5py.File(OUTPUT_PATH, "r") as f:
    saved_clips = f["clips"]

    print(f"File attributes: {dict(f.attrs)}")
    print(f"Datasets: {list(f.keys())}")
    print(f"Number of clips: {len(saved_clips)}")

    # Inspect the first clip in the group's key order
    first_clip = list(saved_clips.keys())[0]
    first_clip_grp = saved_clips[first_clip]
    print(f"\nFirst clip ({first_clip}):")
    print(f"  orig_kp shape: {first_clip_grp['orig_kp'].shape}")
    print(f"  aligned_kp shape: {first_clip_grp['aligned_kp'].shape}")

print(f"\nOutput file ready for STAC registration: {OUTPUT_PATH}")

## Summary

This notebook has:
1. Loaded the Berlin tethered walking dataset CSV
2. Extracted walking bouts and transformed keypoints to model reference frame
3. Applied Procrustes alignment with scaling to match the reference fly pose
4. Applied ground contact alignment to ensure leg tips touch the floor
5. Exported aligned keypoints to H5 format for STAC registration

**Next step:** Use `02_STAC_Registration.ipynb` to run STAC on the aligned keypoints.