# Notebook 3: Trajectory Integration for Walking Data

This notebook integrates FicTrac wheel rotation/displacement data with STAC-registered
joint angles to create complete walking trajectories with proper body position and orientation.

**Workflow:**
1. Load STAC IK output from Notebook 2
2. Load FicTrac trajectory data from original CSV
3. Calculate yaw orientation from trajectory direction
4. Apply position and orientation offsets to qpos
5. Optionally interpolate to higher frequency
6. Export final trajectory-integrated data

## 1. Environment Setup and Imports

In [1]:
# Environment setup (must be before JAX import)
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["XLA_FLAGS"] = "--xla_gpu_triton_gemm_any=True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Adjust GPU as needed

In [2]:
%load_ext autoreload
%autoreload 2

# Standard library
import sys
import warnings
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
from jax import jit, vmap
from scipy import interpolate

# MuJoCo
import mujoco

# Visualization
import matplotlib.pyplot as plt

# Progress bars
from tqdm.auto import tqdm

# H5 file I/O
import h5py

# JAX cache setup
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)

print(f"JAX devices: {jax.devices()}")

  from .autonotebook import tqdm as notebook_tqdm


JAX devices: [CudaDevice(id=0)]


## 2. Configuration

In [3]:
# === PATH CONFIGURATION ===
BASE_PATH = Path("/home/talmolab/Desktop/SalkResearch")
DATA_PATH = BASE_PATH / "data"
STAC_MJX_PATH = BASE_PATH / "stac-mjx"

# Input: STAC IK output from Notebook 2
STAC_IK_PATH = DATA_PATH / "stac_ik_output.h5"

# Input: Original CSV for FicTrac data
CSV_DATA_PATH = DATA_PATH / "wt_berlin_tethered_dataset.csv"

# Input: Aligned keypoints (for clip mapping)
ALIGNED_KP_PATH = DATA_PATH / "aligned_walking_keypoints.h5"

# Model path (for wing joint indices)
MODEL_PATH = STAC_MJX_PATH / "models" / "fruitfly" / "fruitfly_force.xml"

# Output path
OUTPUT_PATH = DATA_PATH / "walking_trajectory_integrated.h5"

# === PROCESSING PARAMETERS ===
SMOOTHING_WINDOW = 75  # Window size for yaw smoothing
Z_OFFSET = 0.0195  # Height offset for z position
SOURCE_HZ = 300  # Source data frequency
TARGET_HZ = 500  # Target frequency for interpolation (None to skip)

# Default wing position (folded wings)
DEFAULT_WING_POS = jnp.array([1.5, 0.814, -0.821, 1.5, 0.814, -0.821])

print(f"STAC IK input: {STAC_IK_PATH}")
print(f"CSV data: {CSV_DATA_PATH}")
print(f"Output: {OUTPUT_PATH}")

STAC IK input: /home/talmolab/Desktop/SalkResearch/data/stac_ik_output.h5
CSV data: /home/talmolab/Desktop/SalkResearch/data/wt_berlin_tethered_dataset.csv
Output: /home/talmolab/Desktop/SalkResearch/data/walking_trajectory_integrated.h5


## 3. Load STAC IK Output

In [4]:
# Load STAC IK results from Notebook 2
print(f"Loading STAC IK output from {STAC_IK_PATH}...")

with h5py.File(STAC_IK_PATH, "r") as f:
    print(f"Attributes: {dict(f.attrs)}")

    # Load qpos (joint positions)
    qpos_all = f["qpos"][:]
    print(f"qpos shape: {qpos_all.shape}")

    # Load clip lengths
    clip_lengths = f["clip_lengths"][:]
    print(f"Number of clips: {len(clip_lengths)}")

    # Load other data if available
    xpos_all = f["xpos"][:] if "xpos" in f else None
    xquat_all = f["xquat"][:] if "xquat" in f else None
    offsets = f["offsets"][:] if "offsets" in f else None

    # Load names
    joint_names = (
        [n.decode("utf-8") for n in f["joint_names"][:]] if "joint_names" in f else None
    )
    kp_names = (
        [n.decode("utf-8") for n in f["kp_names"][:]] if "kp_names" in f else None
    )

