In [4]:
%load_ext autoreload
%autoreload 2

import os, copy, yaml, h5py
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from orbax import checkpoint as ocp
from pyEDM import Simplex
from scipy import signal

# Environment setup
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"        # or "osmesa"
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4")

import mediapy as media
from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils
from track_mjx.agent.mlp_ppo.intention_network import make_intention_policy
from brax.training.acme import running_statistics
from track_mjx.analysis.rollout import create_environment

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════

# Analysis configuration
H5_PATH = "/root/vast/eric/track-mjx/notebooks/rollouts_full_mouse_arm_new_250826_030533_134914.h5"
NEURAL_E = 4     # Default embedding dimension (if action-specific not available)
NEURAL_TAU = -1  # Time delay for neural activations
Tp = 1           # Prediction horizon
PAD = 300        # NaN padding between clips

# Lesioning configuration
ENABLE_LESION = True
TOP_K = 5       # Number of top neurons to lesion per action
LAYERS_TO_ANALYZE = ["layer_0", "layer_1", "layer_2"]

# Action-specific embedding dimensions
ACTION_E = {
    "act0": 6, "act1": 10, "act2": 5, "act3": 5, "act4": 10,
    "act5": 6, "act6": 8,  "act7": 6,  "act8": 7
}

# Layer mapping: activation layer name -> parameter layer name
LAYER_MAPPING = {
    "layer_0": "hidden_0",
    "layer_1": "hidden_1",
    "layer_2": "hidden_2"
}

# Action names (muscle names)
action_names = ["Pec_C", "Lat", "PD", "AD", "MD", "Triceps_Lateral", "Triceps_Long", "Brachialis", "Biceps_Long"]

print(f"Configuration: analyzing {len(LAYERS_TO_ANALYZE)} layers, selecting TOP-{TOP_K} neurons per action")
print("Using action-specific embedding dimensions:")
for act_idx, (act, dim) in enumerate(ACTION_E.items()):
    action_name = action_names[act_idx] if act_idx < len(action_names) else act
    print(f"  {action_name} ({act}): E={dim}")

# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - NEURAL ANALYSIS
# ═══════════════════════════════════════════════════════════════════════════════

def zscore(x, axis=None, eps=1e-12):
    """Z-score normalization with small epsilon to avoid division by zero."""
    mu = np.nanmean(x, axis=axis, keepdims=True)
    sd = np.nanstd(x, axis=axis, keepdims=True)
    return (x - mu) / np.maximum(sd, eps)

def concat_neural_action_padded(neural_dict, actions_data, pad=1, standardize=True):
    """Concatenate neural activations with action data, NaN-padded between clips."""
    N, T = next(iter(neural_dict.values())).shape[:2]
    
    frames = []
    row_cursor = 0
    
    for clip in range(N):
        # Add padding between clips (except before first clip)
        if clip > 0:
            pad_data = {"time": np.arange(row_cursor+1, row_cursor+pad+1)}
            
            # Add NaN columns for all neural activations
            for layer_name, layer_data in neural_dict.items():
                n_neurons = layer_data.shape[-1]
                for neuron in range(n_neurons):
                    pad_data[f"{layer_name}_n{neuron:03d}"] = np.nan
            
            # Add NaN columns for actions
            for a, action_name in enumerate(action_names):
                pad_data[f"action_{action_name}"] = np.nan
            
            pad_df = pd.DataFrame(pad_data)
            frames.append(pad_df)
            row_cursor += pad
        
        # Add actual data for this clip
        clip_data = {"time": np.arange(row_cursor+1, row_cursor+T+1)}
        
        # Neural activations (standardized)
        for layer_name, layer_data in neural_dict.items():
            n_neurons = layer_data.shape[-1]
            for neuron in range(n_neurons):
                x = layer_data[clip, :, neuron]
                clip_data[f"{layer_name}_n{neuron:03d}"] = zscore(x) if standardize else x
        
        # Action values (standardized)
        for a, action_name in enumerate(action_names):
            action_vals = actions_data[clip, :, a]
            clip_data[f"action_{action_name}"] = zscore(action_vals) if standardize else action_vals
        
        clip_df = pd.DataFrame(clip_data)
        frames.append(clip_df)
        row_cursor += T
    
    big_df = pd.concat(frames, ignore_index=True)
    return big_df

def build_lib_pred_from_nan_blocks(df, probe_col, split=0.6, min_edge=8):
    """Split each non-NaN block 60/40 into lib/pred for pyEDM."""
    x = df[probe_col].to_numpy()
    notnan = ~np.isnan(x)
    edges = np.diff(np.concatenate([[0], notnan.view(np.int8), [0]]))
    starts = np.where(edges == 1)[0] + 1  # 1-based indexing
    ends = np.where(edges == -1)[0]       # 1-based indexing
    
    lib_parts, pred_parts = [], []
    for s, e in zip(starts, ends):
        n = e - s + 1
        if n < 2 * min_edge:
            continue
        m = s + int(np.floor(split * n)) - 1
        m = max(m, s + min_edge - 1)
        m = min(m, e - min_edge)
        if m <= s or m >= e:
            continue
        lib_parts.extend([str(s), str(m)])
        pred_parts.extend([str(m + 1), str(e)])
    
    return " ".join(lib_parts), " ".join(pred_parts)

def safe_correlation(y, yhat, min_pairs=10):
    """Calculate correlation with safety checks."""
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    mask = np.isfinite(y) & np.isfinite(yhat)
    
    if mask.sum() < min_pairs:
        return np.nan
    
    y_clean, yhat_clean = y[mask], yhat[mask]
    sy, syh = y_clean.std(ddof=0), yhat_clean.std(ddof=0)
    
    if sy <= 1e-12 or syh <= 1e-12:
        return np.nan
    
    return float(np.corrcoef(y_clean, yhat_clean)[0, 1])

def get_top_neurons_for_action(corr_df, action_name, layer_name, k=25):
    """Get top K neurons for a specific action."""
    action_col = f"action_{action_name}"
    
    # Ensure action column exists
    if action_col not in corr_df.columns:
        print(f"Warning: {action_col} not found in correlation dataframe")
        return []
    
    # Sort by absolute correlation
    corrs = corr_df[action_col].dropna().abs().sort_values(ascending=False)
    top_neurons = corrs.head(k).index.tolist()
    
    # Extract neuron indices from column names
    neuron_indices = []
    for neuron_col in top_neurons:
        neuron_idx = int(neuron_col.split('_n')[1])
        neuron_indices.append(neuron_idx)
    
    # Print top neurons and their correlations
    print(f"\nTop {k} neurons in {layer_name} for {action_name}:")
    for i, neuron_col in enumerate(top_neurons[:min(10, k)]):  # Show top 10 max
        neuron_idx = int(neuron_col.split('_n')[1])
        raw_corr = corr_df.loc[neuron_col, action_col]
        abs_corr = abs(raw_corr)
        print(f"  {i+1:2d}. Neuron {neuron_idx:3d}: {raw_corr:+.4f} (|ρ|={abs_corr:.4f})")
    
    if len(top_neurons) > 10:
        print(f"  ... and {len(top_neurons)-10} more")
    
    return neuron_indices

def neural_to_action_simplex_heatmap(df, neural_cols, action_cols, lib, pred, layer_name, action_E=None, tau=-1, Tp=0):
    """Perform simplex analysis: Neural Activations → Action Dynamics with action-specific embedding dimensions"""
    print(f"Running simplex analysis for {layer_name}: {len(neural_cols)} neurons → {len(action_cols)} actions")
    print(f"Using action-specific embedding dimensions")
    
    # Initialize correlation matrix
    corr_matrix = np.full((len(neural_cols), len(action_cols)), np.nan, dtype=float)
    
    # Total number of computations to perform
    total_pairs = len(neural_cols) * len(action_cols)
    
    # If no action-specific embedding dimensions provided, use default
    if action_E is None:
        action_E = {f"act{i}": NEURAL_E for i in range(len(action_cols))}
    
    # Print embedding dimensions used for each action
    print("\nEmbedding dimensions for each action:")
    for ai, action_col in enumerate(action_cols):
        action_idx = f"act{ai}"
        E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
        print(f"  {action_col} (index {action_idx}): E={E}")
    
    # Use tqdm for progress tracking - no need for additional manual updates
    pbar = tqdm(total=total_pairs, desc=f"{layer_name}→Action Simplex")
    
    for ni, neural_col in enumerate(neural_cols):
        for ai, action_col in enumerate(action_cols):
            # Determine embedding dimension for this action
            action_idx = f"act{ai}"
            E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
            
            try:
                # Run simplex prediction with action-specific embedding dimension
                pred_df = Simplex(
                    dataFrame=df,
                    lib=lib,
                    pred=pred,
                    columns=neural_col,
                    target=action_col,
                    E=E,  # Use action-specific E
                    tau=tau,
                    Tp=Tp,
                    ignoreNan=True,
                    showPlot=False
                )
                
                # Extract predictions and observations
                obs = pred_df["Observations"].to_numpy()
                hat = pred_df["Predictions"].to_numpy()
                
                # Calculate correlation
                corr_matrix[ni, ai] = safe_correlation(obs, hat, min_pairs=10)
                
            except Exception as e:
                print(f"Error for {neural_col} → {action_col} with E={E}: {e}")
                corr_matrix[ni, ai] = np.nan
            
            # Update progress bar without additional prints
            pbar.update(1)
    
    # Close progress bar
    pbar.close()
    print(f"Completed {layer_name} simplex analysis")
    
    # Convert to DataFrame
    corr_df = pd.DataFrame(corr_matrix, index=neural_cols, columns=action_cols)
    return corr_df

def analyze_layer_for_actions(layer_name, neural_dict, actions_data):
    """Analyze a single layer to identify neurons predictive of actions."""
    print(f"\n{'='*80}\nANALYZING LAYER: {layer_name} FOR ACTION PREDICTION\n{'='*80}")
    
    # Build a layer-specific neural dictionary (just this one layer)
    layer_neural_dict = {layer_name: neural_dict[layer_name]}
    
    # Build DataFrame with neural activations and actions
    big_df_actions = concat_neural_action_padded(
        neural_dict=layer_neural_dict,
        actions_data=actions_data,
        pad=PAD,
        standardize=True
    )
    
    # Extract column names for analysis
    neural_cols = [col for col in big_df_actions.columns if layer_name in col]
    action_cols = [col for col in big_df_actions.columns if col.startswith("action_")]
    
    # Build lib/pred indices
    probe_col = neural_cols[0]
    lib, pred = build_lib_pred_from_nan_blocks(big_df_actions, probe_col=probe_col, split=0.6, min_edge=8)
    
    # Run simplex analysis with action-specific embedding dimensions
    print(f"Starting {layer_name} Neural → Action Simplex Analysis...")
    corr_results = neural_to_action_simplex_heatmap(
        df=big_df_actions,
        neural_cols=neural_cols,
        action_cols=action_cols,
        lib=lib,
        pred=pred,
        layer_name=layer_name,
        action_E=ACTION_E,  # Pass the action-specific embedding dimensions
        tau=NEURAL_TAU,
        Tp=Tp
    )
    
    # Create and save heatmap
    plt.figure(figsize=(12, 10))
    
    # Compute mean absolute correlation per neuron for sorting
    mean_abs_corr = corr_results.abs().mean(axis=1).sort_values(ascending=False)
    sorted_neurons = mean_abs_corr.index.tolist()
    
    # Get top 50 neurons for visualization
    top_neurons = sorted_neurons[:50]
    top_corr_df = corr_results.loc[top_neurons]
    
    # Create heatmap
    plt.figure(figsize=(12, 14))
    sns.heatmap(top_corr_df, cmap="RdBu_r", center=0, vmin=-1, vmax=1, 
                xticklabels=True, yticklabels=True, annot=False)
    plt.title(f"{layer_name} Neural Activations → Actions Correlation (Top 50 Neurons)\nUsing action-specific embedding dimensions")
    plt.tight_layout()
    plt.savefig(f"{layer_name}_to_actions_heatmap.png", dpi=150)
    plt.close()
    
    # For each action, get top-k neurons
    top_neurons_by_action = {}
    all_top_neurons = set()
    
    for action_name in action_names:
        top_for_action = get_top_neurons_for_action(
            corr_df=corr_results,
            action_name=action_name,
            layer_name=layer_name,
            k=TOP_K
        )
        top_neurons_by_action[action_name] = top_for_action
        all_top_neurons.update(top_for_action)
    
    # Return unique neurons across all actions
    all_top_neurons_list = list(all_top_neurons)
    print(f"\nFound {len(all_top_neurons_list)} unique neurons in {layer_name} predictive of actions")
    
    return all_top_neurons_list, top_neurons_by_action, corr_results

# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - LESIONING
# ═══════════════════════════════════════════════════════════════════════════════

def create_lesioned_policy_complete_multi_layer(original_policy, neurons_by_layer):
    """
    Create a lesioned version of the policy by zeroing both incoming and outgoing weights
    for multiple layers at once.
    
    Args:
        original_policy: The original policy to lesion
        neurons_by_layer: Dict mapping layer names (e.g., "hidden_0") to lists of neuron indices
    
    Returns:
        Lesioned policy
    """
    # Make a deep copy of the policy to avoid modifying the original
    lesioned_policy = jtu.tree_map(lambda x: jnp.array(x), original_policy)
    
    # Extract policy parameters
    processor_params, policy_params = lesioned_policy
    decoder_params = policy_params['params']['decoder']
    
    print(f"\nLesioning neurons across {len(neurons_by_layer)} layers...")
    
    # Process each layer
    for target_layer, neuron_indices_to_lesion in neurons_by_layer.items():
        print(f"\n=== Lesioning layer: {target_layer} ===")
        
        if target_layer not in decoder_params:
            print(f"Warning: Layer {target_layer} not found in model. Available layers: {list(decoder_params.keys())}")
            continue
        
        layer_params = decoder_params[target_layer]
        neuron_indices_array = jnp.array(neuron_indices_to_lesion)
        
        print(f"Complete lesioning of {len(neuron_indices_to_lesion)} neurons in {target_layer}")
        
        # 1. Zero outgoing weights (columns in kernel, elements in bias)
        if 'kernel' in layer_params and 'bias' in layer_params:
            kernel = layer_params['kernel']
            bias = layer_params['bias']
            
            # Filter valid indices for outgoing weights
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[1]]
            
            new_kernel = kernel.at[:, valid_indices].set(0.0)  # Set to 0 not -1
            new_bias = bias.at[valid_indices].set(0.0)         # Set to 0 not -1
            
            layer_params['kernel'] = new_kernel
            layer_params['bias'] = new_bias
            
            print(f"✓ Zeroed outgoing weights for {len(valid_indices)} neurons")
        
        # 2. Zero incoming weights to those neurons (rows in the SAME layer's kernel)
        if 'kernel' in layer_params:
            kernel = layer_params['kernel']
            
            # Filter valid indices for incoming weights (rows)
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[0]]
            
            if len(valid_indices) > 0:
                # Zero rows
                new_kernel = kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                layer_params['kernel'] = new_kernel
                
                print(f"✓ Zeroed incoming weights for {len(valid_indices)} neurons")
        
        # 3. Also zero outgoing connections from next layer if it exists
        layer_mapping = {
            "hidden_0": "hidden_1",
            "hidden_1": "hidden_2", 
            "hidden_2": "hidden_3"
        }
        
        next_layer = layer_mapping.get(target_layer)
        if next_layer and next_layer in decoder_params:
            next_layer_params = decoder_params[next_layer]
            if 'kernel' in next_layer_params:
                next_kernel = next_layer_params['kernel']
                valid_indices = neuron_indices_array[neuron_indices_array < next_kernel.shape[0]]
                
                if len(valid_indices) > 0:
                    # Zero rows
                    new_next_kernel = next_kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                    next_layer_params['kernel'] = new_next_kernel
                    
                    print(f"✓ Zeroed connections from lesioned neurons to {next_layer}")
    
    print("\n✓ Multi-layer lesioning completed successfully")
    return lesioned_policy

def verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer):
    """Verify that lesioning has been successfully applied across multiple layers."""
    print("\n=== MULTI-LAYER LESIONING VERIFICATION ===")
    
    processor_params_orig, policy_params_orig = original_policy
    processor_params_les, policy_params_les = lesioned_policy
    
    verification_passed = True
    
    for target_layer, neuron_indices in neurons_by_layer.items():
        print(f"\nVerifying layer: {target_layer}")
        
        try:
            # Get the target layer parameters
            layer_params_orig = policy_params_orig['params']['decoder'][target_layer]
            layer_params_les = policy_params_les['params']['decoder'][target_layer]
            
            # Check a subset of neuron weights
            sample_neurons = neuron_indices[:min(3, len(neuron_indices))]
            
            for idx in sample_neurons:
                if idx < layer_params_les['kernel'].shape[1]:
                    # Check bias
                    bias_orig = layer_params_orig['bias'][idx]
                    bias_les = layer_params_les['bias'][idx]
                    bias_zeroed = jnp.allclose(bias_les, 0.0)
                    
                    # Check outgoing weights
                    outgoing_orig = layer_params_orig['kernel'][:, idx]
                    outgoing_les = layer_params_les['kernel'][:, idx]
                    outgoing_zeroed = jnp.allclose(outgoing_les, jnp.zeros_like(outgoing_les))
                    
                    print(f"  Neuron {idx}: Bias zeroed: {bias_zeroed}, Outgoing weights zeroed: {outgoing_zeroed}")
                    
                    if not bias_zeroed or not outgoing_zeroed:
                        verification_passed = False
            
            # Check overall changes
            kernel_diff = jnp.sum(jnp.abs(layer_params_orig['kernel'] - layer_params_les['kernel']))
            bias_diff = jnp.sum(jnp.abs(layer_params_orig['bias'] - layer_params_les['bias']))
            print(f"  Total weight changes: Kernel diff: {kernel_diff:.2f}, Bias diff: {bias_diff:.2f}")
            
            if kernel_diff <= 0 or bias_diff <= 0:
                verification_passed = False
                
        except Exception as e:
            print(f"Error in verification for {target_layer}: {e}")
            verification_passed = False
    
    return verification_passed

