# Figure 4 Reproduction: Temporal Attention Visualization

This notebook reproduces Figure 4 from the DeepForcing paper, showing how attention weights
are distributed across frames during autoregressive video generation.

**Expected Pattern (from the paper):**
- High attention at the beginning (initial frames provide strong context)
- Attention drops in the middle
- High attention again for recent/current frames

## Prerequisites

Run the extraction script first:
```bash
python run_extraction.py \
    --config_path configs/self_forcing_dmd.yaml \
    --checkpoint_path checkpoints/self_forcing_dmd.pt \
    --output_path attention_cache.pt
```

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

## 1. Load Attention Weights

In [None]:
# Load the cached attention weights
CACHE_PATH = "attention_cache.pt"

if not Path(CACHE_PATH).exists():
    raise FileNotFoundError(
        f"Attention cache not found at {CACHE_PATH}. "
        "Please run run_extraction.py first."
    )

data = torch.load(CACHE_PATH, map_location='cpu')

print("Loaded attention data:")
print(f"  - Prompt: {data.get('prompt', 'N/A')}")
print(f"  - Number of frames: {data.get('num_frames', 'N/A')}")
print(f"  - Frame sequence length: {data.get('frame_seq_length', 'N/A')}")
print(f"  - Frames per block: {data.get('num_frame_per_block', 'N/A')}")
print(f"  - Number of transformer blocks: {data.get('num_transformer_blocks', 'N/A')}")
print(f"  - Captured layer indices: {data.get('layer_indices', 'N/A')}")
print(f"  - Number of captured weight tensors: {len(data['attention_weights'])}")

In [None]:
# Inspect the captured attention weights
attention_weights = data['attention_weights']

print("\nCaptured attention weight tensors:")
for i, w in enumerate(attention_weights):
    print(f"  [{i}] Layer {w['layer_idx']}: "
          f"attn_shape={w['attn_weights'].shape}, "
          f"q_shape={w['q_shape']}, k_shape={w['k_shape']}")

## 2. Process Attention Weights

The attention tensor shape is typically `(Batch, Heads, Query_Len, Key_Len)`.

For Figure 4, we need to:
1. Select a specific layer and head
2. Average across the Query dimension to get the "average attention per key position"
3. Map token positions to frame indices

In [None]:
def process_attention_for_figure4(
    attention_data: dict,
    weight_idx: int = 0,
    head_idx: int = 0,
    frame_seq_length: int = None,
):
    """
    Process attention weights for Figure 4 visualization.
    
    Args:
        attention_data: The loaded attention cache data
        weight_idx: Index of the captured weight tensor to use
        head_idx: Attention head index to visualize
        frame_seq_length: Number of tokens per frame (default: from data)
    
    Returns:
        frame_attention: Average attention weight per frame
        frame_indices: Frame indices
    """
    if frame_seq_length is None:
        frame_seq_length = attention_data.get('frame_seq_length', 1560)
    
    # Get the attention weights: shape [B, N, Lq, Lk]
    attn = attention_data['attention_weights'][weight_idx]['attn_weights']
    print(f"Raw attention shape: {attn.shape}")
    
    # Extract batch=0, head=head_idx
    # Shape: [Lq, Lk]
    attn_matrix = attn[0, head_idx].float().numpy()
    print(f"Selected head attention shape: {attn_matrix.shape}")
    
    lq, lk = attn_matrix.shape
    
    # Calculate number of frames
    num_query_frames = lq // frame_seq_length
    num_key_frames = lk // frame_seq_length
    print(f"Query frames: {num_query_frames}, Key frames: {num_key_frames}")
    
    # Method 1: Average across all query positions
    # This gives us "how much attention does each key position receive on average"
    avg_attention_per_key = attn_matrix.mean(axis=0)  # [Lk]
    
    # Group by frame: average attention per frame
    frame_attention = []
    for f in range(num_key_frames):
        start = f * frame_seq_length
        end = min((f + 1) * frame_seq_length, lk)
        frame_avg = avg_attention_per_key[start:end].mean()
        frame_attention.append(frame_avg)
    
    frame_attention = np.array(frame_attention)
    frame_indices = np.arange(num_key_frames)
    
    return frame_attention, frame_indices, attn_matrix

In [None]:
# Process the first captured attention tensor
if len(attention_weights) > 0:
    frame_attention, frame_indices, raw_attn = process_attention_for_figure4(
        data, weight_idx=0, head_idx=0
    )
    
    print(f"\nFrame attention shape: {frame_attention.shape}")
    print(f"Frame indices: {frame_indices}")
    print(f"Attention values: {frame_attention}")
else:
    print("No attention weights captured!")

## 3. Plot Figure 4: Attention Distribution Across Frames