print(f"\nTotal frames: {qpos_all.shape[0]}")
print(f"qpos dimensions: {qpos_all.shape[1]}")

Loading STAC IK output from /home/talmolab/Desktop/SalkResearch/data/stac_ik_output.h5...
Attributes: {}
qpos shape: (692333, 36)


KeyError: "Unable to synchronously open object (object 'clip_lengths' doesn't exist)"

In [None]:
# Reshape qpos into clips based on clip_lengths
qpos_clips = []
start_idx = 0

for clip_len in clip_lengths:
    end_idx = start_idx + clip_len
    qpos_clips.append(qpos_all[start_idx:end_idx])
    start_idx = end_idx

print(f"Split into {len(qpos_clips)} clips")
print(f"Clip length range: {min(clip_lengths)} - {max(clip_lengths)}")

## 4. Load FicTrac Trajectory Data

In [None]:
# Load the original CSV to get FicTrac data
print(f"Loading FicTrac data from {CSV_DATA_PATH}...")
full_df = pd.read_csv(CSV_DATA_PATH)
print(f"Loaded dataframe with shape: {full_df.shape}")

# Check for FicTrac columns
fictrac_cols = [col for col in full_df.columns if 'fictrac' in col.lower()]
print(f"\nFound {len(fictrac_cols)} FicTrac columns:")
for col in fictrac_cols[:10]:
    print(f"  - {col}")
if len(fictrac_cols) > 10:
    print(f"  ... and {len(fictrac_cols) - 10} more")

In [None]:
# Extract walking bouts and FicTrac trajectory data
all_bout_nums = full_df["walking_bout_number"].unique()
all_bout_nums = all_bout_nums[all_bout_nums > 0]  # Skip bout 0
all_bout_nums = sorted(all_bout_nums)

print(f"Found {len(all_bout_nums)} walking bouts")

# Extract trajectory data for each bout
int_x_cm = []  # Integrated X position in cm
int_y_cm = []  # Integrated Y position in cm
heading_deg = []  # Heading angle
bout_clip_lengths = []  # Actual bout lengths from CSV

for bout_num in tqdm(all_bout_nums, desc="Extracting FicTrac data"):
    bout = full_df[full_df["walking_bout_number"] == bout_num]
    bout_clip_lengths.append(len(bout))
    
    # Extract integrated position (convert mm to cm, swap axes for body model)
    # Note: Axes are swapped because the model coordinate system differs from FicTrac
    x_mm = bout["fictrac_int_y_mm"].values  # Swapped
    y_mm = bout["fictrac_int_x_mm"].values  # Swapped
    
    # Convert to cm and subtract initial position
    int_x_cm.append((x_mm / 10) - (x_mm[0] / 10))
    int_y_cm.append((y_mm / 10) - (y_mm[0] / 10))
    
    # Extract heading (subtract initial to get relative heading)
    heading = bout["fictrac_heading"].values
    heading_deg.append(heading - heading[0])

print(f"\nExtracted {len(int_x_cm)} trajectory clips")
print(f"Bout length range: {min(bout_clip_lengths)} - {max(bout_clip_lengths)}")

In [None]:
# Verify clip counts match
print(f"STAC IK clips: {len(qpos_clips)}")
print(f"FicTrac clips: {len(int_x_cm)}")

if len(qpos_clips) != len(int_x_cm):
    print(f"WARNING: Clip count mismatch!")
    # Use minimum
    n_clips = min(len(qpos_clips), len(int_x_cm))
    print(f"Using first {n_clips} clips")
else:
    n_clips = len(qpos_clips)
    print("Clip counts match!")

