# Raw: Rendering & Analysis (Fly & Rodent)

- Modified from Scott's notebook

This notebook shows how to load a saved rollout, render it to a video, and run a further visual-analysis pipeline on the agent, e.g. to visualize the temporal dynamics of its intentions.

In [1]:
import os
import sys
from pathlib import Path

main_path = Path().resolve().parent
if str(main_path) not in sys.path:
    sys.path.append(str(main_path))

os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

from typing import List
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='retina'
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.animation as animation

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from PIL import Image
from IPython.display import HTML


from track_mjx.environment.task.multi_clip_tracking import MultiClipTracking
from track_mjx.environment.walker.rodent import Rodent

import mujoco
from pathlib import Path
from dm_control import mjcf as mjcf_dm
from dm_control.locomotion.walkers import rescale
import imageio
import numpy as np

import multiprocessing as mp
import functools

# Rendering Related Helper Functions

In [2]:
def agg_backend_context(func):
    """
    Decorator to switch to a headless backend during function execution.

    The matplotlib backend active at call time is recorded, matplotlib is switched
    to "Agg" while ``func`` runs, and afterwards all figures are closed and the
    recorded backend is restored. The restore also happens when ``func`` raises.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func``, returns its result and
        keeps its name and docstring.
    """

    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        orig_backend = matplotlib.get_backend()
        matplotlib.use("Agg")  # Switch to headless 'Agg' to inhibit figure rendering.
        try:
            return func(*args, **kwargs)
        finally:
            # Always undo the backend switch, even if func failed half-way.
            plt.close("all")  # Figure auto-closing upon backend switching is deprecated.
            matplotlib.use(orig_backend)

    return wrapper


def render_from_saved_rollout(
    rollout: dict,
    walker_name: str,
) -> list:
    """
    Render a rollout from saved qposes.

    A "pair" model is compiled from a walker-specific XML file. For every
    timestep the rollout qpos and the reference qpos are concatenated (rollout
    first) into the model's qpos, and one 640x480 frame is rendered.

    Args:
        rollout (dict): A dictionary containing the qposes of the reference and rollout
            trajectories under the keys "qposes_ref" and "qposes_rollout".
        walker_name (str): "rodent" selects the rodent ghost-pair model, the
            "close_profile" camera and a 0.95 geom scaling; any other value selects the
            fruitfly pair model and the "track1-0" camera without scaling.

    Returns:
        list: list of frames of the rendering, one per (rollout, reference) qpos pair.
    """
    qposes_ref, qposes_rollout = rollout["qposes_ref"], rollout["qposes_rollout"]

    # need to change to the new xml file
    if walker_name == "rodent":
        pair_render_xml_path = (
            "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent/rodent_ghostpair_scale080.xml"
        )
        camera_name = "close_profile"
        geom_scale = 0.95  # in training scaled by this amount as well
    else:
        pair_render_xml_path = (
            "/root/vast/kaiwen/track-mjx/track_mjx/environment/walker/assets/fruitfly/fruitfly_force_pair.xml"
        )
        camera_name = "track1-0"
        geom_scale = 1.0  # the fly model is used unscaled

    spec = mujoco.MjSpec.from_file(str(pair_render_xml_path))

    if geom_scale != 1.0:
        for geom in spec.geoms:
            if geom.size is not None:
                geom.size *= geom_scale
            if geom.pos is not None:
                geom.pos *= geom_scale

    mj_model = spec.compile()

    mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG
    mj_model.opt.iterations = 6
    mj_model.opt.ls_iterations = 6
    mj_data = mujoco.MjData(mj_model)

    # Sites whose name contains "-0" are painted opaque red.
    for site_idx in range(mj_model.nsite):
        site = mj_model.site(site_idx)
        if "-0" in site.name:
            site.rgba = [1, 0, 0, 1]

    # Geoms whose name contains "-1" (the ghost) become white, 50% transparent.
    for geom_idx in range(mj_model.ngeom):
        geom = mj_model.geom(geom_idx)
        if "-1" in geom.name:
            geom.rgba = [1, 1, 1, 0.5]

    # visual mujoco rendering
    scene_option = mujoco.MjvOption()
    scene_option.sitegroup[:] = [1, 1, 1, 1, 1, 0]
    mujoco.mj_kinematics(mj_model, mj_data)

    frames = []
    renderer = mujoco.Renderer(mj_model, height=480, width=640)
    try:
        print("MuJoCo Rendering...")
        for qpos1, qpos2 in tqdm(zip(qposes_rollout, qposes_ref), total=len(qposes_rollout)):
            mj_data.qpos = np.append(qpos1, qpos2)
            mujoco.mj_forward(mj_model, mj_data)
            renderer.update_scene(mj_data, camera=camera_name)
            frames.append(renderer.render())
    finally:
        # Release the renderer's GL resources; repeated calls would otherwise leak contexts.
        renderer.close()

    return frames


def plot_pca_intention(
    idx,
    episode_start,
    pca_projections: np.ndarray,
    clip_idx: int,
    feature_name: str,
    n_components: int = 4,
    terminated=False,
    explained_variance_ratio=None,
):
    """
    Plot the PCA intention progression of the episode up to the current timestep.

    Args:
        idx: the current timestep
        episode_start: the start timestep of the episode
        pca_projections: the pca projection of the episode, shape (timestep, n_components)
        clip_idx: the clip index
        feature_name: the feature name
        n_components: the number of pca components to plot
        terminated: whether the episode is terminated
        explained_variance_ratio: optional per-component explained variance ratios used
            in the legend. When omitted, a notebook-level fitted ``pca`` object is used
            if one exists; otherwise the legend shows the component index only.

    Returns:
        The rendered figure as an RGB uint8 array of shape (height, width, 3).
    """
    if explained_variance_ratio is None:
        try:
            # Backward compatible with the previous behaviour, which read a global `pca`.
            explained_variance_ratio = pca.explained_variance_ratio_
        except NameError:
            explained_variance_ratio = None

    plotted = pca_projections[:, :n_components]
    y_lim = (np.min(plotted) - 0.2, np.max(plotted) + 0.2)
    window_size = 530
    idx_in_this_episode = idx - episode_start  # the current timestep in this episode

    fig, ax = plt.subplots(figsize=(9.6, 4.8))
    try:
        for pc_ind in range(n_components):
            label = f"PC {pc_ind}"
            if explained_variance_ratio is not None:
                label += f" ({explained_variance_ratio[pc_ind]*100:.1f}%)"
            # Plot the PCA projection of the episode
            ax.plot(pca_projections[episode_start:idx, pc_ind], label=label)
            if idx_in_this_episode > 0:
                # Marker on the latest plotted sample. Skipped at the very first step,
                # where idx - 1 would otherwise wrap around to the last row.
                ax.scatter(idx_in_this_episode, pca_projections[idx - 1, pc_ind])
        if terminated:
            # Mark the episode termination
            ax.axvline(x=idx_in_this_episode, color="r", linestyle="-")
            ax.text(
                idx_in_this_episode - 8,  # Adjust the x-offset as needed
                sum(y_lim) / 2,  # Adjust the y-position as needed
                "Episode Terminated",
                color="r",
                rotation=90,
            )  # Rotate the text vertically
        if idx_in_this_episode <= window_size:
            ax.set_xlim(0, window_size)
        else:
            # dynamically move xlim as time progress
            ax.set_xlim(idx_in_this_episode - window_size, idx_in_this_episode)
        ax.set_ylim(*y_lim)
        ax.legend(loc="upper right")
        ax.set_xlabel("Timestep")
        ax.set_title(f"PCA {feature_name} Progression for Clip {clip_idx}")  # TODO make it configurable

        # Render the figure off-screen and convert the RGBA buffer to an RGB array.
        canvas = FigureCanvasAgg(fig)
        canvas.draw()
        s, (width, height) = canvas.print_to_buffer()
        image = Image.frombytes("RGBA", (width, height), s)
        rgb_array = np.array(image.convert("RGB"))
    finally:
        # One figure is created per frame; close it so figures do not pile up.
        plt.close(fig)
    return rgb_array


def render_with_pca_progression(
    rollout: dict,
    pca_projections: np.ndarray,
    n_components: int = 4,
    feature_name: str = "ctrl",
    walker_name: str = "rodent",
):
    """
    Render the rollout with the PCA progression graph concatenated alongside each frame.

    Args:
        rollout: the saved rollout; "qposes_ref", "qposes_rollout" and
            rollout["info"][0]["clip_idx"] are read.
        pca_projections: PCA projection of the episode, indexed by timestep.
        n_components: number of PCA components drawn in the graph.
        feature_name: feature name shown in the graph title.
        walker_name: forwarded to ``render_from_saved_rollout`` to pick the model.

    Returns:
        A list of frames, each the MuJoCo frame stacked horizontally with the PCA
        graph, followed by 50 copies of the final "terminated" frame.
    """
    # skip the first frame, since we don't have intention for the first frame
    frames_mujoco = render_from_saved_rollout(rollout, walker_name)[1:]
    clip_idx = int(rollout["info"][0]["clip_idx"])
    episode_start = 0

    orig_backend = matplotlib.get_backend()
    matplotlib.use("Agg")  # Switch to headless 'Agg' to inhibit figure rendering.
    try:
        worker = functools.partial(
            plot_pca_intention,
            episode_start=episode_start,
            clip_idx=clip_idx,
            pca_projections=pca_projections,
            n_components=n_components,
            feature_name=feature_name,
        )
        print("Rendering with PCA progression...")
        # Use multiprocessing to parallelize the rendering of the graph,
        # one plot per MuJoCo frame that will actually be used.
        with mp.Pool(processes=mp.cpu_count()) as pool:
            frames_pca = pool.map(worker, range(len(frames_mujoco)))

        print("Concatenating frames...")
        concat_frames = [
            np.hstack([frame, frames_pca[idx]]) for idx, frame in tqdm(enumerate(frames_mujoco))
        ]
        reward_plot = plot_pca_intention(
            len(frames_mujoco) - 1,
            episode_start,
            pca_projections,
            clip_idx,
            feature_name,
            n_components,
            terminated=True,
        )
    finally:
        plt.close("all")  # Figure auto-closing upon backend switching is deprecated.
        matplotlib.use(orig_backend)

    for _ in range(50):
        concat_frames.append(np.hstack([frames_mujoco[-1], reward_plot]))  # create stoppage when episode terminates
    return concat_frames


def display_video(frames, framerate=30):
    """
    Build an inline HTML5 video from a sequence of frames.

    Args:
        frames (array): (n_frames, height, width, 3)
        framerate (int): the framerate of the video

    Returns:
        An ``IPython.display.HTML`` object wrapping the encoded video.
    """
    dpi = 70
    first_frame = frames[0]
    n_rows, n_cols, _ = first_frame.shape

    # Create the figure under the headless backend, then switch straight back.
    previous_backend = matplotlib.get_backend()
    matplotlib.use("Agg")  # Switch to headless 'Agg' to inhibit figure rendering.
    fig, ax = plt.subplots(1, 1, figsize=(n_cols / dpi, n_rows / dpi), dpi=dpi)
    plt.close("all")  # Figure auto-closing upon backend switching is deprecated.
    matplotlib.use(previous_backend)  # Switch back to the original backend.

    # Let the image fill the whole canvas, without axes.
    ax.set_axis_off()
    ax.set_aspect("equal")
    ax.set_position([0, 0, 1, 1])
    image_artist = ax.imshow(first_frame)

    def _show_frame(frame):
        image_artist.set_data(frame)
        return [image_artist]

    anim = animation.FuncAnimation(
        fig=fig,
        func=_show_frame,
        frames=frames,
        interval=1000 / framerate,
        blit=True,
        repeat=False,
    )
    return HTML(anim.to_html5_video())