In [None]:
def plot_figure4(
    frame_attention: np.ndarray,
    frame_indices: np.ndarray,
    title: str = "Temporal Attention Distribution (Figure 4)",
    save_path: str = None,
):
    """
    Create the Figure 4 plot showing attention distribution across frames.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    
    # Plot the attention distribution
    ax.plot(frame_indices, frame_attention, 'o-', linewidth=2, markersize=8, 
            color='#2E86AB', label='Average Attention')
    
    # Fill under the curve
    ax.fill_between(frame_indices, frame_attention, alpha=0.3, color='#2E86AB')
    
    # Labels and title
    ax.set_xlabel('Key Frame Index', fontsize=14)
    ax.set_ylabel('Average Attention Weight', fontsize=14)
    ax.set_title(title, fontsize=16, fontweight='bold')
    
    # Set x-axis to show all frame indices
    ax.set_xticks(frame_indices)
    ax.set_xlim(frame_indices[0] - 0.5, frame_indices[-1] + 0.5)
    
    # Add grid
    ax.grid(True, alpha=0.3)
    
    # Add annotations for interesting patterns
    max_idx = np.argmax(frame_attention)
    min_idx = np.argmin(frame_attention)
    
    ax.annotate(f'Max: {frame_attention[max_idx]:.4f}',
                xy=(frame_indices[max_idx], frame_attention[max_idx]),
                xytext=(10, 10), textcoords='offset points',
                fontsize=10, color='green',
                arrowprops=dict(arrowstyle='->', color='green', alpha=0.7))
    
    ax.annotate(f'Min: {frame_attention[min_idx]:.4f}',
                xy=(frame_indices[min_idx], frame_attention[min_idx]),
                xytext=(10, -20), textcoords='offset points',
                fontsize=10, color='red',
                arrowprops=dict(arrowstyle='->', color='red', alpha=0.7))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to {save_path}")
    
    plt.show()
    return fig

In [None]:
# Create the Figure 4 plot
if len(attention_weights) > 0:
    fig = plot_figure4(
        frame_attention, 
        frame_indices,
        title="Temporal Self-Attention Distribution (Reproduction of Figure 4)",
        save_path="figure4_reproduction.png"
    )

## 4. Multi-Head Analysis

Let's visualize attention patterns across different heads to see if they show different behaviors.

In [None]:
def plot_multi_head_attention(
    attention_data: dict,
    weight_idx: int = 0,
    num_heads_to_show: int = 8,
    save_path: str = None,
):
    """
    Plot attention distributions for multiple heads.
    """
    attn = attention_data['attention_weights'][weight_idx]['attn_weights']
    total_heads = attn.shape[1]
    num_heads_to_show = min(num_heads_to_show, total_heads)
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.flatten()
    
    for head_idx in range(num_heads_to_show):
        frame_attention, frame_indices, _ = process_attention_for_figure4(
            attention_data, weight_idx=weight_idx, head_idx=head_idx
        )
        
        ax = axes[head_idx]
        ax.plot(frame_indices, frame_attention, 'o-', linewidth=1.5, markersize=4)
        ax.fill_between(frame_indices, frame_attention, alpha=0.3)
        ax.set_title(f'Head {head_idx}', fontsize=12)
        ax.set_xlabel('Frame Index', fontsize=10)
        ax.set_ylabel('Avg Attention', fontsize=10)
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('Attention Distribution Across Different Heads', fontsize=14, fontweight='bold')
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to {save_path}")
    
    plt.show()
    return fig

In [None]:
# Plot multi-head attention
if len(attention_weights) > 0:
    fig = plot_multi_head_attention(
        data,
        weight_idx=0,
        num_heads_to_show=8,
        save_path="figure4_multihead.png"
    )

## 5. Full Attention Matrix Heatmap

Visualize the complete attention matrix to understand Query-Key relationships.

In [None]:
def plot_attention_heatmap(
    attn_matrix: np.ndarray,
    frame_seq_length: int = 1560,
    downsample_factor: int = 100,
    title: str = "Attention Matrix Heatmap",
    save_path: str = None,
):
    """
    Plot the full attention matrix as a heatmap.
    """
    # Downsample for visualization
    lq, lk = attn_matrix.shape
    
    # Use frame-level averaging instead of token-level
    num_q_frames = lq // frame_seq_length
    num_k_frames = lk // frame_seq_length
    
    # Create frame-level attention matrix
    frame_attn = np.zeros((num_q_frames, num_k_frames))
    for qi in range(num_q_frames):
        for ki in range(num_k_frames):
            q_start, q_end = qi * frame_seq_length, min((qi + 1) * frame_seq_length, lq)
            k_start, k_end = ki * frame_seq_length, min((ki + 1) * frame_seq_length, lk)
            frame_attn[qi, ki] = attn_matrix[q_start:q_end, k_start:k_end].mean()
    
    fig, ax = plt.subplots(figsize=(10, 8))
    
    im = ax.imshow(frame_attn, cmap='viridis', aspect='auto')
    plt.colorbar(im, ax=ax, label='Attention Weight')
    
    ax.set_xlabel('Key Frame Index', fontsize=12)
    ax.set_ylabel('Query Frame Index', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    
    # Set ticks
    ax.set_xticks(range(num_k_frames))
    ax.set_yticks(range(num_q_frames))
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to {save_path}")
    
    plt.show()
    return fig

In [None]:
# Plot attention heatmap
if len(attention_weights) > 0:
    frame_seq_length = data.get('frame_seq_length', 1560)
    fig = plot_attention_heatmap(
        raw_attn,
        frame_seq_length=frame_seq_length,
        title="Frame-Level Attention Matrix (Head 0)",
        save_path="figure4_heatmap.png"
    )

## 6. Head-Averaged Attention

Average across all heads for a more robust visualization.

In [None]:
def plot_head_averaged_attention(
    attention_data: dict,
    weight_idx: int = 0,
    save_path: str = None,
):
    """
    Plot attention distribution averaged across all heads.
    """
    attn = attention_data['attention_weights'][weight_idx]['attn_weights']
    frame_seq_length = attention_data.get('frame_seq_length', 1560)
    total_heads = attn.shape[1]
    
    all_frame_attentions = []
    
    for head_idx in range(total_heads):
        frame_attention, frame_indices, _ = process_attention_for_figure4(
            attention_data, weight_idx=weight_idx, head_idx=head_idx
        )
        all_frame_attentions.append(frame_attention)
    
    all_frame_attentions = np.array(all_frame_attentions)  # [num_heads, num_frames]
    
    # Compute mean and std
    mean_attention = all_frame_attentions.mean(axis=0)
    std_attention = all_frame_attentions.std(axis=0)
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    # Plot mean with error band
    ax.plot(frame_indices, mean_attention, 'o-', linewidth=2, markersize=8,
            color='#2E86AB', label='Mean across heads')
    ax.fill_between(frame_indices,
                    mean_attention - std_attention,
                    mean_attention + std_attention,
                    alpha=0.3, color='#2E86AB', label='Â±1 std')
    
    ax.set_xlabel('Key Frame Index', fontsize=14)
    ax.set_ylabel('Average Attention Weight', fontsize=14)
    ax.set_title('Head-Averaged Temporal Attention Distribution', fontsize=16, fontweight='bold')
    ax.set_xticks(frame_indices)
    ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved figure to {save_path}")
    
    plt.show()
    return fig, mean_attention, std_attention

In [None]:
# Plot head-averaged attention
if len(attention_weights) > 0:
    fig, mean_attn, std_attn = plot_head_averaged_attention(
        data,
        weight_idx=0,
        save_path="figure4_head_averaged.png"
    )
    
    print(f"\nMean attention per frame: {mean_attn}")
    print(f"Std attention per frame: {std_attn}")

## 7. Summary Statistics

In [None]:
if len(attention_weights) > 0:
    print("=" * 60)
    print("ATTENTION ANALYSIS SUMMARY")
    print("=" * 60)
    print(f"\nPrompt: {data.get('prompt', 'N/A')}")
    print(f"Number of frames: {data.get('num_frames', 'N/A')}")
    print(f"Captured layer: {attention_weights[0]['layer_idx']}")
    print(f"\nAttention distribution (head-averaged):")
    print(f"  - Max attention frame: {np.argmax(mean_attn)} (value: {mean_attn.max():.6f})")
    print(f"  - Min attention frame: {np.argmin(mean_attn)} (value: {mean_attn.min():.6f})")
    print(f"  - First frame attention: {mean_attn[0]:.6f}")
    print(f"  - Last frame attention: {mean_attn[-1]:.6f}")
    print(f"\nExpected pattern (from Figure 4):")
    print("  - High attention at beginning (context)")
    print("  - Lower attention in middle")
    print("  - High attention for recent frames")
    print("=" * 60)

## 8. Notes

### Interpreting the Results

1. **High attention at the start**: The model relies heavily on initial frames to establish scene context
2. **Attention sink pattern**: Some models show high attention to the very first few tokens (attention sink)
3. **Recent frame attention**: High attention to recent frames for temporal consistency
4. **U-shaped pattern**: Classic pattern showing beginning + end attention, with a dip in the middle

### Differences from Original Figure 4

- The exact pattern depends on the checkpoint and inference settings
- Different layers may show different patterns
- KV cache rolling (for local attention) affects which frames are visible