In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 1. Load the model and tokenizer.
# model_name = "Qwen/Qwen2.5-7B-Instruct"
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# `torch_dtype=torch.float16` already loads the weights in half precision, so no
# extra `.half()` is applied (`.half()` would additionally cast any float32
# buffers, e.g. rotary-embedding frequencies if the model keeps them in fp32).
# `output_attentions=True` is stored on the model config so the attention modules
# return per-head attention probabilities for the hooks in `process_sample`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    output_attentions=True,
).eval()

# 2. Load the dataset (LongBench TriviaQA, test split) and keep the first samples.
triviaqa_dataset = load_dataset('THUDM/LongBench', 'triviaqa', split='test')
num_samples = 10  # number of samples to analyse
dataset_samples = triviaqa_dataset.select(range(num_samples))

# 3. Analysis parameters.
num_layers = len(model.model.layers)
# Head counts come from the config instead of being hard-coded, so changing
# `model_name` (e.g. to the grouped-query-attention Qwen model above) keeps every
# downstream shape consistent. Configs without `num_key_value_heads` fall back to
# one KV head per attention head (plain multi-head attention).
num_heads = model.config.num_attention_heads
kv_pairs = getattr(model.config, "num_key_value_heads", None) or num_heads
group_num = 10  # number of localness groups tokens are split into

# Running sum of per-sample results, shape [num_layers, kv_pairs, group_num].
total_group_results = torch.zeros(num_layers, kv_pairs, group_num, device=model.device)

# dataset = load_dataset('THUDM/LongBench', 'samsum', split='test')
# random_indices = random.sample(range(len(dataset)), 10)
# contexts = [dataset[idx]['context'] for idx in random_indices]
# questions = [dataset[idx]['input'] for idx in random_indices]

# # Create prompt for the first sequence
# prompt_template = lambda c, q: f"Context: {c}\n\nQuestion: {q}\n\nAnalyze the context and answer the question."

# 4. Per-sample analysis: bucket tokens by how similar their key vectors are to
#    their neighbours' keys ("localness") and measure how much attention the
#    final prompt tokens pay to each bucket.

# Leading positions excluded from the analysis (presumably attention-sink tokens
# such as BOS -- NOTE(review): confirm the intended count).
_SINK_TOKENS = 5
# Number of trailing prompt positions whose outgoing attention is measured.
_QUERY_TOKENS = 10
# Neighbourhood half-width: position j is a neighbour of i when |i - j| <= this.
_NEIGHBOR_WINDOW = 2


def _truncate_middle(ids, max_length):
    """Cut a [batch, seq] tensor to at most ``max_length`` columns, dropping the middle.

    Keeps the first ``max_length // 2`` and the last ``max_length - max_length // 2``
    positions. Plain right-side truncation would cut off the question at the end of
    the prompt; keeping both ends preserves the instruction and the question.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if ids.shape[1] <= max_length:
        return ids
    head = max_length // 2
    tail = max_length - head
    return torch.cat([ids[:, :head], ids[:, -tail:]], dim=1)


def _neighbor_mask(seq_len, window, device):
    """Boolean [seq_len, seq_len] mask that is True where |i - j| <= window.

    Row i marks positions max(0, i - window) .. min(seq_len - 1, i + window): the
    token itself plus up to ``window`` neighbours on each side.
    """
    pos = torch.arange(seq_len, device=device)
    return (pos[:, None] - pos[None, :]).abs() <= window


def process_sample(sample, max_length=2048):
    """Measure how much attention the final prompt tokens pay to tokens grouped by key localness.

    Builds the QA prompt from ``sample["context"]`` and ``sample["input"]``, runs one
    forward pass with hooks on every ``self_attn`` module, drops the first
    ``_SINK_TOKENS`` positions, and for each layer:

    1. scores every position by the mean cosine similarity between its key and the
       keys within ``_NEIGHBOR_WINDOW`` positions (itself included), averaged over
       KV heads;
    2. sorts positions by that score and splits them into ``group_num`` equal-size
       groups (group 0 = least local); the ``seq_len % group_num`` most local
       positions are left out so all groups have the same size;
    3. for every KV head, sums the attention the last ``_QUERY_TOKENS`` positions
       give to each group's positions before the query window, averaged over the
       query heads sharing that KV head and over the query positions.

    Uses the module-level ``model``, ``tokenizer``, ``num_layers``, ``kv_pairs`` and
    ``group_num``.

    Args:
        sample: dataset row with ``"input"`` (question) and ``"context"`` fields.
        max_length: token budget for the prompt. Longer prompts are truncated in
            the middle so both the instruction and the question are kept.

    Returns:
        float32 tensor of shape [num_layers, kv_pairs, group_num] on ``model.device``.
        Groups that contain no position outside the query window stay 0.

    Raises:
        ValueError: if the (truncated) prompt is too short for the analysis.
        RuntimeError: if an attention module returns no attention weights.
    """
    question = sample["input"]
    evidence = sample["context"]

    prompt = f"""Answer the question based on the given passage. Only give me the answer and do not output any other words.

