# SinkVis Usage Guide

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zeinalii/SinkVis/blob/main/examples/usage_guide.ipynb)

This notebook demonstrates how to use SinkVis in your projects for:
1. Analyzing attention patterns
2. Testing eviction policies
3. Integrating with HuggingFace models
4. Live streaming to the visualizer
5. Analyzing how cache blocks are distributed across memory tiers


## Setup

First, clone SinkVis and install dependencies (required for Google Colab):


In [None]:
# Clone SinkVis repository (for Google Colab)
import subprocess
import os

if not os.path.exists('SinkVis'):
    subprocess.run(['git', 'clone', 'https://github.com/zeinalii/SinkVis.git'], check=True)
    print("✓ Repository cloned")
else:
    print("✓ Repository already exists")

# Install dependencies
%pip install -q matplotlib numpy websockets

import sys
from pathlib import Path

# Add SinkVis to path (works for both Colab and local)
SINKVIS_PATH = Path("SinkVis").resolve() if Path("SinkVis").exists() else Path(".").resolve().parent
sys.path.insert(0, str(SINKVIS_PATH))

print(f"✓ SinkVis loaded from: {SINKVIS_PATH}")


SinkVis path: /Users/amir/Documents/stuff/SinkVis


---
## 1. Identify Attention Sinks and Heavy Hitters

Use SinkVis to find important tokens in any attention matrix.


In [2]:
import numpy as np
from backend.attention import (
    create_attention_frame,
    generate_attention_pattern,
    identify_heavy_hitters,
    identify_sinks,
)

# Synthetic attention matrix with 4 sinks; swap in your own matrix here.
seq_len = 32
attention = generate_attention_pattern(seq_len, num_sinks=4)

# Sum of query row 10 over key positions 0 through 10.
row_total = attention[10, :11].sum()

print(f"Attention matrix shape: {attention.shape}")
print(f"Each row sums to: {row_total:.4f}")


Attention matrix shape: (32, 32)
Each row sums to: 1.0000


In [4]:
# Attention sinks: positions flagged by identify_sinks at threshold 0.1
sinks = identify_sinks(attention, threshold=0.1)
print("Attention sinks at positions: {}".format(sinks))

# Heavy hitters: flagged at threshold 0.05, with the sinks found above excluded
heavy_hitters = identify_heavy_hitters(
    attention,
    threshold=0.05,
    exclude_sinks=sinks,
)
print("Heavy hitters at positions: {}".format(heavy_hitters))


Attention sinks at positions: []
Heavy hitters at positions: [0, 1, 2, 3, 13, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]


In [None]:
# Visualize the attention matrix and overlay the detected sink / heavy-hitter columns
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(attention, cmap='hot', aspect='auto')

# Mark sinks with solid cyan lines. Only the first line carries a legend label
# (an empty label is skipped by the legend), so the legend shows one entry.
for idx, s in enumerate(sinks):
    ax.axvline(x=s, color='cyan', linewidth=2, label='Sink' if idx == 0 else '')

# Mark heavy hitters with dashed lime lines, labelled the same way.
for idx, h in enumerate(heavy_hitters):
    ax.axvline(x=h, color='lime', linewidth=1, linestyle='--', label='Heavy Hitter' if idx == 0 else '')

ax.set_xlabel('Key Position')
ax.set_ylabel('Query Position')
ax.set_title('Attention Pattern with Sinks and Heavy Hitters')
plt.colorbar(im, ax=ax, label='Attention Weight')

# Draw the legend only when at least one labelled line exists; calling
# ax.legend() with nothing to show produces a warning and an empty legend.
handles, labels = ax.get_legend_handles_labels()
if handles:
    ax.legend(handles, labels, loc='upper right')

plt.tight_layout()
plt.show()


---
## 2. Test Eviction Policies

Compare different KV cache eviction strategies on your prompts.


