## 参数统计
```python
# Looks like a papermill-style parameters cell (single-layer defaults).
# NOTE(review): these names match the parameters injected into
# extract_all_attention.ipynb by the batch loop in this notebook, but nothing
# else visible here reads them — confirm whether this cell belongs in this notebook.
DATA_PATH = "../cache/layer29.pt"  # path of one per-layer attention .pt file
SAVE_DIR = "attention_analysis/layer29"  # output folder for that layer's figures
SAVE_SVG = True  # presumably toggles SVG output in the extraction notebook — confirm there
```

In [1]:
import os
import torch
import re
import cairosvg
import papermill as pm
from io import BytesIO
from tqdm.auto import tqdm
from PIL import Image, ImageDraw, ImageFont

In [2]:
# Batch configuration: which cached attention run to process and how the comparison grid is laid out.
pro_dir = 'frames72/castle_mountain'  # attention .pt folder (relative to ../cache)
NUM_LAYERS = 30  # one file per layer: layer0.pt ... layer{NUM_LAYERS - 1}.pt
EXAMPLE_LAYER = 3  # layer whose record is loaded once for inspection
CACHE_DIR = os.path.join("../cache", pro_dir)
data_paths = [os.path.join(CACHE_DIR, f"layer{i}.pt") for i in range(NUM_LAYERS)]
# Reuse data_paths instead of rebuilding the same path by hand.
data_example = torch.load(data_paths[EXAMPLE_LAYER], map_location='cpu', weights_only=True)
BASE_DIR = os.path.join("attention_analysis", pro_dir)
os.makedirs("runs", exist_ok=True)
os.makedirs(BASE_DIR, exist_ok=True)
ROWS, COLUMNS, LABEL_HEIGHT = 5, 6, 40  # grid layout; LABEL_HEIGHT is the per-tile title strip in px
# Tiles beyond ROWS * COLUMNS would not fit the intended layout; fail early instead.
assert ROWS * COLUMNS >= NUM_LAYERS, "grid too small to hold one tile per layer"

In [3]:
def print_info(data_eg):
    """Dump an overview of one loaded attention record to stdout.

    Prints three banner-separated sections: the record's ``'prompt'`` value,
    the record's keys (one per line), and, for every key, either the tensor
    shape (for ``torch.Tensor`` values) or the Python type of the value.

    Args:
        data_eg: mapping loaded from an attention ``.pt`` file; must contain
            a ``'prompt'`` entry and have string keys (they are joined as text).
    """
    def _banner(title, pad):
        return '=' * pad + title + '=' * pad

    print(_banner(" prompt ", 30))
    print(data_eg['prompt'])
    print(_banner(" keys list ", 30))
    print('\n'.join(data_eg.keys()))
    print(_banner(" keys content ", 28))
    for name, value in data_eg.items():
        if not isinstance(value, torch.Tensor):
            print(f"key: '{name}' content type: {type(value)}")
            continue
        print(f"Tensor '{name}', dim: {value.shape}")

In [4]:
# Inspect the example record (`data_example`): its prompt, key list, and per-key tensor shapes / types.
print_info(data_example)

A medieval castle standing on a foggy mountain, twilight lighting
layer_index
full_frame_attention
last_block_frame_attention
is_logits
prompt
num_frames
frame_seq_length
num_frame_per_block
num_heads
block_sizes
query_frames
key_frames
last_block_query_frames
extraction_method
chunk_frames
key: 'layer_index' content type: <class 'int'>
Tensor 'full_frame_attention', dim: torch.Size([12, 72, 72])
Tensor 'last_block_frame_attention', dim: torch.Size([12, 72])
key: 'is_logits' content type: <class 'bool'>
key: 'prompt' content type: <class 'str'>
key: 'num_frames' content type: <class 'int'>
key: 'frame_seq_length' content type: <class 'int'>
key: 'num_frame_per_block' content type: <class 'int'>
key: 'num_heads' content type: <class 'int'>
key: 'block_sizes' content type: <class 'list'>
key: 'query_frames' content type: <class 'list'>
key: 'key_frames' content type: <class 'list'>
key: 'last_block_query_frames' content type: <class 'list'>
key: 'extraction_method' content type: <class '

In [None]:
# Run extract_all_attention.ipynb once per layer file via papermill. Each layer's figures go to
# BASE_DIR/layer{i}, the same folders create_grid later scans.
os.makedirs("runs", exist_ok=True)  # only needs to exist once, not once per iteration
failed_paths = []
for i, path in enumerate(tqdm(data_paths, desc="Total Progress")):
    output_notebook = f"runs/result_layer_{i}.ipynb"
    try:
        pm.execute_notebook(
            'extract_all_attention.ipynb',
            output_notebook,
            parameters={
                'DATA_PATH': path,
                # Derive from BASE_DIR so the grid cell reads exactly where results are written.
                'SAVE_DIR': os.path.join(BASE_DIR, f"layer{i}"),
                'SAVE_SVG': True,
            },
            progress_bar=True,
        )
    except KeyboardInterrupt:
        tqdm.write("\n[!] KeyboardInterrupt.")
        break
    except Exception as e:
        # Best-effort batch: record the failure and continue with the next layer.
        # tqdm.write keeps the progress bar intact, unlike a bare print.
        tqdm.write(f"\n[X] processing {path} ERROR: {e}")
        failed_paths.append(path)