# ═══════════════════════════════════════════════════════════════════════════════
# ACTION COMPARISON FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def compare_actions(baseline_actions, lesioned_actions):
    """Compare baseline and lesioned actions, generate plots and statistics"""
    print("\n=== COMPARING BASELINE VS LESIONED ACTIONS ===")
    
    n_actions = baseline_actions.shape[1]
    summary_data = []
    
    # Create figure for overall comparison
    plt.figure(figsize=(16, 14))
    
    # Process each action
    for action_idx, action_name in enumerate(action_names[:n_actions]):
        baseline_action = baseline_actions[:, action_idx]
        lesioned_action = lesioned_actions[:, action_idx]
        
        # Calculate metrics
        mae = np.mean(np.abs(baseline_action - lesioned_action))
        rmse = np.sqrt(np.mean((baseline_action - lesioned_action)**2))
        max_diff = np.max(np.abs(baseline_action - lesioned_action))
        
        # Store summary data
        summary_data.append({
            'action': action_name,
            'mae': mae,
            'rmse': rmse,
            'max_diff': max_diff
        })
        
        print(f"\n=== {action_name.upper()} ACTION STATISTICS ===")
        print(f"Mean Absolute Error: {mae:.4f}")
        print(f"RMSE: {rmse:.4f}")
        print(f"Maximum Difference: {max_diff:.4f}")
        
        # Individual action plots
        plt.subplot(3, 3, action_idx + 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Create detailed action figure
        plt.figure(figsize=(14, 12))
        
        # Plot 1: Time series comparison
        plt.subplot(2, 2, 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 2: Difference between baseline and lesioned
        plt.subplot(2, 2, 2)
        plt.plot(lesioned_action - baseline_action, color='purple', alpha=0.7)
        plt.title(f'{action_name} Action Difference (Lesioned - Baseline)')
        plt.xlabel('Frame')
        plt.ylabel('Difference')
        plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        plt.grid(True, alpha=0.3)
        
        # Plot 3: Distribution comparison
        plt.subplot(2, 2, 3)
        sns.kdeplot(baseline_action, label='Baseline', color='blue', fill=True, alpha=0.3)
        sns.kdeplot(lesioned_action, label='Lesioned', color='red', fill=True, alpha=0.3)
        plt.title(f'{action_name} Action Distribution')
        plt.xlabel('Action Value')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 4: Scatter plot of lesioned vs baseline
        plt.subplot(2, 2, 4)
        plt.scatter(baseline_action, lesioned_action, alpha=0.5, s=10)
        plt.plot([min(baseline_action), max(baseline_action)], 
                 [min(baseline_action), max(baseline_action)], 
                 'k--', alpha=0.5)
        plt.title(f'{action_name} Lesioned vs Baseline')
        plt.xlabel('Baseline Action')
        plt.ylabel('Lesioned Action')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / f"{action_name}_action_analysis.png", dpi=150)
        plt.close()
    
    # Finalize and save the action comparison plot
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "all_actions_comparison.png", dpi=150)
    plt.close()
    
    # Create a summary bar chart comparing the effect of lesioning on each action
    plt.figure(figsize=(12, 8))
    
    # Bar chart of RMSE per action
    actions = [d['action'] for d in summary_data]
    rmse_values = [d['rmse'] for d in summary_data]
    
    bars = plt.bar(actions, rmse_values, color='skyblue')
    plt.title('RMSE Between Baseline and Lesioned Actions')
    plt.xlabel('Action')
    plt.ylabel('RMSE')
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)
    
    # Annotate with values
    for bar, value in zip(bars, rmse_values):
        plt.text(bar.get_x() + bar.get_width()/2., 
                 value + 0.01,
                 f'{value:.4f}', 
                 ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "action_rmse_summary.png", dpi=150)
    plt.close()
    
    # Print summary table
    print("\n=== SUMMARY OF LESIONING EFFECTS ACROSS ACTIONS ===")
    print(f"{'Action':<20} {'MAE':>10} {'RMSE':>10} {'Max Difference':>15}")
    print("-" * 60)
    for data in summary_data:
        print(f"{data['action']:<20} {data['mae']:>10.4f} {data['rmse']:>10.4f} {data['max_diff']:>15.4f}")
        
    return summary_data

# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════

print("\n=== STARTING NEURAL ANALYSIS AND LESIONING ===")

# Step 1: Load neural and action data for analysis
print("\nLoading neural activation and action data...")
with h5py.File(H5_PATH, "r") as f:
    actions = f["actions"][...]  # Load actions data
    
    # Neural activations from all decoder layers
    layer_data = {}
    for layer_name in LAYERS_TO_ANALYZE:
        layer_data[layer_name] = f[f"decoder_activations/{layer_name}"][:]
        print(f"{layer_name} shape: {layer_data[layer_name].shape}")
    
    print(f"Actions shape: {actions.shape}")

# Step 2: Analyze each layer to identify neurons predictive of actions
neurons_by_layer = {}
neurons_by_action = {}
correlation_results = {}

for layer_name in LAYERS_TO_ANALYZE:
    # Find the corresponding parameter layer name
    param_layer = LAYER_MAPPING[layer_name]
    # Analyze layer to get top neurons
    top_neurons, top_by_action, corr_df = analyze_layer_for_actions(layer_name, layer_data, actions)
    neurons_by_layer[param_layer] = top_neurons
    neurons_by_action[layer_name] = top_by_action
    correlation_results[layer_name] = corr_df

print("\n=== IDENTIFIED NEURONS TO LESION ===")
for layer_name, neurons in neurons_by_layer.items():
    print(f"{layer_name}: {len(neurons)} unique neurons")

# Step 3: Load checkpoint and config
ckpt_path = Path.cwd().parent / "model_checkpoints/250826_030533_134914"
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Configure data path
cfg.data_path = "/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial01_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial04_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial09_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial10_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial13_ik.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# Create environment for testing
env = rollout.create_environment(cfg)

# Step 4: Create baseline (non-lesioned) rollout for comparison
print("\n=== GENERATING BASELINE ROLLOUT FOR COMPARISON ===")
baseline_inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
baseline_rollout_gen = rollout.create_rollout_generator(
    cfg, env, baseline_inference_fn, 
    log_activations=True, log_metrics=True, log_sensor_data=True
)
baseline_rollout = baseline_rollout_gen(clip_idx=1)
print("✓ Baseline rollout generated")

# Step 5: Apply multi-layer lesioning to the policy
if ENABLE_LESION:
    print("\n=== CREATING MULTI-LAYER LESIONED POLICY ===")
    
    # Create policy network structure first
    original_policy = ckpt["policy"]
    lesioned_policy = create_lesioned_policy_complete_multi_layer(
        original_policy, 
        neurons_by_layer
    )
    
    # Verify that lesioning was properly applied
    verification_passed = verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer)
    print(f"\nMulti-layer lesioning verification {'PASSED' if verification_passed else 'FAILED'}")
    
    # Replace the policy in the checkpoint
    ckpt["policy"] = lesioned_policy
    print(f"Checkpoint policy updated with lesioned policy")
    
    # Use the standard loader with the modified checkpoint
    print("\n=== SETTING UP LESIONED INFERENCE FUNCTION ===")
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Lesioned inference function created")
else:
    # Use the original policy from checkpoint
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Original inference function loaded")

# Step 6: Generate rollout with the lesioned policy
print("\n=== GENERATING ROLLOUT WITH MULTI-LAYER LESIONED POLICY ===")
generate_rollout = rollout.create_rollout_generator(
    cfg, 
    env, 
    inference_fn, 
    log_activations=True, 
    log_metrics=True, 
    log_sensor_data=True
)

lesioned_rollout = generate_rollout(clip_idx=1)
print("✓ Lesioned rollout generated")

# Step 7: Extract and analyze activations from rollout
if ENABLE_LESION and lesioned_rollout.get('activations') is not None:
    print("\n=== ANALYZING ROLLOUT ACTIVATIONS ===")
    activations = lesioned_rollout['activations']['decoder']
    
    # Check each layer
    for activation_layer, target_layer in LAYER_MAPPING.items():
        if activation_layer in activations:
            layer_acts = activations[activation_layer]
            neurons_to_check = neurons_by_layer[target_layer]
            
            # Check activation stats for lesioned neurons
            valid_neurons = [n for n in neurons_to_check if n < layer_acts.shape[-1]]
            if valid_neurons:
                lesioned_acts = layer_acts[:, valid_neurons]
                avg_activation = jnp.mean(jnp.abs(lesioned_acts))
                max_activation = jnp.max(jnp.abs(lesioned_acts))
                
                print(f"\n{activation_layer} ({target_layer}) lesioned neurons stats:")
                print(f"  Average absolute activation: {avg_activation:.6f}")
                print(f"  Maximum absolute activation: {max_activation:.6f}")
                print(f"  {'⚠️ NEURONS STILL ACTIVE' if max_activation > 0.1 else '✓ NEURONS PROPERLY SILENCED'}")
        else:
            print(f"Warning: {activation_layer} not found in activations")

# Step 8: Compare actions between baseline and lesioned rollouts
print("\n=== COMPARING BASELINE AND LESIONED ACTIONS ===")
baseline_actions = np.array(baseline_rollout['actions']) if 'actions' in baseline_rollout else None
lesioned_actions = np.array(lesioned_rollout['actions']) if 'actions' in lesioned_rollout else None

if baseline_actions is not None and lesioned_actions is not None:
    action_summary = compare_actions(baseline_actions, lesioned_actions)
else:
    print("ERROR: Could not extract actions from rollouts for comparison")

# Step 9: Render comparison videos
print("\n=== RENDERING COMPARISON VIDEOS ===")

# Render lesioned video
lesioned_frames, lesioned_framerate = render.render_rollout(
    cfg, 
    lesioned_rollout, 
    height=480,
    width=640,
)

# Render baseline video
baseline_frames, baseline_framerate = render.render_rollout(
    cfg, 
    baseline_rollout, 
    height=480,
    width=640,
)

# Save the videos
lesioned_video_path = Path(ckpt_path) / f"rollout_actions_lesioned_top{TOP_K}.mp4"
baseline_video_path = Path(ckpt_path) / f"rollout_baseline.mp4"

media.write_video(lesioned_video_path, lesioned_frames, fps=lesioned_framerate)
media.write_video(baseline_video_path, baseline_frames, fps=baseline_framerate)

print(f"✓ Lesioned video saved to {lesioned_video_path}")
print(f"✓ Baseline video saved to {baseline_video_path}")

# Display the lesioned video
print("\n=== DISPLAYING MULTI-LAYER LESIONED VIDEO ===")
media.show_video(lesioned_frames, fps=lesioned_framerate)

# Step 10: Compare joint positions between baseline and lesioned rollouts
print("\n=== COMPARING CONTROL VS LESIONED PERFORMANCE ACROSS ALL JOINTS ===")

# Joint names for better labeling
JOINT_NAMES = ["sh_elv", "sh_ext", "sh_rot", "elbow"]

def extract_joint_data(rollout, key_name='qposes_rollout'):
    """Extract data for all joints from specified key"""
    if key_name in rollout:
        print(f"Found {key_name} with shape: {rollout[key_name].shape}")
        return np.array(rollout[key_name])
    
    print(f"Could not find {key_name}. Available arrays:")
    for key, value in rollout.items():
        if isinstance(value, np.ndarray):
            print(f"  {key}: {value.shape}")
        elif hasattr(value, 'keys'):
            print(f"  {key} (dict/object with keys): {list(value.keys())}")
    
    return None

# Extract data for all joints
baseline_joints = extract_joint_data(baseline_rollout, 'qposes_rollout')
lesioned_joints = extract_joint_data(lesioned_rollout, 'qposes_rollout')
reference_joints = extract_joint_data(baseline_rollout, 'qposes_ref')

# If reference is not found in baseline, try lesioned
if reference_joints is None:
    reference_joints = extract_joint_data(lesioned_rollout, 'qposes_ref')

# Extract rewards
baseline_rewards = np.array(baseline_rollout['rewards']) if 'rewards' in baseline_rollout else None
lesioned_rewards = np.array(lesioned_rollout['rewards']) if 'rewards' in lesioned_rollout else None

# Fall back to state_rewards if needed
if baseline_rewards is None and 'state_rewards' in baseline_rollout:
    baseline_rewards = np.array(baseline_rollout['state_rewards'])
if lesioned_rewards is None and 'state_rewards' in lesioned_rollout:
    lesioned_rewards = np.array(lesioned_rollout['state_rewards'])

# Proceed only if we found all the necessary data
if baseline_joints is None or lesioned_joints is None or reference_joints is None:
    print("Could not find required joint data in the rollouts")
