# SwiftKV Strategy Search with Visualization

This notebook runs the adaptive best-first strategy search with comprehensive visualizations:
- **KV cache distance heatmaps** - Visual similarity between all layer pairs
- **Distance distribution analysis** - Statistical overview of layer distances  
- **Top similar/dissimilar pairs** - Best and worst candidates for sharing
- **Live search progress** - Round-by-round evaluation of candidates
- **Architecture diagram** - Visual representation of final sharing strategy


## 1. Setup and Imports


In [None]:
import os

# Restrict this process to physical GPU 0. The mask is set *before* torch is
# imported so that no import can initialise CUDA with a different device set.
# With a single visible GPU, the only valid device string afterwards is "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from IPython.display import display, HTML
import pandas as pd

# Set matplotlib style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Import SwiftKV modules
from projects.swiftkv.models.qwen3 import Qwen3SwiftKVForCausalLM
from projects.swiftkv.models.llama import LlamaSwiftKVForCausalLM
from projects.swiftkv.strategy_search import (
    run_strategy_search,
    collect_kv_means_streaming,
    pairwise_distance_matrix,
    rank_layer_pairs,
    make_loader
)
from transformers import AutoTokenizer

print("‚úÖ All imports successful!")


## 2. Configuration


In [None]:
# Model configuration
MODEL_NAME = "Qwen/Qwen3-8B"
# The setup cell sets CUDA_VISIBLE_DEVICES="0", so exactly one GPU is visible
# to this process and it is always addressed as index 0. Any other index
# (e.g. "cuda:2") raises an invalid-device error at model load time.
DEVICE = "cuda:0"

# Strategy search configuration
TARGET_SHARED_LAYERS = 8
SIMILARITY_THRESHOLD = 0.6
DISTANCE_METRIC = "euclidean"  # Options: "euclidean", "cosine", "manhattan"
TOP_K = 15  # Number of candidates to evaluate per round

# Data configuration
NUM_CALIBRATION_SAMPLES = 9  # Must equal the number of entries in `texts` (section 3)
BATCH_SIZE = 4
MAX_LENGTH = 512

print(f"Model: {MODEL_NAME}")
print(f"Target shared layers: {TARGET_SHARED_LAYERS}")
print(f"Similarity threshold: {SIMILARITY_THRESHOLD}")
print(f"Calibration: {NUM_CALIBRATION_SAMPLES} diverse scientific questions (one sentence each)")


## 3. Load Calibration Data


In [None]:
# Calibration data: 9 short, single-sentence questions spanning different
# scientific and technical domains.
# NOTE: every entry must end with a comma. Adjacent string literals without a
# separating comma are silently concatenated by Python, which would merge two
# samples into one and shrink the calibration set.
texts = [
    """When two particles become quantum entangled, how does a measurement on one influence the other, and what fundamental principle of locality does this challenge?""",

    """According to the central dogma of molecular biology, how is genetic information transferred from DNA to functional proteins, and what processes add complexity beyond this basic framework?""",

    """What role does the event horizon play in defining a black hole, and how have recent astronomical observations helped confirm theoretical predictions about these objects?""",

    """How do neurons communicate through electrical and chemical signals, and what mechanisms allow the brain to adapt and reorganize over time?""",

    """How do greenhouse gases influence Earth's energy balance, and what feedback mechanisms amplify or mitigate global climate change according to current scientific understanding?""",

    """How does quantum computing leverage the principles of quantum mechanics to process information, and what are the potential applications of this technology?""",

    """How does artificial intelligence leverage machine learning algorithms to mimic human intelligence, and what are the ethical considerations associated with this technology?""",

    """How does machine learning leverage statistical methods to improve decision-making and prediction accuracy, and what are the challenges in scaling this technology to large datasets?""",

    """How does computer vision leverage deep learning algorithms to analyze and understand visual data, and what are the applications of this technology in fields such as autonomous driving and medical diagnosis?""",
]

print(f"‚úÖ Created {len(texts)} calibration samples")
print("Domains: Quantum Physics, Molecular Biology, Astrophysics, Neuroscience, Climate Science, "
      "Quantum Computing, Artificial Intelligence, Machine Learning, Computer Vision")