if failed_paths:
    print(f"[X] {len(failed_paths)}/{len(data_paths)} layer(s) failed:")
    print("\n".join(failed_paths))

Total Progress:   0%|          | 0/30 [00:00<?, ?it/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

Executing:   0%|          | 0/13 [00:00<?, ?cell/s]

In [None]:
def get_layer_num(dir_name):
    """Return the integer layer index embedded in ``dir_name``.

    Looks for the first ``layer<digits>`` occurrence and returns the digits as an int.
    Used as a sort key, this gives numeric order (0, 1, 2, ..., 29) rather than
    lexicographic order (0, 1, 10, ...). Returns -1 when no ``layer<digits>``
    substring is present.
    """
    found = re.search(r'layer(\d+)', dir_name)
    if found is None:
        return -1
    return int(found.group(1))

In [None]:
def _load_label_font(size=24):
    """Return Arial at ``size`` pt, or PIL's built-in default font if Arial cannot be opened."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # Arial is often missing (e.g. on Linux); fall back instead of failing the grid.
        return ImageFont.load_default()


def _grid_svg_path(folder, typ):
    """Return the per-layer SVG path that grid type ``typ`` tiles for ``folder``."""
    if typ == "heatmap":
        return os.path.join(BASE_DIR, folder, f"{folder}_2d_heatmap_all_heads.svg")
    # Any other type selects the per-head grid, as before.
    return os.path.join(BASE_DIR, folder, f"{folder}_perhead_grid.svg")


def create_grid(typ, output_file=None):
    """Tile each layer's rendered SVG into one labelled comparison PNG.

    Scans ``BASE_DIR`` for ``layer*`` sub-directories and sorts them numerically
    with ``get_layer_num``. Each folder's SVG is rasterised with cairosvg and
    given a "Layer N" label strip of ``LABEL_HEIGHT`` px. The tiles are pasted
    row-major into a grid ``COLUMNS`` tiles wide. Layers whose SVG is missing
    are skipped.

    Args:
        typ: ``"heatmap"`` tiles ``<folder>_2d_heatmap_all_heads.svg``; any other
            value tiles ``<folder>_perhead_grid.svg``.
        output_file: destination PNG path. When omitted, the notebook-global
            ``OUTPUT_FILE`` is used (the original calling convention).

    Returns:
        None. Progress and the final size are printed. If no SVG is found, a
        message is printed and nothing is saved.
    """
    # 1. Collect layer directories in numeric order.
    layer_dirs = sorted(
        (d for d in os.listdir(BASE_DIR)
         if d.startswith('layer') and os.path.isdir(os.path.join(BASE_DIR, d))),
        key=get_layer_num,
    )

    processed_images = []
    cell_width = 0
    cell_height = 0
    font = _load_label_font(24)  # loop-invariant: load once, not per layer

    print("开始渲染 SVG 并添加标注...")

    for folder in layer_dirs:
        file_path = _grid_svg_path(folder, typ)
        if not os.path.exists(file_path):
            continue

        # Rasterise the SVG to an in-memory PNG.
        png_data = cairosvg.svg2png(url=file_path)
        img = Image.open(BytesIO(png_data)).convert("RGB")

        # Unify tile size to the first image; otherwise larger images are cropped
        # and smaller ones leave gaps.
        if cell_width == 0:
            cell_width, cell_height = img.size
        elif img.size != (cell_width, cell_height):
            img = img.resize((cell_width, cell_height))

        # Tile = white label strip on top + the rendered image below it.
        cell = Image.new("RGB", (cell_width, cell_height + LABEL_HEIGHT), (255, 255, 255))
        cell.paste(img, (0, LABEL_HEIGHT))
        draw = ImageDraw.Draw(cell)
        draw.text((10, 5), f"Layer {get_layer_num(folder)}", fill=(0, 0, 0), font=font)

        processed_images.append(cell)

    if not processed_images:
        print("未找到图片，请检查路径。")
        return

    # 2. Canvas: at least ROWS rows, plus extra rows so no tile falls off the canvas.
    n_rows = max(ROWS, (len(processed_images) + COLUMNS - 1) // COLUMNS)
    tile_height = cell_height + LABEL_HEIGHT
    grid_width = cell_width * COLUMNS
    grid_height = tile_height * n_rows
    final_image = Image.new("RGB", (grid_width, grid_height), (255, 255, 255))

    # 3. Paste tiles row-major.
    for i, tile in enumerate(processed_images):
        row, col = divmod(i, COLUMNS)
        final_image.paste(tile, (col * cell_width, row * tile_height))

    # 4. Save (resolve the legacy global only when actually needed).
    if output_file is None:
        output_file = OUTPUT_FILE
    final_image.save(output_file)
    print(f"✅ 成功！对比图已生成：{output_file}")
    print(f"尺寸: {grid_width}x{grid_height}")

In [None]:
# Build both comparison grids; OUTPUT_FILE is set before each call because create_grid saves to it.
for grid_type in ("heatmap", "perhead"):
    OUTPUT_FILE = BASE_DIR + f"/{grid_type}_attention_comparison_grid.png"
    create_grid(grid_type)