In [None]:
from backend.eviction import run_simulation, tokenize_simple
from backend.models import EvictionPolicy, SimulationConfig

# Example conversation used by the eviction experiments below
prompt = """
System: You are a helpful assistant. The user's name is Alice and she likes cats.
User: What's my name?
Assistant: Your name is Alice.
User: Tell me a long story about dragons and knights.
Assistant: Once upon a time in a kingdom far away, there lived a brave knight...
User: What pet do I like?
"""

# Inspect the tokenization: total count plus a short preview
tokens = tokenize_simple(prompt)
token_count = len(tokens)
preview = tokens[:10]
print(f"Token count: {token_count}")
print(f"First 10 tokens: {preview}")


In [None]:
# Run a single simulation with the StreamingLLM policy
streaming_settings = dict(
    policy=EvictionPolicy.STREAMING_LLM,
    cache_size=64,    # max tokens in cache
    sink_count=4,     # preserve first 4 tokens
    window_size=32,   # keep last 32 tokens
)
config = SimulationConfig(**streaming_settings)

result = run_simulation(prompt, config, generate_tokens=32)

print(f"Policy: {result.policy.value}")

# Remaining statistics are plain attributes on the result, printed in order
report_fields = (
    ("Tokens processed", "total_tokens_processed"),
    ("Cache hits", "cache_hits"),
    ("Cache misses", "cache_misses"),
    ("Evictions", "evictions"),
    ("Retained sinks", "retained_sinks"),
    ("Retained heavy hitters", "retained_heavy_hitters"),
)
for label, attr in report_fields:
    print(f"{label}: {getattr(result, attr)}")


In [None]:
# Run the same prompt through every eviction policy and collect summary stats
policies = [
    EvictionPolicy.FULL,            # no eviction (baseline)
    EvictionPolicy.LRU,             # least recently used
    EvictionPolicy.SLIDING_WINDOW,  # keep only recent tokens
    EvictionPolicy.STREAMING_LLM,   # sinks + sliding window
    EvictionPolicy.H2O,             # sinks + heavy hitters
]

# Cache settings shared by every run so the policies are compared like-for-like
shared_settings = dict(cache_size=64, sink_count=4, window_size=32)

results = []
for policy in policies:
    config = SimulationConfig(policy=policy, **shared_settings)
    result = run_simulation(prompt, config, generate_tokens=32)
    results.append(dict(
        policy=policy.value,
        evictions=result.evictions,
        sinks_kept=result.retained_sinks,
        heavy_kept=result.retained_heavy_hitters,
        cached=result.final_cache.total_tokens,
    ))

# One summary line per policy
row_template = "{policy:15} | evictions: {evictions:3} | sinks: {sinks_kept} | cached: {cached}"
for r in results:
    print(row_template.format(**r))


In [None]:
# Bar-chart comparison of the collected policy results, one panel per metric
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

policies_names = [r['policy'] for r in results]

# (result key, panel title, bar color) for each of the three panels
panels = (
    ('evictions', 'Evictions (lower is better)', 'coral'),
    ('sinks_kept', 'Sinks Retained', 'cyan'),
    ('cached', 'Final Cache Size', 'lime'),
)

for panel_ax, (key, title, color) in zip(axes, panels):
    heights = [r[key] for r in results]
    panel_ax.bar(policies_names, heights, color=color)
    panel_ax.set_title(title)
    panel_ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()


---
## 3. Integrate with HuggingFace Transformers

Extract real attention weights from a HuggingFace model.


In [None]:
# Optional dependency probe. Install if needed: pip install transformers torch
# HF_AVAILABLE gates every HuggingFace cell below.
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer
    import torch
except ImportError:
    print("✗ Install with: pip install transformers torch")
    HF_AVAILABLE = False
else:
    HF_AVAILABLE = True
    print("✓ HuggingFace transformers available")