print(f"\nToken counts (approximate):")
for i, text in enumerate(texts, 1):
    word_count = len(text.split())
    # Rough heuristic only: ~1.3 tokens per whitespace-separated word.
    print(f"  {i}. ~{word_count} words (~{int(word_count * 1.3)} tokens)")


## 4. Load Model and Compute KV Statistics


In [None]:
# --- Tokenizer -------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to the EOS token when the tokenizer ships without a pad token
# (presumably needed for padded batches built by make_loader -- confirm).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# --- Model -----------------------------------------------------------------
print(f"\nLoading model: {MODEL_NAME}...")
# Half-precision weights, placed on the device chosen in the configuration cell.
load_kwargs = {"torch_dtype": torch.float16, "device_map": DEVICE}
model = Qwen3SwiftKVForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
model.eval()  # switch to evaluation mode

# Layer count is read from the model config and reused by the cells below.
num_layers = model.config.num_hidden_layers
print(f"‚úÖ Model loaded! Total layers: {num_layers}")


In [None]:
# Run the calibration texts through the model and gather KV statistics
# (presumably one mean representation per layer -- confirm in strategy_search).
print("\nComputing KV cache statistics...")
loader = make_loader(texts, tokenizer, BATCH_SIZE, MAX_LENGTH)
kv_means = collect_kv_means_streaming(model, loader, device=DEVICE)

# Layer-vs-layer distances under the configured metric.
print("\nComputing pairwise distance matrix...")
distance_matrix = pairwise_distance_matrix(kv_means, num_layers, metric=DISTANCE_METRIC)

# Quick summary of the matrix. The mean is taken over the strict upper
# triangle (k=1), i.e. each unordered layer pair once and no diagonal entries.
print(f"\n‚úÖ Distance matrix: {distance_matrix.shape}")
_dist_lo, _dist_hi = distance_matrix.min(), distance_matrix.max()
print(f"   Range: [{_dist_lo:.6f}, {_dist_hi:.6f}]")
_pair_distances = distance_matrix[np.triu_indices_from(distance_matrix, k=1)]
print(f"   Mean: {_pair_distances.mean():.6f}")


## 5. Visualize KV Distance Heatmap

Blue = Similar layers (good for sharing) | Red = Dissimilar layers (avoid sharing)


In [None]:
fig, ax = plt.subplots(figsize=(16, 14))

