# 单层注意力分析 / Single-Layer Attention Analysis

## 工作流程 / Workflow

**Step 1:** 运行注意力提取脚本 / Run the attention extraction script:
```bash
bash experiments/extract_attention.sh ./Wan2.1-T2V-1.3B 20 cache
```

**Step 2:** 更新下方 `DATA_PATH` 指向生成的 `.pt` 文件 / Update `DATA_PATH` below to point to the generated `.pt` file

**Step 3:** 运行所有 cell 生成图表 / Run all cells to generate plots

---

绘制两张图 / Two plots:
1. **2D 热力图 / 2D heatmaps**: Query Frame × Key Frame，每个 head 一张子图，4 列网格（12 个 head 时为 3×4）/ one panel per head in a 4-column grid (3×4 for 12 heads)
2. **Per-Head Grid**: 最后一个 block 对各帧的注意力柱状图 / per-head bar charts of the last query block's attention to each key frame

In [1]:
import math
import os

import seaborn as sns
from matplotlib import colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import torch

# Publication-style Matplotlib defaults shared by every figure in this notebook.
_PUBLICATION_RC = {
    # Keep SVG text as text (not glyph paths) so labels stay editable.
    "svg.fonttype": "none",
    # Sans-serif font stack with fallbacks.
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "DejaVu Sans", "Helvetica"],
    # Base text sizes.
    "font.size": 11,
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # On-screen resolution vs. saved-file resolution, with tight cropping on save.
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.1,
}
plt.rcParams.update(_PUBLICATION_RC)

## 配置 / Configuration

In [None]:
# ============================================================
# Configuration
# ============================================================
# After running extract_attention.sh, point DATA_PATH at the generated .pt file.
#
# Example:
#   bash experiments/extract_attention.sh ./Wan2.1-T2V-1.3B 20 cache
#   -> generates cache/layer20.pt
#
# Relative paths are resolved against the notebook's working directory.

DATA_PATH = "../cache/layer20.pt"
SAVE_DIR = "attention_analysis"  # root directory for saved figures
SAVE_SVG = True  # also write each figure to an .svg file

# Check up front that the file exists, so a missing dump is reported clearly.
import os
if not os.path.exists(DATA_PATH):
    print(f"⚠️  Warning: {DATA_PATH} not found!")
    # The script takes three arguments (checkpoint dir, layer index, output dir),
    # as in the example above; show all of them in the hint.
    print("   Run: bash experiments/extract_attention.sh <ckpt_dir> <layer_index> <output_dir>")
else:
    print(f"✓ Found: {DATA_PATH}")

## 加载数据 / Load Data

In [None]:
# Load the dump written by extract_attention.sh.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects; only point DATA_PATH at files you generated yourself.
data = torch.load(DATA_PATH, map_location="cpu", weights_only=False)

print("=" * 60)
print(f"Layer: {data['layer_index']}")
print(f"Prompt: {data.get('prompt', 'N/A')}")
print(f"Num frames (saved): {data.get('num_frames', 'N/A')}")
print(f"Num heads: {data.get('num_heads', 'N/A')}")
print(f"Block sizes: {data.get('block_sizes', 'N/A')}")
print(f"Last block Q frames (saved): {data.get('last_block_query_frames', 'N/A')}")

layer_idx = data['layer_index']

# Per-layer output directory for the figures.
SAVE_DIR_LAYER = os.path.join(SAVE_DIR, f"layer{layer_idx}")

# Full frame x frame attention matrix over latent frames: [num_heads, Q, K].
full_frame_attn = data['full_frame_attention'].float().numpy()
latent_num_frames = full_frame_attn.shape[1]
# The header above tolerates a missing 'num_heads' (prints 'N/A'), so fall back
# to the head axis of the attention tensor instead of raising a KeyError.
num_heads = data.get('num_heads', full_frame_attn.shape[0])
print(f"\nFull attention shape (latent): {full_frame_attn.shape}")
print(f"Range: [{full_frame_attn.min():.4f}, {full_frame_attn.max():.4f}]")

# ============================================================
# Frame View Config
# ============================================================
# The model downsamples along time (vae_stride[0]=4), so there are usually fewer
# latent frames than original frames. To view the maps per original frame
# (e.g. 21), set USE_ORIG_FRAMES=True and set ORIG_NUM_FRAMES.
USE_ORIG_FRAMES = True
ORIG_NUM_FRAMES = data.get('orig_num_frames', 21)  # change manually if your run differs
TEMPORAL_STRIDE = data.get('vae_stride_t', 4)  # 4 for t2v-1.3B