else:
    # Print reward statistics
    if baseline_rewards is not None and lesioned_rewards is not None:
        print("\n=== OVERALL REWARD STATISTICS ===")
        print(f"Control mean reward: {np.mean(baseline_rewards):.4f}, Lesioned mean reward: {np.mean(lesioned_rewards):.4f}")
        print(f"Reward reduction: {np.mean(baseline_rewards) - np.mean(lesioned_rewards):.4f} ({(1 - np.mean(lesioned_rewards)/np.mean(baseline_rewards))*100:.1f}%)")
        
        # Plot reward comparison
        plt.figure(figsize=(8, 6))
        plt.plot(baseline_rewards, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_rewards, label='Lesioned', color='red', alpha=0.7)
        plt.title('Reward Over Time')
        plt.xlabel('Frame')
        plt.ylabel('Reward')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / "overall_lesion_rewards.png", dpi=150)
        plt.close()
    
    # Calculate overall metrics for summary table
    summary_data = []
    
    # Create a 2x2 grid of joint plots
    plt.figure(figsize=(16, 14))
    
    # Process each joint
    for joint_idx, joint_name in enumerate(JOINT_NAMES):
        baseline_joint = baseline_joints[:, joint_idx]
        lesioned_joint = lesioned_joints[:, joint_idx]
        reference_joint = reference_joints[:, joint_idx]
        
        # Calculate error metrics
        baseline_error = baseline_joint - reference_joint
        lesioned_error = lesioned_joint - reference_joint
        
        # Calculate absolute errors for visualization
        baseline_abs_error = np.abs(baseline_error)
        lesioned_abs_error = np.abs(lesioned_error)
        
        # Calculate statistics
        baseline_rmse = np.sqrt(np.mean(baseline_error**2))
        lesioned_rmse = np.sqrt(np.mean(lesioned_error**2))
        baseline_mae = np.mean(baseline_abs_error)
        lesioned_mae = np.mean(lesioned_abs_error)
        error_var_ratio = np.var(lesioned_error) / np.var(baseline_error)
        
        # Store summary data
        summary_data.append({
            'joint': joint_name,
            'baseline_rmse': baseline_rmse,
            'lesioned_rmse': lesioned_rmse,
            'rmse_increase': (lesioned_rmse - baseline_rmse),
            'rmse_percent': (lesioned_rmse / baseline_rmse - 1) * 100,
            'baseline_mae': baseline_mae,
            'lesioned_mae': lesioned_mae,
            'mae_increase': (lesioned_mae - baseline_mae),
            'mae_percent': (lesioned_mae / baseline_mae - 1) * 100,
            'error_var_ratio': error_var_ratio
        })
        
        print(f"\n=== {joint_name.upper()} ERROR STATISTICS ===")
        print(f"Control mean error: {np.mean(baseline_error):.4f} rad, Lesioned mean error: {np.mean(lesioned_error):.4f} rad")
        print(f"Control RMSE: {baseline_rmse:.4f} rad, Lesioned RMSE: {lesioned_rmse:.4f} rad")
        print(f"Control MAE: {baseline_mae:.4f} rad, Lesioned MAE: {lesioned_mae:.4f} rad")
        print(f"Control error std: {np.std(baseline_error):.4f} rad, Lesioned error std: {np.std(lesioned_error):.4f} rad")
        print(f"Error variance ratio (lesioned/control): {error_var_ratio:.4f}")
        
        # Individual joint plots
        plt.subplot(2, 2, joint_idx + 1)
        
        # Plot time series with reference
        plt.plot(reference_joint, label='Reference', color='green', linestyle='--', linewidth=2)
        plt.plot(baseline_joint, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_joint, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Trajectories')
        plt.xlabel('Frame')
        plt.ylabel('Joint Position (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Create detailed single-joint figure
        plt.figure(figsize=(14, 12))
        
        # Plot 1: Time series comparison with reference
        plt.subplot(2, 2, 1)
        plt.plot(reference_joint, label='Reference', color='green', linestyle='--', linewidth=2)
        plt.plot(baseline_joint, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_joint, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Trajectories')
        plt.xlabel('Frame')
        plt.ylabel('Joint Position (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 2: Error distribution comparison (using absolute errors)
        plt.subplot(2, 2, 2)
        sns.kdeplot(baseline_abs_error, label='Control Error', color='blue', fill=True, alpha=0.3)
        sns.kdeplot(lesioned_abs_error, label='Lesioned Error', color='red', fill=True, alpha=0.3)
        plt.title(f'{joint_name.upper()} Absolute Error Distribution')
        plt.xlabel('Absolute Error (rad)')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 3: Absolute error time series
        plt.subplot(2, 2, 3)
        plt.plot(baseline_abs_error, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_abs_error, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Absolute Error Over Time')
        plt.xlabel('Frame')
        plt.ylabel('Absolute Error (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 4: Box plot comparison of absolute errors
        plt.subplot(2, 2, 4)
        error_data = [baseline_abs_error, lesioned_abs_error]
        plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])
        plt.title(f'{joint_name.upper()} Absolute Error Distribution')
        plt.ylabel('Absolute Error (rad)')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / f"{joint_name}_lesion_error_analysis.png", dpi=150)
        plt.close()

    # Finalize and save the joint comparison plot
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "all_joints_lesion_comparison.png", dpi=150)
    plt.close()
    
    # Create a summary bar chart comparing the effect of lesioning on each joint
    plt.figure(figsize=(12, 8))
    
    # Bar chart of RMSE percent increase
    joints = [d['joint'] for d in summary_data]
    rmse_pct_increases = [d['rmse_percent'] for d in summary_data]
    
    bars = plt.bar(joints, rmse_pct_increases, color='skyblue')
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.title('Percent Increase in RMSE After Lesioning')
    plt.xlabel('Joint')
    plt.ylabel('% Increase in RMSE')
    plt.grid(axis='y', alpha=0.3)
    
    # Annotate with values
    for bar, value in zip(bars, rmse_pct_increases):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., 
                 height + (5 if height > 0 else -15),
                 f'{value:.1f}%', 
                 ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "joint_error_increase_summary.png", dpi=150)
    plt.close()
    
    # Print summary table
    print("\n=== SUMMARY OF LESIONING EFFECTS ACROSS JOINTS ===")
    print(f"{'Joint':<10} {'Baseline RMSE':>13} {'Lesioned RMSE':>14} {'RMSE Δ%':>10} {'Baseline MAE':>13} {'Lesioned MAE':>14} {'MAE Δ%':>10} {'Error Var Ratio':>15}")
    print("-" * 100)
    for data in summary_data:
        print(f"{data['joint']:<10} {data['baseline_rmse']:>13.4f} {data['lesioned_rmse']:>14.4f} {data['rmse_percent']:>+10.1f}% {data['baseline_mae']:>13.4f} {data['lesioned_mae']:>14.4f} {data['mae_percent']:>+10.1f}% {data['error_var_ratio']:>15.2f}")

# Step 11: Print summary of lesioning
print("\n=== MULTI-LAYER ACTION-PREDICTING NEURON LESIONING SUMMARY ===")
for layer_name, neurons in neurons_by_layer.items():
    print(f"{layer_name}: Lesioned {len(neurons)} unique neurons")

total_neurons = sum(len(neurons) for neurons in neurons_by_layer.values())
print(f"Total neurons lesioned: {total_neurons}")
print(f"\nThis experiment dynamically identified and lesioned the top {TOP_K} neurons")
print("from each layer that are most predictive of each action/muscle activation")
print("Using action-specific embedding dimensions for optimal analysis")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Configuration: analyzing 3 layers, selecting TOP-5 neurons per action
Using action-specific embedding dimensions:
  Pec_C (act0): E=6
  Lat (act1): E=10
  PD (act2): E=5
  AD (act3): E=5
  MD (act4): E=10
  Triceps_Lateral (act5): E=6
  Triceps_Long (act6): E=8
  Brachialis (act7): E=6
  Biceps_Long (act8): E=7

=== STARTING NEURAL ANALYSIS AND LESIONING ===

Loading neural activation and action data...
layer_0 shape: (46, 100, 512)
layer_1 shape: (46, 100, 512)
layer_2 shape: (46, 100, 512)
Actions shape: (46, 100, 9)

ANALYZING LAYER: layer_0 FOR ACTION PREDICTION


  big_df = pd.concat(frames, ignore_index=True)


Starting layer_0 Neural → Action Simplex Analysis...
Running simplex analysis for layer_0: 512 neurons → 9 actions
Using action-specific embedding dimensions

Embedding dimensions for each action:
  action_Pec_C (index act0): E=6
  action_Lat (index act1): E=10
  action_PD (index act2): E=5
  action_AD (index act3): E=5
  action_MD (index act4): E=10
  action_Triceps_Lateral (index act5): E=6
  action_Triceps_Long (index act6): E=8
  action_Brachialis (index act7): E=6
  action_Biceps_Long (index act8): E=7


layer_0→Action Simplex: 100%|██████████| 4608/4608 [01:21<00:00, 56.67it/s]


Completed layer_0 simplex analysis

Top 5 neurons in layer_0 for Pec_C:
   1. Neuron 135: +0.5627 (|ρ|=0.5627)
   2. Neuron 428: +0.5436 (|ρ|=0.5436)
   3. Neuron 222: +0.5374 (|ρ|=0.5374)
   4. Neuron 322: +0.5241 (|ρ|=0.5241)
   5. Neuron 315: +0.5100 (|ρ|=0.5100)

Top 5 neurons in layer_0 for Lat:
   1. Neuron 156: +0.5069 (|ρ|=0.5069)
   2. Neuron   9: +0.4964 (|ρ|=0.4964)
   3. Neuron 396: +0.4946 (|ρ|=0.4946)
   4. Neuron 275: +0.4912 (|ρ|=0.4912)
   5. Neuron  33: +0.4906 (|ρ|=0.4906)

Top 5 neurons in layer_0 for PD:
   1. Neuron 356: +0.2329 (|ρ|=0.2329)
   2. Neuron 229: +0.2121 (|ρ|=0.2121)
   3. Neuron 131: +0.1998 (|ρ|=0.1998)
   4. Neuron 362: +0.1990 (|ρ|=0.1990)
   5. Neuron 505: +0.1950 (|ρ|=0.1950)

Top 5 neurons in layer_0 for AD:
   1. Neuron 322: +0.3319 (|ρ|=0.3319)
   2. Neuron 135: +0.3252 (|ρ|=0.3252)
   3. Neuron 289: +0.2972 (|ρ|=0.2972)
   4. Neuron 282: +0.2860 (|ρ|=0.2860)
   5. Neuron 217: +0.2775 (|ρ|=0.2775)

Top 5 neurons in layer_0 for MD:
   1. Neuro

  big_df = pd.concat(frames, ignore_index=True)


Starting layer_1 Neural → Action Simplex Analysis...
Running simplex analysis for layer_1: 512 neurons → 9 actions
Using action-specific embedding dimensions

Embedding dimensions for each action:
  action_Pec_C (index act0): E=6
  action_Lat (index act1): E=10
  action_PD (index act2): E=5
  action_AD (index act3): E=5
  action_MD (index act4): E=10
  action_Triceps_Lateral (index act5): E=6
  action_Triceps_Long (index act6): E=8
  action_Brachialis (index act7): E=6
  action_Biceps_Long (index act8): E=7


layer_1→Action Simplex: 100%|██████████| 4608/4608 [01:21<00:00, 56.39it/s]


Completed layer_1 simplex analysis

Top 5 neurons in layer_1 for Pec_C:
   1. Neuron 192: +0.5968 (|ρ|=0.5968)
   2. Neuron 257: +0.5787 (|ρ|=0.5787)
   3. Neuron 435: +0.5675 (|ρ|=0.5675)
   4. Neuron 201: +0.5631 (|ρ|=0.5631)
   5. Neuron 178: +0.5374 (|ρ|=0.5374)

Top 5 neurons in layer_1 for Lat:
   1. Neuron 251: +0.5689 (|ρ|=0.5689)
   2. Neuron 168: +0.5311 (|ρ|=0.5311)
   3. Neuron 137: +0.5039 (|ρ|=0.5039)
   4. Neuron 399: +0.4708 (|ρ|=0.4708)
   5. Neuron  62: +0.4611 (|ρ|=0.4611)

Top 5 neurons in layer_1 for PD:
   1. Neuron 189: +0.2982 (|ρ|=0.2982)
   2. Neuron 310: +0.2913 (|ρ|=0.2913)
   3. Neuron 224: +0.2812 (|ρ|=0.2812)
   4. Neuron 414: +0.2389 (|ρ|=0.2389)
   5. Neuron 232: +0.2145 (|ρ|=0.2145)

Top 5 neurons in layer_1 for AD:
   1. Neuron 134: +0.2770 (|ρ|=0.2770)
   2. Neuron 400: +0.2684 (|ρ|=0.2684)
   3. Neuron 111: +0.2538 (|ρ|=0.2538)
   4. Neuron 162: +0.2530 (|ρ|=0.2530)
   5. Neuron 333: +0.2523 (|ρ|=0.2523)

Top 5 neurons in layer_1 for MD:
   1. Neuro

  big_df = pd.concat(frames, ignore_index=True)


Starting layer_2 Neural → Action Simplex Analysis...
Running simplex analysis for layer_2: 512 neurons → 9 actions
Using action-specific embedding dimensions

Embedding dimensions for each action:
  action_Pec_C (index act0): E=6
  action_Lat (index act1): E=10
  action_PD (index act2): E=5
  action_AD (index act3): E=5
  action_MD (index act4): E=10
  action_Triceps_Lateral (index act5): E=6
  action_Triceps_Long (index act6): E=8
  action_Brachialis (index act7): E=6
  action_Biceps_Long (index act8): E=7


layer_2→Action Simplex: 100%|██████████| 4608/4608 [01:21<00:00, 56.51it/s]


Completed layer_2 simplex analysis

Top 5 neurons in layer_2 for Pec_C:
   1. Neuron 461: +0.6496 (|ρ|=0.6496)
   2. Neuron 340: +0.6481 (|ρ|=0.6481)
   3. Neuron  82: +0.6188 (|ρ|=0.6188)
   4. Neuron 312: +0.6163 (|ρ|=0.6163)
   5. Neuron 156: +0.6002 (|ρ|=0.6002)

Top 5 neurons in layer_2 for Lat:
   1. Neuron  69: +0.6639 (|ρ|=0.6639)
   2. Neuron 325: +0.6411 (|ρ|=0.6411)
   3. Neuron 460: +0.5717 (|ρ|=0.5717)
   4. Neuron 200: +0.5580 (|ρ|=0.5580)
   5. Neuron 377: +0.5468 (|ρ|=0.5468)

Top 5 neurons in layer_2 for PD:
   1. Neuron 415: +0.3543 (|ρ|=0.3543)
   2. Neuron 298: +0.2534 (|ρ|=0.2534)
   3. Neuron 176: +0.2177 (|ρ|=0.2177)
   4. Neuron 261: +0.2133 (|ρ|=0.2133)
   5. Neuron 318: +0.2075 (|ρ|=0.2075)

Top 5 neurons in layer_2 for AD:
   1. Neuron 340: +0.3689 (|ρ|=0.3689)
   2. Neuron 246: +0.3236 (|ρ|=0.3236)
   3. Neuron  93: +0.3109 (|ρ|=0.3109)
   4. Neuron 150: +0.3070 (|ρ|=0.3070)
   5. Neuron 351: +0.3026 (|ρ|=0.3026)

Top 5 neurons in layer_2 for MD:
   1. Neuro

100%|██████████| 200/200 [00:00<00:00, 671.72it/s]


MuJoCo Rendering with Ghost Model...


100%|██████████| 200/200 [00:00<00:00, 690.85it/s]


✓ Lesioned video saved to /root/vast/eric/track-mjx/model_checkpoints/250826_030533_134914/rollout_actions_lesioned_top5.mp4
✓ Baseline video saved to /root/vast/eric/track-mjx/model_checkpoints/250826_030533_134914/rollout_baseline.mp4

=== DISPLAYING MULTI-LAYER LESIONED VIDEO ===


0
This browser does not support the video tag.



=== COMPARING CONTROL VS LESIONED PERFORMANCE ACROSS ALL JOINTS ===
Found qposes_rollout with shape: (200, 4)
Found qposes_rollout with shape: (200, 4)
Found qposes_ref with shape: (200, 4)

=== OVERALL REWARD STATISTICS ===
Control mean reward: 4.8536, Lesioned mean reward: 4.6778
Reward reduction: 0.1758 (3.6%)

=== SH_ELV ERROR STATISTICS ===
Control mean error: -0.1592 rad, Lesioned mean error: -0.2546 rad
Control RMSE: 0.2325 rad, Lesioned RMSE: 0.3624 rad
Control MAE: 0.2035 rad, Lesioned MAE: 0.3265 rad
Control error std: 0.1694 rad, Lesioned error std: 0.2580 rad
Error variance ratio (lesioned/control): 2.3178


  plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])



=== SH_EXT ERROR STATISTICS ===
Control mean error: 0.0493 rad, Lesioned mean error: 0.0698 rad
Control RMSE: 0.1865 rad, Lesioned RMSE: 0.3195 rad
Control MAE: 0.1077 rad, Lesioned MAE: 0.2254 rad
Control error std: 0.1799 rad, Lesioned error std: 0.3118 rad
Error variance ratio (lesioned/control): 3.0031


  plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])



=== SH_ROT ERROR STATISTICS ===
Control mean error: -0.0943 rad, Lesioned mean error: -0.1720 rad
Control RMSE: 0.1657 rad, Lesioned RMSE: 0.2450 rad
Control MAE: 0.1212 rad, Lesioned MAE: 0.1993 rad
Control error std: 0.1362 rad, Lesioned error std: 0.1745 rad
Error variance ratio (lesioned/control): 1.6406


  plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])



=== ELBOW ERROR STATISTICS ===
Control mean error: -0.0472 rad, Lesioned mean error: -0.0202 rad
Control RMSE: 0.0927 rad, Lesioned RMSE: 0.1394 rad
Control MAE: 0.0762 rad, Lesioned MAE: 0.1145 rad
Control error std: 0.0798 rad, Lesioned error std: 0.1379 rad
Error variance ratio (lesioned/control): 2.9894


  plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])



=== SUMMARY OF LESIONING EFFECTS ACROSS JOINTS ===
Joint      Baseline RMSE  Lesioned RMSE    RMSE Δ%  Baseline MAE   Lesioned MAE     MAE Δ% Error Var Ratio
----------------------------------------------------------------------------------------------------
sh_elv            0.2325         0.3624      +55.9%        0.2035         0.3265      +60.4%            2.32
sh_ext            0.1865         0.3195      +71.3%        0.1077         0.2254     +109.3%            3.00
sh_rot            0.1657         0.2450      +47.9%        0.1212         0.1993      +64.5%            1.64
elbow             0.0927         0.1394      +50.4%        0.0762         0.1145      +50.3%            2.99

=== MULTI-LAYER ACTION-PREDICTING NEURON LESIONING SUMMARY ===
hidden_0: Lesioned 36 unique neurons
hidden_1: Lesioned 37 unique neurons
hidden_2: Lesioned 34 unique neurons
Total neurons lesioned: 107

This experiment dynamically identified and lesioned the top 5 neurons
from each layer that are most 

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

<Figure size 1200x1000 with 0 Axes>

In [None]:
%load_ext autoreload
%autoreload 2

import os, copy, yaml, h5py
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from orbax import checkpoint as ocp
from pyEDM import Simplex
from scipy import signal

# Environment setup
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"        # or "osmesa"
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4")

import mediapy as media
from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils
from track_mjx.agent.mlp_ppo.intention_network import make_intention_policy
from brax.training.acme import running_statistics
from track_mjx.analysis.rollout import create_environment

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════

# Analysis configuration
H5_PATH = "/root/vast/eric/track-mjx/notebooks/rollouts_full_mouse_arm_new_250826_030533_134914.h5"
NEURAL_E = 4     # Default embedding dimension (if action-specific not available)
NEURAL_TAU = -1  # Time delay for neural activations
Tp = 1           # Prediction horizon
PAD = 300        # NaN padding between clips

# Lesioning configuration
ENABLE_LESION = True
TOP_K = 5       # Number of top neurons to lesion per action
LAYERS_TO_ANALYZE = ["layer_0", "layer_1", "layer_2"]

# Action-specific embedding dimensions
ACTION_E = {
    "act0": 6, "act1": 10, "act2": 5, "act3": 5, "act4": 10,
    "act5": 6, "act6": 8,  "act7": 6,  "act8": 7
}

# Layer mapping: activation layer name -> parameter layer name
LAYER_MAPPING = {
    "layer_0": "hidden_0",
    "layer_1": "hidden_1",
    "layer_2": "hidden_2"
}

# Action names (muscle names)
action_names = ["Pec_C", "Lat", "PD", "AD", "MD", "Triceps_Lateral", "Triceps_Long", "Brachialis", "Biceps_Long"]
joint_names = ["sh_elv", "sh_ext", "sh_rot", "elbow"]

# Load CCM matrices
ccm_matrix = np.array([
    [1.      , 0.525813, 0.203717, 0.568015, 0.691638, 0.696206, 0.678814, 0.682802, 0.53741 ],
    [0.535335, 1.      , 0.139094, 0.681686, 0.522431, 0.667265, 0.463946, 0.663696, 0.537878],
    [0.295953, 0.163062, 1.      , 0.176845, 0.31952 , 0.165365, 0.213455, 0.375396, 0.216511],
    [0.597855, 0.7982  , 0.139659, 1.      , 0.747117, 0.731527, 0.566099, 0.710519, 0.65959 ],
    [0.719331, 0.698684, 0.245862, 0.794672, 1.      , 0.685992, 0.615126, 0.671923, 0.641866],
    [0.544157, 0.633614, 0.227312, 0.685914, 0.516835, 1.      , 0.248369, 0.862421, 0.316522],
    [0.619141, 0.398754, 0.146874, 0.399908, 0.393907, 0.320955, 1.      , 0.317059, 0.369504],
    [0.430258, 0.553677, 0.390358, 0.498353, 0.316513, 0.78614 , 0.232873, 1.      , 0.148855],
    [0.526948, 0.625399, 0.311665, 0.642974, 0.517234, 0.535678, 0.541893, 0.51868 , 1.      ]
])

# Convert a2j matrix from string to DataFrame
ccm_a2j_data = {
    'sh_elv': [0.676606, 0.614237, 0.572824, 0.702254, 0.760244, 0.655498, 0.390534, 0.611454, 0.541563],
    'sh_ext': [0.650763, 0.589529, 0.481112, 0.722722, 0.726062, 0.567940, 0.422373, 0.475919, 0.539591],
    'sh_rot': [0.635858, 0.713230, 0.444087, 0.783400, 0.774024, 0.698159, 0.503080, 0.610788, 0.570240],
    'elbow':  [0.679048, 0.721847, 0.376172, 0.783492, 0.763284, 0.710400, 0.507781, 0.616184, 0.521874]
}
ccm_matrix_a2j = pd.DataFrame(ccm_a2j_data, index=action_names)

print(f"Configuration: analyzing {len(LAYERS_TO_ANALYZE)} layers, selecting TOP-{TOP_K} neurons per action")
print("Using action-specific embedding dimensions:")
for act_idx, (act, dim) in enumerate(ACTION_E.items()):
    action_name = action_names[act_idx] if act_idx < len(action_names) else act
    print(f"  {action_name} ({act}): E={dim}")

# ═══════════════════════════════════════════════════════════════════════════════
# CCM PREDICTION FUNCTIONS 
# ═══════════════════════════════════════════════════════════════════════════════

def predict_action_effects(actions_to_lesion):
    """
    Predict the effects of lesioning specific actions (muscles) on other actions
    using the CCM causal relationship matrix.
    
    Args:
        actions_to_lesion: List of action indices to lesion
        
    Returns:
        predicted_impact: Dictionary mapping action names to predicted RMSE increase
    """
    # Convert action names to indices if needed
    action_indices = []
    for action in actions_to_lesion:
        if isinstance(action, str):
            if action in action_names:
                action_indices.append(action_names.index(action))
            else:
                print(f"Warning: Action '{action}' not found in action_names")
        else:
            action_indices.append(action)
    
    # Calculate predicted impact on each action
    predicted_impact = {}
    for i, action_name in enumerate(action_names):
        # Skip if this is an action we're lesioning
        if i in action_indices:
            predicted_impact[action_name] = 1.0  # Maximum impact
            continue
        
        # Calculate impact based on CCM values
        impact = 0
        for lesion_idx in action_indices:
            # How much does the lesioned action causally affect this action?
            # Higher CCM value means stronger causal relationship
            causal_strength = ccm_matrix[lesion_idx, i]
            impact += causal_strength
        
        # Normalize impact score (0-1)
        impact = min(impact, 1.0)
        predicted_impact[action_name] = impact
    
    return predicted_impact

def predict_joint_effects(actions_to_lesion):
    """
    Predict the effects of lesioning specific actions (muscles) on joint angles
    using the CCM action-to-joint causal relationship matrix.
    
    Args:
        actions_to_lesion: List of action indices to lesion
        
    Returns:
        predicted_impact: Dictionary mapping joint names to predicted RMSE increase
    """
    # Convert action names to indices if needed
    action_indices = []
    for action in actions_to_lesion:
        if isinstance(action, str):
            if action in action_names:
                action_indices.append(action_names.index(action))
            else:
                print(f"Warning: Action '{action}' not found in action_names")
        else:
            action_indices.append(action)
    
    # Calculate predicted impact on each joint
    predicted_impact = {}
    for joint_name in joint_names:
        impact = 0
        for lesion_idx in action_indices:
            action_name = action_names[lesion_idx]
            # How much does the lesioned action causally affect this joint?
            causal_strength = ccm_matrix_a2j.loc[action_name, joint_name]
            impact += causal_strength
        
        # Normalize impact score (0-1)
        impact = min(impact / len(action_indices), 1.0)
        predicted_impact[joint_name] = impact
    
    return predicted_impact

