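# Extracting the Attention Map and Values of Token 500 in Block 9

This notebook extracts, for token 500 in block 9 across 3 frames:
- the Attention Map (the distribution of attention weights)
- the Attention Values (the raw numbers behind the map; the V matrix could also be extracted if needed, though this notebook does not do so)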
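As background for "attention map" versus "values": in standard scaled dot-product attention, the attention-map row for one query token is softmax(q·Kᵀ/√d), and the attention output is that row multiplied by V. The following is a minimal NumPy sketch of that computation for a single query, with random arrays standing in for real projections; `single_query_attention` is a hypothetical helper, not part of `visualize_attention` or VGGT.

```python
import numpy as np

def single_query_attention(q, K, V):
    """Scaled dot-product attention for one query vector (hypothetical helper).

    q: [d] query of the chosen token
    K: [N, d] keys of all tokens
    V: [N, d_v] values of all tokens
    Returns (weights [N], output [d_v]).
    """
    d = q.shape[-1]
    scores = K @ q / np.sqrt(d)      # [N] similarity of the query to every key
    scores = scores - scores.max()   # shift for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum()         # softmax: one row of the attention map
    output = weights @ V             # [d_v] attention-weighted sum of the V rows
    return weights, output

rng = np.random.default_rng(0)
K = rng.standard_normal((6, 4))
V = rng.standard_normal((6, 3))
w, out = single_query_attention(K[2], K, V)
print(w.shape, out.shape)  # (6,) (3,)
```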

## 1. Load the required packages and the model

In [None]:
from visualize_attention import (
    AttentionExtractor,
    extract_both_attentions,
    visualize_attention_on_image,
    get_attention_values
)
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
import torch
import numpy as np
import matplotlib.pyplot as plt

# Select the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the model
print("Loading the VGGT model...")
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
model.eval()
print("Model loaded!")

## 2. Load the 3 images

Set `image_paths` to the paths of the 3 images you want to use.
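If the image folder contains more frames than you need, one option is to pick three images spaced evenly across the sorted file list. A minimal sketch; `pick_evenly` is a hypothetical helper and the 11-file list below is illustrative only.

```python
def pick_evenly(paths, k=3):
    """Pick k paths spread evenly from first to last across a sorted list (hypothetical helper)."""
    paths = sorted(paths)
    if k <= 0 or not paths:
        return []
    k = min(k, len(paths))
    if k == 1:
        return paths[:1]
    step = (len(paths) - 1) / (k - 1)   # spacing between picked indices
    return [paths[round(i * step)] for i in range(k)]

# Illustrative file list, not a real directory listing
names = [f"examples/llff_flower/images/{i:03d}.png" for i in range(11)]
print(pick_evenly(names, k=3))
# ['examples/llff_flower/images/000.png', 'examples/llff_flower/images/005.png', 'examples/llff_flower/images/010.png']
```

For a real folder you could pass `[str(p) for p in Path("examples/llff_flower/images").glob("*.png")]` (with `from pathlib import Path`) and assign the result to `image_paths`.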

In [None]:
# Edit the image paths here
image_paths = [
    "examples/llff_flower/images/000.png",
    "examples/llff_flower/images/005.png",
    "examples/llff_flower/images/010.png"
]

# Load and preprocess the images
print(f"Loading {len(image_paths)} images...")
images = load_and_preprocess_images(image_paths).to(device)
print(f"Images shape: {images.shape}")  # [S, 3, H, W]
print("Images loaded!")

## 3. Extract the attention of block 9

Extract both the frame attention and the global attention of block 9.
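VGGT alternates frame-wise attention (each frame is its own sequence) with global attention (all frames concatenated into one sequence). The sketch below shows the token reshapes this implies and the resulting attention-weight shapes. The sizes are placeholders (P = 1374 tokens per frame, i.e. 5 special tokens plus a 37×37 patch grid, and 16 heads), and the exact shapes returned by `extract_both_attentions` are not confirmed here; `attention_weight_shapes` is a hypothetical helper.

```python
import numpy as np

def attention_weight_shapes(B, S, P, num_heads):
    """Expected attention-weight shapes for frame vs global attention (hypothetical helper).

    Assumes tokens are laid out as [B, S, P, C] (batch, frames, tokens per frame, channels).
    """
    # Frame attention: every frame is a separate sequence -> tokens [B*S, P, C]
    frame_shape = (B * S, num_heads, P, P)
    # Global attention: all frames form one sequence -> tokens [B, S*P, C]
    global_shape = (B, num_heads, S * P, S * P)
    return frame_shape, global_shape

# Placeholder sizes, not read from the model
B, S, P, C, H = 1, 3, 1374, 8, 16
tokens = np.zeros((B, S, P, C), dtype=np.float32)
frame_tokens = tokens.reshape(B * S, P, C)    # input layout for frame attention
global_tokens = tokens.reshape(B, S * P, C)   # input layout for global attention
print(frame_tokens.shape, global_tokens.shape)  # (3, 1374, 8) (1, 4122, 8)
print(attention_weight_shapes(B, S, P, H))
```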

In [None]:
# Extract the attention of block 9
block_idx = 9
token_idx = 500
head_idx = 0  # change this to inspect another head

print(f"Extracting attention from Block {block_idx}...")
result = extract_both_attentions(model, images, block_idx=block_idx)

print("\nExtraction results:")
print(f"  Frame Attention shape: {result['frame_attention']['attn_weights'].shape}")
print(f"  Global Attention shape: {result['global_attention']['attn_weights'].shape}")
print(f"  Patch start index: {result['patch_start_idx']}")
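To see which image location token 500 refers to, an index can be decoded into (frame, row, col). The sketch assumes that `token_idx` indexes the concatenated global sequence, that each frame contributes `patch_start_idx` special tokens followed by a `grid_h × grid_w` patch grid, and that the patch size is 14 pixels; none of this is confirmed from `visualize_attention`, and `locate_token` is hypothetical.

```python
def locate_token(token_idx, patch_start_idx, grid_h, grid_w):
    """Map a global-sequence token index to (frame, kind, row, col) (hypothetical helper).

    Assumes each frame holds patch_start_idx special tokens followed by
    grid_h * grid_w patch tokens, and frames are concatenated in order.
    """
    tokens_per_frame = patch_start_idx + grid_h * grid_w
    frame, local = divmod(token_idx, tokens_per_frame)
    if local < patch_start_idx:
        return frame, "special", None, None   # camera/register token: no pixel location
    row, col = divmod(local - patch_start_idx, grid_w)
    return frame, "patch", row, col

print(locate_token(500, patch_start_idx=5, grid_h=37, grid_w=37))  # (0, 'patch', 13, 14)
```

With the notebook variables this would be `locate_token(token_idx, result['patch_start_idx'], images.shape[-2] // 14, images.shape[-1] // 14)`, under the 14-pixel patch assumption.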

## 4. Extract the attention map of token 500 (numerical values)

### 4.1 Global Attention: attention maps of the 3 frames

In [None]:
# Extract the raw numerical values of the global attention
global_values = get_attention_values(
    attn_weights=result['global_attention']['attn_weights'],
    images=images,
    token_idx=token_idx,
    head_idx=head_idx,
    patch_start_idx=result['patch_start_idx'],
    attention_type='global'
)

print("Global Attention Values:")
print(f"  Attention maps shape: {global_values['attention_maps'].shape}")  # [3, grid_h, grid_w]
print(f"  Attention maps resized shape: {global_values['attention_maps_resized'].shape}")  # [3, H, W]
print(f"  Images shape: {global_values['images'].shape}")  # [3, H, W, 3]
print("\nMetadata:")
for key, value in global_values['metadata'].items():
    print(f"  {key}: {value}")
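For reference, splitting one row of global attention into per-frame grids can be sketched as below, assuming the frames are concatenated and each frame's block starts with `patch_start_idx` special tokens. Whether `get_attention_values` does exactly this is not confirmed; `split_global_row` is a hypothetical helper. It also returns the attention mass each frame puts on its special tokens, which the grids leave out.

```python
import numpy as np

def split_global_row(row, num_frames, patch_start_idx, grid_h, grid_w):
    """Split one global-attention row [S * tokens_per_frame] into per-frame grids.

    Returns (maps [S, grid_h, grid_w], special_mass [S]) (hypothetical helper).
    """
    tokens_per_frame = patch_start_idx + grid_h * grid_w
    per_frame = np.asarray(row).reshape(num_frames, tokens_per_frame)
    special_mass = per_frame[:, :patch_start_idx].sum(axis=1)   # attention on special tokens
    maps = per_frame[:, patch_start_idx:].reshape(num_frames, grid_h, grid_w)
    return maps, special_mass

n = 3 * (2 + 4 * 5)
row = np.full(n, 1.0 / n)   # uniform dummy row that sums to 1
maps, special = split_global_row(row, num_frames=3, patch_start_idx=2, grid_h=4, grid_w=5)
print(maps.shape, special.shape)  # (3, 4, 5) (3,)
```

On real data the row could come from `result['global_attention']['attn_weights'][0, head_idx, token_idx].float().cpu().numpy()`, assuming the weights are a PyTorch tensor of shape [B, heads, N, N] as indexed in section 7.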

### 4.2 Frame Attention: attention map within the first frame

In [None]:
# Extract the raw numerical values of the frame attention
frame_values = get_attention_values(
    attn_weights=result['frame_attention']['attn_weights'],
    images=images,
    token_idx=token_idx,
    head_idx=head_idx,
    patch_start_idx=result['patch_start_idx'],
    attention_type='frame'
)

print("Frame Attention Values:")
print(f"  Attention maps shape: {frame_values['attention_maps'].shape}")  # [grid_h, grid_w]
print(f"  Attention maps resized shape: {frame_values['attention_maps_resized'].shape}")  # [H, W]
print("\nMetadata:")
for key, value in frame_values['metadata'].items():
    print(f"  {key}: {value}")
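In frame attention each frame is a separate sequence, so the query index is local to its frame and must be smaller than the per-frame token count; with batch size 1, entry `s` of the first dimension corresponds to frame `s`. A minimal sketch assuming weights of shape [B*S, heads, P, P] (not confirmed from `get_attention_values`); `frame_row_to_grid` is hypothetical.

```python
import numpy as np

def frame_row_to_grid(frame_attn, frame_idx, head_idx, local_token_idx,
                      patch_start_idx, grid_h, grid_w):
    """Extract one query token's frame-attention map as a [grid_h, grid_w] grid.

    frame_attn: [B*S, heads, P, P]; with B = 1, index frame_idx selects that frame.
    local_token_idx indexes tokens within the frame (0 <= idx < P).
    """
    P = frame_attn.shape[-1]
    if not 0 <= local_token_idx < P:
        raise IndexError(f"local_token_idx must be in [0, {P}), got {local_token_idx}")
    row = frame_attn[frame_idx, head_idx, local_token_idx]   # [P], sums to 1 within the frame
    return row[patch_start_idx:].reshape(grid_h, grid_w)      # drop the special tokens

demo = np.random.default_rng(0).random((3, 2, 22, 22))
demo /= demo.sum(axis=-1, keepdims=True)   # make each row a softmax-like distribution
grid = frame_row_to_grid(demo, frame_idx=1, head_idx=0, local_token_idx=7,
                         patch_start_idx=2, grid_h=4, grid_w=5)
print(grid.shape)  # (4, 5)
```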

## 5. Visualize the attention maps

### 5.1 Visualize the Global Attention (3 frames)

In [None]:
# Visualize the global attention
fig_global = visualize_attention_on_image(
    attn_weights=result['global_attention']['attn_weights'],
    images=images,
    token_idx=token_idx,
    head_idx=head_idx,
    patch_start_idx=result['patch_start_idx'],
    attention_type='global',
    layer_name=f'Block {block_idx}'
)
plt.show()
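The resized maps bring the patch grid up to image resolution. How `visualize_attention_on_image` resizes (for example bilinear or nearest) is not confirmed here; the sketch below uses nearest-neighbour upsampling, assuming a 14-pixel patch so that `H = grid_h * 14` and `W = grid_w * 14`, plus min-max normalization for display. Both helpers are hypothetical.

```python
import numpy as np

def upsample_grid(attn_grid, patch_size=14):
    """Nearest-neighbour upsampling: each patch value fills a patch_size x patch_size block."""
    return np.repeat(np.repeat(attn_grid, patch_size, axis=0), patch_size, axis=1)

def normalize01(x):
    """Min-max normalize to [0, 1] for display; a constant map becomes all zeros."""
    span = x.max() - x.min()
    return (x - x.min()) / span if span > 0 else np.zeros_like(x)

grid = np.arange(6, dtype=float).reshape(2, 3)
up = upsample_grid(grid, patch_size=14)
print(up.shape)  # (28, 42)
```

To overlay a map yourself, assuming `grid` is a NumPy [grid_h, grid_w] array: `plt.imshow(global_values['images'][0]); plt.imshow(normalize01(upsample_grid(grid)), cmap='jet', alpha=0.5)`.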

### 5.2 Visualize the Frame Attention

In [None]:
# Visualize the frame attention
fig_frame = visualize_attention_on_image(
    attn_weights=result['frame_attention']['attn_weights'],
    images=images,
    token_idx=token_idx,
    head_idx=head_idx,
    patch_start_idx=result['patch_start_idx'],
    attention_type='frame',
    layer_name=f'Block {block_idx}'
)
plt.show()

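## 6. Inspect the attention values

### 6.1 Per-frame attention statistics

In [None]:
# Summary statistics of the global attention on each frame
num_frames = global_values['attention_maps'].shape[0]
print(f"Global Attention statistics (attention of token {token_idx} on each of the {num_frames} frames):")
print("=" * 60)

for frame_idx in range(num_frames):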
    frame_attn = global_values['attention_maps'][frame_idx]
    print(f"\nFrame {frame_idx}:")
    print(f"  Shape: {frame_attn.shape}")
    print(f"  Min: {frame_attn.min():.6f}")
    print(f"  Max: {frame_attn.max():.6f}")
    print(f"  Mean: {frame_attn.mean():.6f}")
    print(f"  Std: {frame_attn.std():.6f}")
    print(f"  Sum: {frame_attn.sum():.6f}")
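The `Sum` printed for each frame is the raw attention mass on that frame's patch tokens. If `attention_maps` holds raw softmax slices (not confirmed), the per-frame sums plus the mass on special tokens add up to 1, so frames are easier to compare as relative shares. A minimal sketch that also reports the peak patch of each frame; `summarize_frames` is hypothetical.

```python
import numpy as np

def summarize_frames(attention_maps):
    """Per-frame share of attention mass and the (row, col) of the peak patch.

    attention_maps: [S, grid_h, grid_w] non-negative attention values.
    """
    maps = np.asarray(attention_maps, dtype=np.float64)
    mass = maps.sum(axis=(1, 2))
    share = mass / mass.sum()                                   # relative share, sums to 1
    flat_peak = maps.reshape(maps.shape[0], -1).argmax(axis=1)   # flat index of the max per frame
    peaks = [tuple(int(v) for v in np.unravel_index(i, maps.shape[1:])) for i in flat_peak]
    return share, peaks

demo = np.zeros((3, 2, 2))
demo[0, 1, 0] = 0.6
demo[1, 0, 1] = 0.3
demo[2, 1, 1] = 0.1
share, peaks = summarize_frames(demo)
print(share, peaks)
```

Pass a NumPy array, e.g. `summarize_frames(global_values['attention_maps'])`, converting with `.cpu().numpy()` first if it is a PyTorch tensor.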

### 6.2 Save the attention map values

In [None]:
# Save to a .npy file for later use
output_data = {
    'global_attention_maps': global_values['attention_maps'],  # [3, grid_h, grid_w]
    'global_attention_maps_resized': global_values['attention_maps_resized'],  # [3, H, W]
    'frame_attention_map': frame_values['attention_maps'],  # [grid_h, grid_w]
    'frame_attention_map_resized': frame_values['attention_maps_resized'],  # [H, W]
    'metadata': {
        'block_idx': block_idx,
        'token_idx': token_idx,
        'head_idx': head_idx,
        **global_values['metadata']
    }
}

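# Save (np.save pickles the dict; load it back with np.load(path, allow_pickle=True).item())
output_path = f'block{block_idx}_token{token_idx}_attention.npy'
np.save(output_path, output_data)
print(f"Saved the attention maps to '{output_path}'")

# Show the structure of the saved data
print("\nSaved data structure:")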
print(f"  global_attention_maps: {output_data['global_attention_maps'].shape}")
print(f"  global_attention_maps_resized: {output_data['global_attention_maps_resized'].shape}")
print(f"  frame_attention_map: {output_data['frame_attention_map'].shape}")
print(f"  frame_attention_map_resized: {output_data['frame_attention_map_resized'].shape}")
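`np.save` stores a dict as a pickled 0-d object array, so reading the file back needs `allow_pickle=True` followed by `.item()`. A minimal round-trip sketch with dummy data written to a temporary directory (if the saved values are PyTorch tensors rather than NumPy arrays, loading also requires PyTorch to be installed).

```python
import os
import tempfile
import numpy as np

# Dummy data shaped like output_data, not real attention values
data = {
    "global_attention_maps": np.ones((3, 2, 2)),
    "metadata": {"block_idx": 9, "token_idx": 500},
}
path = os.path.join(tempfile.mkdtemp(), "demo_attention.npy")
np.save(path, data)                             # dict -> pickled 0-d object array

loaded = np.load(path, allow_pickle=True).item()   # .item() unwraps the dict
print(loaded["metadata"], loaded["global_attention_maps"].shape)
```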

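## 7. (Optional) Inspect the raw attention weights

If you need the raw attention weights rather than the per-frame maps, index the full attention matrix (which covers the relations between all tokens) directly; the cell below takes the row of the query token: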

In [None]:
# Global attention: weights from the query token (token_idx) to every token in the sequence
global_attn_weights = result['global_attention']['attn_weights']
token_500_global_attn = global_attn_weights[0, head_idx, token_idx, :]  # [num_tokens]

print(f"Global Attention weights of token {token_idx}:")
print(f"  Shape: {token_500_global_attn.shape}")
print(f"  Total number of tokens: {token_500_global_attn.shape[0]}")
print(f"  Sum (should be ≈ 1.0): {token_500_global_attn.sum():.6f}")

# Find the top-10 tokens by attention (one topk call returns both values and indices)
top_k = 10
top_values, top_indices = torch.topk(token_500_global_attn, k=top_k)

print(f"\nTop {top_k} tokens by attention:")
for i, (idx, val) in enumerate(zip(top_indices, top_values)):
    print(f"  #{i+1}: Token {idx.item()}, Attention = {val.item():.6f}")
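Special tokens (such as camera and register tokens) may take a large share of attention and crowd the top-k list. To rank only image patches, mask each frame's first `patch_start_idx` positions before sorting. A minimal NumPy sketch, assuming the global sequence is the frames concatenated, each with `tokens_per_frame = patch_start_idx + grid_h * grid_w` tokens; `topk_patch_tokens` is hypothetical.

```python
import numpy as np

def topk_patch_tokens(row, k, patch_start_idx, tokens_per_frame):
    """Top-k token indices of one attention row, skipping each frame's special tokens."""
    row = np.asarray(row, dtype=np.float64)
    local = np.arange(row.shape[0]) % tokens_per_frame   # position within each frame
    is_patch = local >= patch_start_idx
    k = min(k, int(is_patch.sum()))                      # never return masked entries
    masked = np.where(is_patch, row, -np.inf)            # exclude special tokens
    idx = np.argsort(masked)[::-1][:k]                   # descending order
    return idx, row[idx]

# 2 frames x (1 special + 3 patch tokens); the special tokens hold 0.5 and 0.4
row = np.array([0.5, 0.1, 0.05, 0.2, 0.4, 0.01, 0.3, 0.02])
idx, vals = topk_patch_tokens(row, k=2, patch_start_idx=1, tokens_per_frame=4)
print(idx, vals)  # [6 3] [0.3 0.2]
```

With the notebook variables: `topk_patch_tokens(token_500_global_attn.float().cpu().numpy(), 10, result['patch_start_idx'], tokens_per_frame)`, where `tokens_per_frame` is computed from the patch grid.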

## 8. Summary

This notebook extracted:

1. **Global Attention Maps**: the attention distribution of token 500 over the 3 frames
   - `global_values['attention_maps']`: shape [3, grid_h, grid_w]
   - `global_values['attention_maps_resized']`: shape [3, H, W]

2. **Frame Attention Map**: the attention distribution of token 500 within a single frame
   - `frame_values['attention_maps']`: shape [grid_h, grid_w]
   - `frame_values['attention_maps_resized']`: shape [H, W]

3. **Raw Attention Weights**: the attention weights from token 500 to all tokens
   - `token_500_global_attn`: shape [num_tokens]

The attention maps from items 1 and 2, together with the metadata, are saved to `block9_token500_attention.npy`; the raw weights from item 3 are not included in that file.