if USE_ORIG_FRAMES and ORIG_NUM_FRAMES is not None:
    orig_indices = np.arange(ORIG_NUM_FRAMES, dtype=int)
    raw_latent = orig_indices // TEMPORAL_STRIDE
    # Clamp so indices never run past the last latent frame.
    orig_to_latent = np.minimum(raw_latent, latent_num_frames - 1)

    # A mismatched ORIG_NUM_FRAMES / TEMPORAL_STRIDE would otherwise silently
    # either clamp many original frames onto the last latent frame or leave
    # latent frames out of every plot. For example, ORIG_NUM_FRAMES=21 against
    # 21 latent frames shows only latent frames 0-5.
    n_clamped = int((raw_latent > latent_num_frames - 1).sum())
    n_unshown = latent_num_frames - np.unique(orig_to_latent).size
    if n_clamped or n_unshown:
        print(
            f"⚠️  Warning: ORIG_NUM_FRAMES={ORIG_NUM_FRAMES} with stride {TEMPORAL_STRIDE} "
            f"does not match {latent_num_frames} latent frames "
            f"({n_clamped} original frames clamped, {n_unshown} latent frames never shown). "
            f"Check ORIG_NUM_FRAMES / TEMPORAL_STRIDE."
        )

    # Expand both the query and key axes from latent to original frame indices.
    display_full_frame_attn = full_frame_attn[:, orig_to_latent][:, :, orig_to_latent]
    display_num_frames = ORIG_NUM_FRAMES
    display_frame_indices = orig_indices.tolist()
    print(f"\nUsing expanded attention for original frames: {ORIG_NUM_FRAMES} (stride={TEMPORAL_STRIDE}).")
else:
    display_full_frame_attn = full_frame_attn
    display_num_frames = latent_num_frames
    display_frame_indices = list(range(latent_num_frames))
    print(f"\nUsing latent-frame attention: {latent_num_frames} frames.")

# Query frames to average over: the last NUM_QUERY_FRAMES displayed frames
# (frames 18-20 for a 21-frame view; fewer if the view is shorter). Previously
# this was pinned to [18, 19, 20] for any view of >= 21 frames, which is not the
# last block once the view is longer.
NUM_QUERY_FRAMES = 3
QUERY_FRAMES = list(range(max(display_num_frames - NUM_QUERY_FRAMES, 0), display_num_frames))
last_block_q_frames = QUERY_FRAMES
last_block_attn = display_full_frame_attn[:, QUERY_FRAMES, :].mean(axis=1)  # [num_heads, K]

print(f"Last block query frames (display): {last_block_q_frames}")
print(f"Last block attention shape: {last_block_attn.shape}")
print(f"Save directory: {SAVE_DIR_LAYER}")


## 图1: 2D 热力图 (Query Frame × Key Frame)

In [None]:
# ============================================================
# Figure 1: 2D heatmaps, one small panel per head
# X-axis: key frame index
# Y-axis: query frame index
# Layout: 4 columns, ceil(num_heads / 4) rows (3x4 for 12 heads)
# ============================================================

ncols = 4
nrows = math.ceil(num_heads / ncols)

# Scale the height with the row count so panels keep their size for any head
# count (4 in per row -> 16x12 for the 12-head, 3-row case).
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 4 * nrows), squeeze=False)
axes = axes.flatten()

# Shared, symmetric color range: 0 sits at the neutral centre of RdBu_r and all
# heads use the same scale.
vmax = float(np.abs(display_full_frame_attn).max())
vmin = -vmax