def predict_lesion_impact(top_neurons_by_action):
    """
    Predict the impact of lesioning neurons on actions and joints.
    
    Args:
        top_neurons_by_action: Dict mapping action names to lists of neuron indices
        
    Returns:
        Tuple of (action_impact, joint_impact) dictionaries
    """
    print("\n=== PREDICTING LESIONING IMPACT USING CCM MATRICES ===")
    
    # Determine which actions will be most affected by the lesioning
    affected_actions = list(top_neurons_by_action.keys())
    print(f"Neurons for these actions will be lesioned: {', '.join(affected_actions)}")
    
    # Predict effect on other actions (muscles)
    action_impact = predict_action_effects(affected_actions)
    
    # Sort actions by predicted impact
    sorted_actions = sorted(action_impact.items(), key=lambda x: x[1], reverse=True)
    print("\nPredicted impact on muscle activations (ordered by severity):")
    for action, impact in sorted_actions:
        impact_level = "High" if impact > 0.7 else "Medium" if impact > 0.4 else "Low"
        print(f"  {action}: {impact:.4f} - {impact_level} impact")
    
    # Predict effect on joints
    joint_impact = predict_joint_effects(affected_actions)
    
    # Sort joints by predicted impact
    sorted_joints = sorted(joint_impact.items(), key=lambda x: x[1], reverse=True)
    print("\nPredicted impact on joint angles (ordered by severity):")
    for joint, impact in sorted_joints:
        impact_level = "High" if impact > 0.7 else "Medium" if impact > 0.4 else "Low"
        print(f"  {joint}: {impact:.4f} - {impact_level} impact")
        
    # Create impact visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Action impact plot
    action_names_list = [a for a, _ in sorted_actions]
    action_impact_vals = [i for _, i in sorted_actions]
    ax1.barh(action_names_list, action_impact_vals, color='skyblue')
    ax1.set_title('Predicted Impact on Muscle Activations')
    ax1.set_xlabel('Predicted Impact (0-1)')
    ax1.grid(alpha=0.3)
    
    # Joint impact plot
    joint_names_list = [j for j, _ in sorted_joints]
    joint_impact_vals = [i for _, i in sorted_joints]
    ax2.barh(joint_names_list, joint_impact_vals, color='lightcoral')
    ax2.set_title('Predicted Impact on Joint Angles')
    ax2.set_xlabel('Predicted Impact (0-1)')
    ax2.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig("predicted_lesioning_impact.png", dpi=150)
    plt.show()
    
    return action_impact, joint_impact

def visualize_ccm_networks():
    """Create visualizations of the CCM networks to better understand causal structures"""
    print("\n=== VISUALIZING CCM CAUSAL NETWORKS ===")
    
    # Create a more readable muscle-to-muscle CCM matrix visualization
    plt.figure(figsize=(10, 8))
    sns.heatmap(ccm_matrix, cmap="YlGnBu", annot=True, fmt=".2f",
                xticklabels=action_names, yticklabels=action_names)
    plt.title('Muscle-to-Muscle CCM Causal Relationships')
    plt.tight_layout()
    plt.savefig("muscle_to_muscle_ccm.png", dpi=150)
    plt.show()
    
    # Create a transposed action-to-joint CCM visualization (easier to read)
    plt.figure(figsize=(12, 6))
    sns.heatmap(ccm_matrix_a2j.T, cmap="YlGnBu", annot=True, fmt=".2f")
    plt.title('Joint-Muscle CCM Relationships (Transposed)')
    plt.tight_layout()
    plt.savefig("joint_to_muscle_ccm_transposed.png", dpi=150)
    plt.show()

def compare_predicted_vs_actual(predicted_action_impact, predicted_joint_impact, action_summary, joint_summary):
    """
    Compare the predicted impact from CCM with actual measured impact after lesioning
    
    Args:
        predicted_action_impact: Dict mapping action names to predicted impact (0-1)
        predicted_joint_impact: Dict mapping joint names to predicted impact (0-1)
        action_summary: List of dicts with actual action RMSE measurements
        joint_summary: List of dicts with actual joint RMSE measurements
    """
    print("\n=== COMPARING PREDICTED VS ACTUAL LESIONING EFFECTS ===")
    
    # Prepare action data
    action_data = []
    for item in action_summary:
        action = item['action']
        if action in predicted_action_impact:
            action_data.append({
                'action': action,
                'predicted_impact': predicted_action_impact[action],
                'actual_rmse': item['rmse'],
                'actual_max_diff': item['max_diff']
            })
    
    # Prepare joint data
    joint_data = []
    for item in joint_summary:
        joint = item['joint']
        if joint in predicted_joint_impact:
            joint_data.append({
                'joint': joint,
                'predicted_impact': predicted_joint_impact[joint],
                'actual_rmse_increase': item['rmse_percent'] / 100,  # Convert percent to fraction
                'actual_mae_increase': item['mae_percent'] / 100     # Convert percent to fraction
            })
    
    # Create comparison plots
    fig, axes = plt.subplots(2, 1, figsize=(12, 14))
    
    # Action comparison
    action_df = pd.DataFrame(action_data)
    action_df = action_df.sort_values('predicted_impact', ascending=False)
    
    # Normalize actual RMSE to 0-1 scale for comparison
    max_rmse = action_df['actual_rmse'].max()
    action_df['normalized_actual'] = action_df['actual_rmse'] / max_rmse if max_rmse > 0 else 0
    
    ax1 = axes[0]
    x = np.arange(len(action_df))
    width = 0.35
    
    ax1.bar(x - width/2, action_df['predicted_impact'], width, label='Predicted Impact', color='skyblue')
    ax1.bar(x + width/2, action_df['normalized_actual'], width, label='Normalized Actual RMSE', color='coral')
    
    ax1.set_xticks(x)
    ax1.set_xticklabels(action_df['action'], rotation=45, ha='right')
    ax1.set_title('Muscle Activation: Predicted vs Actual Impact')
    ax1.set_ylabel('Impact (0-1 scale)')
    ax1.legend()
    ax1.grid(alpha=0.3)
    
    # Calculate correlation
    action_corr = np.corrcoef(action_df['predicted_impact'], action_df['normalized_actual'])[0, 1]
    ax1.text(0.02, 0.95, f"Correlation: {action_corr:.3f}", transform=ax1.transAxes, 
             bbox=dict(facecolor='white', alpha=0.8))
    
    # Joint comparison
    joint_df = pd.DataFrame(joint_data)
    joint_df = joint_df.sort_values('predicted_impact', ascending=False)
    
    ax2 = axes[1]
    x = np.arange(len(joint_df))
    
    ax2.bar(x - width/2, joint_df['predicted_impact'], width, label='Predicted Impact', color='skyblue')
    ax2.bar(x + width/2, joint_df['actual_rmse_increase'], width, label='Actual RMSE % Increase', color='coral')
    
    ax2.set_xticks(x)
    ax2.set_xticklabels(joint_df['joint'], rotation=45, ha='right')
    ax2.set_title('Joint Angles: Predicted vs Actual Impact')
    ax2.set_ylabel('Impact (0-1 scale)')
    ax2.legend()
    ax2.grid(alpha=0.3)
    
    # Calculate correlation
    joint_corr = np.corrcoef(joint_df['predicted_impact'], joint_df['actual_rmse_increase'])[0, 1]
    ax2.text(0.02, 0.95, f"Correlation: {joint_corr:.3f}", transform=ax2.transAxes,
             bbox=dict(facecolor='white', alpha=0.8))
    
    plt.tight_layout()
    plt.savefig("predicted_vs_actual_impact.png", dpi=150)
    plt.show()
    
    # Print summary statistics
    print(f"Action prediction correlation: {action_corr:.3f}")
    print(f"Joint prediction correlation: {joint_corr:.3f}")
    
    if action_corr > 0.7 and joint_corr > 0.7:
        print("\n✓ CCM-based predictions showed STRONG correlation with actual lesioning effects")
    elif action_corr > 0.4 and joint_corr > 0.4:
        print("\n✓ CCM-based predictions showed MODERATE correlation with actual lesioning effects")
    else:
        print("\n⚠ CCM-based predictions showed WEAK correlation with actual lesioning effects")

# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - NEURAL ANALYSIS
# ═══════════════════════════════════════════════════════════════════════════════

def zscore(x, axis=None, eps=1e-12):
    """Z-score normalization with small epsilon to avoid division by zero."""
    mu = np.nanmean(x, axis=axis, keepdims=True)
    sd = np.nanstd(x, axis=axis, keepdims=True)
    return (x - mu) / np.maximum(sd, eps)

def concat_neural_action_padded(neural_dict, actions_data, pad=1, standardize=True):
    """Concatenate neural activations with action data, NaN-padded between clips."""
    N, T = next(iter(neural_dict.values())).shape[:2]
    
    frames = []
    row_cursor = 0
    
    for clip in range(N):
        # Add padding between clips (except before first clip)
        if clip > 0:
            pad_data = {"time": np.arange(row_cursor+1, row_cursor+pad+1)}
            
            # Add NaN columns for all neural activations
            for layer_name, layer_data in neural_dict.items():
                n_neurons = layer_data.shape[-1]
                for neuron in range(n_neurons):
                    pad_data[f"{layer_name}_n{neuron:03d}"] = np.nan
            
            # Add NaN columns for actions
            for a, action_name in enumerate(action_names):
                pad_data[f"action_{action_name}"] = np.nan
            
            pad_df = pd.DataFrame(pad_data)
            frames.append(pad_df)
            row_cursor += pad
        
        # Add actual data for this clip
        clip_data = {"time": np.arange(row_cursor+1, row_cursor+T+1)}
        
        # Neural activations (standardized)
        for layer_name, layer_data in neural_dict.items():
            n_neurons = layer_data.shape[-1]
            for neuron in range(n_neurons):
                x = layer_data[clip, :, neuron]
                clip_data[f"{layer_name}_n{neuron:03d}"] = zscore(x) if standardize else x
        
        # Action values (standardized)
        for a, action_name in enumerate(action_names):
            action_vals = actions_data[clip, :, a]
            clip_data[f"action_{action_name}"] = zscore(action_vals) if standardize else action_vals
        
        clip_df = pd.DataFrame(clip_data)
        frames.append(clip_df)
        row_cursor += T
    
    big_df = pd.concat(frames, ignore_index=True)
    return big_df

def build_lib_pred_from_nan_blocks(df, probe_col, split=0.6, min_edge=8):
    """Split each non-NaN block 60/40 into lib/pred for pyEDM."""
    x = df[probe_col].to_numpy()
    notnan = ~np.isnan(x)
    edges = np.diff(np.concatenate([[0], notnan.view(np.int8), [0]]))
    starts = np.where(edges == 1)[0] + 1  # 1-based indexing
    ends = np.where(edges == -1)[0]       # 1-based indexing
    
    lib_parts, pred_parts = [], []
    for s, e in zip(starts, ends):
        n = e - s + 1
        if n < 2 * min_edge:
            continue
        m = s + int(np.floor(split * n)) - 1
        m = max(m, s + min_edge - 1)
        m = min(m, e - min_edge)
        if m <= s or m >= e:
            continue
        lib_parts.extend([str(s), str(m)])
        pred_parts.extend([str(m + 1), str(e)])
    
    return " ".join(lib_parts), " ".join(pred_parts)

def safe_correlation(y, yhat, min_pairs=10):
    """Calculate correlation with safety checks."""
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    mask = np.isfinite(y) & np.isfinite(yhat)
    
    if mask.sum() < min_pairs:
        return np.nan
    
    y_clean, yhat_clean = y[mask], yhat[mask]
    sy, syh = y_clean.std(ddof=0), yhat_clean.std(ddof=0)
    
    if sy <= 1e-12 or syh <= 1e-12:
        return np.nan
    
    return float(np.corrcoef(y_clean, yhat_clean)[0, 1])

def get_top_neurons_for_action(corr_df, action_name, layer_name, k=25):
    """Get top K neurons for a specific action."""
    action_col = f"action_{action_name}"
    
    # Ensure action column exists
    if action_col not in corr_df.columns:
        print(f"Warning: {action_col} not found in correlation dataframe")
        return []
    
    # Sort by absolute correlation
    corrs = corr_df[action_col].dropna().abs().sort_values(ascending=False)
    top_neurons = corrs.head(k).index.tolist()
    
    # Extract neuron indices from column names
    neuron_indices = []
    for neuron_col in top_neurons:
        neuron_idx = int(neuron_col.split('_n')[1])
        neuron_indices.append(neuron_idx)
    
    # Print top neurons and their correlations
    print(f"\nTop {k} neurons in {layer_name} for {action_name}:")
    for i, neuron_col in enumerate(top_neurons[:min(10, k)]):  # Show top 10 max
        neuron_idx = int(neuron_col.split('_n')[1])
        raw_corr = corr_df.loc[neuron_col, action_col]
        abs_corr = abs(raw_corr)
        print(f"  {i+1:2d}. Neuron {neuron_idx:3d}: {raw_corr:+.4f} (|ρ|={abs_corr:.4f})")
    
    if len(top_neurons) > 10:
        print(f"  ... and {len(top_neurons)-10} more")
    
    return neuron_indices

def neural_to_action_simplex_heatmap(df, neural_cols, action_cols, lib, pred, layer_name, action_E=None, tau=-1, Tp=0):
    """Perform simplex analysis: Neural Activations → Action Dynamics with action-specific embedding dimensions"""
    print(f"Running simplex analysis for {layer_name}: {len(neural_cols)} neurons → {len(action_cols)} actions")
    print(f"Using action-specific embedding dimensions")
    
    # Initialize correlation matrix
    corr_matrix = np.full((len(neural_cols), len(action_cols)), np.nan, dtype=float)
    
    # Total number of computations to perform
    total_pairs = len(neural_cols) * len(action_cols)
    
    # If no action-specific embedding dimensions provided, use default
    if action_E is None:
        action_E = {f"act{i}": NEURAL_E for i in range(len(action_cols))}
    
    # Print embedding dimensions used for each action
    print("\nEmbedding dimensions for each action:")
    for ai, action_col in enumerate(action_cols):
        action_idx = f"act{ai}"
        E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
        print(f"  {action_col} (index {action_idx}): E={E}")
    
    # Use tqdm for progress tracking - no need for additional manual updates
    pbar = tqdm(total=total_pairs, desc=f"{layer_name}→Action Simplex")
    
    for ni, neural_col in enumerate(neural_cols):
        for ai, action_col in enumerate(action_cols):
            # Determine embedding dimension for this action
            action_idx = f"act{ai}"
            E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
            
            try:
                # Run simplex prediction with action-specific embedding dimension
                pred_df = Simplex(
                    dataFrame=df,
                    lib=lib,
                    pred=pred,
                    columns=neural_col,
                    target=action_col,
                    E=E,  # Use action-specific E
                    tau=tau,
                    Tp=Tp,
                    ignoreNan=True,
                    showPlot=False
                )
                
                # Extract predictions and observations
                obs = pred_df["Observations"].to_numpy()
                hat = pred_df["Predictions"].to_numpy()
                
                # Calculate correlation
                corr_matrix[ni, ai] = safe_correlation(obs, hat, min_pairs=10)
                
            except Exception as e:
                print(f"Error for {neural_col} → {action_col} with E={E}: {e}")
                corr_matrix[ni, ai] = np.nan
            
            # Update progress bar without additional prints
            pbar.update(1)
    
    # Close progress bar
    pbar.close()
    print(f"Completed {layer_name} simplex analysis")
    
    # Convert to DataFrame
    corr_df = pd.DataFrame(corr_matrix, index=neural_cols, columns=action_cols)
    return corr_df

def analyze_layer_for_actions(layer_name, neural_dict, actions_data):
    """Analyze a single layer to identify neurons predictive of actions."""
    print(f"\n{'='*80}\nANALYZING LAYER: {layer_name} FOR ACTION PREDICTION\n{'='*80}")
    
    # Build a layer-specific neural dictionary (just this one layer)
    layer_neural_dict = {layer_name: neural_dict[layer_name]}
    
    # Build DataFrame with neural activations and actions
    big_df_actions = concat_neural_action_padded(
        neural_dict=layer_neural_dict,
        actions_data=actions_data,
        pad=PAD,
        standardize=True
    )
    
    # Extract column names for analysis
    neural_cols = [col for col in big_df_actions.columns if layer_name in col]
    action_cols = [col for col in big_df_actions.columns if col.startswith("action_")]
    
    # Build lib/pred indices
    probe_col = neural_cols[0]
    lib, pred = build_lib_pred_from_nan_blocks(big_df_actions, probe_col=probe_col, split=0.6, min_edge=8)
    
    # Run simplex analysis with action-specific embedding dimensions
    print(f"Starting {layer_name} Neural → Action Simplex Analysis...")
    corr_results = neural_to_action_simplex_heatmap(
        df=big_df_actions,
        neural_cols=neural_cols,
        action_cols=action_cols,
        lib=lib,
        pred=pred,
        layer_name=layer_name,
        action_E=ACTION_E,  # Pass the action-specific embedding dimensions
        tau=NEURAL_TAU,
        Tp=Tp
    )
    
    # Create and save heatmap
    plt.figure(figsize=(12, 10))
    
    # Compute mean absolute correlation per neuron for sorting
    mean_abs_corr = corr_results.abs().mean(axis=1).sort_values(ascending=False)
    sorted_neurons = mean_abs_corr.index.tolist()
    
    # Get top 50 neurons for visualization
    top_neurons = sorted_neurons[:50]
    top_corr_df = corr_results.loc[top_neurons]
    
    # Create heatmap
    plt.figure(figsize=(12, 14))
    sns.heatmap(top_corr_df, cmap="RdBu_r", center=0, vmin=-1, vmax=1, 
                xticklabels=True, yticklabels=True, annot=False)
    plt.title(f"{layer_name} Neural Activations → Actions Correlation (Top 50 Neurons)\nUsing action-specific embedding dimensions")
    plt.tight_layout()
    plt.savefig(f"{layer_name}_to_actions_heatmap.png", dpi=150)
    plt.close()
    
    # For each action, get top-k neurons
    top_neurons_by_action = {}
    all_top_neurons = set()
    
    for action_name in action_names:
        top_for_action = get_top_neurons_for_action(
            corr_df=corr_results,
            action_name=action_name,
            layer_name=layer_name,
            k=TOP_K
        )
        top_neurons_by_action[action_name] = top_for_action
        all_top_neurons.update(top_for_action)
    
    # Return unique neurons across all actions
    all_top_neurons_list = list(all_top_neurons)
    print(f"\nFound {len(all_top_neurons_list)} unique neurons in {layer_name} predictive of actions")
    
    return all_top_neurons_list, top_neurons_by_action, corr_results

