# Attention Mechanism Explorer

**Duration:** ~30 min | **Platform:** Kaggle dual Tesla T4

This notebook explores **attention mechanisms** — extracting the attention
configuration from GGUF metadata and tensor shapes, analyzing multi-head
attention structure, measuring how prompt length affects inference latency,
and visualizing illustrative (synthetic) attention patterns.

### What you'll learn
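1. Extract the attention configuration from a GGUF file
2. Understand multi-head attention dimensions
3. Measure how prompt length affects inference latency
4. Visualize illustrative per-head attention patterns (synthetic, not read from the model)
5. Analyze the shapes of the Q/K/V/O projection weights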

In [None]:
# Install the pinned llamatelemetry release (v1.0.0) and matplotlib.
# The "!" lines are IPython shell escapes, so this cell only runs in Jupyter/Kaggle.
!pip install -q git+https://github.com/llamatelemetry/llamatelemetry.git@v1.0.0
!pip install -q matplotlib

import llamatelemetry
from llamatelemetry.llama import ServerManager, LlamaCppClient, parse_gguf_header
from huggingface_hub import hf_hub_download

# Initialise llamatelemetry under the service name "attention-explorer".
llamatelemetry.init(service_name="attention-explorer")

# Fetch the Q4_K_M-quantized Gemma 3 1B instruct GGUF from the Hugging Face Hub
# into cache_dir. The returned local path (`model_path`) is reused by later cells.
model_path = hf_hub_download(
    repo_id="bartowski/google_gemma-3-1b-it-GGUF",
    filename="google_gemma-3-1b-it-Q4_K_M.gguf",
    cache_dir="/root/.cache/huggingface",
)

# Start server for inference experiments.
# gpu_layers=99 presumably requests offloading every layer (more than the model
# has) to the GPU; ctx_size presumably sets the context window in tokens.
mgr = ServerManager()
mgr.start_server(model_path=model_path, gpu_layers=99, ctx_size=2048)
# Wait for the server to become ready (timeout=60, presumably seconds).
mgr.wait_until_ready(timeout=60)
# NOTE(review): port 8090 is hard-coded here while start_server() is called
# without a port — confirm ServerManager's default port matches.
client = LlamaCppClient(base_url="http://127.0.0.1:8090")
print("Ready")

## Multi-Head Attention Explained

In transformer models, each attention layer splits the input into multiple
"heads" that attend to different aspects of the input independently.

```
Input → [Q_proj] → Q matrix → split into H heads
      → [K_proj] → K matrix → split into H_kv heads   (H_kv = H for MHA; H_kv < H for GQA)
      → [V_proj] → V matrix → split into H_kv heads

Each head: Attention(Q_h, K_h, V_h) = softmax(Q_h × K_h^T / √d_k) × V_h
           d_k = head dimension; under GQA each K/V head is shared by H / H_kv query heads

All heads → [concatenate] → [O_proj] → Output
```

In [None]:
# Attention-projection tensor names. llama.cpp GGUF files use
# ``blk.<i>.attn_q.weight`` / ``attn_k`` / ``attn_v`` / ``attn_output``, while
# Hugging Face checkpoints use ``...self_attn.q_proj.weight`` etc. Names are
# matched against whole dot-separated components so that, for example,
# ``attn_q_norm`` (a QK-norm weight) is not mistaken for ``attn_q``.
_ATTN_PROJ_ALIASES = {
    "q_proj": ("q_proj", "attn_q"),
    "k_proj": ("k_proj", "attn_k"),
    "v_proj": ("v_proj", "attn_v"),
    "o_proj": ("o_proj", "attn_output"),
}


def _proj_out_dim(shape, embed_dim):
    """Return the output width of a Q/K/V projection weight of the given shape.

    GGUF readers may list dimensions in ggml order (input axis first), while
    PyTorch lists ``(out, in)``. A Q/K/V projection reads the embedding, so one
    axis equals ``embed_dim`` and the other is the output width, whatever the
    order. Falls back to the first axis when neither axis matches, and to
    ``embed_dim`` for an empty shape.
    """
    dims = [int(d) for d in shape]
    if not dims:
        return embed_dim
    if len(dims) == 2:
        first, second = dims
        if first == embed_dim:
            return second
        if second == embed_dim:
            return first
    return dims[0]


def _positive_int(value):
    """Return *value* if it is a positive (non-bool) int, else ``None``."""
    if isinstance(value, int) and not isinstance(value, bool) and value > 0:
        return value
    return None


@llamatelemetry.task(name="extract-attention-config")
def extract_attention_config(model_path):
    """Summarise the attention geometry of a GGUF model.

    Args:
        model_path: Path to a ``.gguf`` file.

    Returns:
        A dict that always holds ``architecture``, ``embedding_dim`` and
        ``n_layers``. When they can be determined, it also holds ``q_shape`` and
        ``k_shape`` (shape of the first Q/K projection tensor found),
        ``n_heads``, ``head_dim``, ``n_kv_heads`` and ``gqa_ratio`` (query heads
        per K/V head; 1 for standard multi-head attention).

    Head counts are taken from the GGUF metadata when the parsed metadata
    exposes them. Otherwise ``head_dim`` is guessed as the first of 64/128/256
    that divides the Q projection width. That guess can be wrong; for example,
    Gemma 3 uses 256.
    """
    info = parse_gguf_header(model_path, read_tensors=True)
    meta = info.metadata

    embed_dim = meta.embedding_length or 0
    n_layers = meta.block_count or 0

    # Collect Q/K/V/O projection tensors (all layers, in file order).
    qkv_tensors = {key: [] for key in _ATTN_PROJ_ALIASES}
    for t in info.tensors:
        components = set(t.name.split("."))
        for key, aliases in _ATTN_PROJ_ALIASES.items():
            if components.intersection(aliases):
                qkv_tensors[key].append(t)

    config = {
        "architecture": meta.architecture,
        "embedding_dim": embed_dim,
        "n_layers": n_layers,
    }

    # NOTE(review): GGUF stores head counts under ``<arch>.attention.head_count``
    # and ``head_count_kv``. This assumes the parsed metadata exposes them as
    # ``head_count`` / ``head_count_kv`` (like ``embedding_length``). Missing or
    # non-int values (e.g. per-layer arrays) fall back to the shape heuristic.
    n_heads = _positive_int(getattr(meta, "head_count", None))
    n_kv_heads = _positive_int(getattr(meta, "head_count_kv", None))
    head_dim = None

    q_out = None
    if qkv_tensors["q_proj"]:
        q_shape = qkv_tensors["q_proj"][0].shape
        config["q_shape"] = q_shape
        # The Q width is n_heads * head_dim, which need not equal embed_dim.
        q_out = _proj_out_dim(q_shape, embed_dim)

    if n_heads:
        width = q_out or embed_dim
        if width and width % n_heads == 0:
            head_dim = width // n_heads
    elif q_out:
        # Heuristic: first common head size that divides the Q width.
        for candidate in (64, 128, 256):
            if q_out % candidate == 0:
                n_heads, head_dim = q_out // candidate, candidate
                break

    if n_heads:
        config["n_heads"] = n_heads
    if head_dim:
        config["head_dim"] = head_dim

    if qkv_tensors["k_proj"]:
        k_shape = qkv_tensors["k_proj"][0].shape
        config["k_shape"] = k_shape
        if n_kv_heads is None and head_dim:
            # K width = n_kv_heads * head_dim; narrower than Q under GQA.
            n_kv_heads = _proj_out_dim(k_shape, embed_dim) // head_dim or None

    if n_kv_heads:
        config["n_kv_heads"] = n_kv_heads
        if n_heads:
            config["gqa_ratio"] = n_heads // n_kv_heads

    return config

attn_config = extract_attention_config(model_path)

# Show each discovered field on its own indented "name: value" line.
print("Attention Configuration:")
for field in attn_config:
    print(f"  {field}: {attn_config[field]}")

## Attention Patterns

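Run inference on prompts of increasing length and measure end-to-end latency. The inference server does not expose attention weights, so this cell looks at timing only. Prompt processing (prefill) gets more expensive as the prompt grows, because every prompt token attends to all earlier ones.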

In [None]:
import time

# Time each request end to end. Only total wall-clock time is measured, so
# prefill and decode cannot be separated; the throughput figure printed below
# therefore includes prompt processing.
test_prompts = [
    ("Short", "What is AI?"),
    ("Medium", "Explain the transformer architecture and its key innovations over previous approaches."),
    ("Long", "Write a detailed comparison of RNNs, LSTMs, and Transformers for sequence modeling, "
             "covering their architectures, strengths, weaknesses, and modern applications."),
]

for name, prompt in test_prompts:
    # Token count of the raw prompt, before any chat template is applied.
    n_input = len(client.tokenize(prompt).tokens)

    t0 = time.perf_counter()
    resp = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        max_tokens=64, temperature=0.7,
    )
    elapsed = time.perf_counter() - t0

    # prompt_tokens is the server-side count for the templated chat request,
    # so it may exceed n_input.
    n_prompt = resp.usage.prompt_tokens
    n_output = resp.usage.completion_tokens

    print(f"{name} prompt ({n_input} tokens → {n_output} output):")
    print(f"  Total time: {elapsed*1000:.0f} ms")
    print(f"  Prefill: ~{n_prompt} tokens processed")
    print(f"  Output throughput (incl. prefill): {n_output/elapsed:.1f} tok/s\n")

## Head Specialization

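Visualize how different attention heads might specialize. These heatmaps are **synthetic**: hand-built examples of commonly described head types (local, [BOS]/sink, content-based, …). They are not weights read from this model.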

In [None]:
import matplotlib.pyplot as plt
import numpy as np

n_heads = attn_config.get("n_heads", 8)
head_dim = attn_config.get("head_dim", 64)

# These heatmaps are *synthetic*: hand-built examples of commonly described
# head behaviours. Real attention weights require access to model internals,
# which the inference server used in this notebook does not provide.
rng = np.random.default_rng(42)
tokens_display = ["[BOS]", "The", "cat", "sat", "on", "the", "mat", "and", "looked", "at", "me", "[EOS]"]
seq_len = len(tokens_display)

patterns = {
    0: "Local (nearby tokens)",
    1: "Global (attend to [BOS])",
    2: "Diagonal (self + previous)",
    3: "Content (nouns)",
    4: "Syntactic (verbs)",
    5: "Positional (even/odd)",
    6: "Uniform",
    7: "Sink (first token)",
}


def synthetic_attention(head, size, gen):
    """Return a causal, row-normalised ``(size, size)`` attention map for archetype *head*.

    Rows are query positions and columns are key positions. Heads without an
    archetype get background noise only. The noun/verb column indices refer to
    ``tokens_display`` above, so *size* must be at least 9.
    """
    query = np.arange(size)[:, None]
    key = np.arange(size)[None, :]
    attn = gen.random((size, size)) * 0.1  # low-level background noise
    if head == 0:  # Local: window of +/-2 tokens (future half is masked below)
        attn[np.abs(query - key) <= 2] += 0.5
    elif head == 1:  # Global: every query looks at [BOS]
        attn[:, 0] += 0.8
    elif head == 2:  # Diagonal: self plus the previous token
        idx = np.arange(size)
        attn[idx, idx] += 0.7
        attn[idx[1:], idx[:-1]] += 0.3
    elif head == 3:  # Content: nouns "cat", "mat"
        attn[:, [2, 6]] += 0.6
    elif head == 4:  # Syntactic: verbs "sat", "looked"
        attn[:, [3, 8]] += 0.6
    elif head == 5:  # Positional: keys with the same parity as the query
        attn[(query - key) % 2 == 0] += 0.4
    elif head == 6:  # Uniform: the noise is small next to the constant
        attn += 1.0
    elif head == 7:  # Sink: the first token absorbs most of the mass
        attn[:, 0] += 3.0
    # Decoder-only LMs use a causal mask: a query never attends to later keys.
    attn = np.tril(attn)
    return attn / attn.sum(axis=1, keepdims=True)


# One panel per head (never more than the model has, at most one per
# archetype), laid out in at most two rows of up to four columns.
n_plots = max(1, min(len(patterns), n_heads))
n_cols = min(4, -(-n_plots // 2))  # ceil(n_plots / 2), capped at 4
n_rows = -(-n_plots // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 8), squeeze=False)
axes = axes.flatten()

for h, ax in enumerate(axes):
    if h >= n_plots:
        ax.axis("off")  # unused panel
        continue
    ax.imshow(synthetic_attention(h, seq_len, rng), cmap="Blues", aspect="auto")
    ax.set_title(f"Head {h}: {patterns[h]}", fontsize=9)
    # Token labels only on panels with nothing plotted beneath them.
    if h + n_cols >= n_plots:
        ax.set_xticks(range(seq_len))
        ax.set_xticklabels(tokens_display, rotation=45, ha="right", fontsize=7)
    else:
        ax.set_xticks([])

plt.suptitle(f"Synthetic attention-head patterns (model: {n_heads} heads, head_dim={head_dim})", fontsize=12)
plt.tight_layout()
plt.show()

## QKV Analysis

Analyze the Q, K, V weight matrices from the GGUF file.

In [None]:
info_full = parse_gguf_header(model_path, read_tensors=True)
embed_dim_full = info_full.metadata.embedding_length or 0

# llama.cpp GGUF names (``blk.<i>.attn_q.weight``) and Hugging Face names
# (``...q_proj.weight``) for each projection. Matched against whole
# dot-separated components so ``attn_q_norm`` is not taken for ``attn_q``.
proj_name_aliases = {
    "q_proj": {"q_proj", "attn_q"},
    "k_proj": {"k_proj", "attn_k"},
    "v_proj": {"v_proj", "attn_v"},
    "o_proj": {"o_proj", "attn_output"},
}

# Categorize attention tensors per layer: {layer: {"q_proj": tensor, ...}}
qkv_by_layer = {}
for t in info_full.tensors:
    parts = t.name.split(".")
    # The layer index is the first all-digit name component (e.g. ``blk.3...``).
    layer_num = next((int(p) for p in parts if p.isdigit()), None)
    if layer_num is None:
        continue
    for proj, aliases in proj_name_aliases.items():
        if aliases.intersection(parts):
            qkv_by_layer.setdefault(layer_num, {})[proj] = t


def _qkv_out_width(tensor):
    """Return the output width of a Q/K/V weight, whatever order its axes are listed in.

    One axis of a Q/K/V projection equals the embedding width (its input); the
    other is the output width. Falls back to the first axis, or 0 with no shape.
    """
    dims = [int(d) for d in tensor.shape] if tensor.shape is not None else []
    if len(dims) == 2 and embed_dim_full in dims:
        return dims[1] if dims[0] == embed_dim_full else dims[0]
    return dims[0] if dims else 0


print(f"Attention layers found: {len(qkv_by_layer)}")
print(f"\n{'Layer':<8} {'Q shape':<20} {'K shape':<20} {'V shape':<20} {'O shape'}")
print("-" * 80)
for layer in sorted(qkv_by_layer)[:5]:  # Show first 5 layers
    tensors = qkv_by_layer[layer]
    q_s, k_s, v_s, o_s = (
        str(tensors[p].shape) if p in tensors else "N/A"
        for p in ("q_proj", "k_proj", "v_proj", "o_proj")
    )
    print(f"{layer:<8} {q_s:<20} {k_s:<20} {v_s:<20} {o_s}")

# Check for Grouped Query Attention (GQA): K projection narrower than Q.
if qkv_by_layer:
    first_layer = qkv_by_layer[min(qkv_by_layer)]
    if "q_proj" in first_layer and "k_proj" in first_layer:
        q_dim = _qkv_out_width(first_layer["q_proj"])
        k_dim = _qkv_out_width(first_layer["k_proj"])
        known_head_dim = attn_config.get("head_dim")
        if q_dim <= 0 or k_dim <= 0:
            print("\nCould not determine Q/K projection widths")
        elif q_dim != k_dim:
            print(f"\nGrouped Query Attention (GQA) detected: ratio={q_dim // k_dim}:1")
            if known_head_dim:
                print(f"  Q heads: {q_dim // known_head_dim}, KV heads: {k_dim // known_head_dim}")
        else:
            print("\nStandard Multi-Head Attention (MHA): Q and K projections have equal width")

## Key Findings

| Aspect | Details |
|--------|--------|
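| **MHA vs GQA** | Many modern models use Grouped Query Attention (several query heads share one K/V head) to shrink the KV cache |
| **Head dim** | Commonly 64–256 (Gemma 3 uses 256) — the width of each head's query/key/value vectors |
| **Head count** | More heads let a layer attend to more relationships in parallel; when `n_heads × head_dim` equals the embedding width, more heads means a smaller `head_dim` |
| **KV cache cost** | Proportional to `2 × n_kv_heads × head_dim × seq_len × n_layers` (K and V) times bytes per element |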

In [None]:
# Tear down: stop the inference server managed by `mgr` (started in the setup
# cell), then shut down llamatelemetry (presumably flushing pending telemetry).
mgr.stop_server()
llamatelemetry.shutdown()
print("Done.")