In [1]:
import os
os.chdir('/data/apdesai/code/sparse-attention-hub')

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import numpy as np
from typing import Dict, List

from sparse_attention_hub.sparse_attention.research_attention import ResearchAttentionConfig
from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations import (
    LocalMaskerConfig, SinkMaskerConfig, HashAttentionTopKMaskerConfig
)
from sparse_attention_hub.sparse_attention.integrations.hugging_face import SparseAttentionHF

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")


Device: cuda


In [2]:
def extract_linear_weights(sequential_module):
    """
    Extract weight matrices and biases from Sequential module containing Linear layers.
    
    Args:
        sequential_module: nn.Sequential containing Linear layers and activations
        
    Returns:
        tuple: (weight_matrices, bias_vectors)
    """
    matrices = []
    biases = []
    
    for module in sequential_module:
        if isinstance(module, torch.nn.Linear):
            matrices.append(module.weight.data.clone())  # Shape: (out_features, in_features)
            if module.bias is not None:
                biases.append(module.bias.data.clone())
            else:
                biases.append(torch.zeros(module.out_features))
    
    return matrices, biases

def convert_usa_weights_to_hash_attention(usa_checkpoint_path: str, num_layers: int = 32, num_heads: int = 32) -> Dict[int, Dict[str, List[torch.Tensor]]]:
    """
    Convert USA module weights to HashAttentionTopKMasker format.
    
    Args:
        usa_checkpoint_path: Path to USA module checkpoint
        num_layers: Number of transformer layers
        num_heads: Number of attention heads
        
    Returns:
        Dict mapping layer_idx to weight dictionaries
    """
    # Load USA checkpoint
    print(f"Loading USA weights from {usa_checkpoint_path}")
    usa_state_dict = torch.load(usa_checkpoint_path, map_location='cpu')
    
    hat_weights = {}
    
    for layer_idx in range(num_layers):
        layer_weights = {
            "query_matrix": [],
            "query_bias": [],
            "key_matrix": [],
            "key_bias": []
        }
        
        # For each head, extract weights and stack them
        # First, collect all weights for this layer
        query_matrices_per_layer = [[] for _ in range(3)]  # 3 linear layers
        query_biases_per_layer = [[] for _ in range(3)]
        key_matrices_per_layer = [[] for _ in range(3)]
        key_biases_per_layer = [[] for _ in range(3)]
        
        for head_idx in range(num_heads):
            # Extract query transformation weights
            query_prefix = f"{layer_idx}.learning_to_hash_transformation_q.{head_idx}"
            key_prefix = f"{layer_idx}.learning_to_hash_transformation_k.{head_idx}"
            
            # For 3-layer MLP: 0.weight, 0.bias, 2.weight, 2.bias, 4.weight, 4.bias
            # (indices 1, 3 are SiLU activations)
            for i, linear_idx in enumerate([0, 2, 4]):  # Linear layer indices in Sequential
                weight_key = f"{query_prefix}.{linear_idx}.weight"
                bias_key = f"{query_prefix}.{linear_idx}.bias"
                
                if weight_key in usa_state_dict:
                    # HashAttentionTopKMasker expects shape (H, in_features, out_features)
                    # USA stores as (out_features, in_features), so we need to transpose
                    weight = usa_state_dict[weight_key].t()  # Transpose to (in_features, out_features)
                    query_matrices_per_layer[i].append(weight)
                    
                    if bias_key in usa_state_dict:
                        query_biases_per_layer[i].append(usa_state_dict[bias_key])
                    else:
                        query_biases_per_layer[i].append(torch.zeros(usa_state_dict[weight_key].shape[0]))
                
                # Same for key weights
                weight_key = f"{key_prefix}.{linear_idx}.weight"
                bias_key = f"{key_prefix}.{linear_idx}.bias"
                
                if weight_key in usa_state_dict:
                    weight = usa_state_dict[weight_key].t()  # Transpose to (in_features, out_features)
                    key_matrices_per_layer[i].append(weight)
                    
                    if bias_key in usa_state_dict:
                        key_biases_per_layer[i].append(usa_state_dict[bias_key])
                    else:
                        key_biases_per_layer[i].append(torch.zeros(usa_state_dict[weight_key].shape[0]))
        
        # Stack all heads for each layer
        for i in range(3):
            if query_matrices_per_layer[i]:
                layer_weights["query_matrix"].append(torch.stack(query_matrices_per_layer[i]))
                layer_weights["query_bias"].append(torch.stack(query_biases_per_layer[i]))
                layer_weights["key_matrix"].append(torch.stack(key_matrices_per_layer[i]))
                layer_weights["key_bias"].append(torch.stack(key_biases_per_layer[i]))
        
        hat_weights[layer_idx] = layer_weights
    
    print(f"✅ Converted weights for {num_layers} layers, {num_heads} heads")
    return hat_weights