# [Rest of the helper functions continue as in the original code...]
# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - LESIONING
# ═══════════════════════════════════════════════════════════════════════════════

def create_lesioned_policy_complete_multi_layer(original_policy, neurons_by_layer):
    """
    Create a lesioned version of the policy by zeroing both incoming and outgoing weights
    for multiple layers at once.
    
    Args:
        original_policy: The original policy to lesion
        neurons_by_layer: Dict mapping layer names (e.g., "hidden_0") to lists of neuron indices
    
    Returns:
        Lesioned policy
    """
    # Make a deep copy of the policy to avoid modifying the original
    lesioned_policy = jtu.tree_map(lambda x: jnp.array(x), original_policy)
    
    # Extract policy parameters
    processor_params, policy_params = lesioned_policy
    decoder_params = policy_params['params']['decoder']
    
    print(f"\nLesioning neurons across {len(neurons_by_layer)} layers...")
    
    # Process each layer
    for target_layer, neuron_indices_to_lesion in neurons_by_layer.items():
        print(f"\n=== Lesioning layer: {target_layer} ===")
        
        if target_layer not in decoder_params:
            print(f"Warning: Layer {target_layer} not found in model. Available layers: {list(decoder_params.keys())}")
            continue
        
        layer_params = decoder_params[target_layer]
        neuron_indices_array = jnp.array(neuron_indices_to_lesion)
        
        print(f"Complete lesioning of {len(neuron_indices_to_lesion)} neurons in {target_layer}")
        
        # 1. Zero outgoing weights (columns in kernel, elements in bias)
        if 'kernel' in layer_params and 'bias' in layer_params:
            kernel = layer_params['kernel']
            bias = layer_params['bias']
            
            # Filter valid indices for outgoing weights
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[1]]
            
            new_kernel = kernel.at[:, valid_indices].set(0.0)  # Set to 0 not -1
            new_bias = bias.at[valid_indices].set(0.0)         # Set to 0 not -1
            
            layer_params['kernel'] = new_kernel
            layer_params['bias'] = new_bias
            
            print(f"✓ Zeroed outgoing weights for {len(valid_indices)} neurons")
        
        # 2. Zero incoming weights to those neurons (rows in the SAME layer's kernel)
        if 'kernel' in layer_params:
            kernel = layer_params['kernel']
            
            # Filter valid indices for incoming weights (rows)
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[0]]
            
            if len(valid_indices) > 0:
                # Zero rows
                new_kernel = kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                layer_params['kernel'] = new_kernel
                
                print(f"✓ Zeroed incoming weights for {len(valid_indices)} neurons")
        
        # 3. Also zero outgoing connections from next layer if it exists
        layer_mapping = {
            "hidden_0": "hidden_1",
            "hidden_1": "hidden_2", 
            "hidden_2": "hidden_3"
        }
        
        next_layer = layer_mapping.get(target_layer)
        if next_layer and next_layer in decoder_params:
            next_layer_params = decoder_params[next_layer]
            if 'kernel' in next_layer_params:
                next_kernel = next_layer_params['kernel']
                valid_indices = neuron_indices_array[neuron_indices_array < next_kernel.shape[0]]
                
                if len(valid_indices) > 0:
                    # Zero rows
                    new_next_kernel = next_kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                    next_layer_params['kernel'] = new_next_kernel
                    
                    print(f"✓ Zeroed connections from lesioned neurons to {next_layer}")
    
    print("\n✓ Multi-layer lesioning completed successfully")
    return lesioned_policy

def verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer):
    """Verify that lesioning has been successfully applied across multiple layers."""
    print("\n=== MULTI-LAYER LESIONING VERIFICATION ===")
    
    processor_params_orig, policy_params_orig = original_policy
    processor_params_les, policy_params_les = lesioned_policy
    
    verification_passed = True
    
    for target_layer, neuron_indices in neurons_by_layer.items():
        print(f"\nVerifying layer: {target_layer}")
        
        try:
            # Get the target layer parameters
            layer_params_orig = policy_params_orig['params']['decoder'][target_layer]
            layer_params_les = policy_params_les['params']['decoder'][target_layer]
            
            # Check a subset of neuron weights
            sample_neurons = neuron_indices[:min(3, len(neuron_indices))]
            
            for idx in sample_neurons:
                if idx < layer_params_les['kernel'].shape[1]:
                    # Check bias
                    bias_orig = layer_params_orig['bias'][idx]
                    bias_les = layer_params_les['bias'][idx]
                    bias_zeroed = jnp.allclose(bias_les, 0.0)
                    
                    # Check outgoing weights
                    outgoing_orig = layer_params_orig['kernel'][:, idx]
                    outgoing_les = layer_params_les['kernel'][:, idx]
                    outgoing_zeroed = jnp.allclose(outgoing_les, jnp.zeros_like(outgoing_les))
                    
                    print(f"  Neuron {idx}: Bias zeroed: {bias_zeroed}, Outgoing weights zeroed: {outgoing_zeroed}")
                    
                    if not bias_zeroed or not outgoing_zeroed:
                        verification_passed = False
            
            # Check overall changes
            kernel_diff = jnp.sum(jnp.abs(layer_params_orig['kernel'] - layer_params_les['kernel']))
            bias_diff = jnp.sum(jnp.abs(layer_params_orig['bias'] - layer_params_les['bias']))
            print(f"  Total weight changes: Kernel diff: {kernel_diff:.2f}, Bias diff: {bias_diff:.2f}")
            
            if kernel_diff <= 0 or bias_diff <= 0:
                verification_passed = False
                
        except Exception as e:
            print(f"Error in verification for {target_layer}: {e}")
            verification_passed = False
    
    return verification_passed

# ═══════════════════════════════════════════════════════════════════════════════
# ACTION COMPARISON FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def compare_actions(baseline_actions, lesioned_actions):
    """Compare baseline and lesioned actions, generate plots and statistics"""
    print("\n=== COMPARING BASELINE VS LESIONED ACTIONS ===")
    
    n_actions = baseline_actions.shape[1]
    summary_data = []
    
    # Create figure for overall comparison
    plt.figure(figsize=(16, 14))
    
    # Process each action
    for action_idx, action_name in enumerate(action_names[:n_actions]):
        baseline_action = baseline_actions[:, action_idx]
        lesioned_action = lesioned_actions[:, action_idx]
        
        # Calculate metrics
        mae = np.mean(np.abs(baseline_action - lesioned_action))
        rmse = np.sqrt(np.mean((baseline_action - lesioned_action)**2))
        max_diff = np.max(np.abs(baseline_action - lesioned_action))
        
        # Store summary data
        summary_data.append({
            'action': action_name,
            'mae': mae,
            'rmse': rmse,
            'max_diff': max_diff
        })
        
        print(f"\n=== {action_name.upper()} ACTION STATISTICS ===")
        print(f"Mean Absolute Error: {mae:.4f}")
        print(f"RMSE: {rmse:.4f}")
        print(f"Maximum Difference: {max_diff:.4f}")
        
        # Individual action plots
        plt.subplot(3, 3, action_idx + 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Create detailed action figure
        plt.figure(figsize=(14, 12))
        
        # Plot 1: Time series comparison
        plt.subplot(2, 2, 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 2: Difference between baseline and lesioned
        plt.subplot(2, 2, 2)
        plt.plot(lesioned_action - baseline_action, color='purple', alpha=0.7)
        plt.title(f'{action_name} Action Difference (Lesioned - Baseline)')
        plt.xlabel('Frame')
        plt.ylabel('Difference')
        plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        plt.grid(True, alpha=0.3)
        
        # Plot 3: Distribution comparison
        plt.subplot(2, 2, 3)
        sns.kdeplot(baseline_action, label='Baseline', color='blue', fill=True, alpha=0.3)
        sns.kdeplot(lesioned_action, label='Lesioned', color='red', fill=True, alpha=0.3)
        plt.title(f'{action_name} Action Distribution')
        plt.xlabel('Action Value')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 4: Scatter plot of lesioned vs baseline
        plt.subplot(2, 2, 4)
        plt.scatter(baseline_action, lesioned_action, alpha=0.5, s=10)
        plt.plot([min(baseline_action), max(baseline_action)], 
                 [min(baseline_action), max(baseline_action)], 
                 'k--', alpha=0.5)
        plt.title(f'{action_name} Lesioned vs Baseline')
        plt.xlabel('Baseline Action')
        plt.ylabel('Lesioned Action')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / f"{action_name}_action_analysis.png", dpi=150)
        plt.close()
    
    # Finalize and save the action comparison plot
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "all_actions_comparison.png", dpi=150)
    plt.close()
    
    # Create a summary bar chart comparing the effect of lesioning on each action
    plt.figure(figsize=(12, 8))
    
    # Bar chart of RMSE per action
    actions = [d['action'] for d in summary_data]
    rmse_values = [d['rmse'] for d in summary_data]
    
    bars = plt.bar(actions, rmse_values, color='skyblue')
    plt.title('RMSE Between Baseline and Lesioned Actions')
    plt.xlabel('Action')
    plt.ylabel('RMSE')
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)
    
    # Annotate with values
    for bar, value in zip(bars, rmse_values):
        plt.text(bar.get_x() + bar.get_width()/2., 
                 value + 0.01,
                 f'{value:.4f}', 
                 ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "action_rmse_summary.png", dpi=150)
    plt.close()
    
    # Print summary table
    print("\n=== SUMMARY OF LESIONING EFFECTS ACROSS ACTIONS ===")
    print(f"{'Action':<20} {'MAE':>10} {'RMSE':>10} {'Max Difference':>15}")
    print("-" * 60)
    for data in summary_data:
        print(f"{data['action']:<20} {data['mae']:>10.4f} {data['rmse']:>10.4f} {data['max_diff']:>15.4f}")
        
    return summary_data

# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════

print("\n=== STARTING NEURAL ANALYSIS AND LESIONING ===")

# Visualize CCM networks to understand causal relationships
visualize_ccm_networks()

# Step 1: Load neural and action data for analysis
print("\nLoading neural activation and action data...")
with h5py.File(H5_PATH, "r") as f:
    actions = f["actions"][...]  # Load actions data
    
    # Neural activations from all decoder layers
    layer_data = {}
    for layer_name in LAYERS_TO_ANALYZE:
        layer_data[layer_name] = f[f"decoder_activations/{layer_name}"][:]
        print(f"{layer_name} shape: {layer_data[layer_name].shape}")
    
    print(f"Actions shape: {actions.shape}")

# Step 2: Analyze each layer to identify neurons predictive of actions
neurons_by_layer = {}
neurons_by_action = {}
correlation_results = {}

for layer_name in LAYERS_TO_ANALYZE:
    # Find the corresponding parameter layer name
    param_layer = LAYER_MAPPING[layer_name]
    # Analyze layer to get top neurons
    top_neurons, top_by_action, corr_df = analyze_layer_for_actions(layer_name, layer_data, actions)
    neurons_by_layer[param_layer] = top_neurons
    neurons_by_action[layer_name] = top_by_action
    correlation_results[layer_name] = corr_df

print("\n=== IDENTIFIED NEURONS TO LESION ===")
for layer_name, neurons in neurons_by_layer.items():
    print(f"{layer_name}: {len(neurons)} unique neurons")

# Step 3: Predict impact of lesioning using CCM matrices
predicted_action_impact, predicted_joint_impact = predict_lesion_impact(neurons_by_action["layer_0"])

# Step 4: Load checkpoint and config
ckpt_path = Path.cwd().parent / "model_checkpoints/250826_030533_134914"
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Configure data path
cfg.data_path = "/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial01_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial04_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial09_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial10_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial13_ik.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# Create environment for testing
env = rollout.create_environment(cfg)

# Step 5: Create baseline (non-lesioned) rollout for comparison
print("\n=== GENERATING BASELINE ROLLOUT FOR COMPARISON ===")
baseline_inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
baseline_rollout_gen = rollout.create_rollout_generator(
    cfg, env, baseline_inference_fn, 
    log_activations=True, log_metrics=True, log_sensor_data=True
)
baseline_rollout = baseline_rollout_gen(clip_idx=1)
print("✓ Baseline rollout generated")

# Step 6: Apply multi-layer lesioning to the policy
if ENABLE_LESION:
    print("\n=== CREATING MULTI-LAYER LESIONED POLICY ===")
    
    # Create policy network structure first
    original_policy = ckpt["policy"]
    lesioned_policy = create_lesioned_policy_complete_multi_layer(
        original_policy, 
        neurons_by_layer
    )
    
    # Verify that lesioning was properly applied
    verification_passed = verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer)
    print(f"\nMulti-layer lesioning verification {'PASSED' if verification_passed else 'FAILED'}")
    
    # Replace the policy in the checkpoint
    ckpt["policy"] = lesioned_policy
    print(f"Checkpoint policy updated with lesioned policy")
    
    # Use the standard loader with the modified checkpoint
    print("\n=== SETTING UP LESIONED INFERENCE FUNCTION ===")
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Lesioned inference function created")
else:
    # Use the original policy from checkpoint
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Original inference function loaded")

# Step 7: Generate rollout with the lesioned policy
print("\n=== GENERATING ROLLOUT WITH MULTI-LAYER LESIONED POLICY ===")
generate_rollout = rollout.create_rollout_generator(
    cfg, 
    env, 
    inference_fn, 
    log_activations=True, 
    log_metrics=True, 
    log_sensor_data=True
)

lesioned_rollout = generate_rollout(clip_idx=1)
print("✓ Lesioned rollout generated")

# Step 8: Extract and analyze activations from rollout
if ENABLE_LESION and lesioned_rollout.get('activations') is not None:
    print("\n=== ANALYZING ROLLOUT ACTIVATIONS ===")
    activations = lesioned_rollout['activations']['decoder']
    
    # Check each layer
    for activation_layer, target_layer in LAYER_MAPPING.items():
        if activation_layer in activations:
            layer_acts = activations[activation_layer]
            neurons_to_check = neurons_by_layer[target_layer]
            
            # Check activation stats for lesioned neurons
            valid_neurons = [n for n in neurons_to_check if n < layer_acts.shape[-1]]
            if valid_neurons:
                lesioned_acts = layer_acts[:, valid_neurons]
                avg_activation = jnp.mean(jnp.abs(lesioned_acts))
                max_activation = jnp.max(jnp.abs(lesioned_acts))
                
                print(f"\n{activation_layer} ({target_layer}) lesioned neurons stats:")
                print(f"  Average absolute activation: {avg_activation:.6f}")
                print(f"  Maximum absolute activation: {max_activation:.6f}")
                print(f"  {'⚠️ NEURONS STILL ACTIVE' if max_activation > 0.1 else '✓ NEURONS PROPERLY SILENCED'}")
        else:
            print(f"Warning: {activation_layer} not found in activations")

# Step 9: Compare actions between baseline and lesioned rollouts
print("\n=== COMPARING BASELINE AND LESIONED ACTIONS ===")
baseline_actions = np.array(baseline_rollout['actions']) if 'actions' in baseline_rollout else None
lesioned_actions = np.array(lesioned_rollout['actions']) if 'actions' in lesioned_rollout else None

action_summary = None
if baseline_actions is not None and lesioned_actions is not None:
    action_summary = compare_actions(baseline_actions, lesioned_actions)
else:
    print("ERROR: Could not extract actions from rollouts for comparison")

# Step 10: Render comparison videos
print("\n=== RENDERING COMPARISON VIDEOS ===")

# Render lesioned video
lesioned_frames, lesioned_framerate = render.render_rollout(
    cfg, 
    lesioned_rollout, 
    height=480,
    width=640,
)

# Render baseline video
baseline_frames, baseline_framerate = render.render_rollout(
    cfg, 
    baseline_rollout, 
    height=480,
    width=640,
)

# Save the videos
lesioned_video_path = Path(ckpt_path) / f"rollout_actions_lesioned_top{TOP_K}.mp4"
baseline_video_path = Path(ckpt_path) / f"rollout_baseline.mp4"

media.write_video(lesioned_video_path, lesioned_frames, fps=lesioned_framerate)
media.write_video(baseline_video_path, baseline_frames, fps=baseline_framerate)

print(f"✓ Lesioned video saved to {lesioned_video_path}")
print(f"✓ Baseline video saved to {baseline_video_path}")

# Display the lesioned video
print("\n=== DISPLAYING MULTI-LAYER LESIONED VIDEO ===")
media.show_video(lesioned_frames, fps=lesioned_framerate)

# Step 11: Compare joint positions between baseline and lesioned rollouts
print("\n=== COMPARING CONTROL VS LESIONED PERFORMANCE ACROSS ALL JOINTS ===")

# Joint names for better labeling
JOINT_NAMES = ["sh_elv", "sh_ext", "sh_rot", "elbow"]

def extract_joint_data(rollout, key_name='qposes_rollout'):
    """Extract data for all joints from specified key"""
    if key_name in rollout:
        print(f"Found {key_name} with shape: {rollout[key_name].shape}")
        return np.array(rollout[key_name])
    
    print(f"Could not find {key_name}. Available arrays:")
    for key, value in rollout.items():
        if isinstance(value, np.ndarray):
            print(f"  {key}: {value.shape}")
        elif hasattr(value, 'keys'):
            print(f"  {key} (dict/object with keys): {list(value.keys())}")
    
    return None

# Extract data for all joints
baseline_joints = extract_joint_data(baseline_rollout, 'qposes_rollout')
lesioned_joints = extract_joint_data(lesioned_rollout, 'qposes_rollout')
reference_joints = extract_joint_data(baseline_rollout, 'qposes_ref')

# If reference is not found in baseline, try lesioned
if reference_joints is None:
    reference_joints = extract_joint_data(lesioned_rollout, 'qposes_ref')

# Extract rewards
baseline_rewards = np.array(baseline_rollout['rewards']) if 'rewards' in baseline_rollout else None
lesioned_rewards = np.array(lesioned_rollout['rewards']) if 'rewards' in lesioned_rollout else None

# Fall back to state_rewards if needed
if baseline_rewards is None and 'state_rewards' in baseline_rollout:
    baseline_rewards = np.array(baseline_rollout['state_rewards'])
if lesioned_rewards is None and 'state_rewards' in lesioned_rollout:
    lesioned_rewards = np.array(lesioned_rollout['state_rewards'])

# Proceed only if we found all the necessary data
joint_summary = None
if baseline_joints is None or lesioned_joints is None or reference_joints is None:
    print("Could not find required joint data in the rollouts")