# First / middle / last frame ticks, shared by every panel.
tick_pos = [0, display_num_frames // 2, display_num_frames - 1]

for h in range(num_heads):
    ax = axes[h]
    attn_map = display_full_frame_attn[h]  # [Q, K]

    im = ax.imshow(
        attn_map,
        cmap="RdBu_r",
        aspect="auto",
        origin="lower",
        vmin=vmin,
        vmax=vmax,
        interpolation="nearest",
    )

    ax.set_title(f"Head {h}", fontsize=11, fontweight="bold")
    ax.set_xlabel("Key Frame", fontsize=9)
    ax.set_ylabel("Query Frame", fontsize=9)

    ax.set_xticks(tick_pos)
    ax.set_xticklabels(tick_pos)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(tick_pos)

# Hide grid cells left over after the last head.
for k in range(num_heads, len(axes)):
    axes[k].axis("off")

fig.suptitle(
    f"Layer {layer_idx}: 2D Attention Maps (All Heads)\n"
    f"Query Frames: 0-{display_num_frames-1}, Key Frames: 0-{display_num_frames-1}",
    fontsize=14,
    fontweight="bold",
    y=1.02,
)

# Lay out the subplot grid BEFORE adding the colorbar axes, reserving the right
# 8% of the figure for it. tight_layout only handles subplot-grid axes; running
# it after a manually placed fig.add_axes(...) can emit an "Axes not compatible
# with tight_layout" warning and misplace things.
fig.tight_layout(rect=[0, 0, 0.92, 0.98])

# Shared colorbar in the reserved right-hand strip.
cbar_ax = fig.add_axes([0.94, 0.15, 0.02, 0.7])
cbar = fig.colorbar(im, cax=cbar_ax)
cbar.set_label("Attention Logits", fontsize=11)

if SAVE_SVG:
    os.makedirs(SAVE_DIR_LAYER, exist_ok=True)
    save_path = os.path.join(SAVE_DIR_LAYER, f"layer{layer_idx}_2d_heatmap_all_heads.svg")
    plt.savefig(save_path, format="svg", bbox_inches="tight")
    print(f"Saved: {save_path}")

plt.show()


## 图2: Per-Head Grid (最后一个 block 对各帧的注意力)

In [None]:
# ============================================================
# Figure 2: per-head bar charts
# Each panel shows the attention from QUERY_FRAMES (averaged) to every key
# frame. With a 21-frame view the query frames default to 18-20.
# ============================================================

key_indices = np.arange(display_num_frames)

ncols = 4
nrows = math.ceil(num_heads / ncols)

fig, axes = plt.subplots(nrows, ncols, figsize=(14, 3 * nrows))
axes = axes.flatten()

BAR_COLOR = sns.color_palette("colorblind")[0]
n_keys = len(key_indices)

for h, ax in enumerate(axes[:num_heads]):
    scores = last_block_attn[h]  # one value per key frame

    # Sink score = first key frame minus the mean of the interior frames
    # (whole-row mean when there are no interior frames).
    interior = scores[1:-1] if len(scores) > 2 else scores
    sink_score = scores[0] - interior.mean()

    ax.bar(key_indices, scores, alpha=0.85, width=0.8, color=BAR_COLOR)
    ax.plot(key_indices, scores, "o-", color="black", linewidth=1, markersize=2)
    ax.set_title(f"H{h} (sink={sink_score:.2f})", fontsize=10, fontweight="bold")
    ax.tick_params(axis="both", which="major", labelsize=7)
    ax.grid(True, alpha=0.3)

    # With many key frames, keep only first / middle / last x ticks.
    if n_keys > 10:
        ax.set_xticks([0, n_keys // 2, n_keys - 1])

# Blank out grid cells left over after the last head.
for spare_ax in axes[num_heads:]:
    spare_ax.axis("off")

title_y = 1.0 + 0.01 * nrows
fig.suptitle(
    f"Layer {layer_idx}: Per-Head Attention Distribution\n"
    f"Query: frames {last_block_q_frames}",
    fontsize=12,
    fontweight="bold",
    y=title_y,
)

plt.tight_layout()

if SAVE_SVG:
    os.makedirs(SAVE_DIR_LAYER, exist_ok=True)
    svg_path = os.path.join(SAVE_DIR_LAYER, f"layer{layer_idx}_perhead_grid.svg")
    plt.savefig(svg_path, format="svg", bbox_inches="tight")
    print(f"Saved: {svg_path}")

plt.show()


## 统计 / Statistics

In [6]:
print("=" * 60)
print(f"Layer {layer_idx} Statistics")
print("=" * 60)

# ---- Global statistics over all heads ----
# Self-attention: the main diagonal of every head's query x key map. It is taken
# in one vectorized call instead of a Python double loop over heads and frames.
diag = np.diagonal(display_full_frame_attn, axis1=1, axis2=2)  # [num_heads, frames]
diag_mean = diag.mean()

# Attention to key frame 0 from every query frame, in every head.
first_col = display_full_frame_attn[:, :, 0]
# Exact zeros are excluded from the mean; 0 is reported if nothing remains.
# NOTE(review): presumably zeros mark masked / uncomputed entries — confirm.
nonzero_first = first_col[first_col != 0]
first_col_mean = nonzero_first.mean() if nonzero_first.size else 0

print(f"Diagonal mean (self-attention): {diag_mean:.4f}")
print(f"First frame mean (sink): {first_col_mean:.4f}")

# ---- Per-head statistics for the selected query frames ----
print("\nPer-Head Statistics (selected frames):")
for h in range(num_heads):
    head = last_block_attn[h]
    first = head[0]
    # Interior-frame mean; whole-row mean when there are no interior frames.
    middle = head[1:-1].mean() if len(head) > 2 else head.mean()
    last = head[-1]
    sink = first - middle

    print(f"  H{h:2d}: first={first:7.3f}, mid={middle:7.3f}, last={last:7.3f}, sink={sink:+7.3f}")


Layer 20 Statistics
Diagonal mean (self-attention): 7.9511
First frame mean (sink): 1.4547

Per-Head Statistics (last block):
  H 0: first=  1.787, mid=  1.454, last=  2.680, sink= +0.333
  H 1: first=  1.240, mid=  0.186, last=  5.461, sink= +1.055
  H 2: first=  4.270, mid=  0.990, last=  7.281, sink= +3.279
  H 3: first=  3.805, mid=  0.452, last=  7.125, sink= +3.352
  H 4: first=  3.789, mid=  3.106, last=  2.883, sink= +0.683
  H 5: first=  2.662, mid=  0.583, last=  6.707, sink= +2.079
  H 6: first=  4.168, mid=  3.730, last=  5.309, sink= +0.438
  H 7: first=  6.180, mid=  1.992, last= 10.375, sink= +4.188
  H 8: first=  0.476, mid= -0.379, last=  4.203, sink= +0.855
  H 9: first=  1.579, mid=  0.468, last=  3.674, sink= +1.111
  H10: first=  2.215, mid=  1.685, last=  3.797, sink= +0.530
  H11: first=  1.619, mid=  0.062, last=  5.547, sink= +1.557