## Step 1: Load the rollout file from the `.h5` file

In [3]:
import h5py

def load_from_h5py(file, group_path="/") -> dict:
    """
    Recursively rebuild a nested Python structure from an HDF5 file.

    Args:
        file (h5py.File): An open HDF5 file object.
        group_path (str): The HDF5 path to start reading from.

    Returns:
        The dataset value for a dataset; a list for a group whose child names are
        all digits (ordered numerically); a dict keyed by child name otherwise.

    Raises:
        TypeError: if the object at ``group_path`` is neither a dataset nor a group.
    """
    node = file[group_path]

    if isinstance(node, h5py.Dataset):
        return node[()]  # Read dataset value

    if not isinstance(node, h5py.Group):
        raise TypeError(f"Unsupported group type: {type(node)}")

    child_names = list(node.keys())
    if all(name.isdigit() for name in child_names):
        # Purely numeric child names are treated as list indices.
        ordered = sorted(child_names, key=int)
        return [load_from_h5py(file, f"{group_path}/{name}") for name in ordered]

    return {name: load_from_h5py(file, f"{group_path}/{name}") for name in child_names}

# Example usage: read the whole file for clip 0 (starting at the root group "/")
# into the notebook-level `rollout` structure.
with h5py.File("/root/vast/kaiwen/track-mjx/rodent_rollout_info/clip_0.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)

We load both the fly and the rodent data for comparison.

In [4]:
# Read only the "/activations" group of clip 1 for each dataset and collect the
# "intention" entry of every element.
# NOTE(review): the "fly" file below lives under a `rodent_rollout_info` directory —
# confirm this path really holds the fly rollouts.
with h5py.File("/root/vast/kaiwen/track-mjx/rodent_rollout_info/clip_1.h5", "r") as h5file:
    activations_fly = load_from_h5py(h5file, group_path="/activations")
    intentions_fly = [a["intention"] for a in activations_fly]

with h5py.File("/root/vast/scott-yang/rodent_rollout_info/data/clip_1.h5", "r") as h5file:
    activations_rodent = load_from_h5py(h5file, group_path="/activations")
    intentions_rodent = [a["intention"] for a in activations_rodent]

In [None]:
# Shape of the "intention" array stored in the first element of the loaded activations.
activations_fly[0]["intention"].shape

# Analysis: PCA for all intentions across rollout

The following cell traverses the recorded rollout `.h5` files in the directory and parses out the intentions of each episode. All of the intention vectors are aggregated into a single matrix for PCA analysis.

In [6]:
def get_aggregate_data(group_path, keys: List[str], clip_idx: int, path: str):
    """
    Get the aggregate data from the hdf5 file of one clip.

    The file ``<path>/clip_<clip_idx>.h5`` is opened, the structure under
    ``group_path`` is loaded with ``load_from_h5py``, and ``keys`` are then applied
    one after another: a dict is indexed with the key, a list of dicts is mapped to
    the list of each element's value for that key.

    Args:
        group_path: HDF5 group path to load.
        keys: keys to descend through, in order.
        clip_idx: index of the clip file to read.
        path: directory that contains the clip files.

    Returns:
        The data remaining after all keys have been applied.

    Raises:
        ValueError: if the data at some level is neither a dict nor a list of dicts.
    """
    clip_file = Path(path) / f"clip_{clip_idx}.h5"
    with h5py.File(str(clip_file), "r") as h5file:
        data = load_from_h5py(h5file, group_path=group_path)

    # The data is fully read into memory above, so the descent needs no open file.
    for key in keys:
        if isinstance(data, dict):
            data = data[key]
        elif isinstance(data, list) and all(isinstance(d, dict) for d in data):
            # An empty list is passed through as an empty list instead of failing.
            data = [d[key] for d in data]
        else:
            raise ValueError("Data structure not supported")
    return data

# MultiProcessing SpeedUp

Run in an ordinary for loop, this would take 0.63 * 850 = 535 seconds = 8.9 minutes.

In [None]:
%%timeit
# Time a single call that loads the "intention" entries from "/activations" for clip 1.
get_aggregate_data("/activations", ["intention"], 1, path=f"/root/vast/kaiwen/track-mjx/rodent_rollout_info")

# At about 0.63 s per clip, loading ~850 clips one after another would take 0.63 * 850 = 535 seconds = 8.9 minutes.

However, if the I/O calls are parallelized with multiprocessing, the whole load completes in only 26 seconds.

In [None]:
import multiprocessing as mp
from tqdm import tqdm

# Pre-bind everything except the clip index, so each worker call only needs `clip_idx`.
# NOTE(review): the fly loader points at a `rodent_rollout_info` directory — confirm the path.
work_fly = functools.partial(get_aggregate_data, "/activations", ['intention'], path=f"/root/vast/kaiwen/track-mjx/rodent_rollout_info")
work_rodent = functools.partial(get_aggregate_data, "/activations", ['intention'], path=f"/root/vast/scott-yang/rodent_rollout_info/data/")

# Load clips 0..498 (fly) and 0..841 (rodent) in parallel; imap keeps the results in clip order.
with mp.Pool(processes=mp.cpu_count()) as pool:
    activations_fly = list(tqdm(pool.imap(work_fly, range(499)), total=499))
    activations_rodent = list(tqdm(pool.imap(work_rodent, range(842)), total=842))

# Stack the per-clip results row-wise into one matrix per dataset.
activations_fly = np.vstack(activations_fly)
activations_rodent = np.vstack(activations_rodent)

In [None]:
# Shapes of the two stacked activation matrices.
activations_fly.shape, activations_rodent.shape

# 2D PCA + Cluster Analysis

In [None]:
# Standardize each dataset separately; the activation matrices are overwritten
# with their scaled versions.
# NOTE(review): the same `scaler` object is refit on the fly data, so after this cell it
# holds the fly statistics only — confirm nothing later reuses it for rodent data.
scaler = StandardScaler()
activations_rodent = scaler.fit_transform(activations_rodent)
activations_fly = scaler.fit_transform(activations_fly)

# PCA with all components on the rodent activations.
pca_rodent = PCA()

pca_rodent = pca_rodent.fit(activations_rodent)
# Cumulative explained variance ratio of the first 10 components.
print(np.cumsum(pca_rodent.explained_variance_ratio_[:10]))
pca_embedded_rodent = pca_rodent.transform(activations_rodent)

# Same procedure for the fly activations, with an independent PCA.
pca_fly = PCA()

pca_fly = pca_fly.fit(activations_fly)
print(np.cumsum(pca_fly.explained_variance_ratio_[:10]))
pca_embedded_fly = pca_fly.transform(activations_fly)

In [None]:
# First two rodent PCs, coloured by sample index. The colour array is sized from the
# data so the plot keeps working when the number of samples changes.
plt.scatter(
    pca_embedded_rodent[:, 0],
    pca_embedded_rodent[:, 1],
    c=np.arange(len(pca_embedded_rodent)),
    cmap="tab20",
    alpha=0.5,
)
plt.xlabel(f"PCA 1 {pca_rodent.explained_variance_ratio_[0]*100:.2f}%")
plt.ylabel(f"PCA 2 {pca_rodent.explained_variance_ratio_[1]*100:.2f}%")
plt.title("PCA of intentions across all episodes")

In [None]:
# First two fly PCs, coloured by sample index. The colour array is sized from the
# data so the plot keeps working when the number of samples changes.
plt.scatter(
    pca_embedded_fly[:, 0],
    pca_embedded_fly[:, 1],
    c=np.arange(len(pca_embedded_fly)),
    cmap="tab20",
    alpha=0.5,
)
plt.xlabel(f"PCA 1 {pca_fly.explained_variance_ratio_[0]*100:.2f}%")
plt.ylabel(f"PCA 2 {pca_fly.explained_variance_ratio_[1]*100:.2f}%")
plt.title("PCA of intentions across all episodes")

In [None]:
# Refit the rodent PCA with three components for a 3D view.
# The embedding gets its own rodent-specific name: it was previously stored in
# `pca_embedded_fly`, which silently replaced the fly embedding with rodent data.
pca_rodent = PCA(n_components=3)
pca_embedded_rodent_3d = pca_rodent.fit_transform(activations_rodent)

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

# Colour each point by its sample index.
N = pca_embedded_rodent_3d.shape[0]
c_values = np.arange(N)

sc = ax.scatter(
    pca_embedded_rodent_3d[:, 0],
    pca_embedded_rodent_3d[:, 1],
    pca_embedded_rodent_3d[:, 2],
    c=c_values,
    cmap="tab20",
    alpha=0.5
)

cb = plt.colorbar(sc, ax=ax, shrink=0.6)
cb.set_label("Sample Index")

ax.set_xlabel(f"PC1 ({pca_rodent.explained_variance_ratio_[0]*100:.2f}%)")
ax.set_ylabel(f"PC2 ({pca_rodent.explained_variance_ratio_[1]*100:.2f}%)")
ax.set_zlabel(f"PC3 ({pca_rodent.explained_variance_ratio_[2]*100:.2f}%)")

ax.set_title("3D PCA of intentions across all episodes")
plt.tight_layout()
plt.show()

In [None]:
# Refit the fly PCA with three components and cluster the 3D embedding with KMeans.
pca_fly = PCA(n_components=3)
pca_embedded_fly = pca_fly.fit_transform(activations_fly)

k = 6
kmeans_fly = KMeans(n_clusters=k, random_state=42).fit(pca_embedded_fly)
fly_labels = kmeans_fly.labels_

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

# Kept as notebook-level names; this plot itself is coloured by cluster label.
N = pca_embedded_fly.shape[0]
c_values = np.arange(N)

sc = ax.scatter(
    pca_embedded_fly[:, 0],
    pca_embedded_fly[:, 1],
    pca_embedded_fly[:, 2],
    c=fly_labels,
    cmap="tab20",
    alpha=0.5
)

cb = plt.colorbar(sc, ax=ax, shrink=0.6)
# The colours encode the KMeans cluster label, not the sample index.
cb.set_label("Cluster")

ax.set_xlabel(f"PC1 ({pca_fly.explained_variance_ratio_[0]*100:.2f}%)")
ax.set_ylabel(f"PC2 ({pca_fly.explained_variance_ratio_[1]*100:.2f}%)")
ax.set_zlabel(f"PC3 ({pca_fly.explained_variance_ratio_[2]*100:.2f}%)")

ax.set_title("3D PCA of intentions across all episodes")
plt.tight_layout()
plt.show()

In [72]:
# import plotly.express as px
# pca_fly = PCA(n_components=3)
# pca_embedded_fly = pca_fly.fit_transform(activations_fly) 

# k = 6
# kmeans_fly = KMeans(n_clusters=k, random_state=42).fit(pca_embedded_fly)
# fly_labels = kmeans_fly.labels_

# fig = px.scatter_3d(
#     x=pca_embedded_fly[:, 0],
#     y=pca_embedded_fly[:, 1],
#     z=pca_embedded_fly[:, 2],
#     color=fly_labels.astype(str),
#     labels={
#         "x": f"PC1 ({pca_fly.explained_variance_ratio_[0]*100:.2f}%)",
#         "y": f"PC2 ({pca_fly.explained_variance_ratio_[1]*100:.2f}%)",
#         "z": f"PC3 ({pca_fly.explained_variance_ratio_[2]*100:.2f}%)",
#         "color": "Cluster"
#     },
#     title="3D PCA of intentions across all episodes"
# )

# fig.update_traces(marker=dict(size=5, opacity=0.5))
# fig.show()

In [None]:
# Side-by-side view of the first two PCs of each dataset, each in its own PCA space.
plt.figure(figsize=(12, 5))

# Left panel: fly embedding.
plt.subplot(1, 2, 1)
plt.scatter(pca_embedded_fly[:, 0], pca_embedded_fly[:, 1], alpha=0.5)
plt.title("Fly data in PCA space (own PCA)")
plt.xlabel("PC1")
plt.ylabel("PC2")

# Right panel: rodent embedding.
plt.subplot(1, 2, 2)
plt.scatter(pca_embedded_rodent[:, 0], pca_embedded_rodent[:, 1], alpha=0.5, color="orange")
plt.title("Rodent data in PCA space (own PCA)")
plt.xlabel("PC1")
plt.ylabel("PC2")

plt.tight_layout()
plt.show()

In [None]:
# Stack both datasets row-wise (rodent first) so one PCA can be fit on all samples.
# NOTE(review): vstack requires both matrices to have the same number of columns — confirm
# that the fly and rodent intention vectors share a dimensionality.
combined_data = np.vstack([activations_rodent, activations_fly])

# Fit a single PCA on the combined dataset
pca_shared = PCA(n_components=2, random_state=42)
pca_shared.fit(combined_data)

# Transform each set into the same 2D space
fly_2d_shared = pca_shared.transform(activations_fly)
rodent_2d_shared = pca_shared.transform(activations_rodent)

# Overlay both projections in one scatter plot.
plt.figure(figsize=(7, 6))
plt.scatter(fly_2d_shared[:, 0], fly_2d_shared[:, 1], 
            alpha=0.5, label="Fly")
plt.scatter(rodent_2d_shared[:, 0], rodent_2d_shared[:, 1], 
            alpha=0.5, label="Rodent", color="orange")
plt.title("Fly vs. Rodent in a Shared PCA Space")
plt.xlabel("PC1 (shared)")
plt.ylabel("PC2 (shared)")
plt.legend()
plt.show()

In [None]:
# KMeans (5 clusters) is fit on every column of the rodent embedding, while the plot
# below shows only the first two components.
kmeans_rodent = KMeans(n_clusters=5, random_state=42).fit(pca_embedded_rodent)
rodent_labels = kmeans_rodent.labels_

plt.figure(figsize=(6,5))
plt.scatter(pca_embedded_rodent[:, 0], pca_embedded_rodent[:, 1], 
            c=rodent_labels, cmap='tab10', alpha=0.6)
plt.title("Rodent Clusters in PCA space")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.show()

# Same for the fly embedding; this rebinds `kmeans_fly` and `fly_labels`.
# NOTE(review): which `pca_embedded_fly` is clustered here depends on which earlier cell
# ran last — confirm the intended embedding.
kmeans_fly = KMeans(n_clusters=5, random_state=42).fit(pca_embedded_fly)
fly_labels = kmeans_fly.labels_

plt.figure(figsize=(6,5))
plt.scatter(pca_embedded_fly[:, 0], pca_embedded_fly[:, 1], 
            c=fly_labels, cmap='tab10', alpha=0.6)
plt.title("Fly Clusters in PCA space")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.show()

Helper functions for plotting the clustering results

In [35]:
def plot_pca_intention_with_clusters(
    frame_idx: int,
    pca_projections: np.ndarray,
    cluster_labels: np.ndarray,
    var: np.ndarray,
    clip_idx: int,
    n_components: int,
    terminated: bool = False,
):
    """
    Plot the PCA progression for the entire rollout, color-coded by clusters.

    All points are drawn in the first two PCA dimensions and the point at
    ``frame_idx`` is highlighted with a larger, black-edged marker in the colour of
    its cluster.

    Args:
        frame_idx: index of the frame to highlight.
        pca_projections: PCA projection of the rollout; columns 0 and 1 are plotted.
        cluster_labels: one cluster label per row of ``pca_projections``.
        var: explained variance ratios; entries 0 and 1 are shown on the axis labels.
        clip_idx: clip index shown in the title.
        n_components: number of PCA components available; must be at least 2.
        terminated: when True the title reads "Final Frame" instead of the frame index.

    Returns:
        An RGB image (H x W x 3) as a uint8 numpy array. With the 6.4 x 4.8 inch
        figure at 100 dpi this is 480 x 640 px, matching the MuJoCo frame height.

    Raises:
        ValueError: if ``n_components`` is smaller than 2.
    """
    # Validate before creating a figure, so a bad call does not leave one open.
    if n_components < 2:
        raise ValueError("This code expects at least 2 PCA components to plot in 2D.")

    x = pca_projections[:, 0]
    y = pca_projections[:, 1]

    # Both scatters share explicit colour limits. A scatter with a single colour value
    # would otherwise normalise to the first colormap entry regardless of the cluster.
    color_limits = {"vmin": np.min(cluster_labels), "vmax": np.max(cluster_labels)}

    #   4.8 inches * 100 dpi = 480 px
    fig, ax = plt.subplots(figsize=(6.4, 4.8), dpi=100)
    try:
        # entire trajectory, color-coded by cluster
        ax.scatter(x, y, c=cluster_labels, cmap="tab10", alpha=0.3, s=20, **color_limits)

        # highlight the current frame in a larger circle
        ax.scatter(
            x[frame_idx],
            y[frame_idx],
            c=[cluster_labels[frame_idx]],
            cmap="tab10",
            edgecolor="black",
            s=100,
            alpha=1.0,
            **color_limits,
        )

        title_str = f"Clip {clip_idx} - Frame {frame_idx} - Cluster {cluster_labels[frame_idx]}"
        if terminated:
            title_str = f"Clip {clip_idx} - Final Frame - Cluster {cluster_labels[frame_idx]}"
        ax.set_title(title_str)
        ax.set_xlabel(f"PC1 ({var[0]*100:.2f}%)")
        ax.set_ylabel(f"PC2 ({var[1]*100:.2f}%)")

        fig.canvas.draw()
        # `tostring_rgb` is no longer available in recent Matplotlib; read the RGBA
        # buffer, drop the alpha channel and copy so the array owns its memory.
        plot_img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return plot_img

def render_with_pca_progression_and_clusters(
    rollout: dict, 
    pca_projections: np.ndarray, 
    cluster_labels: np.ndarray,
    var: np.ndarray,
    n_components: int = 2, 
    walker_name: str = "rodent",
):
    """
    Render the MuJoCo frames side-by-side with a PCA+cluster progression graph
    showing which cluster is active at each frame.

    Args:
        rollout: the saved rollout; rollout["info"][0]["clip_idx"] is read in addition
            to whatever ``render_from_saved_rollout`` needs.
        pca_projections: PCA projection with one row per rendered frame.
        cluster_labels: one cluster label per row of ``pca_projections``.
        var: explained variance ratios used for the axis labels.
        n_components: forwarded to ``plot_pca_intention_with_clusters``.
        walker_name: forwarded to ``render_from_saved_rollout``.

    Returns:
        A list of horizontally stacked frames (MuJoCo | PCA plot), followed by 50
        copies of the final frame next to the "terminated" plot.

    Raises:
        ValueError: if the number of frames differs from ``len(pca_projections)``.
    """
    # raw Mujoco frames with the first frame skipped
    frames_mujoco = render_from_saved_rollout(rollout, walker_name)[1:]

    T = len(frames_mujoco)
    if T != len(pca_projections):
        raise ValueError(f"Mismatch: got {T} Mujoco frames vs {len(pca_projections)} PCA steps.")

    clip_idx = int(rollout["info"][0]["clip_idx"])

    orig_backend = matplotlib.get_backend()
    matplotlib.use("Agg")  # headless for parallel figure creation
    try:
        # partial function to handle each time step's plot
        worker = functools.partial(
            plot_pca_intention_with_clusters,
            pca_projections=pca_projections,
            cluster_labels=cluster_labels,
            clip_idx=clip_idx,
            n_components=n_components,
            var=var,
        )

        print("Rendering PCA + Cluster progression...")

        # create PCA+cluster frames in parallel for all timesteps
        with mp.Pool(processes=mp.cpu_count()) as pool:
            frames_pca = pool.map(worker, range(T))

        # concatenate frames side-by-side; both images are expected to share a height
        print("Concatenating frames...")
        concat_frames = [
            np.hstack([frames_mujoco[idx], frames_pca[idx]]) for idx in tqdm(range(T))
        ]

        # terminated PCA frame at the end
        reward_plot = plot_pca_intention_with_clusters(
            frame_idx=T - 1,
            pca_projections=pca_projections,
            cluster_labels=cluster_labels,
            clip_idx=clip_idx,
            n_components=n_components,
            terminated=True,
            var=var,
        )
    finally:
        # Restore the caller's backend even if plotting failed.
        plt.close("all")
        matplotlib.use(orig_backend)

    # short pause at the end (50 repeated frames)
    final_frame = frames_mujoco[-1]
    for _ in range(50):
        concat_frames.append(np.hstack([final_frame, reward_plot]))

    return concat_frames

def global_local_pca_worker(
    frame_idx: int,
    pca_global: np.ndarray,
    cluster_global: np.ndarray,
    global_subset_indices: np.ndarray,
    pca_local: np.ndarray,
    cluster_local: np.ndarray,
    T: int,
    clip_idx: int,
    var_lcoal: np.ndarray,
    var_global: np.ndarray,
) -> np.ndarray:
    """
    Module-level (hence picklable) entry point used by the multiprocessing Pool.

    Forwards its arguments to ``plot_global_and_local_pca_intention_with_clusters``
    and marks the plot as terminated when ``frame_idx`` is the last of ``T`` frames.
    """
    plot_kwargs = {
        "pca_global": pca_global,
        "cluster_global": cluster_global,
        "global_subset_indices": global_subset_indices,
        "pca_local": pca_local,
        "cluster_local": cluster_local,
        "frame_idx_local": frame_idx,
        "clip_idx": clip_idx,
        "var_lcoal": var_lcoal,
        "var_global": var_global,
        # Only the last frame (index T - 1) is flagged as terminated.
        "terminated": frame_idx == T - 1,
    }
    return plot_global_and_local_pca_intention_with_clusters(**plot_kwargs)

def plot_global_and_local_pca_intention_with_clusters(
    pca_global: np.ndarray,
    cluster_global: np.ndarray, 
    global_subset_indices: np.ndarray, 
    pca_local: np.ndarray, 
    cluster_local: np.ndarray,
    frame_idx_local: int,
    clip_idx: int,
    var_lcoal: np.ndarray,
    var_global: np.ndarray,
    terminated: bool = False
):
    """
    Plot side-by-side:
    1) Global PCA (left subplot): entire dataset in gray, this clip in color, highlight current frame.
    2) Local PCA (right subplot): just this clip, highlight current frame.

    Args:
        pca_global: global PCA embedding; columns 0 and 1 are plotted.
        cluster_global: one cluster label per row of ``pca_global``.
        global_subset_indices: row indices into the global arrays that belong to this clip.
        pca_local: PCA embedding of this clip only; columns 0 and 1 are plotted.
        cluster_local: one cluster label per row of ``pca_local``.
        frame_idx_local: index of the current frame within the clip.
        clip_idx: clip index shown in titles.
        var_lcoal: explained variance ratios of the local PCA (axis labels, right).
        var_global: explained variance ratios of the global PCA (axis labels, left).
        terminated: when True, " (terminated)" is appended to the right title.

    Returns:
        An RGB image (H x W x 3) as a uint8 numpy array.

    Raises:
        ValueError: if ``frame_idx_local`` is not smaller than ``len(pca_local)``.
    """
    # assume the local embedding is T points and we have T frames. frame_idx_local in [0..T).
    T = len(pca_local)
    if frame_idx_local >= T:
        raise ValueError(f"frame_idx_local={frame_idx_local} exceeds local clip length={T}.")

    # highlight just the points belonging to this clip
    pca_global_clip = pca_global[global_subset_indices]
    cluster_global_clip = cluster_global[global_subset_indices]

    # frame_idx_local corresponds to subset_indices[frame_idx_local] in the global space
    current_global_idx = global_subset_indices[frame_idx_local]

    # Explicit colour limits per panel. A highlight scatter with a single colour value
    # would otherwise normalise to the first colormap entry instead of the cluster colour.
    global_limits = {"vmin": np.min(cluster_global_clip), "vmax": np.max(cluster_global_clip)}
    local_limits = {"vmin": np.min(cluster_local), "vmax": np.max(cluster_local)}

    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12.8, 4.8), dpi=100)
    try:
        # --- left: global PCA -------------------------------------------------
        # light grey plot of the whole dataset
        ax1.scatter(
            pca_global[:, 0], pca_global[:, 1], color="lightgray", alpha=0.5, s=10, label="All data (global)"
        )
        ax1.scatter(
            pca_global_clip[:, 0],
            pca_global_clip[:, 1],
            c=cluster_global_clip,
            cmap="tab10",
            alpha=0.7,
            s=30,
            label=f"Clip {clip_idx}",
            **global_limits,
        )
        ax1.scatter(
            pca_global[current_global_idx, 0],
            pca_global[current_global_idx, 1],
            c=[cluster_global[current_global_idx]],
            cmap="tab10",
            edgecolor="black",
            s=100,
            alpha=1.0,
            **global_limits,
        )
        ax1.set_title(f"Global PCA (Clip {clip_idx})")
        ax1.set_xlabel(f"PC1 (global) ({var_global[0]*100:.2f}%)")
        ax1.set_ylabel(f"PC2 (global) ({var_global[1]*100:.2f}%)")

        # --- right: local PCA -------------------------------------------------
        x_l = pca_local[:, 0]
        y_l = pca_local[:, 1]
        ax2.scatter(
            x_l,
            y_l,
            c=cluster_local,
            cmap="tab10",
            alpha=0.7,
            s=30,
            label=f"Clip {clip_idx} local",
            **local_limits,
        )
        ax2.scatter(
            x_l[frame_idx_local],
            y_l[frame_idx_local],
            c=[cluster_local[frame_idx_local]],
            cmap="tab10",
            edgecolor="black",
            s=100,
            alpha=1.0,
            **local_limits,
        )
        local_title_str = f"Local PCA: Clip {clip_idx}, Frame {frame_idx_local}"
        if terminated:
            local_title_str += " (terminated)"
        ax2.set_title(local_title_str)
        ax2.set_xlabel(f"PC1 (local) ({var_lcoal[0]*100:.2f}%)")
        ax2.set_ylabel(f"PC2 (local) ({var_lcoal[1]*100:.2f}%)")

        fig.tight_layout()

        fig.canvas.draw()
        # `tostring_rgb` is no longer available in recent Matplotlib; read the RGBA
        # buffer, drop the alpha channel and copy so the array owns its memory.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        plt.close(fig)
    return img