else:
    # Print reward statistics
    if baseline_rewards is not None and lesioned_rewards is not None:
        print("\n=== OVERALL REWARD STATISTICS ===")
        print(f"Control mean reward: {np.mean(baseline_rewards):.4f}, Lesioned mean reward: {np.mean(lesioned_rewards):.4f}")
        print(f"Reward reduction: {np.mean(baseline_rewards) - np.mean(lesioned_rewards):.4f} ({(1 - np.mean(lesioned_rewards)/np.mean(baseline_rewards))*100:.1f}%)")
        
        # Plot reward comparison
        plt.figure(figsize=(8, 6))
        plt.plot(baseline_rewards, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_rewards, label='Lesioned', color='red', alpha=0.7)
        plt.title('Reward Over Time')
        plt.xlabel('Frame')
        plt.ylabel('Reward')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / "overall_lesion_rewards.png", dpi=150)
        plt.close()
    
    # Calculate overall metrics for summary table
    joint_summary = []
    
    # Create a 2x2 grid of joint plots
    plt.figure(figsize=(16, 14))
    
    # Process each joint
    for joint_idx, joint_name in enumerate(JOINT_NAMES):
        baseline_joint = baseline_joints[:, joint_idx]
        lesioned_joint = lesioned_joints[:, joint_idx]
        reference_joint = reference_joints[:, joint_idx]
        
        # Calculate error metrics
        baseline_error = baseline_joint - reference_joint
        lesioned_error = lesioned_joint - reference_joint
        
        # Calculate absolute errors for visualization
        baseline_abs_error = np.abs(baseline_error)
        lesioned_abs_error = np.abs(lesioned_error)
        
        # Calculate statistics
        baseline_rmse = np.sqrt(np.mean(baseline_error**2))
        lesioned_rmse = np.sqrt(np.mean(lesioned_error**2))
        baseline_mae = np.mean(baseline_abs_error)
        lesioned_mae = np.mean(lesioned_abs_error)
        error_var_ratio = np.var(lesioned_error) / np.var(baseline_error)
        
        # Store summary data
        joint_summary.append({
            'joint': joint_name,
            'baseline_rmse': baseline_rmse,
            'lesioned_rmse': lesioned_rmse,
            'rmse_increase': (lesioned_rmse - baseline_rmse),
            'rmse_percent': (lesioned_rmse / baseline_rmse - 1) * 100,
            'baseline_mae': baseline_mae,
            'lesioned_mae': lesioned_mae,
            'mae_increase': (lesioned_mae - baseline_mae),
            'mae_percent': (lesioned_mae / baseline_mae - 1) * 100,
            'error_var_ratio': error_var_ratio
        })
        
        print(f"\n=== {joint_name.upper()} ERROR STATISTICS ===")
        print(f"Control mean error: {np.mean(baseline_error):.4f} rad, Lesioned mean error: {np.mean(lesioned_error):.4f} rad")
        print(f"Control RMSE: {baseline_rmse:.4f} rad, Lesioned RMSE: {lesioned_rmse%load_ext autoreload
%autoreload 2

import os, copy, yaml, h5py
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm import tqdm
from omegaconf import OmegaConf
from orbax import checkpoint as ocp
from pyEDM import Simplex
from scipy import signal

# Environment setup
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["MUJOCO_GL"] = "egl"        # or "osmesa"
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = os.getenv("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.4")

import mediapy as media
from track_mjx.agent import checkpointing
from track_mjx.analysis import rollout, render, utils
from track_mjx.agent.mlp_ppo.intention_network import make_intention_policy
from brax.training.acme import running_statistics
from track_mjx.analysis.rollout import create_environment

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════

# Analysis configuration
H5_PATH = "/root/vast/eric/track-mjx/notebooks/rollouts_full_mouse_arm_new_250826_030533_134914.h5"
NEURAL_E = 4     # Default embedding dimension (if action-specific not available)
NEURAL_TAU = -1  # Time delay for neural activations
Tp = 1           # Prediction horizon
PAD = 300        # NaN padding between clips

# Lesioning configuration
ENABLE_LESION = True
TOP_K = 5       # Number of top neurons to lesion per action
LAYERS_TO_ANALYZE = ["layer_0", "layer_1", "layer_2"]

# Action-specific embedding dimensions
ACTION_E = {
    "act0": 6, "act1": 10, "act2": 5, "act3": 5, "act4": 10,
    "act5": 6, "act6": 8,  "act7": 6,  "act8": 7
}

# Layer mapping: activation layer name -> parameter layer name
LAYER_MAPPING = {
    "layer_0": "hidden_0",
    "layer_1": "hidden_1",
    "layer_2": "hidden_2"
}

# Action names (muscle names)
action_names = ["Pec_C", "Lat", "PD", "AD", "MD", "Triceps_Lateral", "Triceps_Long", "Brachialis", "Biceps_Long"]
joint_names = ["sh_elv", "sh_ext", "sh_rot", "elbow"]

# Load CCM matrices
ccm_matrix = np.array([
    [1.      , 0.525813, 0.203717, 0.568015, 0.691638, 0.696206, 0.678814, 0.682802, 0.53741 ],
    [0.535335, 1.      , 0.139094, 0.681686, 0.522431, 0.667265, 0.463946, 0.663696, 0.537878],
    [0.295953, 0.163062, 1.      , 0.176845, 0.31952 , 0.165365, 0.213455, 0.375396, 0.216511],
    [0.597855, 0.7982  , 0.139659, 1.      , 0.747117, 0.731527, 0.566099, 0.710519, 0.65959 ],
    [0.719331, 0.698684, 0.245862, 0.794672, 1.      , 0.685992, 0.615126, 0.671923, 0.641866],
    [0.544157, 0.633614, 0.227312, 0.685914, 0.516835, 1.      , 0.248369, 0.862421, 0.316522],
    [0.619141, 0.398754, 0.146874, 0.399908, 0.393907, 0.320955, 1.      , 0.317059, 0.369504],
    [0.430258, 0.553677, 0.390358, 0.498353, 0.316513, 0.78614 , 0.232873, 1.      , 0.148855],
    [0.526948, 0.625399, 0.311665, 0.642974, 0.517234, 0.535678, 0.541893, 0.51868 , 1.      ]
])

# Convert a2j matrix from string to DataFrame
ccm_a2j_data = {
    'sh_elv': [0.676606, 0.614237, 0.572824, 0.702254, 0.760244, 0.655498, 0.390534, 0.611454, 0.541563],
    'sh_ext': [0.650763, 0.589529, 0.481112, 0.722722, 0.726062, 0.567940, 0.422373, 0.475919, 0.539591],
    'sh_rot': [0.635858, 0.713230, 0.444087, 0.783400, 0.774024, 0.698159, 0.503080, 0.610788, 0.570240],
    'elbow':  [0.679048, 0.721847, 0.376172, 0.783492, 0.763284, 0.710400, 0.507781, 0.616184, 0.521874]
}
ccm_matrix_a2j = pd.DataFrame(ccm_a2j_data, index=action_names)

print(f"Configuration: analyzing {len(LAYERS_TO_ANALYZE)} layers, selecting TOP-{TOP_K} neurons per action")
print("Using action-specific embedding dimensions:")
for act_idx, (act, dim) in enumerate(ACTION_E.items()):
    action_name = action_names[act_idx] if act_idx < len(action_names) else act
    print(f"  {action_name} ({act}): E={dim}")

# ═══════════════════════════════════════════════════════════════════════════════
# CCM PREDICTION FUNCTIONS 
# ═══════════════════════════════════════════════════════════════════════════════

def predict_action_effects(actions_to_lesion):
    """
    Predict the effects of lesioning specific actions (muscles) on other actions
    using the CCM causal relationship matrix.
    
    Args:
        actions_to_lesion: List of action indices to lesion
        
    Returns:
        predicted_impact: Dictionary mapping action names to predicted RMSE increase
    """
    # Convert action names to indices if needed
    action_indices = []
    for action in actions_to_lesion:
        if isinstance(action, str):
            if action in action_names:
                action_indices.append(action_names.index(action))
            else:
                print(f"Warning: Action '{action}' not found in action_names")
        else:
            action_indices.append(action)
    
    # Calculate predicted impact on each action
    predicted_impact = {}
    for i, action_name in enumerate(action_names):
        # Skip if this is an action we're lesioning
        if i in action_indices:
            predicted_impact[action_name] = 1.0  # Maximum impact
            continue
        
        # Calculate impact based on CCM values
        impact = 0
        for lesion_idx in action_indices:
            # How much does the lesioned action causally affect this action?
            # Higher CCM value means stronger causal relationship
            causal_strength = ccm_matrix[lesion_idx, i]
            impact += causal_strength
        
        # Normalize impact score (0-1)
        impact = min(impact, 1.0)
        predicted_impact[action_name] = impact
    
    return predicted_impact

def predict_joint_effects(actions_to_lesion):
    """
    Predict the effects of lesioning specific actions (muscles) on joint angles
    using the CCM action-to-joint causal relationship matrix.
    
    Args:
        actions_to_lesion: List of action indices to lesion
        
    Returns:
        predicted_impact: Dictionary mapping joint names to predicted RMSE increase
    """
    # Convert action names to indices if needed
    action_indices = []
    for action in actions_to_lesion:
        if isinstance(action, str):
            if action in action_names:
                action_indices.append(action_names.index(action))
            else:
                print(f"Warning: Action '{action}' not found in action_names")
        else:
            action_indices.append(action)
    
    # Calculate predicted impact on each joint
    predicted_impact = {}
    for joint_name in joint_names:
        impact = 0
        for lesion_idx in action_indices:
            action_name = action_names[lesion_idx]
            # How much does the lesioned action causally affect this joint?
            causal_strength = ccm_matrix_a2j.loc[action_name, joint_name]
            impact += causal_strength
        
        # Normalize impact score (0-1)
        impact = min(impact / len(action_indices), 1.0)
        predicted_impact[joint_name] = impact
    
    return predicted_impact

def predict_lesion_impact(top_neurons_by_action):
    """
    Predict the impact of lesioning neurons on actions and joints.
    
    Args:
        top_neurons_by_action: Dict mapping action names to lists of neuron indices
        
    Returns:
        Tuple of (action_impact, joint_impact) dictionaries
    """
    print("\n=== PREDICTING LESIONING IMPACT USING CCM MATRICES ===")
    
    # Determine which actions will be most affected by the lesioning
    affected_actions = list(top_neurons_by_action.keys())
    print(f"Neurons for these actions will be lesioned: {', '.join(affected_actions)}")
    
    # Predict effect on other actions (muscles)
    action_impact = predict_action_effects(affected_actions)
    
    # Sort actions by predicted impact
    sorted_actions = sorted(action_impact.items(), key=lambda x: x[1], reverse=True)
    print("\nPredicted impact on muscle activations (ordered by severity):")
    for action, impact in sorted_actions:
        impact_level = "High" if impact > 0.7 else "Medium" if impact > 0.4 else "Low"
        print(f"  {action}: {impact:.4f} - {impact_level} impact")
    
    # Predict effect on joints
    joint_impact = predict_joint_effects(affected_actions)
    
    # Sort joints by predicted impact
    sorted_joints = sorted(joint_impact.items(), key=lambda x: x[1], reverse=True)
    print("\nPredicted impact on joint angles (ordered by severity):")
    for joint, impact in sorted_joints:
        impact_level = "High" if impact > 0.7 else "Medium" if impact > 0.4 else "Low"
        print(f"  {joint}: {impact:.4f} - {impact_level} impact")
        
    # Create impact visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
    
    # Action impact plot
    action_names_list = [a for a, _ in sorted_actions]
    action_impact_vals = [i for _, i in sorted_actions]
    ax1.barh(action_names_list, action_impact_vals, color='skyblue')
    ax1.set_title('Predicted Impact on Muscle Activations')
    ax1.set_xlabel('Predicted Impact (0-1)')
    ax1.grid(alpha=0.3)
    
    # Joint impact plot
    joint_names_list = [j for j, _ in sorted_joints]
    joint_impact_vals = [i for _, i in sorted_joints]
    ax2.barh(joint_names_list, joint_impact_vals, color='lightcoral')
    ax2.set_title('Predicted Impact on Joint Angles')
    ax2.set_xlabel('Predicted Impact (0-1)')
    ax2.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig("predicted_lesioning_impact.png", dpi=150)
    plt.show()
    
    return action_impact, joint_impact

def visualize_ccm_networks():
    """Create visualizations of the CCM networks to better understand causal structures"""
    print("\n=== VISUALIZING CCM CAUSAL NETWORKS ===")
    
    # Create a more readable muscle-to-muscle CCM matrix visualization
    plt.figure(figsize=(10, 8))
    sns.heatmap(ccm_matrix, cmap="YlGnBu", annot=True, fmt=".2f",
                xticklabels=action_names, yticklabels=action_names)
    plt.title('Muscle-to-Muscle CCM Causal Relationships')
    plt.tight_layout()
    plt.savefig("muscle_to_muscle_ccm.png", dpi=150)
    plt.show()
    
    # Create a transposed action-to-joint CCM visualization (easier to read)
    plt.figure(figsize=(12, 6))
    sns.heatmap(ccm_matrix_a2j.T, cmap="YlGnBu", annot=True, fmt=".2f")
    plt.title('Joint-Muscle CCM Relationships (Transposed)')
    plt.tight_layout()
    plt.savefig("joint_to_muscle_ccm_transposed.png", dpi=150)
    plt.show()

def compare_predicted_vs_actual(predicted_action_impact, predicted_joint_impact, action_summary, joint_summary):
    """
    Compare the predicted impact from CCM with actual measured impact after lesioning
    
    Args:
        predicted_action_impact: Dict mapping action names to predicted impact (0-1)
        predicted_joint_impact: Dict mapping joint names to predicted impact (0-1)
        action_summary: List of dicts with actual action RMSE measurements
        joint_summary: List of dicts with actual joint RMSE measurements
    """
    print("\n=== COMPARING PREDICTED VS ACTUAL LESIONING EFFECTS ===")
    
    # Prepare action data
    action_data = []
    for item in action_summary:
        action = item['action']
        if action in predicted_action_impact:
            action_data.append({
                'action': action,
                'predicted_impact': predicted_action_impact[action],
                'actual_rmse': item['rmse'],
                'actual_max_diff': item['max_diff']
            })
    
    # Prepare joint data
    joint_data = []
    for item in joint_summary:
        joint = item['joint']
        if joint in predicted_joint_impact:
            joint_data.append({
                'joint': joint,
                'predicted_impact': predicted_joint_impact[joint],
                'actual_rmse_increase': item['rmse_percent'] / 100,  # Convert percent to fraction
                'actual_mae_increase': item['mae_percent'] / 100     # Convert percent to fraction
            })
    
    # Create comparison plots
    fig, axes = plt.subplots(2, 1, figsize=(12, 14))
    
    # Action comparison
    action_df = pd.DataFrame(action_data)
    action_df = action_df.sort_values('predicted_impact', ascending=False)
    
    # Normalize actual RMSE to 0-1 scale for comparison
    max_rmse = action_df['actual_rmse'].max()
    action_df['normalized_actual'] = action_df['actual_rmse'] / max_rmse if max_rmse > 0 else 0
    
    ax1 = axes[0]
    x = np.arange(len(action_df))
    width = 0.35
    
    ax1.bar(x - width/2, action_df['predicted_impact'], width, label='Predicted Impact', color='skyblue')
    ax1.bar(x + width/2, action_df['normalized_actual'], width, label='Normalized Actual RMSE', color='coral')
    
    ax1.set_xticks(x)
    ax1.set_xticklabels(action_df['action'], rotation=45, ha='right')
    ax1.set_title('Muscle Activation: Predicted vs Actual Impact')
    ax1.set_ylabel('Impact (0-1 scale)')
    ax1.legend()
    ax1.grid(alpha=0.3)
    
    # Calculate correlation
    action_corr = np.corrcoef(action_df['predicted_impact'], action_df['normalized_actual'])[0, 1]
    ax1.text(0.02, 0.95, f"Correlation: {action_corr:.3f}", transform=ax1.transAxes, 
             bbox=dict(facecolor='white', alpha=0.8))
    
    # Joint comparison
    joint_df = pd.DataFrame(joint_data)
    joint_df = joint_df.sort_values('predicted_impact', ascending=False)
    
    ax2 = axes[1]
    x = np.arange(len(joint_df))
    
    ax2.bar(x - width/2, joint_df['predicted_impact'], width, label='Predicted Impact', color='skyblue')
    ax2.bar(x + width/2, joint_df['actual_rmse_increase'], width, label='Actual RMSE % Increase', color='coral')
    
    ax2.set_xticks(x)
    ax2.set_xticklabels(joint_df['joint'], rotation=45, ha='right')
    ax2.set_title('Joint Angles: Predicted vs Actual Impact')
    ax2.set_ylabel('Impact (0-1 scale)')
    ax2.legend()
    ax2.grid(alpha=0.3)
    
    # Calculate correlation
    joint_corr = np.corrcoef(joint_df['predicted_impact'], joint_df['actual_rmse_increase'])[0, 1]
    ax2.text(0.02, 0.95, f"Correlation: {joint_corr:.3f}", transform=ax2.transAxes,
             bbox=dict(facecolor='white', alpha=0.8))
    
    plt.tight_layout()
    plt.savefig("predicted_vs_actual_impact.png", dpi=150)
    plt.show()
    
    # Print summary statistics
    print(f"Action prediction correlation: {action_corr:.3f}")
    print(f"Joint prediction correlation: {joint_corr:.3f}")
    
    if action_corr > 0.7 and joint_corr > 0.7:
        print("\n✓ CCM-based predictions showed STRONG correlation with actual lesioning effects")
    elif action_corr > 0.4 and joint_corr > 0.4:
        print("\n✓ CCM-based predictions showed MODERATE correlation with actual lesioning effects")
    else:
        print("\n⚠ CCM-based predictions showed WEAK correlation with actual lesioning effects")

# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - NEURAL ANALYSIS
# ═══════════════════════════════════════════════════════════════════════════════

def zscore(x, axis=None, eps=1e-12):
    """Z-score normalization with small epsilon to avoid division by zero."""
    mu = np.nanmean(x, axis=axis, keepdims=True)
    sd = np.nanstd(x, axis=axis, keepdims=True)
    return (x - mu) / np.maximum(sd, eps)

def concat_neural_action_padded(neural_dict, actions_data, pad=1, standardize=True):
    """Concatenate neural activations with action data, NaN-padded between clips."""
    N, T = next(iter(neural_dict.values())).shape[:2]
    
    frames = []
    row_cursor = 0
    
    for clip in range(N):
        # Add padding between clips (except before first clip)
        if clip > 0:
            pad_data = {"time": np.arange(row_cursor+1, row_cursor+pad+1)}
            
            # Add NaN columns for all neural activations
            for layer_name, layer_data in neural_dict.items():
                n_neurons = layer_data.shape[-1]
                for neuron in range(n_neurons):
                    pad_data[f"{layer_name}_n{neuron:03d}"] = np.nan
            
            # Add NaN columns for actions
            for a, action_name in enumerate(action_names):
                pad_data[f"action_{action_name}"] = np.nan
            
            pad_df = pd.DataFrame(pad_data)
            frames.append(pad_df)
            row_cursor += pad
        
        # Add actual data for this clip
        clip_data = {"time": np.arange(row_cursor+1, row_cursor+T+1)}
        
        # Neural activations (standardized)
        for layer_name, layer_data in neural_dict.items():
            n_neurons = layer_data.shape[-1]
            for neuron in range(n_neurons):
                x = layer_data[clip, :, neuron]
                clip_data[f"{layer_name}_n{neuron:03d}"] = zscore(x) if standardize else x
        
        # Action values (standardized)
        for a, action_name in enumerate(action_names):
            action_vals = actions_data[clip, :, a]
            clip_data[f"action_{action_name}"] = zscore(action_vals) if standardize else action_vals
        
        clip_df = pd.DataFrame(clip_data)
        frames.append(clip_df)
        row_cursor += T
    
    big_df = pd.concat(frames, ignore_index=True)
    return big_df

def build_lib_pred_from_nan_blocks(df, probe_col, split=0.6, min_edge=8):
    """Split each non-NaN block 60/40 into lib/pred for pyEDM."""
    x = df[probe_col].to_numpy()
    notnan = ~np.isnan(x)
    edges = np.diff(np.concatenate([[0], notnan.view(np.int8), [0]]))
    starts = np.where(edges == 1)[0] + 1  # 1-based indexing
    ends = np.where(edges == -1)[0]       # 1-based indexing
    
    lib_parts, pred_parts = [], []
    for s, e in zip(starts, ends):
        n = e - s + 1
        if n < 2 * min_edge:
            continue
        m = s + int(np.floor(split * n)) - 1
        m = max(m, s + min_edge - 1)
        m = min(m, e - min_edge)
        if m <= s or m >= e:
            continue
        lib_parts.extend([str(s), str(m)])
        pred_parts.extend([str(m + 1), str(e)])
    
    return " ".join(lib_parts), " ".join(pred_parts)

def safe_correlation(y, yhat, min_pairs=10):
    """Calculate correlation with safety checks."""
    y = np.asarray(y, float)
    yhat = np.asarray(yhat, float)
    mask = np.isfinite(y) & np.isfinite(yhat)
    
    if mask.sum() < min_pairs:
        return np.nan
    
    y_clean, yhat_clean = y[mask], yhat[mask]
    sy, syh = y_clean.std(ddof=0), yhat_clean.std(ddof=0)
    
    if sy <= 1e-12 or syh <= 1e-12:
        return np.nan
    
    return float(np.corrcoef(y_clean, yhat_clean)[0, 1])

def get_top_neurons_for_action(corr_df, action_name, layer_name, k=25):
    """Get top K neurons for a specific action."""
    action_col = f"action_{action_name}"
    
    # Ensure action column exists
    if action_col not in corr_df.columns:
        print(f"Warning: {action_col} not found in correlation dataframe")
        return []
    
    # Sort by absolute correlation
    corrs = corr_df[action_col].dropna().abs().sort_values(ascending=False)
    top_neurons = corrs.head(k).index.tolist()
    
    # Extract neuron indices from column names
    neuron_indices = []
    for neuron_col in top_neurons:
        neuron_idx = int(neuron_col.split('_n')[1])
        neuron_indices.append(neuron_idx)
    
    # Print top neurons and their correlations
    print(f"\nTop {k} neurons in {layer_name} for {action_name}:")
    for i, neuron_col in enumerate(top_neurons[:min(10, k)]):  # Show top 10 max
        neuron_idx = int(neuron_col.split('_n')[1])
        raw_corr = corr_df.loc[neuron_col, action_col]
        abs_corr = abs(raw_corr)
        print(f"  {i+1:2d}. Neuron {neuron_idx:3d}: {raw_corr:+.4f} (|ρ|={abs_corr:.4f})")
    
    if len(top_neurons) > 10:
        print(f"  ... and {len(top_neurons)-10} more")
    
    return neuron_indices

def neural_to_action_simplex_heatmap(df, neural_cols, action_cols, lib, pred, layer_name, action_E=None, tau=-1, Tp=0):
    """Perform simplex analysis: Neural Activations → Action Dynamics with action-specific embedding dimensions"""
    print(f"Running simplex analysis for {layer_name}: {len(neural_cols)} neurons → {len(action_cols)} actions")
    print(f"Using action-specific embedding dimensions")
    
    # Initialize correlation matrix
    corr_matrix = np.full((len(neural_cols), len(action_cols)), np.nan, dtype=float)
    
    # Total number of computations to perform
    total_pairs = len(neural_cols) * len(action_cols)
    
    # If no action-specific embedding dimensions provided, use default
    if action_E is None:
        action_E = {f"act{i}": NEURAL_E for i in range(len(action_cols))}
    
    # Print embedding dimensions used for each action
    print("\nEmbedding dimensions for each action:")
    for ai, action_col in enumerate(action_cols):
        action_idx = f"act{ai}"
        E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
        print(f"  {action_col} (index {action_idx}): E={E}")
    
    # Use tqdm for progress tracking - no need for additional manual updates
    pbar = tqdm(total=total_pairs, desc=f"{layer_name}→Action Simplex")
    
    for ni, neural_col in enumerate(neural_cols):
        for ai, action_col in enumerate(action_cols):
            # Determine embedding dimension for this action
            action_idx = f"act{ai}"
            E = action_E.get(action_idx, NEURAL_E)  # Fall back to default if not specified
            
            try:
                # Run simplex prediction with action-specific embedding dimension
                pred_df = Simplex(
                    dataFrame=df,
                    lib=lib,
                    pred=pred,
                    columns=neural_col,
                    target=action_col,
                    E=E,  # Use action-specific E
                    tau=tau,
                    Tp=Tp,
                    ignoreNan=True,
                    showPlot=False
                )
                
                # Extract predictions and observations
                obs = pred_df["Observations"].to_numpy()
                hat = pred_df["Predictions"].to_numpy()
                
                # Calculate correlation
                corr_matrix[ni, ai] = safe_correlation(obs, hat, min_pairs=10)
                
            except Exception as e:
                print(f"Error for {neural_col} → {action_col} with E={E}: {e}")
                corr_matrix[ni, ai] = np.nan
            
            # Update progress bar without additional prints
            pbar.update(1)
    
    # Close progress bar
    pbar.close()
    print(f"Completed {layer_name} simplex analysis")
    
    # Convert to DataFrame
    corr_df = pd.DataFrame(corr_matrix, index=neural_cols, columns=action_cols)
    return corr_df

def analyze_layer_for_actions(layer_name, neural_dict, actions_data):
    """Analyze a single layer to identify neurons predictive of actions."""
    print(f"\n{'='*80}\nANALYZING LAYER: {layer_name} FOR ACTION PREDICTION\n{'='*80}")
    
    # Build a layer-specific neural dictionary (just this one layer)
    layer_neural_dict = {layer_name: neural_dict[layer_name]}
    
    # Build DataFrame with neural activations and actions
    big_df_actions = concat_neural_action_padded(
        neural_dict=layer_neural_dict,
        actions_data=actions_data,
        pad=PAD,
        standardize=True
    )
    
    # Extract column names for analysis
    neural_cols = [col for col in big_df_actions.columns if layer_name in col]
    action_cols = [col for col in big_df_actions.columns if col.startswith("action_")]
    
    # Build lib/pred indices
    probe_col = neural_cols[0]
    lib, pred = build_lib_pred_from_nan_blocks(big_df_actions, probe_col=probe_col, split=0.6, min_edge=8)
    
    # Run simplex analysis with action-specific embedding dimensions
    print(f"Starting {layer_name} Neural → Action Simplex Analysis...")
    corr_results = neural_to_action_simplex_heatmap(
        df=big_df_actions,
        neural_cols=neural_cols,
        action_cols=action_cols,
        lib=lib,
        pred=pred,
        layer_name=layer_name,
        action_E=ACTION_E,  # Pass the action-specific embedding dimensions
        tau=NEURAL_TAU,
        Tp=Tp
    )
    
    # Create and save heatmap
    plt.figure(figsize=(12, 10))
    
    # Compute mean absolute correlation per neuron for sorting
    mean_abs_corr = corr_results.abs().mean(axis=1).sort_values(ascending=False)
    sorted_neurons = mean_abs_corr.index.tolist()
    
    # Get top 50 neurons for visualization
    top_neurons = sorted_neurons[:50]
    top_corr_df = corr_results.loc[top_neurons]
    
    # Create heatmap
    plt.figure(figsize=(12, 14))
    sns.heatmap(top_corr_df, cmap="RdBu_r", center=0, vmin=-1, vmax=1, 
                xticklabels=True, yticklabels=True, annot=False)
    plt.title(f"{layer_name} Neural Activations → Actions Correlation (Top 50 Neurons)\nUsing action-specific embedding dimensions")
    plt.tight_layout()
    plt.savefig(f"{layer_name}_to_actions_heatmap.png", dpi=150)
    plt.close()
    
    # For each action, get top-k neurons
    top_neurons_by_action = {}
    all_top_neurons = set()
    
    for action_name in action_names:
        top_for_action = get_top_neurons_for_action(
            corr_df=corr_results,
            action_name=action_name,
            layer_name=layer_name,
            k=TOP_K
        )
        top_neurons_by_action[action_name] = top_for_action
        all_top_neurons.update(top_for_action)
    
    # Return unique neurons across all actions
    all_top_neurons_list = list(all_top_neurons)
    print(f"\nFound {len(all_top_neurons_list)} unique neurons in {layer_name} predictive of actions")
    
    return all_top_neurons_list, top_neurons_by_action, corr_results

# [Rest of the helper functions continue as in the original code...]
# ═══════════════════════════════════════════════════════════════════════════════
# HELPER FUNCTIONS - LESIONING
# ═══════════════════════════════════════════════════════════════════════════════

def create_lesioned_policy_complete_multi_layer(original_policy, neurons_by_layer):
    """
    Create a lesioned version of the policy by zeroing both incoming and outgoing weights
    for multiple layers at once.
    
    Args:
        original_policy: The original policy to lesion
        neurons_by_layer: Dict mapping layer names (e.g., "hidden_0") to lists of neuron indices
    
    Returns:
        Lesioned policy
    """
    # Make a deep copy of the policy to avoid modifying the original
    lesioned_policy = jtu.tree_map(lambda x: jnp.array(x), original_policy)
    
    # Extract policy parameters
    processor_params, policy_params = lesioned_policy
    decoder_params = policy_params['params']['decoder']
    
    print(f"\nLesioning neurons across {len(neurons_by_layer)} layers...")
    
    # Process each layer
    for target_layer, neuron_indices_to_lesion in neurons_by_layer.items():
        print(f"\n=== Lesioning layer: {target_layer} ===")
        
        if target_layer not in decoder_params:
            print(f"Warning: Layer {target_layer} not found in model. Available layers: {list(decoder_params.keys())}")
            continue
        
        layer_params = decoder_params[target_layer]
        neuron_indices_array = jnp.array(neuron_indices_to_lesion)
        
        print(f"Complete lesioning of {len(neuron_indices_to_lesion)} neurons in {target_layer}")
        
        # 1. Zero outgoing weights (columns in kernel, elements in bias)
        if 'kernel' in layer_params and 'bias' in layer_params:
            kernel = layer_params['kernel']
            bias = layer_params['bias']
            
            # Filter valid indices for outgoing weights
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[1]]
            
            new_kernel = kernel.at[:, valid_indices].set(0.0)  # Set to 0 not -1
            new_bias = bias.at[valid_indices].set(0.0)         # Set to 0 not -1
            
            layer_params['kernel'] = new_kernel
            layer_params['bias'] = new_bias
            
            print(f"✓ Zeroed outgoing weights for {len(valid_indices)} neurons")
        
        # 2. Zero incoming weights to those neurons (rows in the SAME layer's kernel)
        if 'kernel' in layer_params:
            kernel = layer_params['kernel']
            
            # Filter valid indices for incoming weights (rows)
            valid_indices = neuron_indices_array[neuron_indices_array < kernel.shape[0]]
            
            if len(valid_indices) > 0:
                # Zero rows
                new_kernel = kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                layer_params['kernel'] = new_kernel
                
                print(f"✓ Zeroed incoming weights for {len(valid_indices)} neurons")
        
        # 3. Also zero outgoing connections from next layer if it exists
        layer_mapping = {
            "hidden_0": "hidden_1",
            "hidden_1": "hidden_2", 
            "hidden_2": "hidden_3"
        }
        
        next_layer = layer_mapping.get(target_layer)
        if next_layer and next_layer in decoder_params:
            next_layer_params = decoder_params[next_layer]
            if 'kernel' in next_layer_params:
                next_kernel = next_layer_params['kernel']
                valid_indices = neuron_indices_array[neuron_indices_array < next_kernel.shape[0]]
                
                if len(valid_indices) > 0:
                    # Zero rows
                    new_next_kernel = next_kernel.at[valid_indices, :].set(0.0)  # Set to 0 not -1
                    next_layer_params['kernel'] = new_next_kernel
                    
                    print(f"✓ Zeroed connections from lesioned neurons to {next_layer}")
    
    print("\n✓ Multi-layer lesioning completed successfully")
    return lesioned_policy

def verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer):
    """Verify that lesioning has been successfully applied across multiple layers."""
    print("\n=== MULTI-LAYER LESIONING VERIFICATION ===")
    
    processor_params_orig, policy_params_orig = original_policy
    processor_params_les, policy_params_les = lesioned_policy
    
    verification_passed = True
    
    for target_layer, neuron_indices in neurons_by_layer.items():
        print(f"\nVerifying layer: {target_layer}")
        
        try:
            # Get the target layer parameters
            layer_params_orig = policy_params_orig['params']['decoder'][target_layer]
            layer_params_les = policy_params_les['params']['decoder'][target_layer]
            
            # Check a subset of neuron weights
            sample_neurons = neuron_indices[:min(3, len(neuron_indices))]
            
            for idx in sample_neurons:
                if idx < layer_params_les['kernel'].shape[1]:
                    # Check bias
                    bias_orig = layer_params_orig['bias'][idx]
                    bias_les = layer_params_les['bias'][idx]
                    bias_zeroed = jnp.allclose(bias_les, 0.0)
                    
                    # Check outgoing weights
                    outgoing_orig = layer_params_orig['kernel'][:, idx]
                    outgoing_les = layer_params_les['kernel'][:, idx]
                    outgoing_zeroed = jnp.allclose(outgoing_les, jnp.zeros_like(outgoing_les))
                    
                    print(f"  Neuron {idx}: Bias zeroed: {bias_zeroed}, Outgoing weights zeroed: {outgoing_zeroed}")
                    
                    if not bias_zeroed or not outgoing_zeroed:
                        verification_passed = False
            
            # Check overall changes
            kernel_diff = jnp.sum(jnp.abs(layer_params_orig['kernel'] - layer_params_les['kernel']))
            bias_diff = jnp.sum(jnp.abs(layer_params_orig['bias'] - layer_params_les['bias']))
            print(f"  Total weight changes: Kernel diff: {kernel_diff:.2f}, Bias diff: {bias_diff:.2f}")
            
            if kernel_diff <= 0 or bias_diff <= 0:
                verification_passed = False
                
        except Exception as e:
            print(f"Error in verification for {target_layer}: {e}")
            verification_passed = False
    
    return verification_passed

# ═══════════════════════════════════════════════════════════════════════════════
# ACTION COMPARISON FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def compare_actions(baseline_actions, lesioned_actions):
    """Compare baseline and lesioned actions, generate plots and statistics"""
    print("\n=== COMPARING BASELINE VS LESIONED ACTIONS ===")
    
    n_actions = baseline_actions.shape[1]
    summary_data = []
    
    # Create figure for overall comparison
    plt.figure(figsize=(16, 14))
    
    # Process each action
    for action_idx, action_name in enumerate(action_names[:n_actions]):
        baseline_action = baseline_actions[:, action_idx]
        lesioned_action = lesioned_actions[:, action_idx]
        
        # Calculate metrics
        mae = np.mean(np.abs(baseline_action - lesioned_action))
        rmse = np.sqrt(np.mean((baseline_action - lesioned_action)**2))
        max_diff = np.max(np.abs(baseline_action - lesioned_action))
        
        # Store summary data
        summary_data.append({
            'action': action_name,
            'mae': mae,
            'rmse': rmse,
            'max_diff': max_diff
        })
        
        print(f"\n=== {action_name.upper()} ACTION STATISTICS ===")
        print(f"Mean Absolute Error: {mae:.4f}")
        print(f"RMSE: {rmse:.4f}")
        print(f"Maximum Difference: {max_diff:.4f}")
        
        # Individual action plots
        plt.subplot(3, 3, action_idx + 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Create detailed action figure
        plt.figure(figsize=(14, 12))
        
        # Plot 1: Time series comparison
        plt.subplot(2, 2, 1)
        plt.plot(baseline_action, label='Baseline', color='blue', alpha=0.7)
        plt.plot(lesioned_action, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{action_name} Actions')
        plt.xlabel('Frame')
        plt.ylabel('Action Value')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 2: Difference between baseline and lesioned
        plt.subplot(2, 2, 2)
        plt.plot(lesioned_action - baseline_action, color='purple', alpha=0.7)
        plt.title(f'{action_name} Action Difference (Lesioned - Baseline)')
        plt.xlabel('Frame')
        plt.ylabel('Difference')
        plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        plt.grid(True, alpha=0.3)
        
        # Plot 3: Distribution comparison
        plt.subplot(2, 2, 3)
        sns.kdeplot(baseline_action, label='Baseline', color='blue', fill=True, alpha=0.3)
        sns.kdeplot(lesioned_action, label='Lesioned', color='red', fill=True, alpha=0.3)
        plt.title(f'{action_name} Action Distribution')
        plt.xlabel('Action Value')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 4: Scatter plot of lesioned vs baseline
        plt.subplot(2, 2, 4)
        plt.scatter(baseline_action, lesioned_action, alpha=0.5, s=10)
        plt.plot([min(baseline_action), max(baseline_action)], 
                 [min(baseline_action), max(baseline_action)], 
                 'k--', alpha=0.5)
        plt.title(f'{action_name} Lesioned vs Baseline')
        plt.xlabel('Baseline Action')
        plt.ylabel('Lesioned Action')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / f"{action_name}_action_analysis.png", dpi=150)
        plt.close()
    
    # Finalize and save the action comparison plot
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "all_actions_comparison.png", dpi=150)
    plt.close()
    
    # Create a summary bar chart comparing the effect of lesioning on each action
    plt.figure(figsize=(12, 8))
    
    # Bar chart of RMSE per action
    actions = [d['action'] for d in summary_data]
    rmse_values = [d['rmse'] for d in summary_data]
    
    bars = plt.bar(actions, rmse_values, color='skyblue')
    plt.title('RMSE Between Baseline and Lesioned Actions')
    plt.xlabel('Action')
    plt.ylabel('RMSE')
    plt.xticks(rotation=45)
    plt.grid(axis='y', alpha=0.3)
    
    # Annotate with values
    for bar, value in zip(bars, rmse_values):
        plt.text(bar.get_x() + bar.get_width()/2., 
                 value + 0.01,
                 f'{value:.4f}', 
                 ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "action_rmse_summary.png", dpi=150)
    plt.close()
    
    # Print summary table
    print("\n=== SUMMARY OF LESIONING EFFECTS ACROSS ACTIONS ===")
    print(f"{'Action':<20} {'MAE':>10} {'RMSE':>10} {'Max Difference':>15}")
    print("-" * 60)
    for data in summary_data:
        print(f"{data['action']:<20} {data['mae']:>10.4f} {data['rmse']:>10.4f} {data['max_diff']:>15.4f}")
        
    return summary_data

