In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import random
from datasets import load_dataset

# Load model and tokenizer.
# The weights are loaded directly in fp16 and placed across available devices by
# device_map="auto"; no extra .half() cast of the dispatched model is needed.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    # Passed through to the model config, so forward passes also return per-layer
    # attention maps. NOTE(review): this cell never reads them and they grow with
    # layers * heads * seq_len^2 — drop this flag if no later cell needs them.
    output_attentions=True,
).eval()

# Sample a reproducible subset of LongBench TriviaQA test examples.
# A dedicated Random instance keeps the selection stable across kernel restarts
# without touching the global `random` state.
SAMPLE_SEED = 0
NUM_SAMPLES = 10

dataset = load_dataset('THUDM/LongBench', 'triviaqa', split='test')
random_indices = random.Random(SAMPLE_SEED).sample(range(len(dataset)), NUM_SAMPLES)
contexts = [dataset[idx]['context'] for idx in random_indices]
questions = [dataset[idx]['input'] for idx in random_indices]


# Prompt template for the QA task: context, then question, then an instruction.
def prompt_template(c, q):
    """Return the QA prompt for context `c` and question `q`."""
    return f"Context: {c}\n\nQuestion: {q}\n\nAnalyze the context and answer the question."


first_prompt = prompt_template(contexts[0], questions[0])

# Tokenize the prompt, keeping at most MAX_PROMPT_TOKENS tokens.
# Many LongBench TriviaQA contexts exceed this budget. Plain right-truncation would
# cut off the question and instruction that close the template. Instead, keep the
# first and last halves of the budget and drop the middle of the context, mirroring
# the middle truncation in LongBench's reference evaluation script. This way the BOS
# token, the start of the context and the question all survive.
# (The tokenizer may warn that the untruncated prompt exceeds the model's max length;
# the sequence is shortened below before it reaches the model.)
MAX_PROMPT_TOKENS = 2048

inputs = tokenizer(first_prompt, return_tensors="pt")
if inputs["input_ids"].shape[1] > MAX_PROMPT_TOKENS:
    head_len = MAX_PROMPT_TOKENS // 2
    tail_len = MAX_PROMPT_TOKENS - head_len
    # Apply the same cut to every per-token field (input_ids, attention_mask, ...).
    for field in list(inputs.keys()):
        full = inputs[field]
        inputs[field] = torch.cat([full[:, :head_len], full[:, -tail_len:]], dim=1)
inputs = inputs.to(model.device)

input_ids = inputs.input_ids      # [1, seq_len]
seq_len = input_ids.shape[1]

# Key vectors per decoder layer, read from the KV cache filled by a single forward pass.
# layer_keys[layer_idx] -> [batch_size, num_key_heads, seq_len, head_dim]
# (for Llama these are the keys after rotary position embedding is applied).
layer_keys = {}


def _cached_layer_keys(past_key_values, layer_idx):
    """Return the cached key tensor of decoder layer `layer_idx`.

    The container type of `past_key_values` differs between transformers
    releases, so the known layouts are probed in turn. Reading the cache also
    avoids forward hooks on `self_attn`. The return tuple of that module changed
    length and element order across releases, and hooks accumulate on every
    re-run of the cell.
    """
    if hasattr(past_key_values, "layers"):     # Cache object holding per-layer entries
        return past_key_values.layers[layer_idx].keys
    if hasattr(past_key_values, "key_cache"):  # DynamicCache with a list of key tensors
        return past_key_values.key_cache[layer_idx]
    return past_key_values[layer_idx][0]       # legacy tuple of (key, value) per layer


# Forward pass; use_cache=True makes the model return the populated KV cache.
with torch.no_grad():
    outputs = model(**inputs, use_cache=True)

for layer_idx in range(len(model.model.layers)):
    layer_keys[layer_idx] = _cached_layer_keys(outputs.past_key_values, layer_idx)

# Model dimensions. Head count and head size are read from the captured keys, so
# grouped-query models (fewer key heads than query heads) are handled as well.
num_layers = len(model.model.layers)
num_heads = layer_keys[0].shape[1]   # key heads per layer
head_dim = layer_keys[0].shape[-1]

# Each token is compared with up to NEIGHBOR_WINDOW tokens on each side (itself excluded).
NEIGHBOR_WINDOW = 3

# NOTE: despite the "dissimilarity" name (kept for downstream cells), each entry is the
# MEAN COSINE SIMILARITY between a token's key and the keys of its neighbours.
# Shape: [num_layers, num_heads, seq_len]
token_dissimilarity_all = torch.zeros((num_layers, num_heads, seq_len), device=model.device)