def render_with_global_and_local_pca_progression(
    rollout: dict,
    walker_name: str,
    pca_global: np.ndarray,
    cluster_global: np.ndarray,
    global_subset_indices: np.ndarray,
    pca_local: np.ndarray,
    cluster_local: np.ndarray,
    var_lcoal: np.ndarray,
    var_global: np.ndarray,
) -> List[np.ndarray]:
    """
    Render the MuJoCo frames side-by-side with a figure showing
    BOTH global PCA + local PCA for each frame.

    Args:
        rollout: Saved rollout; ``rollout["info"][0]["clip_idx"]`` is read
            to label the plots.
        walker_name: Forwarded to ``render_from_saved_rollout``.
        pca_global: Global PCA embedding, indexed with ``global_subset_indices``.
        cluster_global: Cluster label per row of ``pca_global``.
        global_subset_indices: Row indices of this clip inside the global
            embedding; must have one entry per rendered frame.
        pca_local: Per-clip PCA embedding; must have one row per rendered frame.
        cluster_local: Cluster label per row of ``pca_local``.
        var_lcoal: Explained-variance values forwarded to the worker for the
            local panel (the misspelled name is kept for existing callers).
        var_global: Explained-variance values forwarded to the worker for the
            global panel.

    Returns:
        One horizontally stacked image (MuJoCo frame + PCA figure) per frame.

    Raises:
        ValueError: If the number of rendered frames differs from
            ``len(pca_local)`` or ``len(global_subset_indices)``.
    """
    # The first rendered frame is dropped before comparing lengths.
    frames_mujoco = render_from_saved_rollout(rollout, walker_name)[1:]
    T = len(frames_mujoco)

    if T != len(pca_local):
        raise ValueError(f"Mismatch: {T} MuJoCo frames vs {len(pca_local)} local PCA steps.")
    if T != len(global_subset_indices):
        raise ValueError(f"Mismatch: {T} frames vs {len(global_subset_indices)} global_subset_indices.")

    clip_idx = int(rollout["info"][0]["clip_idx"])

    # One argument tuple per frame, in the positional order the worker expects.
    worker_args = [
        (
            frame_idx,
            pca_global,
            cluster_global,
            global_subset_indices,
            pca_local,
            cluster_local,
            T,
            clip_idx,
            var_lcoal,
            var_global,
        )
        for frame_idx in range(T)
    ]

    orig_backend = matplotlib.get_backend()
    matplotlib.use("Agg")
    try:
        print("Rendering PCA (global+local) progression...")
        with mp.Pool(processes=mp.cpu_count()) as pool:
            frames_pca = pool.starmap(global_local_pca_worker, worker_args)
    finally:
        # Always hand the caller's backend back, even if a worker failed.
        matplotlib.use(orig_backend)

    return [
        np.hstack([mujoco_frame, pca_frame])
        for mujoco_frame, pca_frame in zip(frames_mujoco, frames_pca)
    ]