print("✅ Weight conversion functions defined")


✅ Weight conversion functions defined


In [3]:
# Load and convert USA weights
usa_checkpoint_path = "/data/apdesai/code/HashAttention-1.0/artifacts/llama3.1-8b-patch.32K.v1.pt"

# Check if file exists
if not os.path.exists(usa_checkpoint_path):
    print(f"❌ USA checkpoint not found at {usa_checkpoint_path}")
    print("Please make sure the checkpoint file exists at the specified path.")
    # Create dummy weights for demonstration
    print("Creating dummy weights for demonstration...")
    hat_weights = {}
    for layer_idx in range(32):
        hat_weights[layer_idx] = {
            "query_matrix": [
                torch.randn(32, 128, 128),  # First linear layer
                torch.randn(32, 128, 128),  # Second linear layer
                torch.randn(32, 128, 32),   # Third linear layer
            ],
            "query_bias": [
                torch.randn(32, 128),
                torch.randn(32, 128),
                torch.randn(32, 32),
            ],
            "key_matrix": [
                torch.randn(32, 128, 128),  # First linear layer
                torch.randn(32, 128, 128),  # Second linear layer
                torch.randn(32, 128, 32),   # Third linear layer
            ],
            "key_bias": [
                torch.randn(32, 128),
                torch.randn(32, 128),
                torch.randn(32, 32),
            ],
        }
    print("✅ Created dummy weights")
else:
    try:
        hat_weights = convert_usa_weights_to_hash_attention(
            usa_checkpoint_path, 
            num_layers=32, 
            num_heads=32
        )
        print("✅ Successfully loaded and converted USA weights")
    except Exception as e:
        print(f"❌ Error loading USA weights: {e}")
        print("Creating dummy weights for demonstration...")
        hat_weights = {}
        for layer_idx in range(32):
            hat_weights[layer_idx] = {
                "query_matrix": [
                    torch.randn(32, 128, 128),
                    torch.randn(32, 128, 128),
                    torch.randn(32, 128, 32),
                ],
                "query_bias": [
                    torch.randn(32, 128),
                    torch.randn(32, 128),
                    torch.randn(32, 32),
                ],
                "key_matrix": [
                    torch.randn(32, 128, 128),
                    torch.randn(32, 128, 128),
                    torch.randn(32, 128, 32),
                ],
                "key_bias": [
                    torch.randn(32, 128),
                    torch.randn(32, 128),
                    torch.randn(32, 32),
                ],
            }
        print("✅ Created dummy weights")


Loading USA weights from /data/apdesai/code/HashAttention-1.0/artifacts/llama3.1-8b-patch.32K.v1.pt


  usa_state_dict = torch.load(usa_checkpoint_path, map_location='cpu')


✅ Converted weights for 32 layers, 32 heads
✅ Successfully loaded and converted USA weights


In [4]:
model_name = "meta-llama/Llama-3.1-8B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        attn_implementation="eager"
    )
    print(f"✅ Loaded {model_name}")
    
except Exception as e:
    raise e

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✅ Loaded meta-llama/Llama-3.1-8B-Instruct