for layer_idx in range(num_layers):
    # [num_heads, seq_len, head_dim]; fp32 so norms and sums do not lose fp16 precision.
    keys = layer_keys[layer_idx][0].float()
    normed_keys = keys / (keys.norm(dim=-1, keepdim=True) + 1e-6)

    # Only a narrow diagonal band of the seq_len x seq_len similarity matrix is needed.
    # Accumulate it one offset at a time: O(seq_len) memory instead of O(seq_len^2).
    token_count = keys.shape[1]
    sim_sum = torch.zeros((keys.shape[0], token_count), device=keys.device)
    neighbor_count = torch.zeros(token_count, device=keys.device)
    for offset in range(1, NEIGHBOR_WINDOW + 1):
        # pair_sim[:, i] = cos(key_i, key_{i+offset}); empty once offset >= token_count.
        pair_sim = (normed_keys[:, :-offset] * normed_keys[:, offset:]).sum(dim=-1)
        sim_sum[:, :-offset] += pair_sim   # token i gains its right-hand neighbour
        sim_sum[:, offset:] += pair_sim    # token i+offset gains its left-hand neighbour
        neighbor_count[:-offset] += 1
        neighbor_count[offset:] += 1
    mean_sim = sim_sum / (neighbor_count + 1e-6)  # [num_heads, seq_len]

    # Layers may live on different GPUs under device_map="auto".
    token_dissimilarity_all[layer_idx] = mean_sim.to(token_dissimilarity_all.device)

# Host-side NumPy copy of the scores: [num_layers, num_heads, seq_len].
token_dissimilarity_np = token_dissimilarity_all.cpu().numpy()

# Token strings aligned with the last axis of the score array.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

# Everything needed to export the analysis; nested lists keep it JSON-serialisable.
results = dict(
    prompt=first_prompt,
    tokens=tokens,
    token_dissimilarity=token_dissimilarity_np.tolist(),
    shape=token_dissimilarity_np.shape,  # (num_layers, num_heads, seq_len)
)

# Short textual summary, then a peek at the first scores of layer 0 / head 0.
preview_chars = 200
summary_lines = [
    f"Analyzed prompt: {first_prompt[:preview_chars]}...",
    f"Number of tokens: {seq_len}",
    f"Dissimilarity array shape: {token_dissimilarity_np.shape} (layers, heads, tokens)",
    "\nExample dissimilarity values for first layer, first head:",
]
print("\n".join(summary_lines))
print(token_dissimilarity_np[0, 0, :10])

In [None]:
import matplotlib.pyplot as plt
from matplotlib import rcParams
import numpy as np

# Global style. Prefer Arial, with matplotlib's bundled DejaVu Sans as a fallback,
# so machines without Arial still render and do not emit "findfont" warnings.
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
rcParams['font.weight'] = 'normal'

# Font sizes
fontsize = 14        # panel titles
fontlabel = 16       # axis labels
tick_fontsize = 12   # tick labels

token_positions = np.arange(token_dissimilarity_np.shape[2])
layers_to_plot = [7, 10, 21, 29]                # 0-based layer indices, one grid row each
heads_to_plot = [3, 7, 11, 15, 19, 23, 27, 31]  # 0-based head indices, one grid column each
assert max(layers_to_plot) < token_dissimilarity_np.shape[0], "layer index out of range"
assert max(heads_to_plot) < token_dissimilarity_np.shape[1], "head index out of range"

# Plotted quantity: |mean neighbour cosine similarity|. Take abs() once, not per panel.
abs_neighbor_sim = np.abs(token_dissimilarity_np)

# Grid shape follows the selection lists (4 x 8 panels -> 24 x 12 inches).
fig, axs = plt.subplots(
    len(layers_to_plot),
    len(heads_to_plot),
    figsize=(3 * len(heads_to_plot), 3 * len(layers_to_plot)),
    dpi=300,
    squeeze=False,
)

for row, layer_idx in enumerate(layers_to_plot):
    for col, head_idx in enumerate(heads_to_plot):
        ax = axs[row, col]
        ax.plot(token_positions,
                abs_neighbor_sim[layer_idx, head_idx, :],
                linestyle='-',
                linewidth=1,
                color='#3A7CA5')
        ax.set_title(f'Layer {layer_idx+1} / Head {head_idx+1}', fontsize=fontsize)
        ax.set_ylabel("Neighborhood Similarity", fontsize=fontlabel)
        ax.set_xlabel("Token Pos", fontsize=fontlabel)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
        for spine in ('top', 'right'):
            ax.spines[spine].set_visible(False)
        ax.grid(axis='y', linestyle='--', alpha=0.3)

# tight_layout computes the inter-panel spacing itself, so no separate
# subplots_adjust call is needed.
fig.tight_layout()
plt.show()