In [None]:
if HF_AVAILABLE:
    # GPT-2 is small enough for a quick demo; attentions are requested at load time
    model_name = "gpt2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
    model.eval()  # switch to evaluation mode

    model_config = model.config
    print(f"Loaded {model_name}")
    print(f"Layers: {model_config.n_layer}, Heads: {model_config.n_head}")


In [None]:
if HF_AVAILABLE:
    # Forward pass over a short sentence, with gradients disabled
    text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Last layer's attention tensor, shape (batch, heads, seq_len, seq_len);
    # take batch item 0, head 0 and convert to a numpy array.
    last_layer_attention = outputs.attentions[-1]
    attention_weights = last_layer_attention[0, 0].numpy()

    # Token strings for the same sequence, used as axis labels later
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    print(f"Attention shape: {attention_weights.shape}")
    print(f"Tokens: {tokens}")


In [None]:
if HF_AVAILABLE:
    # Run the SinkVis detectors on the real attention matrix
    sinks = identify_sinks(attention_weights, threshold=0.15)
    heavy = identify_heavy_hitters(attention_weights, threshold=0.08, exclude_sinks=sinks)

    # Report the detected positions as token strings
    sink_tokens = [tokens[i] for i in sinks]
    heavy_tokens = [tokens[i] for i in heavy]
    print(f"\nAttention Sinks: {sink_tokens}")
    print(f"Heavy Hitters: {heavy_tokens}")

    # Heatmap with one tick per token on both axes
    tick_positions = range(len(tokens))
    fig, ax = plt.subplots(figsize=(10, 8))
    im = ax.imshow(attention_weights, cmap='hot')
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tokens, rotation=45, ha='right')
    ax.set_yticklabels(tokens)
    ax.set(
        xlabel='Key (attended to)',
        ylabel='Query (attending from)',
        title='GPT-2 Attention (Last Layer, Head 0)',
    )
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()


---
## 4. Stream to Live Visualizer

Send attention data to the SinkVis web interface in real-time.

**First, start the server in a terminal:**
```bash
cd SinkVis && python run.py
```

Then open http://localhost:8765


In [None]:
import asyncio
import json

# Optional dependency probe; WS_AVAILABLE gates the streaming cell below.
try:
    import websockets
except ImportError:
    print("✗ Install with: pip install websockets")
    WS_AVAILABLE = False
else:
    WS_AVAILABLE = True
    print("✓ websockets available")


In [None]:
async def send_attention_to_sinkvis(attention_matrix, tokens, layer=0, head=0,
                                    uri="ws://localhost:8765/ws/attention"):
    """
    Send an attention matrix to SinkVis for visualization.

    Sinks and heavy hitters are detected with `identify_sinks` /
    `identify_heavy_hitters`, packed together with the matrix into a single
    "frame" message, and sent as JSON over a WebSocket connection.

    Args:
        attention_matrix: numpy array of shape (seq_len, seq_len)
        tokens: list of token strings
        layer: layer number (for display)
        head: head number (for display)
        uri: WebSocket endpoint of the SinkVis server

    Returns:
        None. Success or connection failure is reported via print().

    Raises:
        TypeError: if the frame cannot be serialized to JSON. The payload is
            built before connecting, so a data problem is not reported as a
            connection failure.
    """
    # Find sinks and heavy hitters
    sinks = identify_sinks(attention_matrix, threshold=0.1)
    heavy = identify_heavy_hitters(attention_matrix, threshold=0.05, exclude_sinks=sinks)

    frame = {
        "type": "frame",
        "data": {
            "layer": layer,
            "head": head,
            "seq_len": len(tokens),
            "attention_weights": attention_matrix.tolist(),
            "token_labels": list(tokens),
            # Coerce to plain ints: numpy integer scalars (which the detection
            # helpers may return) are rejected by json.dumps.
            "sink_indices": [int(i) for i in sinks],
            "heavy_hitter_indices": [int(i) for i in heavy],
            "timestamp": 0.0
        }
    }

    # Serialize outside the try block so only network problems are reported
    # as "could not connect".
    payload = json.dumps(frame)

    try:
        async with websockets.connect(uri) as ws:
            await ws.send(payload)
            print(f"✓ Sent attention frame ({len(tokens)} tokens)")
    except Exception as e:
        # Best-effort delivery: the notebook keeps running when the server is down.
        print(f"✗ Could not connect to SinkVis: {e}")
        print("  Make sure the server is running: python run.py")