In [5]:
# Create HashAttention configuration
local_config = LocalMaskerConfig(window_size=16)
sink_config = SinkMaskerConfig(sink_size=16)
hash_config = HashAttentionTopKMaskerConfig(
    heavy_size=32,
    hat_bits=32,
    hat_mlp_layers=3,
    hat_mlp_hidden_size=128,
    hat_mlp_activation="silu",
    hat_weights=hat_weights
)

# Combine all maskers
research_config = ResearchAttentionConfig(
    masker_configs=[local_config, sink_config, hash_config]
)

print("✅ HashAttention config: Local(16) + Sink(16) + Hash(32 bits, 32 heavy)")


✅ HashAttention config: Local(16) + Sink(16) + Hash(32 bits, 32 heavy)


In [6]:
# Create SparseAttentionHF integration object
sparse_attention_hf = SparseAttentionHF.create_from_config(research_config)
print("✅ SparseAttentionHF with HashAttention created")


✅ SparseAttentionHF with HashAttention created


In [7]:
# Test text about HashAttention and sparse attention mechanisms
test_text = """ 
        HashAttention is an innovative sparse attention mechanism that combines multiple strategies:
        
        1. Local Attention: Maintains a sliding window of recent tokens for immediate context
        2. Sink Tokens: Preserves the first few tokens which often contain crucial global information
        3. Hash-based Selection: Uses learned hash functions to identify important tokens beyond the local window
        
        This approach allows the model to maintain long-range dependencies while significantly reducing computational costs.
        The hash functions are trained to recognize patterns that indicate token importance, enabling dynamic attention allocation.
        
        Unlike traditional attention mechanisms that compute all pairwise interactions, HashAttention selectively focuses on:
        - Recent tokens (local window)
        - Important early tokens (sink tokens)
        - Semantically relevant distant tokens (hash-selected)
        
        This combination provides an excellent balance between computational efficiency and model performance.
        
        Please summarize the key benefits of HashAttention in one concise sentence.
        """

# Tokenize input
inputs = tokenizer(test_text, return_tensors="pt", truncation=True, max_length=32000)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

print(f"✅ Input prepared: {input_ids.shape[1]} tokens")


✅ Input prepared: 190 tokens


In [8]:
# Run with full attention
model.eval()
max_new_tokens = 50

start_time = time.time()
with torch.no_grad():
    full_outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
full_time = time.time() - start_time

# Get generated text
full_generated_ids = full_outputs[0]
full_generated_text = tokenizer.decode(full_generated_ids, skip_special_tokens=True)

print(f"⏱️ Full attention: {full_time:.2f}s")
print(f"📝 Generated: {len(full_generated_ids) - len(input_ids[0])} tokens")
print("\nOutput:")
print("-" * 50)
print(full_generated_text)


⏱️ Full attention: 3.04s
📝 Generated: 32 tokens

Output:
--------------------------------------------------
 
        HashAttention is an innovative sparse attention mechanism that combines multiple strategies:
        
        1. Local Attention: Maintains a sliding window of recent tokens for immediate context
        2. Sink Tokens: Preserves the first few tokens which often contain crucial global information
        3. Hash-based Selection: Uses learned hash functions to identify important tokens beyond the local window
        
        This approach allows the model to maintain long-range dependencies while significantly reducing computational costs.
        The hash functions are trained to recognize patterns that indicate token importance, enabling dynamic attention allocation.
        
        Unlike traditional attention mechanisms that compute all pairwise interactions, HashAttention selectively focuses on:
        - Recent tokens (local window)
        - Important early toke

In [9]:
# Run with HashAttention
start_time = time.time()
with torch.no_grad():
    with sparse_attention_hf(model) as sparse_model:
        sparse_outputs = sparse_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            sparse_meta_data={}
        )
sparse_time = time.time() - start_time

# Get generated text
sparse_generated_ids = sparse_outputs[0]
sparse_generated_text = tokenizer.decode(sparse_generated_ids, skip_special_tokens=True)