Context: {evidence}

Question: {question}"""

    encoded = tokenizer(prompt, return_tensors="pt")
    inputs = {
        name: _truncate_middle(tensor, max_length).to(model.device)
        for name, tensor in encoded.items()
    }
    # Positions analysed once the leading sink tokens are dropped.
    seq_len = inputs["input_ids"].shape[1] - _SINK_TOKENS
    # Positions before the query window; these are the attention targets.
    n_context = seq_len - _QUERY_TOKENS
    if n_context <= 0 or seq_len < group_num:
        raise ValueError(f"prompt too short for the analysis: {seq_len} usable tokens")

    # Per-layer key vectors and attention weights captured by the hooks.
    key_vectors = {}
    attention_weights = {}

    def get_key_hook(layer_idx):
        def hook(module, args, output):
            # NOTE(review): expects the attention module to return
            # (attn_output, attn_weights, keys, ...). In stock transformers index 2
            # holds the KV cache: legacy versions return a (key, value) tuple, which
            # is unpacked below; newer versions may return a Cache object or no
            # third element, which this hook does not support -- confirm against
            # the installed / patched modeling code.
            keys = output[2]
            if isinstance(keys, (tuple, list)):
                keys = keys[0]
            weights = output[1]
            if weights is None:
                raise RuntimeError(
                    "attention weights were not returned; load the model with "
                    "output_attentions=True and an eager attention implementation"
                )
            key_vectors[layer_idx] = keys.detach()[:, :, _SINK_TOKENS:, :]
            attention_weights[layer_idx] = weights.detach()[:, :, _SINK_TOKENS:, _SINK_TOKENS:]
        return hook

    hooks = [
        layer.self_attn.register_forward_hook(get_key_hook(layer_idx))
        for layer_idx, layer in enumerate(model.model.layers)
    ]
    try:
        with torch.no_grad():
            model(**inputs)
    finally:
        # Always detach the hooks: a failed forward pass (e.g. OOM) must not leave
        # stale hooks (and the tensors they capture) on the model for later samples.
        for handle in hooks:
            handle.remove()

    sample_group_results = torch.zeros(num_layers, kv_pairs, group_num, device=model.device)
    # Loop-invariant: the neighbourhood structure only depends on seq_len.
    neighbor_mask = _neighbor_mask(seq_len, _NEIGHBOR_WINDOW, model.device)
    neighbor_counts = neighbor_mask.sum(dim=1)  # [seq_len]
    group_size = seq_len // group_num

    for layer_idx in range(num_layers):
        # Layers may live on different devices under device_map="auto"; do the
        # maths on model.device in float32 (in fp16 a tiny epsilon rounds to 0).
        layer_key = key_vectors[layer_idx][0].to(model.device, torch.float32)  # [n_kv, seq_len, head_dim]
        layer_attn = attention_weights[layer_idx][0]  # [n_heads, seq_len, seq_len]

        # Localness score: mean cosine similarity to neighbouring keys, averaged over KV heads.
        unit_keys = F.normalize(layer_key, dim=-1)
        cos_sim = unit_keys @ unit_keys.transpose(1, 2)  # [n_kv, seq_len, seq_len]
        neighbor_sim = (cos_sim * neighbor_mask).sum(dim=2) / neighbor_counts  # [n_kv, seq_len]
        avg_sim = neighbor_sim.mean(dim=0)  # [seq_len]

        # Attention from the query window to the preceding positions, averaged over
        # the query heads that share a KV head (consecutive heads form one group,
        # matching transformers' repeat_kv layout).
        n_kv = layer_key.shape[0]
        heads_per_kv = layer_attn.shape[0] // n_kv
        query_attn = layer_attn[:, n_context:, :n_context].to(model.device, torch.float32)
        avg_attn = query_attn.reshape(n_kv, heads_per_kv, _QUERY_TOKENS, n_context).mean(dim=1)

        order = torch.argsort(avg_sim)  # ascending localness
        for group_idx in range(group_num):
            members = order[group_idx * group_size:(group_idx + 1) * group_size]
            members = members[members < n_context]  # query-window positions are not targets
            if members.numel() > 0:
                sample_group_results[layer_idx, :, group_idx] = (
                    avg_attn[:, :, members].sum(dim=2).mean(dim=1)
                )

    return sample_group_results

# 5. Process every sample. A failing sample (e.g. OOM on a long prompt) is
#    reported and skipped so one bad row does not abort the run.
# Reset the accumulator so re-running this cell does not double count.
total_group_results.zero_()
num_processed = 0
for sample in tqdm(dataset_samples, desc="Processing samples"):
    try:
        sample_result = process_sample(sample)
        total_group_results += sample_result
    except Exception as e:
        print(f"Error processing sample: {e}")
        continue
    num_processed += 1

# 6. Average over the samples that actually succeeded. Dividing by num_samples
#    would count every skipped sample as all-zero attention.
if num_processed == 0:
    raise RuntimeError("No samples were processed successfully; see the errors above.")
if num_processed < num_samples:
    print(f"Averaging over {num_processed}/{num_samples} successfully processed samples.")
avg_group_results = total_group_results / num_processed

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.cm as cm
from matplotlib.colors import Normalize, LinearSegmentedColormap

# Set Nature-inspired style parameters
plt.rcParams.update({
    'font.family': 'Arial',
    'font.size': 10,
    'axes.labelsize': 14+2,
    'axes.titlesize': 14+2,
    'xtick.labelsize': 12+2,
    'ytick.labelsize': 12+2,
    'legend.fontsize': 12+2,
    'figure.dpi': 600,
    'axes.spines.top': False,
    'axes.spines.right': False
})

# Averaged results: [num_layers, kv_pairs, group_num].
data = avg_group_results.cpu()
n_layers, n_heads, n_groups = data.shape

# Custom blue colormap: one colour per KV head, light (first head) to dark (last head).
colors = ["#caf0f8", "#48cae4", "#0077b6", "#023e8a", "#001233"]
cmap = LinearSegmentedColormap.from_list("custom_blue", colors, N=n_heads)
norm = Normalize(vmin=0, vmax=n_heads - 1)

# Layers to plot: 0-based indices into model.model.layers, one per subplot.
selected_layers = [0, 1, 4, 12, 20, 28]
nrows, ncols = 2, 3
if len(selected_layers) != nrows * ncols:
    raise ValueError(f"expected {nrows * ncols} layers, got {len(selected_layers)}")
if max(selected_layers) >= n_layers:
    raise ValueError(
        f"layer index {max(selected_layers)} is out of range for a {n_layers}-layer model"
    )
selected_data = data[selected_layers]

# Titles are derived from the same 0-based indices so labels always match the data.
layer_titles = [f"Layer {idx}" for idx in selected_layers]

# x-axis: localness group 1..n_groups (groups sorted by increasing mean
# neighbour key similarity).
x = np.arange(1, n_groups + 1)

fig, axs = plt.subplots(nrows, ncols, figsize=(7.5, 3.5))
plt.subplots_adjust(wspace=0.23, hspace=0.45)

for i, (ax, layer_data, title) in enumerate(zip(axs.flat, selected_data, layer_titles)):
    # One line per KV head, coloured by head index.
    for j in range(n_heads):
        ax.plot(x, layer_data[j].numpy(),
                marker='o', markersize=6, linestyle='-', linewidth=2.2,
                color=cmap(norm(j)),
                markerfacecolor='white',
                markeredgewidth=0.8,
                alpha=0.8)

    ax.set_title(title, fontsize=14, fontweight="bold")

    # Only the left column gets a y-label.
    if i % ncols == 0:
        ax.set_ylabel('Attention', fontsize=12+2)

    # Only the bottom row gets an x-label.
    if i >= (nrows - 1) * ncols:
        ax.set_xlabel('Localness Degree', fontsize=12+2)

    ax.tick_params(axis='both', labelsize=12)
    ax.grid(True, linestyle=':', alpha=0.3)
    ax.set_xticks(x[1::2])

# Horizontal colorbar below the subplot grid (negative bottom coordinate).
cbar_ax = fig.add_axes([0.15, -0.12, 0.7, 0.02])  # [left, bottom, width, height]
sm = cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, cax=cbar_ax, orientation='horizontal')

# Colorbar label, centred over the head-index range [0, n_heads - 1].
cbar.ax.text((n_heads - 1) / 2, 2.1, f'Attention Head Index From 1 To {n_heads}',
             fontsize=12+3, ha='center', va='center')
cbar.ax.tick_params(labelsize=10)

# Save and show
fig.savefig('layers_attention.pdf', bbox_inches='tight', dpi=600)
plt.show()