In [None]:
# Send synthetic attention to visualizer
if WS_AVAILABLE:
    test_attention = generate_attention_pattern(20)
    test_tokens = [f"tok_{i}" for i in range(20)]
    
    # Run async function
    await send_attention_to_sinkvis(test_attention, test_tokens)


In [None]:
---
## 5. Cache Block Analysis

Analyze how tokens are distributed across memory tiers.


In [None]:
from backend.attention import generate_cache_blocks
from backend.models import MemoryTier

# Build cache blocks for a 256-token sequence, 16 per block, with known
# sink and heavy-hitter positions.
blocks = generate_cache_blocks(
    seq_len=256,
    block_size=16,
    sink_indices=[0, 1, 2, 3],
    heavy_hitter_indices=[50, 100, 150],
)

print(f"Total blocks: {len(blocks)}")

# Tally block count and total bytes per memory tier (keyed by the tier's value)
tier_counts = {}
tier_sizes = {}
for block in blocks:
    tier = block.memory_tier.value
    if tier not in tier_counts:
        tier_counts[tier] = 0
        tier_sizes[tier] = 0
    tier_counts[tier] += 1
    tier_sizes[tier] += block.size_bytes

print("\nBlocks per tier:")
for tier in tier_counts:
    count = tier_counts[tier]
    size_kb = tier_sizes[tier] / 1024
    print(f"  {tier}: {count} blocks ({size_kb:.1f} KB)")


In [None]:
# Draw one unit-height bar per cache block, colored by memory tier
from matplotlib.patches import Patch

tier_colors = {
    'gpu_hbm': '#00ff88',
    'gpu_l2': '#00ccff',
    'system_ram': '#ffd93d',
    'disk': '#ff6b9d'
}

fig, ax = plt.subplots(figsize=(14, 3))

for i, block in enumerate(blocks):
    # Sinks get a red outline, heavy hitters a cyan one, other blocks none
    if block.is_sink:
        edgecolor = 'red'
    elif block.is_heavy_hitter:
        edgecolor = 'cyan'
    else:
        edgecolor = 'none'
    highlighted = block.is_sink or block.is_heavy_hitter

    ax.bar(
        i,
        1,
        color=tier_colors[block.memory_tier.value],
        edgecolor=edgecolor,
        linewidth=2 if highlighted else 0,
    )

ax.set_xlabel('Block Index')
ax.set_title('KV Cache Memory Hierarchy')
ax.set_yticks([])

# Legend: one patch per tier, reusing the colors defined above
tier_display_names = (
    ('gpu_hbm', 'GPU HBM'),
    ('gpu_l2', 'GPU L2'),
    ('system_ram', 'System RAM'),
    ('disk', 'Disk'),
)
legend_elements = [
    Patch(facecolor=tier_colors[tier_key], label=display_name)
    for tier_key, display_name in tier_display_names
]
ax.legend(handles=legend_elements, loc='upper right')

plt.tight_layout()
plt.show()


---
## Summary

| Function | Use Case |
|----------|----------|
| `identify_sinks()` | Find tokens receiving high attention |
| `identify_heavy_hitters()` | Find semantically important tokens |
| `run_simulation()` | Test eviction policies on prompts |
| `generate_cache_blocks()` | Analyze memory hierarchy |
| WebSocket streaming | Real-time visualization |

For more details, see the [README.md](../README.md) or run the web interface:
```bash
python run.py
```
