# SinkVis Demo with Real LLM

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zeinalii/SinkVis/blob/main/examples/real_model_demo.ipynb)

This notebook demonstrates SinkVis with a real open-source language model.

We'll:
1. Clone SinkVis and install dependencies
2. Download DistilGPT-2 (small, fast, ~350MB)
3. Extract real attention patterns
4. Identify attention sinks and heavy hitters
5. Simulate KV cache eviction policies
6. Visualize results


## Step 1: Clone SinkVis & Install Dependencies

Run this cell first to set up the environment (required for Google Colab).


In [None]:
# Clone SinkVis repository
import subprocess
import os

if not os.path.exists('SinkVis'):
    subprocess.run(['git', 'clone', 'https://github.com/zeinalii/SinkVis.git'], check=True)
    print("✓ Repository cloned")
else:
    print("✓ Repository already exists")

# Install dependencies
%pip install -q transformers torch matplotlib numpy websockets

print("✓ Setup complete!")
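If an import in the next cell fails, it helps to know which packages the kernel cannot see. The check below uses only the standard library: it asks `importlib.util.find_spec` whether each top-level module can be found. `missing_modules` is a hypothetical helper, not part of SinkVis. For the packages installed above, the pip names and the import names happen to match.

```python
import importlib.util

def missing_modules(names):
    """Return the subset of module names that cannot be found in this kernel."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Packages installed by the setup cell (pip names match import names here)
REQUIRED = ["transformers", "torch", "matplotlib", "numpy", "websockets"]

missing = missing_modules(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("✓ All dependencies importable")
```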


In [None]:
import sys
from pathlib import Path

# Add SinkVis to path (works for both Colab and local)
SINKVIS_PATH = Path("SinkVis").resolve() if Path("SinkVis").exists() else Path(".").resolve().parent
sys.path.insert(0, str(SINKVIS_PATH))

import numpy as np
import matplotlib.pyplot as plt

# Import SinkVis modules
from backend.attention import identify_sinks, identify_heavy_hitters
from backend.eviction import run_simulation
from backend.models import SimulationConfig, EvictionPolicy

print(f"✓ SinkVis loaded from: {SINKVIS_PATH}")


## Step 2: Download and Load Model

We'll use **DistilGPT-2**, a distilled version of GPT-2 with these properties:
- Small (~350MB download)
- Fast to run on CPU
- 6 layers, each with 12 attention heads
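The layer count also determines how much memory the KV cache needs, which is what the eviction policies simulated later try to bound. For a standard multi-head attention decoder like GPT-2, each layer caches one key vector and one value vector of the hidden size per token, giving roughly 2 × layers × hidden size × bytes per element per token. The sketch below assumes float32 (4 bytes), batch size 1, and DistilGPT-2's hidden size of 768, and it ignores framework overhead; `kv_cache_bytes` is a hypothetical helper, not a SinkVis function.

```python
def kv_cache_bytes(n_layer, hidden_size, seq_len, bytes_per_elem=4, batch_size=1):
    """Approximate KV cache size for a multi-head attention decoder.

    Each layer stores one key and one value vector of length `hidden_size`
    per token, hence the factor of 2.
    """
    return 2 * n_layer * hidden_size * seq_len * bytes_per_elem * batch_size

# DistilGPT-2: 6 layers, hidden size 768, float32
per_token = kv_cache_bytes(n_layer=6, hidden_size=768, seq_len=1)
full_context = kv_cache_bytes(n_layer=6, hidden_size=768, seq_len=1024)

print(f"Per token: {per_token / 1024:.0f} KiB")        # 36 KiB
print(f"1024 tokens: {full_context / 2**20:.0f} MiB")  # 36 MiB
```

At DistilGPT-2's 1024-token maximum context the cache is small, but the same formula shows why eviction matters for larger models and longer contexts.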


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Download and load model (first run will download ~350MB)
MODEL_NAME = "distilgpt2"

print(f"Loading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
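# Eager attention ensures attention weights are actually returned
# (fused SDPA/flash kernels do not materialize them)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, output_attentions=True, attn_implementation="eager"
)
model.eval()

print("✓ Model loaded!")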
print(f"  - Layers: {model.config.n_layer}")
print(f"  - Heads per layer: {model.config.n_head}")
print(f"  - Hidden size: {model.config.n_embd}")
print(f"  - Vocab size: {model.config.vocab_size}")


## Step 3: Run Inference and Extract Attention


In [None]:
# Input prompt - try changing this!
prompt = """The transformer architecture has revolutionized natural language processing. 
Self-attention allows models to focus on relevant parts of the input sequence. 
The key insight is that attention weights determine which tokens influence the output."""

# Tokenize
inputs = tokenizer(prompt, return_tensors="pt")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

print(f"Input: {len(tokens)} tokens")
print(f"Tokens: {tokens[:10]}... (showing first 10)")
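GPT-2's byte-level BPE tokenizer encodes a leading space as `Ġ` and a newline as `Ċ`, so the printed tokens can look odd. For labels in printouts and plots, a small helper can map these markers back to readable text. `display_token` is a hypothetical helper, not a SinkVis function; `tokenizer.convert_tokens_to_string` is the library route when you want full decoding instead of per-token labels.

```python
def display_token(token: str) -> str:
    """Make a GPT-2 byte-level BPE token readable for labels.

    'Ġ' (U+0120) marks a leading space and 'Ċ' (U+010A) marks a newline.
    Newlines are shown as a literal backslash-n so labels stay on one line.
    """
    return token.replace("\u0120", " ").replace("\u010a", "\\n")

example = ["The", "\u0120transformer", "\u0120architecture", ".", "\u010a"]
print([display_token(t) for t in example])
```

Applied to this notebook, `[display_token(t) for t in tokens[:10]]` gives a cleaner preview of the prompt.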


In [None]:
# Run model and get attention weights
with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple with one tensor per layer, each of shape (batch, heads, seq, seq)
print(f"\nAttention tensors: {len(outputs.attentions)} layers")
print(f"Shape per layer: {outputs.attentions[0].shape}")

# Extract attention from last layer, average across heads
last_layer_attention = outputs.attentions[-1][0]  # (heads, seq, seq)
avg_attention = last_layer_attention.mean(dim=0).numpy()  # (seq, seq)

print(f"\nAveraged attention shape: {avg_attention.shape}")
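Before analyzing the averaged matrix, it can be worth checking two properties that any causal self-attention matrix should satisfy: each row is a softmax distribution (sums to 1), and no query attends to future positions (the strict upper triangle is zero). Averaging across heads preserves both. The function below is an illustrative check, not part of SinkVis, demonstrated on small synthetic matrices.

```python
import numpy as np

def check_causal_attention(attn: np.ndarray, atol: float = 1e-5) -> bool:
    """Return True if `attn` (seq x seq) looks like causal softmax attention."""
    rows_sum_to_one = np.allclose(attn.sum(axis=-1), 1.0, atol=atol)
    # np.triu(..., k=1) keeps only entries above the diagonal (future positions)
    no_future_attention = np.allclose(np.triu(attn, k=1), 0.0, atol=atol)
    return bool(rows_sum_to_one and no_future_attention)

# Valid 3-token example: lower-triangular rows that each sum to 1
toy = np.array([
    [1.0, 0.0, 0.0],
    [0.7, 0.3, 0.0],
    [0.6, 0.1, 0.3],
])
# Rows still sum to 1, but query 0 attends to the future position 1
leaky = np.array([
    [0.5, 0.5, 0.0],
    [0.7, 0.3, 0.0],
    [0.6, 0.1, 0.3],
])
print(check_causal_attention(toy))    # True
print(check_causal_attention(leaky))  # False
```

For the real model, `check_causal_attention(avg_attention)` should return `True`; a `False` usually means the wrong tensor or axis was selected.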


## Step 4: Identify Attention Sinks

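**Attention sinks** are tokens that receive disproportionately high attention regardless of their semantic content. Prior work (e.g., StreamingLLM, Xiao et al., 2023) finds that they are typically:
- The BOS (beginning-of-sequence) token, when the model uses one
- The first few tokens of the sequence

DistilGPT-2's tokenizer does not prepend a BOS token, so here the first token of the prompt is the most likely sink.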
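One subtlety when measuring how much attention a token *receives*: under a causal mask, key position `j` can only be attended to by queries `j` through `n-1`. A plain column mean such as `avg_attention[:, idx].mean()` is therefore diluted for later tokens by the zeros above the diagonal. The sketch below normalizes each column by the number of queries that can actually see that position. It illustrates one scoring choice and is not necessarily how `backend.attention.identify_sinks` works; `received_attention`, `find_sinks`, and the 0.5 threshold are hypothetical.

```python
import numpy as np

def received_attention(attn: np.ndarray) -> np.ndarray:
    """Mean attention each key position receives from the queries that can see it.

    Under a causal mask, key j is visible only to queries j..n-1, i.e. n - j rows.
    """
    n = attn.shape[0]
    visible_queries = n - np.arange(n)  # n, n-1, ..., 1
    return attn.sum(axis=0) / visible_queries

def find_sinks(attn: np.ndarray, threshold: float) -> list[int]:
    """Hypothetical sink detector: positions whose received attention exceeds threshold."""
    scores = received_attention(attn)
    return [int(j) for j in np.flatnonzero(scores > threshold)]

# Synthetic 4-token example where every query leans heavily on position 0
toy = np.array([
    [1.00, 0.00, 0.00, 0.00],
    [0.80, 0.20, 0.00, 0.00],
    [0.70, 0.10, 0.20, 0.00],
    [0.60, 0.10, 0.10, 0.20],
])
print(received_attention(toy).round(3))
print(find_sinks(toy, threshold=0.5))
```

Because the last few columns are averaged over very few queries, their normalized scores are noisy; a practical detector might restrict the search to early positions or require a minimum number of visible queries.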


In [None]:
# Find attention sinks
sinks = identify_sinks(avg_attention, threshold=0.08)
print("Attention Sinks:")
for idx in sinks:
    avg_weight = avg_attention[:, idx].mean()
    print(f"  Position {idx}: '{tokens[idx]}' (avg attention: {avg_weight:.3f})")


In [None]:
# Find heavy hitters (semantically important tokens)
heavy_hitters = identify_heavy_hitters(avg_attention, threshold=0.03, exclude_sinks=sinks)
print("\nHeavy Hitters (semantically important):")
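# Sort by average received attention so the first 10 really are the top 10
top_hitters = sorted(heavy_hitters, key=lambda i: avg_attention[:, i].mean(), reverse=True)
for idx in top_hitters[:10]: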
    avg_weight = avg_attention[:, idx].mean()
    print(f"  Position {idx}: '{tokens[idx]}' (avg attention: {avg_weight:.3f})")
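In the KV-cache literature, heavy hitters are often chosen by accumulated attention: H2O (Zhang et al., 2023) keeps the tokens with the largest summed attention scores. The sketch below ranks positions by total attention received across all queries, skips known sinks, and keeps the top `k`. It is a simplified, static version (H2O accumulates scores incrementally during decoding) and may differ from SinkVis's `identify_heavy_hitters`, which takes a threshold rather than a `k`; `top_k_heavy_hitters` is a hypothetical helper.

```python
import numpy as np

def top_k_heavy_hitters(attn: np.ndarray, k: int, exclude=()) -> list[int]:
    """Rank key positions by total attention received (column sums), highest first.

    Positions in `exclude` (e.g. attention sinks) are never selected.
    """
    scores = attn.sum(axis=0).astype(float)
    excluded = [int(j) for j in exclude]
    if excluded:
        scores[excluded] = -np.inf  # push excluded positions to the end
    order = np.argsort(-scores, kind="stable")  # descending; ties keep position order
    return [int(j) for j in order[:k] if np.isfinite(scores[j])]

toy = np.array([
    [1.00, 0.00, 0.00, 0.00],
    [0.80, 0.20, 0.00, 0.00],
    [0.70, 0.10, 0.20, 0.00],
    [0.60, 0.10, 0.10, 0.20],
])
print(top_k_heavy_hitters(toy, k=2, exclude=[0]))
```

With this notebook's variables, the equivalent call would be `top_k_heavy_hitters(avg_attention, k=10, exclude=sinks)`. Summed (rather than causally normalized) scores favor early positions, which matches the accumulated-score idea but is worth keeping in mind when comparing against the sink scores.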