print(f"⚡ HashAttention: {sparse_time:.2f}s")
print(f"📝 Generated: {len(sparse_generated_ids) - len(input_ids[0])} tokens")
print("\nOutput:")
print("-" * 50)
print(sparse_generated_text)


RuntimeError: einsum(): subscript h has size 32 for operand 1 which does not broadcast with previously seen size 8

In [None]:
# Compare performance
speedup = full_time / sparse_time if sparse_time > 0 else 0
print(f"\n📊 Performance Comparison:")
print(f"{'Method':<20} {'Time (s)':<10} {'Speedup':<10}")
print("-" * 40)
print(f"{'Full Attention':<20} {full_time:<10.2f} {'1.00x':<10}")
print(f"{'HashAttention':<20} {sparse_time:<10.2f} {speedup:<10.2f}x")

# Analyze attention pattern efficiency
total_possible_attention = input_ids.shape[1] * input_ids.shape[1]
estimated_sparse_attention = (
    16 * input_ids.shape[1] +  # Local window
    16 * input_ids.shape[1] +  # Sink tokens
    32 * input_ids.shape[1]    # Hash-selected tokens
)
sparsity_ratio = estimated_sparse_attention / total_possible_attention

print(f"\n🎯 Attention Efficiency:")
print(f"Full attention pairs: {total_possible_attention:,}")
print(f"Sparse attention pairs: ~{estimated_sparse_attention:,}")
print(f"Sparsity ratio: {sparsity_ratio:.3f} ({sparsity_ratio*100:.1f}% of full attention)")


In [None]:
print("🔍 HashAttention Components Analysis:")
print("\n1. Local Attention (Window Size: 16)")
print("   - Maintains attention to the most recent 16 tokens")
print("   - Ensures local coherence and immediate context")

print("\n2. Sink Tokens (Sink Size: 16)")
print("   - Preserves attention to the first 16 tokens")
print("   - Captures global context and important initial information")

print("\n3. Hash-based Selection (Heavy Size: 32)")
print("   - Uses learned hash functions to identify 32 important tokens")
print("   - Hash dimension: 32 bits")
print("   - MLP layers: 3 (with SiLU activation)")
print("   - Hidden size: 128")

print("\n4. Combined Strategy:")
print(f"   - Total attended tokens per query: up to {16 + 16 + 32} tokens")
print(f"   - Reduction from full attention: {(1 - sparsity_ratio)*100:.1f}%")
print("   - Maintains both local and global context")
print("   - Adaptive selection of important distant tokens")


In [None]:
# Extract only the generated portions for comparison
full_generated_only = tokenizer.decode(
    full_generated_ids[len(input_ids[0]):], 
    skip_special_tokens=True
)
sparse_generated_only = tokenizer.decode(
    sparse_generated_ids[len(input_ids[0]):], 
    skip_special_tokens=True
)

print("📝 Generated Text Comparison:")
print("\n" + "="*60)
print("FULL ATTENTION OUTPUT:")
print("="*60)
print(full_generated_only)
print("\n" + "="*60)
print("HASHATTENTION OUTPUT:")
print("="*60)
print(sparse_generated_only)
print("\n" + "="*60)

# Simple similarity metrics
full_tokens = set(full_generated_only.split())
sparse_tokens = set(sparse_generated_only.split())

if full_tokens and sparse_tokens:
    token_overlap = len(full_tokens.intersection(sparse_tokens))
    token_union = len(full_tokens.union(sparse_tokens))
    jaccard_similarity = token_overlap / token_union if token_union > 0 else 0
    
    print(f"\n🔍 Text Similarity Analysis:")
    print(f"Full attention tokens: {len(full_tokens)}")
    print(f"HashAttention tokens: {len(sparse_tokens)}")
    print(f"Token overlap: {token_overlap}")
    print(f"Jaccard similarity: {jaccard_similarity:.3f}")