# ═══════════════════════════════════════════════════════════════════════════════
# MAIN EXECUTION
# ═══════════════════════════════════════════════════════════════════════════════

print("\n=== STARTING NEURAL ANALYSIS AND LESIONING ===")

# Visualize CCM networks to understand causal relationships
visualize_ccm_networks()

# Step 1: Load neural and action data for analysis
print("\nLoading neural activation and action data...")
with h5py.File(H5_PATH, "r") as f:
    actions = f["actions"][...]  # Load actions data
    
    # Neural activations from all decoder layers
    layer_data = {}
    for layer_name in LAYERS_TO_ANALYZE:
        layer_data[layer_name] = f[f"decoder_activations/{layer_name}"][:]
        print(f"{layer_name} shape: {layer_data[layer_name].shape}")
    
    print(f"Actions shape: {actions.shape}")

# Step 2: Analyze each layer to identify neurons predictive of actions
neurons_by_layer = {}
neurons_by_action = {}
correlation_results = {}

for layer_name in LAYERS_TO_ANALYZE:
    # Find the corresponding parameter layer name
    param_layer = LAYER_MAPPING[layer_name]
    # Analyze layer to get top neurons
    top_neurons, top_by_action, corr_df = analyze_layer_for_actions(layer_name, layer_data, actions)
    neurons_by_layer[param_layer] = top_neurons
    neurons_by_action[layer_name] = top_by_action
    correlation_results[layer_name] = corr_df

print("\n=== IDENTIFIED NEURONS TO LESION ===")
for layer_name, neurons in neurons_by_layer.items():
    print(f"{layer_name}: {len(neurons)} unique neurons")

# Step 3: Predict impact of lesioning using CCM matrices
predicted_action_impact, predicted_joint_impact = predict_lesion_impact(neurons_by_action["layer_0"])

# Step 4: Load checkpoint and config
ckpt_path = Path.cwd().parent / "model_checkpoints/250826_030533_134914"
ckpt = checkpointing.load_checkpoint_for_eval(ckpt_path)
cfg = ckpt["cfg"]

# Configure data path
cfg.data_path = "/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial01_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial04_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial09_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial10_ik.h5,/root/vast/eric/stac-mjx/refined_STACed_data/A36-1_2023-07-18_16-54-01_lightOff_trial13_ik.h5"
cfg.train_setup.checkpoint_to_restore = ckpt_path

# Create environment for testing
env = rollout.create_environment(cfg)

# Step 5: Create baseline (non-lesioned) rollout for comparison
print("\n=== GENERATING BASELINE ROLLOUT FOR COMPARISON ===")
baseline_inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
baseline_rollout_gen = rollout.create_rollout_generator(
    cfg, env, baseline_inference_fn, 
    log_activations=True, log_metrics=True, log_sensor_data=True
)
baseline_rollout = baseline_rollout_gen(clip_idx=1)
print("✓ Baseline rollout generated")

# Step 6: Apply multi-layer lesioning to the policy
if ENABLE_LESION:
    print("\n=== CREATING MULTI-LAYER LESIONED POLICY ===")
    
    # Create policy network structure first
    original_policy = ckpt["policy"]
    lesioned_policy = create_lesioned_policy_complete_multi_layer(
        original_policy, 
        neurons_by_layer
    )
    
    # Verify that lesioning was properly applied
    verification_passed = verify_multi_layer_lesioning(original_policy, lesioned_policy, neurons_by_layer)
    print(f"\nMulti-layer lesioning verification {'PASSED' if verification_passed else 'FAILED'}")
    
    # Replace the policy in the checkpoint
    ckpt["policy"] = lesioned_policy
    print(f"Checkpoint policy updated with lesioned policy")
    
    # Use the standard loader with the modified checkpoint
    print("\n=== SETTING UP LESIONED INFERENCE FUNCTION ===")
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Lesioned inference function created")
else:
    # Use the original policy from checkpoint
    inference_fn = checkpointing.load_inference_fn(cfg, ckpt["policy"])
    print("✓ Original inference function loaded")

# Step 7: Generate rollout with the lesioned policy
print("\n=== GENERATING ROLLOUT WITH MULTI-LAYER LESIONED POLICY ===")
generate_rollout = rollout.create_rollout_generator(
    cfg, 
    env, 
    inference_fn, 
    log_activations=True, 
    log_metrics=True, 
    log_sensor_data=True
)

lesioned_rollout = generate_rollout(clip_idx=1)
print("✓ Lesioned rollout generated")

# Step 8: Extract and analyze activations from rollout
if ENABLE_LESION and lesioned_rollout.get('activations') is not None:
    print("\n=== ANALYZING ROLLOUT ACTIVATIONS ===")
    activations = lesioned_rollout['activations']['decoder']
    
    # Check each layer
    for activation_layer, target_layer in LAYER_MAPPING.items():
        if activation_layer in activations:
            layer_acts = activations[activation_layer]
            neurons_to_check = neurons_by_layer[target_layer]
            
            # Check activation stats for lesioned neurons
            valid_neurons = [n for n in neurons_to_check if n < layer_acts.shape[-1]]
            if valid_neurons:
                lesioned_acts = layer_acts[:, valid_neurons]
                avg_activation = jnp.mean(jnp.abs(lesioned_acts))
                max_activation = jnp.max(jnp.abs(lesioned_acts))
                
                print(f"\n{activation_layer} ({target_layer}) lesioned neurons stats:")
                print(f"  Average absolute activation: {avg_activation:.6f}")
                print(f"  Maximum absolute activation: {max_activation:.6f}")
                print(f"  {'⚠️ NEURONS STILL ACTIVE' if max_activation > 0.1 else '✓ NEURONS PROPERLY SILENCED'}")
        else:
            print(f"Warning: {activation_layer} not found in activations")

# Step 9: Compare actions between baseline and lesioned rollouts
print("\n=== COMPARING BASELINE AND LESIONED ACTIONS ===")
baseline_actions = np.array(baseline_rollout['actions']) if 'actions' in baseline_rollout else None
lesioned_actions = np.array(lesioned_rollout['actions']) if 'actions' in lesioned_rollout else None

action_summary = None
if baseline_actions is not None and lesioned_actions is not None:
    action_summary = compare_actions(baseline_actions, lesioned_actions)
else:
    print("ERROR: Could not extract actions from rollouts for comparison")

# Step 10: Render comparison videos
print("\n=== RENDERING COMPARISON VIDEOS ===")

# Render lesioned video
lesioned_frames, lesioned_framerate = render.render_rollout(
    cfg, 
    lesioned_rollout, 
    height=480,
    width=640,
)

# Render baseline video
baseline_frames, baseline_framerate = render.render_rollout(
    cfg, 
    baseline_rollout, 
    height=480,
    width=640,
)

# Save the videos
lesioned_video_path = Path(ckpt_path) / f"rollout_actions_lesioned_top{TOP_K}.mp4"
baseline_video_path = Path(ckpt_path) / f"rollout_baseline.mp4"

media.write_video(lesioned_video_path, lesioned_frames, fps=lesioned_framerate)
media.write_video(baseline_video_path, baseline_frames, fps=baseline_framerate)

print(f"✓ Lesioned video saved to {lesioned_video_path}")
print(f"✓ Baseline video saved to {baseline_video_path}")

# Display the lesioned video
print("\n=== DISPLAYING MULTI-LAYER LESIONED VIDEO ===")
media.show_video(lesioned_frames, fps=lesioned_framerate)

# Step 11: Compare joint positions between baseline and lesioned rollouts
print("\n=== COMPARING CONTROL VS LESIONED PERFORMANCE ACROSS ALL JOINTS ===")

# Joint names for better labeling
JOINT_NAMES = ["sh_elv", "sh_ext", "sh_rot", "elbow"]

def extract_joint_data(rollout, key_name='qposes_rollout'):
    """Extract data for all joints from specified key"""
    if key_name in rollout:
        print(f"Found {key_name} with shape: {rollout[key_name].shape}")
        return np.array(rollout[key_name])
    
    print(f"Could not find {key_name}. Available arrays:")
    for key, value in rollout.items():
        if isinstance(value, np.ndarray):
            print(f"  {key}: {value.shape}")
        elif hasattr(value, 'keys'):
            print(f"  {key} (dict/object with keys): {list(value.keys())}")
    
    return None

# Extract data for all joints
baseline_joints = extract_joint_data(baseline_rollout, 'qposes_rollout')
lesioned_joints = extract_joint_data(lesioned_rollout, 'qposes_rollout')
reference_joints = extract_joint_data(baseline_rollout, 'qposes_ref')

# If reference is not found in baseline, try lesioned
if reference_joints is None:
    reference_joints = extract_joint_data(lesioned_rollout, 'qposes_ref')

# Extract rewards
baseline_rewards = np.array(baseline_rollout['rewards']) if 'rewards' in baseline_rollout else None
lesioned_rewards = np.array(lesioned_rollout['rewards']) if 'rewards' in lesioned_rollout else None

# Fall back to state_rewards if needed
if baseline_rewards is None and 'state_rewards' in baseline_rollout:
    baseline_rewards = np.array(baseline_rollout['state_rewards'])
if lesioned_rewards is None and 'state_rewards' in lesioned_rollout:
    lesioned_rewards = np.array(lesioned_rollout['state_rewards'])

# Proceed only if we found all the necessary data
joint_summary = None
if baseline_joints is None or lesioned_joints is None or reference_joints is None:
    print("Could not find required joint data in the rollouts")
else:
    # Print reward statistics
    if baseline_rewards is not None and lesioned_rewards is not None:
        print("\n=== OVERALL REWARD STATISTICS ===")
        print(f"Control mean reward: {np.mean(baseline_rewards):.4f}, Lesioned mean reward: {np.mean(lesioned_rewards):.4f}")
        print(f"Reward reduction: {np.mean(baseline_rewards) - np.mean(lesioned_rewards):.4f} ({(1 - np.mean(lesioned_rewards)/np.mean(baseline_rewards))*100:.1f}%)")
        
        # Plot reward comparison
        plt.figure(figsize=(8, 6))
        plt.plot(baseline_rewards, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_rewards, label='Lesioned', color='red', alpha=0.7)
        plt.title('Reward Over Time')
        plt.xlabel('Frame')
        plt.ylabel('Reward')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / "overall_lesion_rewards.png", dpi=150)
        plt.close()
    
    # Calculate overall metrics for summary table
    joint_summary = []
    
    # Create a 2x2 grid of joint plots
    plt.figure(figsize=(16, 14))
    
    # Process each joint
    for joint_idx, joint_name in enumerate(JOINT_NAMES):
        baseline_joint = baseline_joints[:, joint_idx]
        lesioned_joint = lesioned_joints[:, joint_idx]
        reference_joint = reference_joints[:, joint_idx]
        
        # Calculate error metrics
        baseline_error = baseline_joint - reference_joint
        lesioned_error = lesioned_joint - reference_joint
        
        # Calculate absolute errors for visualization
        baseline_abs_error = np.abs(baseline_error)
        lesioned_abs_error = np.abs(lesioned_error)
        
        # Calculate statistics
        baseline_rmse = np.sqrt(np.mean(baseline_error**2))
        lesioned_rmse = np.sqrt(np.mean(lesioned_error**2))
        baseline_mae = np.mean(baseline_abs_error)
        lesioned_mae = np.mean(lesioned_abs_error)
        error_var_ratio = np.var(lesioned_error) / np.var(baseline_error)
        
        # Store summary data
        joint_summary.append({
            'joint': joint_name,
            'baseline_rmse': baseline_rmse,
            'lesioned_rmse': lesioned_rmse,
            'rmse_increase': (lesioned_rmse - baseline_rmse),
            'rmse_percent': (lesioned_rmse / baseline_rmse - 1) * 100,
            'baseline_mae': baseline_mae,
            'lesioned_mae': lesioned_mae,
            'mae_increase': (lesioned_mae - baseline_mae),
            'mae_percent': (lesioned_mae / baseline_mae - 1) * 100,
            'error_var_ratio': error_var_ratio
        })
        
        print(f"\n=== {joint_name.upper()} ERROR STATISTICS ===")
        print(f"Control mean error: {np.mean(baseline_error):.4f} rad, Lesioned mean error: {np.mean(lesioned_error):.4f} rad")
        print(f"Control RMSE: {baseline_rmse:.4f} rad, Lesioned RMSE: {lesioned_rmse:.4f} rad")
        print(f"Control MAE: {baseline_mae:.4f} rad, Lesioned MAE: {lesioned_mae:.4f} rad")
        print(f"Control error std: {np.std(baseline_error):.4f} rad, Lesioned error std: {np.std(lesioned_error):.4f} rad")
        print(f"Error variance ratio (lesioned/control): {error_var_ratio:.4f}")
        
        # Individual joint plots
        plt.subplot(2, 2, joint_idx + 1)
        
        # Plot time series with reference
        plt.plot(reference_joint, label='Reference', color='green', linestyle='--', linewidth=2)
        plt.plot(baseline_joint, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_joint, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Trajectories')
        plt.xlabel('Frame')
        plt.ylabel('Joint Position (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Create detailed single-joint figure
        plt.figure(figsize=(14, 12))
        
        # Plot 1: Time series comparison with reference
        plt.subplot(2, 2, 1)
        plt.plot(reference_joint, label='Reference', color='green', linestyle='--', linewidth=2)
        plt.plot(baseline_joint, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_joint, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Trajectories')
        plt.xlabel('Frame')
        plt.ylabel('Joint Position (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 2: Error distribution comparison (using absolute errors)
        plt.subplot(2, 2, 2)
        sns.kdeplot(baseline_abs_error, label='Control Error', color='blue', fill=True, alpha=0.3)
        sns.kdeplot(lesioned_abs_error, label='Lesioned Error', color='red', fill=True, alpha=0.3)
        plt.title(f'{joint_name.upper()} Absolute Error Distribution')
        plt.xlabel('Absolute Error (rad)')
        plt.ylabel('Density')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 3: Absolute error time series
        plt.subplot(2, 2, 3)
        plt.plot(baseline_abs_error, label='Control', color='blue', alpha=0.7)
        plt.plot(lesioned_abs_error, label='Lesioned', color='red', alpha=0.7)
        plt.title(f'{joint_name.upper()} Absolute Error Over Time')
        plt.xlabel('Frame')
        plt.ylabel('Absolute Error (rad)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        
        # Plot 4: Box plot comparison of absolute errors
        plt.subplot(2, 2, 4)
        error_data = [baseline_abs_error, lesioned_abs_error]
        plt.boxplot(error_data, labels=['Control Error', 'Lesioned Error'])
        plt.title(f'{joint_name.upper()} Absolute Error Distribution')
        plt.ylabel('Absolute Error (rad)')
        plt.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(Path(ckpt_path) / f"{joint_name}_lesion_error_analysis.png", dpi=150)
        plt.close()

    # Finalize and save the joint comparison plot
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "all_joints_lesion_comparison.png", dpi=150)
    plt.close()
    
    # Create a summary bar chart comparing the effect of lesioning on each joint
    plt.figure(figsize=(12, 8))
    
    # Bar chart of RMSE percent increase
    joints = [d['joint'] for d in joint_summary]
    rmse_pct_increases = [d['rmse_percent'] for d in joint_summary]
    
    bars = plt.bar(joints, rmse_pct_increases, color='skyblue')
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
    plt.title('Percent Increase in RMSE After Lesioning')
    plt.xlabel('Joint')
    plt.ylabel('% Increase in RMSE')
    plt.grid(axis='y', alpha=0.3)
    
    # Annotate with values
    for bar, value in zip(bars, rmse_pct_increases):
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., 
                 height + (5 if height > 0 else -15),
                 f'{value:.1f}%', 
                 ha='center', va='bottom', rotation=0)
    
    plt.tight_layout()
    plt.savefig(Path(ckpt_path) / "joint_error_increase_summary.png", dpi=150)
    plt.close()
    
    # Print summary table
    print("\n=== SUMMARY OF LESIONING EFFECTS ACROSS JOINTS ===")
    print(f"{'Joint':<10} {'Baseline RMSE':>13} {'Lesioned RMSE':>14} {'RMSE Δ%':>10} {'Baseline MAE':>13} {'Lesioned MAE':>14} {'MAE Δ%':>10} {'Error Var Ratio':>15}")
    print("-" * 100)
    for data in joint_summary:
        print(f"{data['joint']:<10} {data['baseline_rmse']:>13.4f} {data['lesioned_rmse']:>14.4f} {data['rmse_percent']:>+10.1f}% {data['baseline_mae']:>13.4f} {data['lesioned_mae']:>14.4f} {data['mae_percent']:>+10.1f}% {data['error_var_ratio']:>15.2f}")

    # Compare predicted vs actual impact if data is available
    if action_summary is not None and joint_summary is not None:
        compare_predicted_vs_actual(predicted_action_impact, predicted_joint_impact, action_summary, joint_summary)

# Step 12: Print summary of lesioning
print("\n=== MULTI-LAYER ACTION-PREDICTING NEURON LESIONING SUMMARY ===")
for layer_name, neurons in neurons_by_layer.items():
    print(f"{layer_name}: Lesioned {len(neurons)} unique neurons")

total_neurons = sum(len(neurons) for neurons in neurons_by_layer.values())
print(f"Total neurons lesioned: {total_neurons}")
print(f"\nThis experiment dynamically identified and lesioned the top {TOP_K} neurons")
print("from each layer that are most predictive of each action/muscle activation")
print("Using action-specific embedding dimensions for optimal analysis")