## Serialize the PCA result to disk

In this way, we can directly load the pca object to do the transformation.

In [77]:
# Gather the fitted attributes of `pca_fly` into a JSON-friendly dictionary.
# Array attributes are converted to nested lists; the component count is
# appended last so the key order matches the attribute tuple below.
pca_data = {
    attr: getattr(pca_fly, attr).tolist()
    for attr in (
        "components_",
        "explained_variance_",
        "explained_variance_ratio_",
        "mean_",
    )
}
pca_data["n_components_"] = pca_fly.n_components_

# Write the dictionary to a JSON file in the working directory.
with open("pca_intention_fly.json", "w") as f:
    f.write(json.dumps(pca_data))

# Load the PCA object from json

In [78]:
with open("pca_intention_fly.json", "r") as f:
    loaded_pca_data = json.load(f)

# Reconstruct the PCA object by assigning the saved fitted attributes onto a
# fresh estimator (no refit is performed).
pca = PCA(n_components=loaded_pca_data["n_components_"])
pca.components_ = np.array(loaded_pca_data["components_"])
pca.explained_variance_ = np.array(loaded_pca_data["explained_variance_"])
pca.explained_variance_ratio_ = np.array(loaded_pca_data["explained_variance_ratio_"])
pca.mean_ = np.array(loaded_pca_data["mean_"])
# Also restore the fitted component count that was serialized alongside the
# arrays, so the rebuilt estimator carries every attribute that was saved.
pca.n_components_ = loaded_pca_data["n_components_"]

In [None]:
# Sanity check: show the shape of the restored `components_` array.
pca.components_.shape

This PCA loads data from the fly, not the rodent. It is included only as a demonstration and is not actually used in the analysis below.

# PCA + Mujoco Rendering W/ Clusters