# Verify lengths match for each clip
length_mismatches = []
for i in range(n_clips):
    qpos_len = len(qpos_clips[i])
    traj_len = len(int_x_cm[i])
    if qpos_len != traj_len:
        length_mismatches.append((i, qpos_len, traj_len))

if length_mismatches:
    print(f"\nFound {len(length_mismatches)} clips with length mismatches:")
    for i, qlen, tlen in length_mismatches[:5]:
        print(f"  Clip {i}: qpos={qlen}, traj={tlen}")
else:
    print("All clip lengths match!")

## 5. Quaternion Utilities

Functions for quaternion operations (rotation, multiplication) needed for yaw integration.

In [None]:
def quat_rot_axis(axis, angle):
    """
    Create a quaternion representing rotation around an axis.
    
    Args:
        axis: (3,) rotation axis (should be normalized)
        angle: rotation angle in radians
        
    Returns:
        q: (4,) quaternion [w, x, y, z]
    """
    axis = axis / jnp.linalg.norm(axis)
    half_angle = angle / 2.0
    w = jnp.cos(half_angle)
    xyz = axis * jnp.sin(half_angle)
    return jnp.array([w, xyz[0], xyz[1], xyz[2]])


def quat_mul(q1, q2):
    """
    Multiply two quaternions.
    
    Args:
        q1: (4,) first quaternion [w, x, y, z]
        q2: (4,) second quaternion [w, x, y, z]
        
    Returns:
        q: (4,) result quaternion
    """
    w1, x1, y1, z1 = q1
    w2, x2, y2, z2 = q2
    
    w = w1*w2 - x1*x2 - y1*y2 - z1*z2
    x = w1*x2 + x1*w2 + y1*z2 - z1*y2
    y = w1*y2 - x1*z2 + y1*w2 + z1*x2
    z = w1*z2 + x1*y2 - y1*x2 + z1*w2
    
    return jnp.array([w, x, y, z])


def normalize_quat(q):
    """
    Normalize a quaternion to unit length.
    """
    return q / jnp.linalg.norm(q)

## 6. Trajectory Processing Functions