# Aim for roughly 20 tick labels per axis. The step is passed to seaborn as an
# *int* so that it labels every `tick_step`-th row/column with its own index.
# Passing a pre-thinned list of labels instead would attach those labels to the
# first len(labels) cells, mislabelling the axes whenever tick_step > 1.
tick_step = max(1, num_layers // 20)

sns.heatmap(
    distance_matrix,
    cmap='RdYlBu_r',  # reversed Red-Yellow-Blue: low distance = blue, high = red
    square=True,
    linewidths=0.5,
    cbar_kws={'label': f'{DISTANCE_METRIC.capitalize()} Distance', 'shrink': 0.8},
    xticklabels=tick_step,
    yticklabels=tick_step,
    ax=ax
)

ax.set_title(f'KV Cache Distance Between Layers ({DISTANCE_METRIC})',
             fontsize=18, fontweight='bold', pad=20)
ax.set_xlabel('Layer Index', fontsize=14, fontweight='bold')
ax.set_ylabel('Layer Index', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Legend matches the RdYlBu_r colormap, which has no green band.
print("üí° Blue = Very similar | Yellow = Moderately similar | Red = Dissimilar")


## 6. Distance Distribution and Top Pairs


In [None]:
# Flatten the strict upper triangle so every unordered layer pair is counted once.
distances_list = distance_matrix[np.triu_indices_from(distance_matrix, k=1)]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Left panel: histogram of pair distances with mean / median reference lines.
ax1.hist(distances_list, bins=50, edgecolor='black', alpha=0.7, color='skyblue')
for stat_label, stat_fn, line_color in (("Mean", np.mean, 'r'), ("Median", np.median, 'g')):
    stat_value = stat_fn(distances_list)
    ax1.axvline(stat_value, color=line_color, linestyle='--', linewidth=2,
                label=f'{stat_label}: {stat_value:.4f}')
ax1.set_xlabel('Distance', fontsize=12)
ax1.set_ylabel('Frequency', fontsize=12)
ax1.set_title('Distribution of KV Cache Distances', fontsize=14, fontweight='bold')
ax1.legend(fontsize=10)
ax1.grid(True, alpha=0.3)

# Right panel: box plot of the same distances.
bp = ax2.boxplot(distances_list, vert=True, patch_artist=True)
box_patch = bp['boxes'][0]
box_patch.set_facecolor('lightblue')
ax2.set_ylabel('Distance', fontsize=12)
ax2.set_title('Distance Statistics', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# One-line numeric summary below the figure.
stats_line = (f"\nüìä Statistics: Min={np.min(distances_list):.6f}, "
              f"Median={np.median(distances_list):.6f}, Max={np.max(distances_list):.6f}")
print(stats_line)


In [None]:
# Map every unordered layer pair (lower index first) to its distance.
distances_dict = {
    (lower, upper): distance_matrix[lower, upper]
    for lower in range(num_layers)
    for upper in range(lower + 1, num_layers)
}

# Ascending by distance: most similar pairs first, most dissimilar last.
sorted_pairs = sorted(distances_dict.items(), key=lambda item: item[1])

_RULE = "=" * 70


def _print_ranked_pairs(title, ranked):
    """Print a banner with `title`, then one numbered line per (pair, distance)."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)
    for rank, ((a, b), dist) in enumerate(ranked, 1):
        print(f"{rank:2d}. Layers {a:2d} <-> {b:2d}  |  Dist: {dist:.6f}  |  Gap: {abs(b-a):2d}")


_print_ranked_pairs("üîµ Top 10 Most Similar Pairs (Best candidates for sharing)",
                    sorted_pairs[:10])
# Take the last ten and reverse them so the largest distance is listed first.
_print_ranked_pairs("üî¥ Top 10 Most Dissimilar Pairs (Should NOT share)",
                    sorted_pairs[-10:][::-1])


## 7. Run Adaptive Best-First Strategy Search

This will evaluate candidates round-by-round and select the best sharing strategy!


In [None]:
# Gather the search settings in one place, then launch the strategy search.
# NOTE(review): the search is given the model class and name rather than the
# `model` object loaded earlier; if it loads its own copy, two copies may sit
# on the GPU at once -- confirm against run_strategy_search.
search_kwargs = dict(
    # What to search over
    model_class=Qwen3SwiftKVForCausalLM,
    model_name=MODEL_NAME,
    calibration_texts=texts,
    # Search objective and candidate selection
    target_shared_layers=TARGET_SHARED_LAYERS,
    similarity_threshold=SIMILARITY_THRESHOLD,
    policy="greedy",
    distance_metric=DISTANCE_METRIC,
    top_k=TOP_K,
    # Data loading / execution
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    device=DEVICE,
    use_direct_similarity=True,
)

sharing_map, diagnostics = run_strategy_search(**search_kwargs)


## 8. Visualize Results


In [None]:
# Headline metrics from the search, as (metric, value) rows.
summary_rows = [
    ("Total Layers", int(diagnostics['layers_total'])),
    ("Shared Layers", int(diagnostics['layers_shared'])),
    # Layers that are not shared are reported as "producer" layers.
    ("Producer Layers", int(diagnostics['layers_total'] - diagnostics['layers_shared'])),
    ("Memory Reduction", f"{diagnostics['memory_reduction_pct']:.2f}%"),
    ("Final Similarity", f"{diagnostics['similarity']:.6f}"),
]
results_df = pd.DataFrame(summary_rows, columns=["Metric", "Value"])

display(HTML("<h3>üìä Strategy Search Results</h3>"))
display(results_df)

if not sharing_map:
    print("‚ö†Ô∏è No sharing pairs found!")
else:
    # One row per (consumer, producer) entry, ordered by consumer layer.
    sharing_df = pd.DataFrame(
        sorted(sharing_map.items()),
        columns=["Consumer Layer", "Producer Layer"],
    )
    display(HTML("<h3>üîó Sharing Map (Consumer ‚Üí Producer)</h3>"))
    display(sharing_df)
    print(f"\nFinal sharing map: {sharing_map}")