In [None]:
# Load the saved rollout of clip 1 together with its "intention" activations.
clip_id = 1
directory = "/root/vast/kaiwen/track-mjx/rodent_rollout_info/"
with h5py.File(f"{directory}clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    # Fetch the activations while the rollout file is still open.
    act = get_aggregate_data("/activations", ["intention"], clip_id, directory)

len(act)

In [None]:
# Inspect the info entry stored at index 11 of the loaded rollout.
rollout['info'][11]

In [None]:
# Inspect the info entry stored at index 10 of the loaded rollout.
rollout['info'][10]

In [27]:
# Shared scaler instance; later cells call `scaler.fit_transform(...)`,
# which refits it on whatever array is passed.
scaler = StandardScaler()

In [None]:
# Per-clip PCA + K-Means on the "intention" activations of clip 2, rendered
# next to the MuJoCo frames with the "fly" walker.
clip_id = 2
directory = "/root/vast/kaiwen/track-mjx/rodent_rollout_info/"
with h5py.File(f"{directory}clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    # Pull the activations for this clip while the file is open.
    act = get_aggregate_data("/activations", ["intention"], clip_id, directory)

# Keep the first 599 steps (original note: the reference only has 600),
# then standardize them.
act = scaler.fit_transform(act[:599])

pca_fly = PCA().fit(act)
pca_embedded = pca_fly.transform(act)
kmeans = KMeans(n_clusters=4, random_state=42).fit(pca_embedded)
frames = render_with_pca_progression_and_clusters(
    rollout=rollout,
    pca_projections=pca_embedded,
    cluster_labels=kmeans.labels_,
    n_components=2,
    walker_name="fly",
    var=pca_fly.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [58]:
# Per-clip PCA + K-Means on the "intention" activations of clip 0, rendered
# next to the MuJoCo frames with the "rodent" walker.
clip_id = 0
directory = "/root/vast/scott-yang/rodent_rollout_info/data/"
with h5py.File(f"{directory}clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    # Pull the activations for this clip while the file is open.
    act = get_aggregate_data("/activations", ["intention"], clip_id, directory)

act = scaler.fit_transform(act)

pca_rodent = PCA().fit(act)
pca_embedded = pca_rodent.transform(act)
kmeans = KMeans(n_clusters=4, random_state=42).fit(pca_embedded)
frames = render_with_pca_progression_and_clusters(
    rollout=rollout,
    pca_projections=pca_embedded,
    cluster_labels=kmeans.labels_,
    n_components=2,
    walker_name="rodent",
    var=pca_rodent.explained_variance_ratio_,
)
display_video(frames, framerate=30)

Compute the global PCA embedding and plot it, then render it alongside the MuJoCo frames.

In [None]:
# Global + local PCA rendering for fly clip 5.
clip_id = 5
clip_length = 599
# Rows of this clip inside the stacked global activations.
# NOTE(review): assumes every clip contributes exactly `clip_length` rows to
# `activations_fly` -- confirm against how that array was stacked.
start = clip_id * clip_length 
end = start + clip_length
global_subset_indices = np.arange(start, end)

directory = f"/root/vast/kaiwen/track-mjx/rodent_rollout_info/"
with h5py.File(directory + f"clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    act = get_aggregate_data("/activations", ['intention'], clip_id, directory)

act = act[:599] # ref only have 600
act = scaler.fit_transform(act)
activations_fly = scaler.fit_transform(activations_fly)

# Global embedding: PCA + K-Means over all stacked fly activations.
pca_fly = PCA()
pca_fly.fit(activations_fly)
global_embedding = pca_fly.transform(activations_fly)
global_labels = KMeans(n_clusters=4, random_state=42).fit(global_embedding).labels_

# Local embedding: PCA + K-Means over this clip only.
pca_local = PCA()
pca_local.fit(act)
local_embedding = pca_local.transform(act)
local_labels = KMeans(n_clusters=4, random_state=42).fit(local_embedding).labels_

frames = render_with_global_and_local_pca_progression(
    rollout=rollout,
    walker_name="fly",
    pca_global=global_embedding, 
    cluster_global=global_labels,
    global_subset_indices=global_subset_indices,
    pca_local=local_embedding,
    cluster_local=local_labels,
    # The local panel must show the per-clip PCA's variance and the global
    # panel the all-clips PCA's variance (these two were previously swapped).
    var_lcoal=pca_local.explained_variance_ratio_,
    var_global=pca_fly.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [None]:
# Global + local PCA rendering for fly clip 20.
clip_id = 20
clip_length = 599
# Rows of this clip inside the stacked global activations.
# NOTE(review): assumes every clip contributes exactly `clip_length` rows to
# `activations_fly` -- confirm against how that array was stacked.
start = clip_id * clip_length 
end = start + clip_length
global_subset_indices = np.arange(start, end)

directory = f"/root/vast/kaiwen/track-mjx/rodent_rollout_info/"
with h5py.File(directory + f"clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    act = get_aggregate_data("/activations", ['intention'], clip_id, directory)

act = act[:599] # ref only have 600
act = scaler.fit_transform(act)
activations_fly = scaler.fit_transform(activations_fly)

# Global embedding: PCA + K-Means over all stacked fly activations.
pca_fly = PCA()
pca_fly.fit(activations_fly)
global_embedding = pca_fly.transform(activations_fly)
global_labels = KMeans(n_clusters=4, random_state=42).fit(global_embedding).labels_

# Local embedding: PCA + K-Means over this clip only.
pca_local = PCA()
pca_local.fit(act)
local_embedding = pca_local.transform(act)
local_labels = KMeans(n_clusters=4, random_state=42).fit(local_embedding).labels_

frames = render_with_global_and_local_pca_progression(
    rollout=rollout,
    walker_name="fly",
    pca_global=global_embedding, 
    cluster_global=global_labels,
    global_subset_indices=global_subset_indices,
    pca_local=local_embedding,
    cluster_local=local_labels,
    # The local panel must show the per-clip PCA's variance and the global
    # panel the all-clips PCA's variance (these two were previously swapped).
    var_lcoal=pca_local.explained_variance_ratio_,
    var_global=pca_fly.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [None]:
# Global + local PCA rendering for rodent clip 20.
clip_id = 20
clip_length = 499
# Rows of this clip inside the stacked global activations.
# NOTE(review): assumes every clip contributes exactly `clip_length` rows to
# `activations_rodent` -- confirm against how that array was stacked.
start = clip_id * clip_length 
end = start + clip_length
global_subset_indices = np.arange(start, end)

directory = f"/root/vast/scott-yang/rodent_rollout_info/data/"
with h5py.File(directory + f"clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    act = get_aggregate_data("/activations", ['intention'], clip_id, directory)

act = scaler.fit_transform(act)
activations_rodent = scaler.fit_transform(activations_rodent)

# Global embedding: PCA + K-Means over all stacked rodent activations.
pca_rodent = PCA()
pca_rodent.fit(activations_rodent)
global_embedding = pca_rodent.transform(activations_rodent)
global_labels = KMeans(n_clusters=4, random_state=42).fit(global_embedding).labels_

# Local embedding: PCA + K-Means over this clip only.
pca_local = PCA()
pca_local.fit(act)
local_embedding = pca_local.transform(act)
local_labels = KMeans(n_clusters=4, random_state=42).fit(local_embedding).labels_

frames = render_with_global_and_local_pca_progression(
    rollout=rollout,
    walker_name="rodent",
    pca_global=global_embedding, 
    cluster_global=global_labels,
    global_subset_indices=global_subset_indices,
    pca_local=local_embedding,
    cluster_local=local_labels,
    # Local panel <- per-clip PCA, global panel <- rodent PCA fitted above.
    # (Previously these were swapped and read the unrelated fly PCA.)
    var_lcoal=pca_local.explained_variance_ratio_,
    var_global=pca_rodent.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [None]:
# Global + local PCA rendering for rodent clip 500.
clip_id = 500
clip_length = 499
# Rows of this clip inside the stacked global activations.
# NOTE(review): assumes every clip contributes exactly `clip_length` rows to
# `activations_rodent` -- confirm against how that array was stacked.
start = clip_id * clip_length 
end = start + clip_length
global_subset_indices = np.arange(start, end)

directory = f"/root/vast/scott-yang/rodent_rollout_info/data/"
with h5py.File(directory + f"clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    act = get_aggregate_data("/activations", ['intention'], clip_id, directory)

act = scaler.fit_transform(act)
activations_rodent = scaler.fit_transform(activations_rodent)

# Global embedding: PCA + K-Means over all stacked rodent activations.
pca_rodent = PCA()
pca_rodent.fit(activations_rodent)
global_embedding = pca_rodent.transform(activations_rodent)
global_labels = KMeans(n_clusters=4, random_state=42).fit(global_embedding).labels_

# Local embedding: PCA + K-Means over this clip only.
pca_local = PCA()
pca_local.fit(act)
local_embedding = pca_local.transform(act)
local_labels = KMeans(n_clusters=4, random_state=42).fit(local_embedding).labels_

frames = render_with_global_and_local_pca_progression(
    rollout=rollout,
    walker_name="rodent",
    pca_global=global_embedding, 
    cluster_global=global_labels,
    global_subset_indices=global_subset_indices,
    pca_local=local_embedding,
    cluster_local=local_labels,
    # Local panel <- per-clip PCA, global panel <- rodent PCA fitted above.
    # (Previously these were swapped and read the unrelated fly PCA.)
    var_lcoal=pca_local.explained_variance_ratio_,
    var_global=pca_rodent.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [None]:
# Global + local PCA rendering for rodent clip 200.
clip_id = 200
clip_length = 499
# Rows of this clip inside the stacked global activations.
# NOTE(review): assumes every clip contributes exactly `clip_length` rows to
# `activations_rodent` -- confirm against how that array was stacked.
start = clip_id * clip_length 
end = start + clip_length
global_subset_indices = np.arange(start, end)

directory = f"/root/vast/scott-yang/rodent_rollout_info/data/"
with h5py.File(directory + f"clip_{clip_id}.h5", "r") as h5file:
    rollout = load_from_h5py(h5file)
    act = get_aggregate_data("/activations", ['intention'], clip_id, directory)

act = scaler.fit_transform(act)
activations_rodent = scaler.fit_transform(activations_rodent)

# Global embedding: PCA + K-Means over all stacked rodent activations.
pca_rodent = PCA()
pca_rodent.fit(activations_rodent)
global_embedding = pca_rodent.transform(activations_rodent)
global_labels = KMeans(n_clusters=4, random_state=42).fit(global_embedding).labels_

# Local embedding: PCA + K-Means over this clip only.
pca_local = PCA()
pca_local.fit(act)
local_embedding = pca_local.transform(act)
local_labels = KMeans(n_clusters=4, random_state=42).fit(local_embedding).labels_

frames = render_with_global_and_local_pca_progression(
    rollout=rollout,
    walker_name="rodent",
    pca_global=global_embedding, 
    cluster_global=global_labels,
    global_subset_indices=global_subset_indices,
    pca_local=local_embedding,
    cluster_local=local_labels,
    # Local panel <- per-clip PCA, global panel <- rodent PCA fitted above.
    # (Previously these were swapped and read the unrelated fly PCA.)
    var_lcoal=pca_local.explained_variance_ratio_,
    var_global=pca_rodent.explained_variance_ratio_,
)
display_video(frames, framerate=30)

In [None]:
# Replay the most recently rendered `frames` at a lower frame rate (10).
display_video(frames, framerate=10)

# Clustering With Labels

In [None]:
# Peek at one node of the snippets file: print its value when it is a
# dataset, otherwise list the keys it contains.
with h5py.File("/root/vast/scott-yang/vnl_ray/clips/all_snippets.h5", "r") as h5file:
    group = h5file["clip_0/walkers/walker_0"]
    print(group[()] if isinstance(group, h5py.Dataset) else group.keys())

In [None]:
import pickle as pkl

# Load the pickled snippet metadata; later cells read its "snips_order" entry.
# NOTE(review): unpickling executes arbitrary code from the file -- only load
# this from a trusted location.
with open("/root/vast/scott-yang/vnl_ray/clips/all_snips.p", "rb") as file:
    all_snips = pkl.load(file)

In [None]:
import os
import re

def extract_clip_info(snippet_path: str):
    """
    Parse a snippet path into its behavior label and clip number.

    The file name (directory and final extension stripped) must start with
    letters, an underscore, then digits.

    Example: '././snippets_2_25_2021/snips/Walk_145.p' -> ('Walk', 145)

    Returns:
        ``(behavior, clip_number)`` on success, ``(None, None)`` when the
        name does not follow the expected pattern.
    """
    stem = os.path.splitext(os.path.basename(snippet_path))[0]  # e.g. 'Walk_145'
    parsed = re.match(r"([a-zA-Z]+)_(\d+)", stem)
    if parsed is None:
        # Unexpected file name: signal it with a pair of Nones.
        return None, None
    label, number = parsed.groups()
    return label, int(number)



# Parse every snippet path in its stored order and display the result.
clip_info = list(map(extract_clip_info, all_snips["snips_order"]))
print(clip_info)

In [None]:
directory = "/root/vast/scott-yang/rodent_rollout_info/data/"

def load_clip_activations(clip_id):
    """
    Return the 'intention' activations of one clip, or None if its file is absent.

    The clip file is looked up as ``clip_<clip_id>.h5`` inside the
    module-level ``directory``; a warning is printed when it does not exist.
    """
    clip_file = os.path.join(directory, f"clip_{clip_id}.h5")
    if os.path.exists(clip_file):
        # The file is opened for the duration of the lookup; the data itself
        # is fetched through get_aggregate_data.
        with h5py.File(clip_file, "r") as h5file:
            intention = get_aggregate_data("/activations", ['intention'], clip_id, directory)  # shape (T, D)
        return intention

    print(f"Warning: Clip {clip_id} not found!")
    return None

# Build per-snippet label / clip-number arrays, in `snips_order` order.
parsed_snippets = [extract_clip_info(snippet) for snippet in all_snips["snips_order"]]
all_labels = np.array([label for label, _ in parsed_snippets])
all_clip_ids = np.array([number for _, number in parsed_snippets])

# Gather the 'intention' activations of clips 0..841 in parallel and stack
# them row-wise.
# NOTE(review): activations are stacked by clip index while the labels follow
# `snips_order` -- confirm the two orderings line up.
work_rodent = functools.partial(
    get_aggregate_data,
    "/activations",
    ["intention"],
    path="/root/vast/scott-yang/rodent_rollout_info/data/",
)
with mp.Pool(processes=mp.cpu_count()) as pool:
    activations_rodent = list(tqdm(pool.imap(work_rodent, range(842)), total=842))

activations_rodent = np.vstack(activations_rodent)

In [None]:
# Show the shape of the stacked rodent activations.
activations_rodent.shape

In [None]:
# Repeat each clip-level label / clip number once per frame so they can be
# compared row-by-row with the stacked activations.
frames_per_clip = 499
num_clips = len(all_clip_ids)  # Should be around 842 clips

expanded_labels, expanded_clip_ids = (
    np.repeat(per_clip, frames_per_clip) for per_clip in (all_labels, all_clip_ids)
)

# Both expanded arrays must have one entry per activation row
# (842 clips * 499 frames = 420,158).
total_frames = activations_rodent.shape[0]
assert len(expanded_labels) == total_frames
assert len(expanded_clip_ids) == total_frames

print(f"Expanded Labels Shape: {expanded_labels.shape}")
print(f"Expanded Clip IDs Shape: {expanded_clip_ids.shape}")

In [None]:
from collections import Counter

# Standardize, project onto two principal components, cluster with K-Means
# (k=4) and annotate each cluster with its most frequent behavior label.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(activations_rodent)
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X_scaled)

var_explained = pca.explained_variance_ratio_ * 100
print(f"Explained variance: PC1 = {var_explained[0]:.2f}%, PC2 = {var_explained[1]:.2f}%")

k = 4
kmeans = KMeans(n_clusters=k, random_state=42)
cluster_ids = kmeans.fit_predict(X_pca)

# Collect the frame-level behavior labels that landed in each cluster.
cluster_to_behavior = {cluster: [] for cluster in range(k)}
for frame, cluster in enumerate(cluster_ids):
    cluster_to_behavior[cluster].append(expanded_labels[frame])

# Majority vote per cluster.
cluster_majority_behavior = {}
for cluster, behaviors in cluster_to_behavior.items():
    most_common_behavior = Counter(behaviors).most_common(1)[0][0]
    cluster_majority_behavior[cluster] = most_common_behavior
    print(f"Cluster {cluster} → Assigned Behavior: {most_common_behavior}")

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(x=X_pca[:, 0], y=X_pca[:, 1], hue=cluster_ids, palette="tab10", alpha=0.7, ax=ax)
ax.set_title("PCA Projection of Rodent Activations (K-Means Clustering)")
ax.set_xlabel(f"PC1 ({var_explained[0]:.2f}%)")
ax.set_ylabel(f"PC2 ({var_explained[1]:.2f}%)")
# Write each cluster's majority behavior at the mean of its points.
for cluster, behavior in cluster_majority_behavior.items():
    cluster_center = X_pca[cluster_ids == cluster].mean(axis=0)
    ax.text(cluster_center[0], cluster_center[1], behavior, fontsize=10, ha='center', va='center',
            bbox=dict(facecolor='white', alpha=0.6, edgecolor='black'))

ax.legend(title="Cluster")
plt.show()

# Activations Analysis

In [None]:
import multiprocessing as mp
from tqdm import tqdm

# Workers that fetch the ['decoder', 'layer_1'] activations of a single clip
# from the fly / rodent rollout directories.
work_fly = functools.partial(
    get_aggregate_data,
    "/activations",
    ["decoder", "layer_1"],
    path="/root/vast/kaiwen/track-mjx/rodent_rollout_info",
)
work_rodent = functools.partial(
    get_aggregate_data,
    "/activations",
    ["decoder", "layer_1"],
    path="/root/vast/scott-yang/rodent_rollout_info/data/",
)

# Fetch clips 0..498 (fly) and 0..841 (rodent) in parallel, then stack each
# set row-wise.
with mp.Pool(processes=mp.cpu_count()) as pool:
    activations_fly = np.vstack(list(tqdm(pool.imap(work_fly, range(499)), total=499)))
    activations_rodent = np.vstack(list(tqdm(pool.imap(work_rodent, range(842)), total=842)))

In [None]:
# Standardize both activation sets (the scaler is refit for each one).
scaler = StandardScaler()
activations_rodent = scaler.fit_transform(activations_rodent)
activations_fly = scaler.fit_transform(activations_fly)

# Full PCA on the rodent activations; print the cumulative explained variance
# of the first ten components.
pca_rodent = PCA().fit(activations_rodent)
print(pca_rodent.explained_variance_ratio_[:10].cumsum())
pca_embedded_rodent = pca_rodent.transform(activations_rodent)

# Same for the fly activations.
pca_fly = PCA().fit(activations_fly)
print(pca_fly.explained_variance_ratio_[:10].cumsum())
pca_embedded_fly = pca_fly.transform(activations_fly)

In [None]:
# K-Means (k=5) on the full rodent embedding, shown on the first two PCs.
kmeans_rodent = KMeans(n_clusters=5, random_state=42).fit(pca_embedded_rodent)
rodent_labels = kmeans_rodent.labels_

fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(pca_embedded_rodent[:, 0], pca_embedded_rodent[:, 1],
           c=rodent_labels, cmap='tab10', alpha=0.6)
ax.set_title("Rodent Clusters in PCA space")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
plt.show()

# Same for the fly embedding.
kmeans_fly = KMeans(n_clusters=5, random_state=42).fit(pca_embedded_fly)
fly_labels = kmeans_fly.labels_

fig, ax = plt.subplots(figsize=(6, 5))
ax.scatter(pca_embedded_fly[:, 0], pca_embedded_fly[:, 1],
           c=fly_labels, cmap='tab10', alpha=0.6)
ax.set_title("Fly Clusters in PCA space")
ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
plt.show()

In [None]:
# 3-component PCA of the fly activations, colored by K-Means cluster (k=6).
pca_fly = PCA(n_components=3)
pca_embedded_fly = pca_fly.fit_transform(activations_fly) 

k = 6
kmeans_fly = KMeans(n_clusters=k, random_state=42).fit(pca_embedded_fly)
fly_labels = kmeans_fly.labels_

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

sc = ax.scatter(
    pca_embedded_fly[:, 0],
    pca_embedded_fly[:, 1],
    pca_embedded_fly[:, 2],
    c=fly_labels,
    cmap="tab20",
    alpha=0.5
)

# The points are colored by cluster label, so label the colorbar accordingly
# (it previously claimed to show the sample index).
cb = plt.colorbar(sc, ax=ax, shrink=0.6)
cb.set_label("Cluster ID")

ax.set_xlabel(f"PC1 ({pca_fly.explained_variance_ratio_[0]*100:.2f}%)")
ax.set_ylabel(f"PC2 ({pca_fly.explained_variance_ratio_[1]*100:.2f}%)")
ax.set_zlabel(f"PC3 ({pca_fly.explained_variance_ratio_[2]*100:.2f}%)")

ax.set_title("3D PCA of intentions across all episodes")
plt.tight_layout()
plt.show()

In [None]:
# 3-component PCA of the rodent activations, colored by sample index.
# The embedding is stored under a rodent-specific name so that it no longer
# overwrites the fly embedding held in `pca_embedded_fly`.
pca_rodent = PCA(n_components=3)
pca_embedded_rodent = pca_rodent.fit_transform(activations_rodent) 

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')

N = pca_embedded_rodent.shape[0]
c_values = np.arange(N)

sc = ax.scatter(
    pca_embedded_rodent[:, 0],
    pca_embedded_rodent[:, 1],
    pca_embedded_rodent[:, 2],
    c=c_values,
    cmap="tab20",
    alpha=0.5
)

cb = plt.colorbar(sc, ax=ax, shrink=0.6)
cb.set_label("Sample Index")

ax.set_xlabel(f"PC1 ({pca_rodent.explained_variance_ratio_[0]*100:.2f}%)")
ax.set_ylabel(f"PC2 ({pca_rodent.explained_variance_ratio_[1]*100:.2f}%)")
ax.set_zlabel(f"PC3 ({pca_rodent.explained_variance_ratio_[2]*100:.2f}%)")

ax.set_title("3D PCA of intentions across all episodes")
plt.tight_layout()
plt.show()

In [None]:
from collections import Counter

# Standardize, project onto two principal components, cluster with K-Means
# (k=3) and annotate each cluster with its most frequent behavior label.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(activations_rodent)
pca = PCA(n_components=2, random_state=42)
X_pca = pca.fit_transform(X_scaled)

var_explained = pca.explained_variance_ratio_ * 100
print(f"Explained variance: PC1 = {var_explained[0]:.2f}%, PC2 = {var_explained[1]:.2f}%")

k = 3
kmeans = KMeans(n_clusters=k, random_state=42)
cluster_ids = kmeans.fit_predict(X_pca)

# Collect the frame-level behavior labels that landed in each cluster.
cluster_to_behavior = {cluster: [] for cluster in range(k)}
for frame, cluster in enumerate(cluster_ids):
    cluster_to_behavior[cluster].append(expanded_labels[frame])

# Majority vote per cluster.
cluster_majority_behavior = {}
for cluster, behaviors in cluster_to_behavior.items():
    most_common_behavior = Counter(behaviors).most_common(1)[0][0]
    cluster_majority_behavior[cluster] = most_common_behavior
    print(f"Cluster {cluster} → Assigned Behavior: {most_common_behavior}")

fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(x=X_pca[:, 0], y=X_pca[:, 1], hue=cluster_ids, palette="tab10", alpha=0.7, ax=ax)
ax.set_title("PCA Projection of Rodent Activations (K-Means Clustering)")
ax.set_xlabel(f"PC1 ({var_explained[0]:.2f}%)")
ax.set_ylabel(f"PC2 ({var_explained[1]:.2f}%)")
# Write each cluster's majority behavior at the mean of its points.
for cluster, behavior in cluster_majority_behavior.items():
    cluster_center = X_pca[cluster_ids == cluster].mean(axis=0)
    ax.text(cluster_center[0], cluster_center[1], behavior, fontsize=10, ha='center', va='center',
            bbox=dict(facecolor='white', alpha=0.6, edgecolor='black'))

ax.legend(title="Cluster")
plt.show()

In [None]:
# Layers whose activations are collected for both species.
layers = [['encoder','layer_0'], ['encoder','layer_1'], ['decoder','layer_0'], ['decoder','layer_1']]

# Stacked activations per layer, keyed by the layer path as a tuple.
fly_activations = {}
rodent_activations = {}

# One single-clip loader per layer and species.
work_fly = {tuple(layer): functools.partial(get_aggregate_data, "/activations", layer,
                                            path="/root/vast/kaiwen/track-mjx/rodent_rollout_info") for layer in layers}
work_rodent = {tuple(layer): functools.partial(get_aggregate_data, "/activations", layer,
                                               path="/root/vast/scott-yang/rodent_rollout_info/data/") for layer in layers}

with mp.Pool(processes=mp.cpu_count()) as pool:
    for layer in layers:
        layer_key = tuple(layer)  # Convert list to tuple to use as a dictionary key
        # Use loop-local names for the per-clip chunks so the notebook-level
        # `activations_fly` / `activations_rodent` arrays used by other cells
        # are not replaced by unstacked Python lists.
        fly_chunks = list(tqdm(pool.imap(work_fly[layer_key], range(499)), total=499))
        rodent_chunks = list(tqdm(pool.imap(work_rodent[layer_key], range(842)), total=842))
        fly_activations[layer_key] = np.vstack(fly_chunks)
        rodent_activations[layer_key] = np.vstack(rodent_chunks)

def process_and_plot_3d(activations, title, ax, n_clusters=3):
    """Project activations onto 3 principal components, cluster, and scatter them.

    Args:
        activations: 2-D array passed to ``PCA.fit_transform``.
        title: Title set on ``ax``.
        ax: A 3-D Matplotlib axes (``set_zlabel`` is called on it).
        n_clusters: Number of K-Means clusters; defaults to the previously
            hard-coded value of 3.

    The axis labels report each component's explained-variance percentage.
    """
    pca = PCA(n_components=3)
    pca_embedded = pca.fit_transform(activations)

    # Cluster in the reduced space; the labels only drive point colors.
    labels = KMeans(n_clusters=n_clusters, random_state=42).fit(pca_embedded).labels_

    ax.scatter(
        pca_embedded[:, 0], pca_embedded[:, 1], pca_embedded[:, 2],
        c=labels, cmap="tab20", alpha=0.5
    )

    ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.2f}%)")
    ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.2f}%)")
    ax.set_zlabel(f"PC3 ({pca.explained_variance_ratio_[2]*100:.2f}%)")
    ax.set_title(title)

def process_and_plot_2d(activations, title, ax, n_clusters=3):
    """Project activations onto 2 principal components, cluster, and scatter them.

    Args:
        activations: 2-D array passed to ``PCA.fit_transform``.
        title: Title set on ``ax``.
        ax: Matplotlib axes to draw on.
        n_clusters: Number of K-Means clusters; defaults to the previously
            hard-coded value of 3.

    The axis labels report each component's explained-variance percentage.
    """
    pca = PCA(n_components=2)
    pca_embedded = pca.fit_transform(activations)

    # Cluster in the reduced space; the labels only drive point colors.
    labels = KMeans(n_clusters=n_clusters, random_state=42).fit(pca_embedded).labels_

    ax.scatter(
        pca_embedded[:, 0], pca_embedded[:, 1], c=labels, cmap="tab10", alpha=0.6
    )

    ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.2f}%)")
    ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.2f}%)")
    ax.set_title(title)

In [None]:
# 2 x 4 grid: top row fly, bottom row rodent, one column per layer.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

for row, (species, per_layer) in enumerate((("Fly", fly_activations), ("Rodent", rodent_activations))):
    for col, layer in enumerate(layers):
        process_and_plot_2d(per_layer[tuple(layer)], f"{species} - {layer[0]} {layer[1]}", axes[row, col])

plt.tight_layout()
plt.show()

In [None]:
# 2 x 4 grid of 3-D axes: top row fly, bottom row rodent, one column per layer.
fig, axes = plt.subplots(2, 4, figsize=(22, 10), subplot_kw={"projection": "3d"})

for row, (species, per_layer) in enumerate((("Fly", fly_activations), ("Rodent", rodent_activations))):
    for col, layer in enumerate(layers):
        process_and_plot_3d(per_layer[tuple(layer)], f"{species} - {layer[0]} {layer[1]}", axes[row, col])

fig.subplots_adjust(left=0.05, right=0.95, top=0.90, bottom=0.10, wspace=0.35, hspace=0.40)

# Rotate every panel and re-apply its axis labels with a smaller font and
# extra padding for readability.
for ax in axes.flat:
    ax.view_init(elev=25, azim=40)  # Adjust camera angles
    ax.tick_params(axis='both', which='major', labelsize=8)  # Smaller tick labels
    for set_label, get_label in (
        (ax.set_xlabel, ax.get_xlabel),
        (ax.set_ylabel, ax.get_ylabel),
        (ax.set_zlabel, ax.get_zlabel),
    ):
        set_label(get_label(), fontsize=10, labelpad=12)

plt.show()

In [None]:
# Same 2 x 4 grid of 3-D panels, left at the default camera angle.
fig, axes = plt.subplots(2, 4, figsize=(22, 10), subplot_kw={"projection": "3d"})

for row, (species, per_layer) in enumerate((("Fly", fly_activations), ("Rodent", rodent_activations))):
    for col, layer in enumerate(layers):
        process_and_plot_3d(per_layer[tuple(layer)], f"{species} - {layer[0]} {layer[1]}", axes[row, col])

fig.subplots_adjust(left=0.05, right=0.95, top=0.90, bottom=0.10, wspace=0.35, hspace=0.40)
plt.show()

# LEAP Visualization

In [None]:
import multiprocessing as mp
from tqdm import tqdm

# Workers that fetch the "/qposes_rollout" data of a single clip.
work_fly = functools.partial(
    get_aggregate_data,
    "/qposes_rollout",
    [],
    path="/root/vast/kaiwen/track-mjx/rodent_rollout_info",
)
work_rodent = functools.partial(
    get_aggregate_data,
    "/qposes_rollout",
    [],
    path="/root/vast/scott-yang/rodent_rollout_info/data/",
)

# Fetch clips 0..29 (fly) and 0..841 (rodent) in parallel, then stack each
# set row-wise.
with mp.Pool(processes=mp.cpu_count()) as pool:
    qpos_fly = np.vstack(list(tqdm(pool.imap(work_fly, range(30)), total=30)))
    qpos_rodent = np.vstack(list(tqdm(pool.imap(work_rodent, range(842)), total=842)))

In [None]:
# Show the shapes of the stacked fly and rodent qpos arrays.
qpos_fly.shape, qpos_rodent.shape

In [None]:
import mujoco

# Load the fly pair-rendering model and list its joints by index.
pair_render_xml_path = "/root/vast/kaiwen/track-mjx/track_mjx/environment/walker/assets/fruitfly/fruitfly_force_pair.xml"
mj_model = mujoco.MjModel.from_xml_path(pair_render_xml_path)

# Iterate over the joints the model actually declares instead of a fixed
# count, so the loop neither overruns a smaller model nor truncates a larger one.
for i in range(mj_model.njnt):
    print(f"Index {i}: {mj_model.joint(i).name}")

In [None]:
# Load the rodent ghost-pair model and list its joints by index.
pair_render_xml_path = "/root/vast/scott-yang/track-mjx/track_mjx/environment/walker/assets/rodent/rodent_ghostpair_scale080.xml"
mj_model = mujoco.MjModel.from_xml_path(pair_render_xml_path)

# Iterate over the joints the model actually declares instead of a fixed
# count, so the loop neither overruns a smaller model nor truncates a larger one.
for i in range(mj_model.njnt):
    print(f"Index {i}: {mj_model.joint(i).name}")

In [None]:
# Number of whole 599-row clips contained in the stacked fly qpos array.
num_clips = qpos_fly.shape[0] // 599
print(f"Total clips: {num_clips}")

In [None]:
from scipy.ndimage import gaussian_filter1d

def compute_forward_velocity(qposes, dt):
    """Differentiate column 0 of ``qposes`` with respect to time.

    Column 0 is assumed to hold the x-position of the center of mass; the
    derivative is taken with ``np.gradient`` using a uniform spacing ``dt``.
    """
    x_track = qposes[:, 0]
    velocity = np.gradient(x_track, dt)
    return velocity

def compute_leg_phases(qposes, leg_indices, threshold=0.02, smooth_sigma=1):
    """
    Determines swing (1) or stance (0) phases for each selected column.

    Args:
        qposes: 2-D array; rows are time steps.
        leg_indices: Columns of ``qposes`` treated as leg-tip heights.
            NOTE(review): these index raw qpos columns -- confirm they really
            hold heights rather than joint coordinates.
        threshold: Values strictly above this count as swing.
        smooth_sigma: Sigma of the Gaussian applied along the time axis
            before re-binarizing at 0.5.

    Returns:
        Integer array of shape ``(len(qposes), len(leg_indices))`` with 1 for
        swing and 0 for stance.
    """
    leg_heights = qposes[:, leg_indices]
    # Smooth in floating point: gaussian_filter1d returns its result in the
    # input dtype, so filtering an integer mask would cast the smoothed
    # values back to integers and make the 0.5 threshold meaningless.
    raw_leg_phases = (leg_heights > threshold).astype(float)
    smoothed_leg_phases = gaussian_filter1d(raw_leg_phases, sigma=smooth_sigma, axis=0)
    return (smoothed_leg_phases > 0.5).astype(int)  # Re-binarize

def plot_gait_analysis(qposes, leg_indices, leg_labels, dt, clip_id, timesteps_per_clip, color, title):
    """Show forward velocity (top) and swing/stance raster (bottom) for one clip.

    The clip is the ``timesteps_per_clip`` rows of ``qposes`` starting at
    ``clip_id * timesteps_per_clip``. Velocity and leg phases come from
    ``compute_forward_velocity`` and ``compute_leg_phases``.
    """
    first_row = clip_id * timesteps_per_clip
    clip = qposes[first_row:first_row + timesteps_per_clip, :]

    velocity = compute_forward_velocity(clip, dt)
    phases = compute_leg_phases(clip, leg_indices)

    # Two stacked panels; the raster gets three times the height of the trace.
    fig, (vel_ax, phase_ax) = plt.subplots(2, 1, figsize=(7, 5), gridspec_kw={'height_ratios': [1, 3]})
    clip_duration = timesteps_per_clip * dt
    time_axis = np.linspace(0, clip_duration, timesteps_per_clip)

    # Velocity trace.
    vel_ax.plot(time_axis, velocity, color=color, linewidth=1)
    vel_ax.set_ylabel("Forward velocity (mm/s)")
    vel_ax.set_xticks([])
    vel_ax.set_xlim(0, time_axis[-1])

    # Swing/stance raster: one row per leg, one column per time step.
    phase_ax.imshow(phases.T, cmap="gray_r", aspect="auto", interpolation="nearest")
    phase_ax.set_yticks(np.arange(len(leg_labels)))
    phase_ax.set_yticklabels(leg_labels)
    phase_ax.set_xlabel("Time (s)")
    # Six evenly spaced ticks, labelled in seconds.
    phase_ax.set_xticks(np.linspace(0, timesteps_per_clip - 1, 6))
    phase_ax.set_xticklabels(np.round(np.linspace(0, clip_duration, 6), 2))

    # Legend explaining the raster colors.
    swing_handle = plt.Line2D([0], [0], color="black", lw=4, label="Swing")
    stance_handle = plt.Line2D([0], [0], color="white", lw=4, label="Stance")
    phase_ax.legend(handles=[swing_handle, stance_handle], loc="upper right", frameon=True, edgecolor="black")

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

# --- Fly: qpos columns and labels for the six legs, then plot clip 0 ---
fly_leg_indices = [6, 12, 18, 24, 30, 36]
fly_leg_labels = [
    "Front Leg (T1) Left",
    "Front Leg (T1) Right",
    "Middle Leg (T2) Left",
    "Middle Leg (T2) Right",
    "Hind Leg (T3) Left",
    "Hind Leg (T3) Right",
]
plot_gait_analysis(
    qposes=qpos_fly,
    leg_indices=fly_leg_indices,
    leg_labels=fly_leg_labels,
    dt=1/500,
    clip_id=0,
    timesteps_per_clip=599,
    color='blue',
    title="Fly - Gait Analysis",
)

# --- Rodent: qpos columns and labels for the four limbs, then plot clip 0 ---
rodent_leg_indices = [12, 18, 59, 67]
rodent_leg_labels = [
    "Hind Left (Toe)",
    "Hind Right (Toe)",
    "Fore Left (Finger)",
    "Fore Right (Finger)",
]
plot_gait_analysis(
    qposes=qpos_rodent,
    leg_indices=rodent_leg_indices,
    leg_labels=rodent_leg_labels,
    dt=1/50,
    clip_id=0,
    timesteps_per_clip=250,
    color='red',
    title="Rodent - Gait Analysis",
)


In [None]:
def compute_forward_velocity(qposes, dt, smooth_sigma=2):
    """Central-difference velocity of column 0, edge-padded and Gaussian-smoothed.

    The interior derivative is ``(x[t+1] - x[t-1]) / (2 * dt)``; the first and
    last values are duplicated so the output has one entry per input row.
    """
    x_track = qposes[:, 0]
    interior = (x_track[2:] - x_track[:-2]) / (2 * dt)  # Central difference
    # Repeat the boundary values on both ends to restore the original length.
    padded = np.concatenate(([interior[0]], interior, [interior[-1]]))
    return gaussian_filter1d(padded, sigma=smooth_sigma)

def compute_leg_phases(qposes, leg_indices, threshold=0.05):
    """Return 1 (swing) where a selected column exceeds ``threshold``, else 0 (stance)."""
    # Strictly greater than the threshold is swing; anything else is stance.
    # The mask is already binary, so no further re-binarization is needed.
    swing_mask = qposes[:, leg_indices] > threshold
    return swing_mask.astype(int)

def plot_gait_analysis_horizontal(qposes, leg_indices, leg_labels, dt, clip_start, num_clips, timesteps_per_clip, color, title_prefix):
    """Plots several consecutive clips side-by-side in a single figure.

    Each clip gets one column: forward velocity on top, swing/stance raster
    below. Clips ``clip_start`` .. ``clip_start + num_clips - 1`` are taken
    as consecutive ``timesteps_per_clip``-row slices of ``qposes``.

    Args:
        qposes: 2-D array of stacked clips (rows are time steps).
        leg_indices: Columns of ``qposes`` passed to ``compute_leg_phases``.
        leg_labels: Y tick labels of the raster, one per leg.
        dt: Time step in seconds.
        clip_start: Index of the first clip to plot.
        num_clips: Number of clips (columns) to plot.
        timesteps_per_clip: Rows per clip.
        color: Line color of the velocity trace.
        title_prefix: Prefix of each column title.
    """
    # squeeze=False keeps `axes` two-dimensional even when num_clips == 1,
    # so the `axes[row, col]` indexing below always works.
    fig, axes = plt.subplots(
        2,
        num_clips,
        figsize=(5 * num_clips, 6),
        gridspec_kw={'height_ratios': [1, 3]},
        squeeze=False,
    )

    for i in range(num_clips):
        clip_id = clip_start + i
        start_idx = clip_id * timesteps_per_clip
        end_idx = start_idx + timesteps_per_clip
        qposes_clip = qposes[start_idx:end_idx, :]

        # Compute velocity and leg phases
        forward_velocity = compute_forward_velocity(qposes_clip, dt)
        leg_phases = compute_leg_phases(qposes_clip, leg_indices)

        # Time axis in seconds
        time_axis = np.linspace(0, timesteps_per_clip * dt, timesteps_per_clip)

        # --- Top Row: Forward Velocity ---
        axes[0, i].plot(time_axis, forward_velocity, color=color, linewidth=1)
        axes[0, i].set_ylabel("Velocity (mm/s)")
        axes[0, i].set_xticks([])
        axes[0, i].set_xlim(0, time_axis[-1])
        axes[0, i].set_title(f"{title_prefix} - Clip {clip_id}")

        # --- Bottom Row: Leg Phases ---
        axes[1, i].imshow(leg_phases.T, cmap="gray_r", aspect="auto", interpolation="nearest")
        axes[1, i].set_yticks(np.arange(len(leg_labels)))
        axes[1, i].set_yticklabels(leg_labels)
        axes[1, i].set_xlabel("Time (s)")
        axes[1, i].set_xticks(np.linspace(0, timesteps_per_clip - 1, 6))
        axes[1, i].set_xticklabels(np.round(np.linspace(0, timesteps_per_clip * dt, 6), 2))

        # Swing/Stance Legend (only add on last one for space)
        if i == num_clips - 1:
            legend_patches = [plt.Line2D([0], [0], color="black", lw=4, label="Swing"),
                              plt.Line2D([0], [0], color="white", lw=4, label="Stance")]
            axes[1, i].legend(handles=legend_patches, loc="upper right", frameon=True, edgecolor="black")

    plt.tight_layout()
    plt.show()

# --- Fly: qpos columns and labels for the six legs, then plot clips 0-2 ---
fly_leg_indices = [6, 12, 18, 24, 30, 36]
fly_leg_labels = [
    "Front Leg (T1) Left",
    "Front Leg (T1) Right",
    "Middle Leg (T2) Left",
    "Middle Leg (T2) Right",
    "Hind Leg (T3) Left",
    "Hind Leg (T3) Right",
]
plot_gait_analysis_horizontal(
    qposes=qpos_fly,
    leg_indices=fly_leg_indices,
    leg_labels=fly_leg_labels,
    dt=1/500,
    clip_start=0,
    num_clips=3,
    timesteps_per_clip=599,
    color='blue',
    title_prefix="Fly - Gait Analysis",
)

# --- Rodent: qpos columns and labels for the four limbs, then plot clips 0-2 ---
rodent_leg_indices = [12, 18, 59, 67]
rodent_leg_labels = [
    "Hind Left (Toe)",
    "Hind Right (Toe)",
    "Fore Left (Finger)",
    "Fore Right (Finger)",
]
plot_gait_analysis_horizontal(
    qposes=qpos_rodent,
    leg_indices=rodent_leg_indices,
    leg_labels=rodent_leg_labels,
    dt=1/50,
    clip_start=0,
    num_clips=3,
    timesteps_per_clip=250,
    color='red',
    title_prefix="Rodent - Gait Analysis",
)


In [None]:
def plot_leg_tip_heights(qposes, leg_indices, labels, dt, smooth_sigma=2, title="Leg Tip Heights Over Time"):
    """
    Plot selected columns of ``qposes`` (the leg-tip heights) against time.

    Every column listed in ``leg_indices`` is Gaussian-smoothed along the time
    axis and drawn as one line in a single figure, which is then shown.

    - qposes: (T, D) array of joint positions, one row per timestep.
    - leg_indices: List of column indices of ``qposes`` holding the leg tip
      Z-coordinates.
    - labels: List of legend labels, one per entry of ``leg_indices`` and in
      the same order.
    - dt: Timestep (seconds per frame) for converting frames to time.
    - smooth_sigma: Standard deviation of the Gaussian kernel, in frames.
      ``0`` or ``None`` disables smoothing and plots the raw values.
    - title: Title of the plot.
    """
    # Accept any array-like (nested lists, device arrays) for the fancy
    # indexing and the ``.shape`` access below.
    qposes = np.asarray(qposes)

    # Extract leg tip heights: shape (T, len(leg_indices)).
    leg_tips_z = qposes[:, leg_indices]

    if smooth_sigma:
        # One filter call along the time axis smooths every leg column
        # independently, replacing a per-column Python loop.
        smoothed_leg_tips = gaussian_filter1d(leg_tips_z, sigma=smooth_sigma, axis=0)
    else:
        # A zero-width kernel is undefined, so treat 0/None as "no smoothing"
        # instead of handing it to the filter.
        smoothed_leg_tips = leg_tips_z

    # Time axis in seconds
    time_axis = np.arange(qposes.shape[0]) * dt

    # Plot one line per label; column i of the smoothed array belongs to labels[i].
    plt.figure(figsize=(8, 4))
    for i, label in enumerate(labels):
        plt.plot(time_axis, smoothed_leg_tips[:, i], label=label, linewidth=1.5)

    plt.xlabel("Time (s)")
    plt.ylabel("Height (Z-coordinate)")
    plt.title(title)
    plt.legend()
    plt.show()

# --- Fly: leg-tip heights for one clip ---
# Columns of qpos used as the Z-coordinates of the fly leg tips.
fly_leg_indices = [6, 12, 18, 24, 30, 36]
fly_leg_labels = [
    "Front Leg (T1) Left",
    "Front Leg (T1) Right",
    "Middle Leg (T2) Left",
    "Middle Leg (T2) Right",
    "Hind Leg (T3) Left",
    "Hind Leg (T3) Right",
]

# Rows [clip_id * timesteps_per_clip, (clip_id + 1) * timesteps_per_clip) of
# qpos_fly form the selected clip.
# NOTE(review): this cell uses 499 timesteps per fly clip, while the gait
# analysis cell above passes timesteps_per_clip=599 -- confirm which value
# matches the saved rollout.
clip_id_fly = 0
timesteps_per_clip_fly = 499
start_idx_fly = clip_id_fly * timesteps_per_clip_fly
end_idx_fly = (clip_id_fly + 1) * timesteps_per_clip_fly
qposes_fly_clip = qpos_fly[start_idx_fly:end_idx_fly, :]

plot_leg_tip_heights(
    qposes_fly_clip,
    fly_leg_indices,
    fly_leg_labels,
    dt=1 / 500,
    title="Fly - Leg Tip Heights Over Time",
)

# --- Rodent: leg-tip heights for one clip ---
# Columns of qpos used as the Z-coordinates of the rodent foot & hand tips.
rodent_leg_indices = [12, 18, 59, 67]
rodent_leg_labels = [
    "Hind Left (Toe)",
    "Hind Right (Toe)",
    "Fore Left (Finger)",
    "Fore Right (Finger)",
]

# Same row-window selection as for the fly, with the rodent clip length.
clip_id_rodent = 0
timesteps_per_clip_rodent = 250
start_idx_rodent = clip_id_rodent * timesteps_per_clip_rodent
end_idx_rodent = (clip_id_rodent + 1) * timesteps_per_clip_rodent
qposes_rodent_clip = qpos_rodent[start_idx_rodent:end_idx_rodent, :]

plot_leg_tip_heights(
    qposes_rodent_clip,
    rodent_leg_indices,
    rodent_leg_labels,
    dt=1 / 50,
    title="Rodent - Leg Tip Heights Over Time",
)
