# VLA Introspection: Inside SmolVLA

This notebook opens up a fine-tuned SmolVLA model and analyzes:
1. **Attention maps** — What does the VLA look at in the scene?
2. **Action trajectories** — How does flow matching generate action plans?
3. **Language conditioning** — Does the model actually use the instruction?
4. **Failure analysis** — Where and why does it break?
5. **Representation analysis** — What did fine-tuning change in the embedding space?

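**Prerequisites**: This notebook loads two checkpoints from the Hugging Face Hub: the fine-tuned model [`jadechoghari/smolvla_metaworld`](https://huggingface.co/jadechoghari/smolvla_metaworld) and, for zero-shot comparison, the base model [`lerobot/smolvla_base`](https://huggingface.co/lerobot/smolvla_base). No local training is required.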

## 1. Setup

In [None]:
# Clone the project repo so all code is available
!git clone https://github.com/vivpra89/VLA_BD.git /content/VLA_BD
%cd /content/VLA_BD

In [None]:
# Install conda into the Colab runtime; the next cell uses `conda install` for ffmpeg.
# NOTE(review): condacolab.install() is documented to restart the kernel — wait for the
# restart to finish before running the next cell.
!pip install -q condacolab
import condacolab
condacolab.install()

In [None]:
!git clone https://github.com/huggingface/lerobot.git
!conda install ffmpeg=7.1.1 -c conda-forge -y
!cd lerobot && pip install -e ".[smolvla]"
!pip install "gymnasium==1.1.0" metaworld matplotlib seaborn scikit-learn

In [None]:
import sys
sys.path.insert(0, "/content/lerobot/src")

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
import seaborn as sns
from sklearn.manifold import TSNE
from collections import defaultdict
import json, os, copy

# Run on the GPU when one is available; the models and their inputs are moved to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the global RNGs once so flow-matching noise and any NumPy/Python sampling are
# reproducible under Restart & Run All. Cells that compare runs still reseed explicitly.
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds every CUDA device

# All figures are written here; override with the INTROSPECTION_OUTPUT_DIR env variable
# when not running on Colab.
OUTPUT_DIR = os.environ.get("INTROSPECTION_OUTPUT_DIR", "/content/introspection_outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

## 2. Load models (pretrained + fine-tuned)

In [None]:
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

# Hub checkpoints: the Meta-World fine-tune and the base model it was fine-tuned from.
FINETUNED_PATH = "jadechoghari/smolvla_metaworld"
PRETRAINED_PATH = "lerobot/smolvla_base"


def _load_eval_policy(repo_id, banner):
    """Announce, download, move to `device`, and switch a SmolVLA policy to eval mode."""
    print(banner)
    policy = SmolVLAPolicy.from_pretrained(repo_id)
    policy.to(device)
    policy.eval()
    return policy


model_finetuned = _load_eval_policy(FINETUNED_PATH, "Loading fine-tuned model...")
model_pretrained = _load_eval_policy(PRETRAINED_PATH, "Loading pretrained (zero-shot) model...")

n_params_millions = sum(p.numel() for p in model_finetuned.parameters()) / 1e6
print(f"Model parameters: {n_params_millions:.1f}M")

## 3. Set up Meta-World environment and collect observations

In [None]:
import metaworld
import metaworld.envs

def make_metaworld_env(task_name="assembly-v3", seed=42, task_index=0, render_mode="rgb_array"):
    """Create a single-task (MT1) Meta-World environment with one fixed task variation.

    Args:
        task_name: Meta-World task id, e.g. ``"assembly-v3"``.
        seed: Seed passed to ``metaworld.MT1`` when it samples task variations.
        task_index: Index into ``mt.train_tasks`` of the variation applied via ``set_task``.
        render_mode: Forwarded to the environment constructor. Every rollout in this
            notebook feeds ``env.render()`` to the policy as an image, so an image render
            mode is required. NOTE(review): Gymnasium MuJoCo envs return None from
            ``render()`` when no render mode is set — confirm for the installed Meta-World.
            Pass ``None`` to construct the env with no keyword arguments.

    Returns:
        The environment only (the task is applied to it, not returned).
    """
    mt = metaworld.MT1(task_name, seed=seed)
    env_kwargs = {} if render_mode is None else {"render_mode": render_mode}
    env = mt.train_classes[task_name](**env_kwargs)
    env.set_task(mt.train_tasks[task_index])
    return env

def collect_episode(env, policy_model, task_desc, max_steps=200, record_internals=False):
    """Roll out one episode of ``policy_model`` in ``env``, optionally recording attention.

    Args:
        env: Gymnasium-style env (``reset() -> (obs, info)``, 5-tuple ``step``) whose
            ``render()`` returns the camera image fed to the policy.
        policy_model: Policy exposing ``reset()`` and ``select_action(obs_dict)``.
        task_desc: Natural-language instruction given to the policy at every step.
        max_steps: Maximum number of environment steps.
        record_internals: If True, each ``select_action`` call runs inside
            ``AttentionCapture`` (the attention section of this notebook), and the dict of
            captured attention tensors is appended to ``internals`` for that step.
            NOTE(review): the dict may be empty on steps where ``select_action`` returns a
            queued action without running the model — confirm for the policy in use.

    Returns:
        dict with ``frames``, ``actions`` (np.ndarray), ``rewards`` (np.ndarray),
        ``success`` (bool), ``internals`` (list, empty unless recording) and ``num_steps``.
    """
    obs, info = env.reset()
    frames, actions_taken, rewards, internals = [], [], [], []
    success = False

    policy_model.reset()

    for _step in range(max_steps):
        img = env.render()
        frames.append(img)

        # Build observation dict matching SmolVLA's expected format
        obs_dict = build_obs_dict(obs, img, task_desc, device)

        with torch.no_grad():
            if record_internals:
                with AttentionCapture(policy_model) as capture:
                    action = policy_model.select_action(obs_dict)
                internals.append(dict(capture.captured))
            else:
                action = policy_model.select_action(obs_dict)

        action_np = action.cpu().numpy().flatten()
        actions_taken.append(action_np)

        obs, reward, terminated, truncated, info = env.step(action_np)
        rewards.append(reward)

        if info.get("success", False):
            success = True
            break
        if terminated or truncated:
            break

    return {
        "frames": frames,
        "actions": np.array(actions_taken),
        "rewards": np.array(rewards),
        "success": success,
        "internals": internals,
        "num_steps": len(frames),
    }


def build_obs_dict(obs_array, image, task_desc, device, *, tokenizer=None,
                   image_key="observation.images.top", state_dim=18, max_length=32):
    """Convert a raw Meta-World observation + rendered image into a SmolVLA-style batch dict.

    NOTE: This is a simplified adapter. In the full LeRobot pipeline the env wrapper and
    processors handle this; the keyword-only arguments let you match a model config.

    Args:
        obs_array: Flat NumPy observation vector; its first ``state_dim`` entries are used
            as the proprioceptive state.
        image: HxWx3 uint8 image (e.g. from ``env.render()``). Any memory layout is accepted.
        task_desc: Language instruction to tokenize.
        device: Device the tensors are moved to.
        tokenizer: Callable tokenizer. Defaults to the fine-tuned model's tokenizer
            (``model_finetuned`` must then already be loaded).
        image_key: Batch key for the image tensor.
        state_dim: Number of leading observation entries used as the state.
        max_length: Padded/truncated token length.

    Returns:
        dict with a (1, 3, H, W) float image in [0, 1], a (1, state_dim) float32 state,
        and (1, max_length) token ids and boolean attention mask.

    Raises:
        ValueError: if ``image`` is None (the env was created without an image render mode).
    """
    from lerobot.utils.constants import OBS_STATE, OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK

    if image is None:
        raise ValueError(
            "image is None: env.render() returned nothing; create the environment with "
            "render_mode='rgb_array'."
        )

    # Image: HWC uint8 -> BCHW float [0,1]. ascontiguousarray handles flipped/strided
    # views, which torch.from_numpy rejects.
    img_tensor = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() / 255.0
    img_tensor = img_tensor.unsqueeze(0).to(device)

    # State: leading proprioceptive entries of the Meta-World observation.
    state = torch.from_numpy(np.asarray(obs_array[:state_dim], dtype=np.float32)).unsqueeze(0).to(device)

    # Language tokens: both models here share the same VLM backbone tokenizer by default.
    if tokenizer is None:
        tokenizer = model_finetuned.model.vlm_with_expert.processor.tokenizer
    tokens = tokenizer(task_desc, return_tensors="pt", padding="max_length",
                       max_length=max_length, truncation=True)

    return {
        image_key: img_tensor,
        OBS_STATE: state,
        OBS_LANGUAGE_TOKENS: tokens["input_ids"].to(device),
        OBS_LANGUAGE_ATTENTION_MASK: tokens["attention_mask"].bool().to(device),
    }

# Confirms in the cell output that the helper definitions above ran.
print("Helper functions defined.")

## 4. Attention Visualization: "What does the VLA look at?"

We temporarily replace the attention function of SmolVLA's VLM-with-action-expert backbone (`eager_attention_forward`) to record attention probabilities. We then overlay the attention paid to the image-token keys on the input image as a heatmap.

In [None]:
class AttentionCapture:
    """Context manager that records attention probabilities from SmolVLA's attention function.

    On ``__enter__`` it replaces ``model.model.vlm_with_expert.eager_attention_forward`` with
    a wrapper. The wrapper recomputes the masked softmax attention probabilities, stores them
    on CPU in ``self.captured["attn_layer_<k>"]``, and then delegates to the original
    function, so the model's outputs are unchanged. ``__exit__`` removes the wrapper.

    ``k`` counts *calls* to the attention function inside the context, not transformer layers.
    NOTE(review): the first calls of a forward pass presumably follow layer order, and later
    calls come from subsequent passes (e.g. denoising steps) — confirm against the model code.
    """

    def __init__(self, model):
        self.model = model
        self.hooks = []  # kept for backward compatibility; not used
        self.captured = {}
        self._call_count = 0
        self._original_attn = None
        self._had_instance_attr = False

    def __enter__(self):
        vlm_model = self.model.model.vlm_with_expert

        # Remember whether the function lived on the instance or only on the class, so
        # __exit__ can restore exactly the original attribute lookup.
        self._had_instance_attr = "eager_attention_forward" in vars(vlm_model)
        original_attn = vlm_model.eager_attention_forward
        self._original_attn = original_attn
        self._call_count = 0
        capture_ref = self

        def hooked_attn(attention_mask, batch_size, head_dim, query_states, key_states, value_states):
            # Inputs assumed (batch, seq, heads, head_dim), as implied by transpose(1, 2) below.
            q = query_states.to(dtype=torch.float32)
            k = key_states.to(dtype=torch.float32)

            # Grouped-query attention: keys may carry fewer heads than queries. Repeat each
            # key head so query head h pairs with key head h // group_size; otherwise the
            # matmul below cannot broadcast.
            num_q_heads, num_kv_heads = q.shape[2], k.shape[2]
            if num_kv_heads != num_q_heads:
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=2)

            q = q.transpose(1, 2)
            k = k.transpose(1, 2)

            att_weights = torch.matmul(q, k.transpose(2, 3)) * (head_dim ** -0.5)
            big_neg = torch.finfo(att_weights.dtype).min
            masked = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
            probs = torch.nn.functional.softmax(masked, dim=-1)

            capture_ref.captured[f"attn_layer_{capture_ref._call_count}"] = probs.detach().cpu()
            capture_ref._call_count += 1

            # Continue with the original computation so the forward pass is unchanged.
            return original_attn(attention_mask, batch_size, head_dim, query_states, key_states, value_states)

        vlm_model.eager_attention_forward = hooked_attn
        return self

    def __exit__(self, *args):
        vlm_model = self.model.model.vlm_with_expert
        if self._had_instance_attr:
            vlm_model.eager_attention_forward = self._original_attn
        else:
            # Drop the instance override so lookups fall back to the class method again.
            vars(vlm_model).pop("eager_attention_forward", None)


def visualize_attention_over_image(image, attn_weights, title="Attention Map", num_img_tokens=64):
    """Overlay head-averaged attention to image tokens on the input image.

    Args:
        image: HxWx3 image array the attention refers to.
        attn_weights: (num_heads, query_len, key_len) tensor. Attention is averaged over
            heads and summed over all query positions.
        title: Figure title.
        num_img_tokens: Number of leading key positions treated as image patches; must be
            a perfect square. The default 64 (8x8) is the notebook's assumption for SmolVLA's
            compressed visual tokens. NOTE(review): assumes image tokens come first in the
            key sequence — confirm against ``embed_prefix``.

    Returns:
        The matplotlib figure (input image, heatmap, overlay).

    Raises:
        ValueError: if ``num_img_tokens`` is not a perfect square or exceeds ``key_len``.
    """
    from scipy.ndimage import zoom

    grid_size = int(np.sqrt(num_img_tokens))
    if grid_size * grid_size != num_img_tokens:
        raise ValueError(f"num_img_tokens={num_img_tokens} is not a perfect square")

    # Average over heads -> (query_len, key_len)
    attn_avg = attn_weights.float().mean(dim=0)
    if attn_avg.shape[-1] < num_img_tokens:
        raise ValueError(
            f"attention has only {attn_avg.shape[-1]} keys; expected at least {num_img_tokens} image tokens"
        )

    # Sum attention from all query positions to the image tokens, then normalize to [0, 1].
    # The epsilon keeps an all-zero map at zero instead of turning it into NaN.
    img_attn = attn_avg[:, :num_img_tokens].sum(dim=0)
    img_attn = img_attn / (img_attn.max() + 1e-8)

    attn_map = img_attn.reshape(grid_size, grid_size).numpy()

    # Upsample the patch grid to the image resolution.
    h, w = image.shape[:2]
    attn_resized = zoom(attn_map, (h / grid_size, w / grid_size), order=1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    axes[0].imshow(image)
    axes[0].set_title("Input Image")
    axes[0].axis("off")

    axes[1].imshow(attn_resized, cmap="hot", interpolation="bilinear")
    axes[1].set_title("Attention Heatmap")
    axes[1].axis("off")

    axes[2].imshow(image)
    axes[2].imshow(attn_resized, cmap="jet", alpha=0.5, interpolation="bilinear")
    axes[2].set_title("Overlay")
    axes[2].axis("off")

    plt.suptitle(title, fontsize=14, fontweight="bold")
    plt.tight_layout()
    return fig

# Confirms in the cell output that the attention helpers above ran.
print("Attention capture tools defined.")

In [None]:
# Run attention capture on a single observation, once per model, and compare early-layer maps.
env = make_metaworld_env("assembly-v3")
obs, _ = env.reset()
img = env.render()
task_desc = "assemble the peg into the hole"

obs_dict = build_obs_dict(obs, img, task_desc, device)

# Same observation through both models; each capture records that model's attention calls.
captures = {}
for tag, policy in (("ft", model_finetuned), ("pt", model_pretrained)):
    with AttentionCapture(policy) as capture, torch.no_grad():
        policy.reset()
        policy.select_action(obs_dict)
    captures[tag] = capture
cap_ft, cap_pt = captures["ft"], captures["pt"]

print(f"Captured {len(cap_ft.captured)} attention layers (fine-tuned)")
print(f"Captured {len(cap_pt.captured)} attention layers (pretrained)")

# Visualize early layer attention (layer 2 — low-level features)
for capture, model_label, file_name in (
    (cap_ft, "Fine-tuned", "attention_finetuned_layer2.png"),
    (cap_pt, "Pretrained (zero-shot)", "attention_pretrained_layer2.png"),
):
    if "attn_layer_2" not in capture.captured:
        continue
    fig = visualize_attention_over_image(
        img, capture.captured["attn_layer_2"][0],
        title=f"{model_label}: Early Layer Attention (Layer 2)"
    )
    fig.savefig(f"{OUTPUT_DIR}/{file_name}", dpi=150, bbox_inches="tight")
    plt.show()

### Attention across task phases

Capture attention at different points during an episode to show how focus shifts:
- **Reaching**: model should attend to the target object
- **Grasping**: attention should shift to gripper-object contact
- **Moving**: attention on destination

In [None]:
from scipy.ndimage import zoom

env = make_metaworld_env("assembly-v3")
obs, _ = env.reset()
model_finetuned.reset()

PHASE_TASK_DESC = "assemble the peg into the hole"
phase_steps = {"early (reaching)": 10, "mid (grasping)": 50, "late (placing)": 100}
phase_at_step = {target_step: name for name, target_step in phase_steps.items()}
phase_frames = {name: None for name in phase_steps}
phase_attns = {name: None for name in phase_steps}

for step in range(150):
    img = env.render()
    obs_dict = build_obs_dict(obs, img, PHASE_TASK_DESC, device)

    phase_name = phase_at_step.get(step)
    if phase_name is not None:
        # Capture attention with a separate predict_action_chunk call. select_action can hand
        # back an action queued from an earlier chunk without running the model, in which case
        # nothing would be captured. NOTE(review): predict_action_chunk is assumed not to
        # consume select_action's action queue — confirm for the installed LeRobot version.
        with AttentionCapture(model_finetuned) as cap, torch.no_grad():
            model_finetuned.predict_action_chunk(obs_dict)
        phase_frames[phase_name] = img.copy()
        # Middle of the captured attention calls (calls are counted, not layers).
        mid_layer = f"attn_layer_{len(cap.captured) // 2}"
        if mid_layer in cap.captured:
            phase_attns[phase_name] = cap.captured[mid_layer][0]

    # Exactly one action is taken per environment step. Previously the capture steps called
    # select_action twice and silently dropped one queued action.
    with torch.no_grad():
        action = model_finetuned.select_action(obs_dict)

    action_np = action.cpu().numpy().flatten()
    obs, reward, terminated, truncated, info = env.step(action_np)
    if terminated or truncated:
        break

# Plot attention across phases: top row = frames, bottom row = attention overlay.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
for idx, (phase_name, frame) in enumerate(phase_frames.items()):
    frame_ax, overlay_ax = axes[0, idx], axes[1, idx]
    frame_ax.axis("off")
    overlay_ax.axis("off")
    if frame is None:
        continue  # episode ended before this phase's step

    frame_ax.imshow(frame)
    frame_ax.set_title(f"Step {phase_steps[phase_name]}: {phase_name}", fontsize=11)

    attn_weights = phase_attns[phase_name]
    if attn_weights is None:
        overlay_ax.set_title("No attention captured", fontsize=11)
        continue

    attn = attn_weights.mean(dim=0)
    num_img_tokens = min(64, attn.shape[-1])
    img_attn = attn[:, :num_img_tokens].sum(dim=0)
    img_attn = img_attn / (img_attn.max() + 1e-8)
    grid_size = int(np.sqrt(num_img_tokens))
    attn_map = img_attn[:grid_size**2].reshape(grid_size, grid_size).numpy()

    h, w = frame.shape[:2]
    attn_resized = zoom(attn_map, (h / grid_size, w / grid_size), order=1)

    overlay_ax.imshow(frame)
    overlay_ax.imshow(attn_resized, cmap="jet", alpha=0.5, interpolation="bilinear")
    overlay_ax.set_title("Attention overlay", fontsize=11)

plt.suptitle("Attention Shift Across Task Phases (assembly-v3)", fontsize=14, fontweight="bold")
plt.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/attention_phases.png", dpi=150, bbox_inches="tight")
plt.show()

## 5. Action Trajectory Visualization: "What does the VLA plan?"

SmolVLA uses **flow matching** to generate action chunks. Starting from pure noise,
it iteratively denoises to produce a sequence of future actions.

We visualize:
- The denoising trajectory: noise -> clean actions across denoising steps
- The predicted action chunk (xyz end-effector displacements: dx, dy, dz)
- Action distribution from multiple noise samples

In [None]:
def capture_denoising_trajectory(model, obs_dict, num_noise_samples=10):
    """Run the model's flow-matching inference and record x_t at every denoising step.

    Rather than re-implementing the denoising loop (which needs the prefix KV cache built
    from the image/language/state tokens), this temporarily wraps
    ``model.model.denoise_step`` and calls the model's own ``sample_actions``. The wrapper
    records the ``x_t`` it receives before each Euler update, so the trajectory is
    conditioned on the observation exactly as in normal inference.

    Args:
        model: SmolVLA policy exposing ``prepare_images``, ``prepare_state`` and ``.model``
            (the flow-matching module with ``sample_noise``, ``sample_actions``,
            ``denoise_step`` and ``config``).
        obs_dict: Batch dict from ``build_obs_dict``.
        num_noise_samples: Number of independent noise draws.

    Returns:
        List of ``num_noise_samples`` CPU tensors, each (num_steps + 1, B, chunk_size,
        max_action_dim): the initial noise, the state after each step, and the final actions.
    """
    from lerobot.utils.constants import OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK

    model.eval()
    flow_model = model.model  # VLAFlowMatching

    with torch.no_grad():
        images, img_masks = model.prepare_images(obs_dict)
        state = model.prepare_state(obs_dict)
    # Use the same keys build_obs_dict writes.
    lang_tokens = obs_dict[OBS_LANGUAGE_TOKENS]
    lang_masks = obs_dict[OBS_LANGUAGE_ATTENTION_MASK]

    bsize = state.shape[0]
    actions_shape = (bsize, flow_model.config.chunk_size, flow_model.config.max_action_dim)

    recorded = []  # x_t snapshots for the sample currently being generated
    had_instance_attr = "denoise_step" in vars(flow_model)
    original_denoise = flow_model.denoise_step

    def recording_denoise(*args, **kwargs):
        # x_t is the third positional argument (prefix_pad_masks, past_key_values, x_t, t).
        x_t = kwargs["x_t"] if "x_t" in kwargs else args[2]
        # Clone: the sampler may update x_t in place.
        recorded.append(x_t.detach().cpu().clone())
        return original_denoise(*args, **kwargs)

    all_trajectories = []
    flow_model.denoise_step = recording_denoise
    try:
        with torch.no_grad():
            for _sample in range(num_noise_samples):
                recorded.clear()
                noise = flow_model.sample_noise(actions_shape, state.device)
                final = flow_model.sample_actions(
                    images, img_masks, lang_tokens, lang_masks, state, noise=noise
                )
                steps = recorded + [final.detach().cpu().clone()]
                all_trajectories.append(torch.stack(steps))
    finally:
        # Always undo the wrapper, even if inference raised.
        if had_instance_attr:
            flow_model.denoise_step = original_denoise
        else:
            vars(flow_model).pop("denoise_step", None)

    return all_trajectories


def plot_denoising_process(trajectories, action_dims=(0, 1, 2), labels=("dx", "dy", "dz")):
    """Draw how selected action dimensions move from initial noise to the final prediction.

    Only ``trajectories[0]`` (batch element 0) is drawn: one panel per action dimension,
    one line per chunk position (at most the first 10), later positions more transparent.

    Args:
        trajectories: List of tensors shaped (num_steps+1, B, chunk_size, action_dim).
        action_dims: Action dimensions to plot, one panel each.
        labels: Display names paired with ``action_dims``.

    Returns:
        The matplotlib figure.
    """
    n_panels = len(action_dims)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 4))
    panel_axes = [axes] if n_panels == 1 else axes

    first = trajectories[0]
    steps = range(first.shape[0])
    n_lines = min(first.shape[2], 10)

    for ax, dim, label in zip(panel_axes, action_dims, labels):
        for pos in range(n_lines):
            fade = max(1.0 - 0.07 * pos, 0.2)
            ax.plot(steps, first[:, 0, pos, dim].numpy(), alpha=fade, linewidth=1.5)

        ax.set_xlabel("Denoising Step")
        ax.set_ylabel(f"Action: {label}")
        ax.set_title(f"Flow Matching Denoising: {label}")
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color="black", linewidth=0.5, linestyle="--")

    plt.suptitle("Noise → Clean Actions (each line = one action step in the chunk)",
                 fontsize=12, fontweight="bold")
    plt.tight_layout()
    return fig


def plot_action_distribution(trajectories, action_dims=(0, 1)):
    """Plot the final denoised action chunk of every noise sample in one 2-D action plane.

    Each trajectory contributes its last denoising state (batch element 0) as a coloured
    line; the black starred line is the element-wise mean over samples. A tight bundle
    means the samples agree (confident); a wide spread means they do not (uncertain).

    Returns:
        The matplotlib figure.
    """
    # (num_samples, chunk_size, action_dim)
    final_actions = np.array([traj[-1, 0, :, :].numpy() for traj in trajectories])

    fig, ax = plt.subplots(figsize=(8, 6))
    palette = plt.cm.viridis(np.linspace(0, 1, len(final_actions)))

    dim_x, dim_y = action_dims[0], action_dims[1]
    for sample_idx, sample in enumerate(final_actions):
        ax.plot(sample[:, dim_x], sample[:, dim_y],
                "o-", color=palette[sample_idx], alpha=0.6, markersize=3, linewidth=1)

    centroid = final_actions.mean(axis=0)
    ax.plot(centroid[:, dim_x], centroid[:, dim_y],
            "k*-", markersize=8, linewidth=2, label="Mean trajectory")

    ax.set_xlabel(f"Action dim {dim_x} (dx)")
    ax.set_ylabel(f"Action dim {dim_y} (dy)")
    ax.set_title("Action Distribution from Multiple Noise Samples\n(tight = confident, spread = uncertain)")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig

# Confirms in the cell output that the trajectory-plotting helpers above ran.
print("Action trajectory tools defined.")

In [None]:
# Sample several action chunks for one observation, each from a different noise seed, and
# compare them. This cell calls predict_action_chunk directly, so it sees only the final
# chunks, not the intermediate denoising states.
NUM_NOISE_SEEDS = 10
ACTION_DIM_NAMES = {0: "dx", 1: "dy", 2: "dz"}

env = make_metaworld_env("assembly-v3")
obs, _ = env.reset()
img = env.render()
obs_dict = build_obs_dict(obs, img, "assemble the peg into the hole", device)

# Collect action predictions from multiple noise seeds
all_action_chunks = []
for seed in range(NUM_NOISE_SEEDS):
    torch.manual_seed(seed)
    model_finetuned.reset()
    with torch.no_grad():
        chunk = model_finetuned.predict_action_chunk(obs_dict)
    # .float() guards against reduced-precision dtypes NumPy cannot represent.
    all_action_chunks.append(chunk.float().cpu().numpy()[0])  # (chunk_size, action_dim)

all_action_chunks = np.stack(all_action_chunks)  # (NUM_NOISE_SEEDS, chunk_size, action_dim)

# Loop-invariant plot data, computed once.
seed_colors = plt.cm.viridis(np.linspace(0, 1, NUM_NOISE_SEEDS))
mean_chunk = all_action_chunks.mean(axis=0)
dim_pairs = [(0, 1, "dx vs dy"), (0, 2, "dx vs dz"), (1, 2, "dy vs dz")]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for ax, (d1, d2, title) in zip(axes, dim_pairs):
    for i in range(NUM_NOISE_SEEDS):
        ax.plot(all_action_chunks[i, :, d1], all_action_chunks[i, :, d2],
                "o-", color=seed_colors[i], alpha=0.5, markersize=2, linewidth=1)
    ax.plot(mean_chunk[:, d1], mean_chunk[:, d2], "r*-", markersize=6, linewidth=2, label="Mean")
    ax.set_title(title)
    ax.set_xlabel(f"Action dim {d1} ({ACTION_DIM_NAMES[d1]})")
    ax.set_ylabel(f"Action dim {d2} ({ACTION_DIM_NAMES[d2]})")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle(f"Predicted Action Chunks ({NUM_NOISE_SEEDS} noise seeds)\nTight bundle = confident, spread = uncertain",
             fontsize=13, fontweight="bold")
plt.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/action_distribution.png", dpi=150, bbox_inches="tight")
plt.show()

## 6. Language Conditioning Probe: "Does it actually read the instruction?"

Test whether changing the language instruction actually changes the predicted actions.
If the model ignores language, actions should be identical regardless of instruction.

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

env = make_metaworld_env("assembly-v3")
obs, _ = env.reset()
img = env.render()

instructions = [
    "assemble the peg into the hole",
    "turn the dial clockwise",
    "press the handle from the side",
    "pick up the red block",
    "asdkjh random nonsense words xyz",  # nonsense control instruction
]

# Same visual observation, different instructions. The noise seed is reset right before
# every prediction, so the flow-matching noise is identical and only the text differs.
action_by_instruction = {}
for instr in instructions:
    obs_dict = build_obs_dict(obs, img, instr, device)
    torch.manual_seed(42)
    model_finetuned.reset()
    with torch.no_grad():
        chunk = model_finetuned.predict_action_chunk(obs_dict)
    action_by_instruction[instr] = chunk.float().cpu().numpy()[0]  # (chunk_size, action_dim)

# Pairwise cosine similarity of the flattened action chunks
flat_actions = np.array([v.flatten() for v in action_by_instruction.values()])
sim_matrix = cosine_similarity(flat_actions)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Cosine similarity lies in [-1, 1]. The colour scale must span the whole range; with
# vmin=0, anti-correlated chunks would be clipped to the same colour as orthogonal ones.
short_labels = [instr[:25] + "..." if len(instr) > 25 else instr for instr in instructions]
sns.heatmap(sim_matrix, xticklabels=short_labels, yticklabels=short_labels,
            annot=True, fmt=".2f", cmap="RdYlGn", vmin=-1, vmax=1, center=0, ax=axes[0])
axes[0].set_title("Cosine Similarity of Action Chunks\n(same image, different instructions)", fontsize=11)
axes[0].tick_params(axis='x', rotation=30)
axes[0].tick_params(axis='y', rotation=0)

# Action trajectories overlay
colors = plt.cm.tab10(np.linspace(0, 1, len(instructions)))
for i, (instr, actions) in enumerate(action_by_instruction.items()):
    label = instr[:30] + "..." if len(instr) > 30 else instr
    axes[1].plot(actions[:, 0], actions[:, 1], "o-", color=colors[i],
                 alpha=0.7, markersize=3, linewidth=1.5, label=label)

axes[1].set_xlabel("Action dim 0 (dx)")
axes[1].set_ylabel("Action dim 1 (dy)")
axes[1].set_title("Action Trajectories by Instruction\n(divergence = language matters)", fontsize=11)
axes[1].legend(fontsize=8, loc="best")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/language_conditioning.png", dpi=150, bbox_inches="tight")
plt.show()

print("\nKey insight: If the model uses language, different instructions should produce")
print("low similarity scores and divergent trajectories. Scrambled text should differ from all.")

## 7. Failure Analysis: "Where and why does it break?"

Run multiple episodes, split them into successes and failures, and compare their episode lengths, reward curves, gripper trajectories, action magnitudes, and key frames to see where failed episodes diverge.

In [None]:
def run_eval_episodes(env, model, task_desc, n_episodes=20, max_steps=200):
    """Run ``n_episodes`` rollouts of ``model`` and return per-episode records.

    Each episode resets the env and the policy, then steps until success, termination,
    truncation or ``max_steps``. NOTE(review): if the env reuses one fixed task variation
    on every reset (as with a single ``set_task``), episodes differ only through policy
    randomness — confirm for the Meta-World version in use.

    Args:
        env: Gymnasium-style env whose ``render()`` returns the policy's camera image.
        model: Policy exposing ``reset()`` and ``select_action(obs_dict)``.
        task_desc: Language instruction used for every step.
        n_episodes: Number of episodes.
        max_steps: Step limit per episode.

    Returns:
        List of dicts with ``success`` (bool), ``frames``, ``actions``, ``rewards``,
        ``gripper_positions`` (obs[:3] before each step), ``num_steps`` and ``total_reward``.
    """
    episodes = []
    for ep in range(n_episodes):
        obs, _ = env.reset()
        model.reset()
        frames, actions, rewards, gripper_positions = [], [], [], []
        info = {}  # stays empty when max_steps == 0, so the success lookup below is safe

        for _step in range(max_steps):
            img = env.render()
            frames.append(img)
            obs_dict = build_obs_dict(obs, img, task_desc, device)

            with torch.no_grad():
                action = model.select_action(obs_dict)

            action_np = action.cpu().numpy().flatten()
            actions.append(action_np)
            gripper_positions.append(obs[:3].copy())  # gripper xyz

            obs, reward, terminated, truncated, info = env.step(action_np)
            rewards.append(reward)

            if info.get("success", False) or terminated or truncated:
                break

        success = bool(info.get("success", False))
        total_reward = float(sum(rewards))
        episodes.append({
            "success": success,
            "frames": frames,
            "actions": np.array(actions),
            "rewards": np.array(rewards),
            "gripper_positions": np.array(gripper_positions),
            "num_steps": len(frames),
            "total_reward": total_reward,
        })
        status = "SUCCESS" if success else "FAIL"
        print(f"  Episode {ep+1}/{n_episodes}: {status} ({len(frames)} steps, reward={total_reward:.2f})")

    return episodes

# Evaluate the fine-tuned policy on repeated assembly-v3 episodes and split by outcome.
env = make_metaworld_env("assembly-v3")
print("Running evaluation episodes...")
episodes = run_eval_episodes(env, model_finetuned, "assemble the peg into the hole", n_episodes=20)

successes, failures = [], []
for episode in episodes:
    (successes if episode["success"] else failures).append(episode)

n_success, n_total = len(successes), len(episodes)
print(f"\nSuccess rate: {n_success}/{n_total} ({100 * n_success / n_total:.0f}%)")

In [None]:
# Analyze failure patterns: four panels contrasting successful (green) and failed (red) episodes.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# 1. Episode length distribution. Both histograms share one set of bin edges so the green
#    and red bars line up and can be compared directly.
success_lengths = [e["num_steps"] for e in successes]
failure_lengths = [e["num_steps"] for e in failures]
all_lengths = success_lengths + failure_lengths
length_bins = np.histogram_bin_edges(all_lengths, bins=15) if all_lengths else 15

axes[0, 0].hist(success_lengths, bins=length_bins, alpha=0.7, color="green", label="Success")
axes[0, 0].hist(failure_lengths, bins=length_bins, alpha=0.7, color="red", label="Failure")
axes[0, 0].set_xlabel("Episode Length (steps)")
axes[0, 0].set_ylabel("Count")
axes[0, 0].set_title("Episode Length Distribution")
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Reward curves (first 3 episodes of each outcome)
for ep_subset, color in ((successes[:3], "green"), (failures[:3], "red")):
    for ep in ep_subset:
        axes[0, 1].plot(np.cumsum(ep["rewards"]), color=color, alpha=0.5)
axes[0, 1].set_xlabel("Step")
axes[0, 1].set_ylabel("Cumulative Reward")
axes[0, 1].set_title("Reward Curves (green=success, red=failure)")
axes[0, 1].grid(True, alpha=0.3)

# 3. Gripper trajectories (top-down view, xy plane). Zero-step episodes have no positions.
for ep_subset, color in ((successes[:5], "g"), (failures[:5], "r")):
    for ep in ep_subset:
        pos = ep["gripper_positions"]
        if len(pos) == 0:
            continue
        axes[1, 0].plot(pos[:, 0], pos[:, 1], f"{color}-", alpha=0.4, linewidth=1)
        axes[1, 0].plot(pos[0, 0], pos[0, 1], f"{color}o", markersize=4)
        axes[1, 0].plot(pos[-1, 0], pos[-1, 1], f"{color}^", markersize=6)

axes[1, 0].set_xlabel("X position")
axes[1, 0].set_ylabel("Y position")
axes[1, 0].set_title("Gripper Trajectories (top-down)\n(circle=start, triangle=end)")
axes[1, 0].grid(True, alpha=0.3)

# 4. Translation-action magnitude over time (smoothness indicator)
for ep_subset, color in ((successes[:3], "green"), (failures[:3], "red")):
    for ep in ep_subset:
        if len(ep["actions"]) == 0:
            continue
        magnitudes = np.linalg.norm(ep["actions"][:, :3], axis=1)
        axes[1, 1].plot(magnitudes, color=color, alpha=0.5)
axes[1, 1].set_xlabel("Step")
axes[1, 1].set_ylabel("Action Magnitude")
axes[1, 1].set_title("Action Magnitude Over Time\n(erratic = loss of control)")
axes[1, 1].grid(True, alpha=0.3)

plt.suptitle("Failure Analysis: assembly-v3", fontsize=14, fontweight="bold")
plt.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/failure_analysis.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Side-by-side key frames: the highest-reward success (top row) vs the lowest-reward failure
# (bottom row), sampled at five evenly spaced steps.
if successes and failures:
    best_success = max(successes, key=lambda e: e["total_reward"])
    worst_failure = min(failures, key=lambda e: e["total_reward"])

    fig, axes = plt.subplots(2, 5, figsize=(20, 8))

    rows = ((best_success, "SUCCESS", "green"), (worst_failure, "FAILURE", "red"))
    for row_axes, (episode, label, title_color) in zip(axes, rows):
        n = len(episode["frames"])
        for col in range(5):
            frame_idx = int(col * (n - 1) / 4)
            ax = row_axes[col]
            ax.imshow(episode["frames"][frame_idx])
            ax.set_title(f"{label} — step {frame_idx}", fontsize=10, color=title_color)
            ax.axis("off")

    plt.suptitle("Best Success vs Worst Failure (assembly-v3)", fontsize=14, fontweight="bold")
    plt.tight_layout()
    fig.savefig(f"{OUTPUT_DIR}/success_vs_failure_frames.png", dpi=150, bbox_inches="tight")
    plt.show()

## 8. Representation Analysis: "What did fine-tuning change?"

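Extract mean-pooled prefix embeddings (the image, language, and state tokens fed into the VLM backbone) for different tasks. Then use t-SNE to compare how this representation space looks before vs after fine-tuning. Note that if the vision encoder and language model were frozen during fine-tuning, differences between the two models come only from the prefix components that were trained (e.g. the state projection).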

In [None]:
def extract_hidden_states(model, env, task_desc, n_obs=50):
    """Collect one pooled prefix embedding per observation while rolling out ``model``.

    For each of ``n_obs`` steps, the observation is embedded with
    ``model.model.embed_prefix`` (the image, language and state tokens fed to the VLM),
    mean-pooled over the non-padding prefix tokens, and stored. The policy's own action
    then advances the env; the env and policy are reset when an episode ends.

    Args:
        model: SmolVLA policy (``prepare_images``, ``prepare_state``, ``select_action``,
            ``reset`` and ``.model.embed_prefix``).
        env: Gymnasium-style env whose ``render()`` returns the policy's camera image.
        task_desc: Language instruction.
        n_obs: Number of observations to embed.

    Returns:
        float32 NumPy array of shape (n_obs, hidden_dim).
    """
    from lerobot.utils.constants import OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_ATTENTION_MASK

    hidden_states = []
    obs, _ = env.reset()
    model.reset()

    for _step in range(n_obs):
        img = env.render()
        obs_dict = build_obs_dict(obs, img, task_desc, device)

        with torch.no_grad():
            images, img_masks = model.prepare_images(obs_dict)
            state = model.prepare_state(obs_dict)
            prefix_embs, prefix_pad_masks, _att = model.model.embed_prefix(
                images, img_masks, obs_dict[OBS_LANGUAGE_TOKENS], obs_dict[OBS_LANGUAGE_ATTENTION_MASK],
                state=state,
            )
            # Masked mean over real prefix tokens only. Language is padded to a fixed length,
            # and plain mean() would mix the padding embeddings into the representation.
            # Assumes prefix_pad_masks is (batch, seq) with True for real tokens — TODO confirm.
            # .float() also makes reduced-precision (e.g. bf16) embeddings NumPy-convertible.
            emb = prefix_embs.float()
            mask = prefix_pad_masks.unsqueeze(-1).float()
            pooled = (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            hidden_states.append(pooled[0].cpu().numpy())

            # Step the environment with the policy's action to get diverse observations
            action = model.select_action(obs_dict)

        obs, _, terminated, truncated, _ = env.step(action.cpu().numpy().flatten())
        if terminated or truncated:
            obs, _ = env.reset()
            model.reset()

    return np.array(hidden_states)  # (n_obs, hidden_dim)

# For each of three tasks, embed the same number of observations with both models.
tasks_for_analysis = [
    ("assembly-v3", "assemble the peg into the hole"),
    ("dial-turn-v3", "turn the dial clockwise"),
    ("handle-press-side-v3", "press the handle from the side"),
]

n_obs_per_task = 30
ft_blocks, pt_blocks = [], []  # per-task arrays for the fine-tuned / pretrained model
all_labels = []

for task_name, task_desc in tasks_for_analysis:
    print(f"Extracting representations for {task_name}...")
    env = make_metaworld_env(task_name)

    h_ft = extract_hidden_states(model_finetuned, env, task_desc, n_obs=n_obs_per_task)
    h_pt = extract_hidden_states(model_pretrained, env, task_desc, n_obs=n_obs_per_task)

    ft_blocks.append(h_ft)
    pt_blocks.append(h_pt)
    all_labels += [task_name.replace("-v3", "")] * n_obs_per_task

all_hidden_ft = np.concatenate(ft_blocks)
all_hidden_pt = np.concatenate(pt_blocks)
print(f"Collected {len(all_labels)} representations per model")

In [None]:
# Run t-SNE on both models' representations and compare them side by side.
fig, axes = plt.subplots(1, 2, figsize=(16, 7))

# First-seen order rather than set() order: string hashing is randomized per process, so
# set() would shuffle the task→colour mapping and legend order between runs.
unique_tasks = list(dict.fromkeys(all_labels))
task_colors = {task: plt.cm.tab10(i) for i, task in enumerate(unique_tasks)}
label_array = np.array(all_labels)

tsne_kwargs = dict(n_components=2, perplexity=15, random_state=42)

for ax, (hidden, title) in zip(axes, [
    (all_hidden_pt, "Pretrained (zero-shot)"),
    (all_hidden_ft, "Fine-tuned on Meta-World"),
]):
    # scikit-learn 1.5 renamed `n_iter` to `max_iter`, and later releases removed `n_iter`.
    # Fall back to the old name only on installs that predate `max_iter`.
    try:
        tsne = TSNE(max_iter=1000, **tsne_kwargs)
    except TypeError:
        tsne = TSNE(n_iter=1000, **tsne_kwargs)
    embedded = tsne.fit_transform(hidden)

    for task in unique_tasks:
        mask = label_array == task
        ax.scatter(embedded[mask, 0], embedded[mask, 1],
                   c=[task_colors[task]], label=task, alpha=0.7, s=40, edgecolors="white", linewidth=0.5)

    ax.set_title(title, fontsize=13, fontweight="bold")
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.2)
    ax.set_xlabel("t-SNE dim 1")
    ax.set_ylabel("t-SNE dim 2")

plt.suptitle("Representation Space: Before vs After Fine-Tuning\n"
             "(clear clusters after fine-tuning = model learned task structure)",
             fontsize=14, fontweight="bold")
plt.tight_layout()
fig.savefig(f"{OUTPUT_DIR}/tsne_comparison.png", dpi=150, bbox_inches="tight")
plt.show()

## 9. Summary of all visualizations

All figures saved to the output directory. Use these directly in your presentation slides.

In [None]:
import glob

# List every PNG written to OUTPUT_DIR and print download instructions.
png_pattern = f"{OUTPUT_DIR}/*.png"
saved_files = sorted(glob.glob(png_pattern))

print("Saved visualizations:")
for figure_path in saved_files:
    print(f"  {figure_path}")

download_hints = (
    f"\nTotal: {len(saved_files)} figures",
    "\nTo download: Use Colab's file browser (left panel) or run:",
    f"  !zip -r /content/introspection_figures.zip {OUTPUT_DIR}",
    "  from google.colab import files; files.download('/content/introspection_figures.zip')",
)
for hint in download_hints:
    print(hint)

In [None]:
# Package all figures for download.
# IPython substitutes the Python value of OUTPUT_DIR for `{OUTPUT_DIR}` in the shell command;
# the zip is written to /content and then passed to Colab's files.download helper.
!zip -r /content/introspection_figures.zip {OUTPUT_DIR}
from google.colab import files
files.download("/content/introspection_figures.zip")