In [None]:
def calculate_yaw_from_trajectory(int_x, int_y, smoothing_window=75):
    """
    Calculate smoothed yaw angles from trajectory position data.
    
    The yaw angle is computed from the direction of travel (velocity),
    then smoothed with a Gaussian kernel for natural transitions.
    
    Args:
        int_x: (T,) X position trajectory
        int_y: (T,) Y position trajectory
        smoothing_window: window size for Gaussian smoothing
        
    Returns:
        yaw_angles: (T,) smoothed yaw angles in radians
    """
    N = smoothing_window
    original_length = int_x.shape[0]
    
    # Calculate velocity using smoothed finite differences
    vel_window = 5
    pad_x = jnp.pad(int_x, vel_window // 2, mode="edge")
    pad_y = jnp.pad(int_y, vel_window // 2, mode="edge")
    
    # Smooth velocity calculation
    vel_x = jnp.convolve(jnp.diff(pad_x), jnp.ones(vel_window) / vel_window, mode="valid")
    vel_y = jnp.convolve(jnp.diff(pad_y), jnp.ones(vel_window) / vel_window, mode="valid")
    
    # Pad velocity to match original length
    current_length = vel_x.shape[0]
    pad_size = original_length - current_length
    pad_left = pad_size // 2
    pad_right = pad_size - pad_left
    vel_x = jnp.pad(vel_x, (pad_left, pad_right), mode="edge")
    vel_y = jnp.pad(vel_y, (pad_left, pad_right), mode="edge")
    
    # Calculate yaw from velocity direction
    dyaw_t = jnp.arctan2(vel_y, vel_x)
    
    # Unwrap to handle angle discontinuities
    unwrapped_yaw = jnp.unwrap(dyaw_t)
    
    # Gaussian smoothing kernel
    sigma = N / 6.0
    kernel_size = N + 1
    x = jnp.arange(-N // 2, N // 2 + 1)
    gaussian_kernel = jnp.exp(-0.5 * (x / sigma) ** 2)
    gaussian_kernel = gaussian_kernel / jnp.sum(gaussian_kernel)
    
    # Apply Gaussian smoothing
    pad_width = kernel_size // 2
    padded_yaw = jnp.pad(unwrapped_yaw, pad_width, mode="edge")
    smooth_yaw_conv = jnp.convolve(padded_yaw, gaussian_kernel, mode="valid")
    
    # Trim to original length
    diff = smooth_yaw_conv.shape[0] - original_length
    start = diff // 2
    smooth_yaw = jax.lax.dynamic_slice(smooth_yaw_conv, (start,), (original_length,))
    
    return smooth_yaw

In [None]:
def process_single_clip_trajectory(
    int_x_clip,
    int_y_clip,
    qpos_clip,
    wing_joint_idxs,
    default_wing_pos,
    smoothing_window=75,
    z_offset=0.0195,
):
    """
    Integrate trajectory data into a clip's qpos.
    
    Updates:
    - X, Y positions from FicTrac integrated trajectory
    - Z position adjusted by offset
    - Quaternion orientation updated with calculated yaw
    - Wing joints set to default (folded) position
    
    Args:
        int_x_clip: (T,) X trajectory positions in cm
        int_y_clip: (T,) Y trajectory positions in cm
        qpos_clip: (T, nq) joint positions from STAC
        wing_joint_idxs: indices of wing joints in qpos
        default_wing_pos: default wing joint positions
        smoothing_window: window for yaw smoothing
        z_offset: height offset for z position
        
    Returns:
        updated_qpos: (T, nq) qpos with integrated trajectory
    """
    # Calculate smoothed yaw from trajectory
    smooth_yaw = calculate_yaw_from_trajectory(int_x_clip, int_y_clip, smoothing_window)
    
    # Get initial quaternions from STAC output
    # qpos format: [x, y, z, qw, qx, qy, qz, joint1, joint2, ...]
    initial_quaternion = qpos_clip[:, 3:7]
    
    # Create yaw rotation quaternions (rotation around Z-axis)
    z_axis = jnp.array([0.0, 0.0, 1.0])
    yaw_quaternions = vmap(lambda angle: quat_rot_axis(z_axis, angle))(smooth_yaw)
    
    # Apply yaw rotation to initial orientation
    final_quaternions = vmap(quat_mul)(yaw_quaternions, initial_quaternion)
    
    # Normalize quaternions
    final_quaternions = vmap(normalize_quat)(final_quaternions)
    
    # Update qpos with trajectory data
    updated_qpos = qpos_clip.at[:, 0].set(int_x_clip)  # X position
    updated_qpos = updated_qpos.at[:, 1].set(int_y_clip)  # Y position
    updated_qpos = updated_qpos.at[:, 2].set(jnp.min(qpos_clip[:, 2]) - z_offset)  # Z position
    updated_qpos = updated_qpos.at[:, 3:7].set(final_quaternions)  # Quaternion
    
    # Set wing joints to default position
    if wing_joint_idxs is not None and len(wing_joint_idxs) > 0:
        for i, idx in enumerate(wing_joint_idxs):
            if i < len(default_wing_pos):
                updated_qpos = updated_qpos.at[:, idx].set(default_wing_pos[i])
    
    return updated_qpos


# JIT compile
jit_process_clip = jit(process_single_clip_trajectory, static_argnames=["smoothing_window"])

## 7. Get Wing Joint Indices

In [None]:
# Load model to find wing joint indices
print(f"Loading model from {MODEL_PATH}...")
spec = mujoco.MjSpec().from_file(str(MODEL_PATH))
mj_model = spec.compile()

# Find wing joints
wing_joint_names = [joint.name for joint in spec.joints if 'wing' in joint.name.lower()]
print(f"Wing joints: {wing_joint_names}")

# Get joint indices in qpos
# Note: qpos includes root position (3) + quaternion (4) + joint angles
# So joint indices start at 7
all_joint_names = [joint.name for joint in spec.joints]
wing_joint_idxs = []

for wing_name in wing_joint_names:
    if wing_name in all_joint_names:
        # Find the index in joint list (excluding free joint if any)
        joint_idx = all_joint_names.index(wing_name)
        # Add offset for root position and quaternion
        qpos_idx = 7 + joint_idx  # 3 pos + 4 quat + joint index
        wing_joint_idxs.append(qpos_idx)
        print(f"  {wing_name}: joint_idx={joint_idx}, qpos_idx={qpos_idx}")

wing_joint_idxs = tuple(wing_joint_idxs) if wing_joint_idxs else None
print(f"\nWing joint qpos indices: {wing_joint_idxs}")

## 8. Visualize Sample Trajectory

In [None]:
# Visualize a sample trajectory before processing
sample_idx = 19  # Change this to view different clips

if sample_idx < len(int_x_cm):
    x = int_x_cm[sample_idx]
    y = int_y_cm[sample_idx]
    
    # Calculate yaw for visualization
    test_yaw = np.array(calculate_yaw_from_trajectory(jnp.array(x), jnp.array(y), SMOOTHING_WINDOW))
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot trajectory with yaw directions
    u = np.cos(test_yaw)
    v = np.sin(test_yaw)
    
    axes[0].plot(x, y, 'k-', linewidth=2, label='Trajectory')
    axes[0].scatter(x[0], y[0], c='g', s=100, label='Start', zorder=5)
    axes[0].scatter(x[-1], y[-1], c='r', s=100, label='End', zorder=5)
    
    # Plot yaw direction arrows (every 10 frames)
    step = max(1, len(x) // 20)
    axes[0].quiver(x[::step], y[::step], u[::step], v[::step],
                   angles='xy', scale_units='xy', scale=10, color='b', alpha=0.7)
    
    axes[0].set_xlabel('X (cm)')
    axes[0].set_ylabel('Y (cm)')
    axes[0].set_title(f'Clip {sample_idx}: COM Trajectory with Yaw Direction')
    axes[0].axis('equal')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Plot yaw over time
    axes[1].plot(test_yaw, 'b-', linewidth=2)
    axes[1].set_xlabel('Frame')
    axes[1].set_ylabel('Yaw Angle (radians)')
    axes[1].set_title('Yaw Angle Evolution')
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print(f"Sample index {sample_idx} out of range (max: {len(int_x_cm)-1})")

## 9. Process All Clips

In [None]:
# Process all clips with trajectory integration
print(f"Processing {n_clips} clips with trajectory integration...")

processed_qpos = []

for clip_idx in tqdm(range(n_clips), desc="Integrating trajectories"):
    # Get data for this clip
    int_x = jnp.array(int_x_cm[clip_idx])
    int_y = jnp.array(int_y_cm[clip_idx])
    qpos = jnp.array(qpos_clips[clip_idx])
    
    # Handle length mismatch by truncating to minimum
    min_len = min(len(int_x), len(qpos))
    int_x = int_x[:min_len]
    int_y = int_y[:min_len]
    qpos = qpos[:min_len]
    
    # Process this clip
    updated_qpos = jit_process_clip(
        int_x,
        int_y,
        qpos,
        wing_joint_idxs,
        DEFAULT_WING_POS,
        smoothing_window=SMOOTHING_WINDOW,
        z_offset=Z_OFFSET,
    )
    
    processed_qpos.append(np.array(updated_qpos))

print(f"\nProcessed {len(processed_qpos)} clips")

## 10. Optional: Interpolate to Higher Frequency

In [None]:
def interpolate_clip(clip_data, source_hz, target_hz):
    """
    Interpolate a clip to a different frequency using cubic splines.
    
    Args:
        clip_data: (T, D) array
        source_hz: source frequency
        target_hz: target frequency
        
    Returns:
        interpolated: (T_new, D) array
    """
    T, D = clip_data.shape
    duration = T / source_hz
    
    t_source = np.linspace(0, duration, T)
    T_new = int(duration * target_hz)
    t_target = np.linspace(0, duration, T_new)
    
    interpolated = np.zeros((T_new, D))
    for d in range(D):
        spline = interpolate.CubicSpline(t_source, clip_data[:, d])
        interpolated[:, d] = spline(t_target)
    
    # Re-normalize quaternions (indices 3:7)
    quat = interpolated[:, 3:7]
    quat_norm = np.linalg.norm(quat, axis=1, keepdims=True)
    interpolated[:, 3:7] = quat / quat_norm
    
    return interpolated

In [None]:
if TARGET_HZ is not None and TARGET_HZ != SOURCE_HZ:
    print(f"Interpolating from {SOURCE_HZ}Hz to {TARGET_HZ}Hz...")
    
    interpolated_qpos = []
    new_clip_lengths = []
    
    for clip_idx, clip in enumerate(tqdm(processed_qpos, desc="Interpolating")):
        interp_clip = interpolate_clip(clip, SOURCE_HZ, TARGET_HZ)
        interpolated_qpos.append(interp_clip)
        new_clip_lengths.append(len(interp_clip))
    
    final_qpos = interpolated_qpos
    final_clip_lengths = new_clip_lengths
    output_hz = TARGET_HZ
    
    print(f"Interpolation complete!")
    print(f"Original total frames: {sum(len(c) for c in processed_qpos)}")
    print(f"Interpolated total frames: {sum(final_clip_lengths)}")
else:
    print(f"No interpolation (source={SOURCE_HZ}Hz, target={TARGET_HZ}Hz)")
    final_qpos = processed_qpos
    final_clip_lengths = [len(c) for c in processed_qpos]
    output_hz = SOURCE_HZ

## 11. Visualize Results

In [None]:
# Compare before and after trajectory integration
sample_idx = 19

if sample_idx < len(qpos_clips) and sample_idx < len(final_qpos):
    original = qpos_clips[sample_idx]
    processed = final_qpos[sample_idx]
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # X position
    axes[0, 0].plot(original[:, 0], 'b-', label='Original', alpha=0.7)
    # Resample processed to match original length for comparison if interpolated
    if len(processed) != len(original):
        t_orig = np.arange(len(original))
        t_proc = np.linspace(0, len(original)-1, len(processed))
        axes[0, 0].plot(t_proc, processed[:, 0], 'r-', label='With Trajectory', alpha=0.7)
    else:
        axes[0, 0].plot(processed[:, 0], 'r-', label='With Trajectory', alpha=0.7)
    axes[0, 0].set_ylabel('X Position (cm)')
    axes[0, 0].set_title('X Position')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Y position
    axes[0, 1].plot(original[:, 1], 'b-', label='Original', alpha=0.7)
    if len(processed) != len(original):
        axes[0, 1].plot(t_proc, processed[:, 1], 'r-', label='With Trajectory', alpha=0.7)
    else:
        axes[0, 1].plot(processed[:, 1], 'r-', label='With Trajectory', alpha=0.7)
    axes[0, 1].set_ylabel('Y Position (cm)')
    axes[0, 1].set_title('Y Position')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # XY trajectory
    axes[1, 0].plot(original[:, 0], original[:, 1], 'b-', label='Original', alpha=0.7)
    axes[1, 0].plot(processed[:, 0], processed[:, 1], 'r-', label='With Trajectory', alpha=0.7)
    axes[1, 0].scatter(processed[0, 0], processed[0, 1], c='g', s=100, zorder=5, label='Start')
    axes[1, 0].set_xlabel('X (cm)')
    axes[1, 0].set_ylabel('Y (cm)')
    axes[1, 0].set_title('XY Trajectory')
    axes[1, 0].axis('equal')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Quaternion W component (shows orientation changes)
    axes[1, 1].plot(original[:, 3], 'b-', label='Original qw', alpha=0.7)
    if len(processed) != len(original):
        axes[1, 1].plot(t_proc, processed[:, 3], 'r-', label='With Trajectory qw', alpha=0.7)
    else:
        axes[1, 1].plot(processed[:, 3], 'r-', label='With Trajectory qw', alpha=0.7)
    axes[1, 1].set_xlabel('Frame')
    axes[1, 1].set_ylabel('Quaternion W')
    axes[1, 1].set_title('Orientation (Quaternion W)')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print(f"Sample index {sample_idx} out of range")

## 12. Export Final Results

In [None]:
# Concatenate all clips
qpos_final = np.concatenate(final_qpos, axis=0)
print(f"Final qpos shape: {qpos_final.shape}")
print(f"Total frames: {qpos_final.shape[0]}")
print(f"Output frequency: {output_hz}Hz")

In [None]:
# Save final trajectory-integrated data
print(f"Saving to {OUTPUT_PATH}...")

with h5py.File(OUTPUT_PATH, 'w') as f:
    # Metadata
    f.attrs['description'] = 'Trajectory-integrated walking data with FicTrac motion'
    f.attrs['source_stac_ik'] = str(STAC_IK_PATH)
    f.attrs['source_csv'] = str(CSV_DATA_PATH)
    f.attrs['output_hz'] = output_hz
    f.attrs['source_hz'] = SOURCE_HZ
    f.attrs['smoothing_window'] = SMOOTHING_WINDOW
    f.attrs['z_offset'] = Z_OFFSET
    f.attrs['n_clips'] = len(final_qpos)
    f.attrs['n_frames'] = qpos_final.shape[0]
    
    # Main data
    f.create_dataset('qpos', data=qpos_final, compression='gzip')
    f.create_dataset('clip_lengths', data=np.array(final_clip_lengths))
    
    # Also save individual clips for convenience
    clips_grp = f.create_group('clips')
    for i, clip in enumerate(final_qpos):
        clips_grp.create_dataset(f'clip_{i:04d}', data=clip, compression='gzip')
    
    # Copy over other relevant data from STAC output
    if offsets is not None:
        f.create_dataset('offsets', data=offsets, compression='gzip')
    if joint_names is not None:
        f.create_dataset('joint_names', data=np.array(joint_names, dtype='S'))
    if kp_names is not None:
        f.create_dataset('kp_names', data=np.array(kp_names, dtype='S'))

print(f"Saved!")

In [None]:
# Verify saved file
print("\nVerifying saved file...")
with h5py.File(OUTPUT_PATH, 'r') as f:
    print(f"Attributes: {dict(f.attrs)}")
    print(f"\nDatasets:")
    for key in f.keys():
        if isinstance(f[key], h5py.Dataset):
            print(f"  {key}: {f[key].shape}")
        elif isinstance(f[key], h5py.Group):
            print(f"  {key}/: {len(f[key])} items")

print(f"\nOutput file: {OUTPUT_PATH}")

## Summary

This notebook has:
1. Loaded STAC IK output (joint angles) from Notebook 2
2. Loaded FicTrac trajectory data (wheel rotation/displacement) from the original CSV
3. Calculated yaw orientation from the trajectory direction of travel
4. Applied trajectory position (X, Y) and orientation (yaw) to update qpos
5. Set wing joints to default folded position
6. Optionally interpolated to higher frequency
7. Exported trajectory-integrated data to H5

**Output:** `walking_trajectory_integrated.h5` containing:
- `qpos`: Joint positions with integrated trajectory motion
- `clip_lengths`: Length of each walking bout clip
- `clips/`: Individual clips for convenience

The data is now ready for use in simulation, imitation learning, or biomechanical analysis.