In [1]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import numpy as np
from typing import Dict, List, Any
from dataclasses import dataclass
import math
from einops import rearrange

# Pin the NumPy and PyTorch global RNGs to the same seed so reruns are repeatable.
np.random.seed(42)
torch.manual_seed(42)

# Report the runtime: library version and the CUDA device name, or 'CPU' when
# CUDA is not available.
print(
    f"PyTorch version: {torch.__version__}",
    f"Device: {torch.cuda.get_device_name() if torch.cuda.is_available() else 'CPU'}",
    sep="\n",
)


PyTorch version: 2.5.0
Device: Tesla V100-SXM2-32GB-LS


In [2]:
# Import USA module components
from hashattention.hashattention_llama import USA, DEFAULT_USA_CFG, SignSTE, ste_sign

# Import HashAttentionTopKMasker components
from sparse_attention_hub.sparse_attention.research_attention.maskers.fixed.implementations.hashattention_top_k import (
    HashAttentionTopKMasker,
    HashAttentionTopKMaskerConfig
)
from sparse_attention_hub.sparse_attention.utils.mask import Mask


CENTER 0


In [3]:
# Problem dimensions used by the cells below.
batch_size, num_heads, head_dim = 2, 4, 64
seq_len_q, seq_len_k = 8, 12
layer_idx = 0

# Hyper-parameters for the USA module (passed to it as `usa_params`).
usa_config = dict(
    lth_int_dim=32,
    lth_final_dim=16,  # This will be our hat_bits
    lth_thold=0.0,
    lth_num_layers=2,
)

# Echo the whole configuration in one go, one setting per line.
print(
    f"Configuration:",
    f"  num_heads: {num_heads}",
    f"  head_dim: {head_dim}",
    f"  batch_size: {batch_size}",
    f"  seq_len_q: {seq_len_q}",
    f"  seq_len_k: {seq_len_k}",
    f"  USA config: {usa_config}",
    sep="\n",
)


Configuration:
  num_heads: 4
  head_dim: 64
  batch_size: 2
  seq_len_q: 8
  seq_len_k: 12
  USA config: {'lth_int_dim': 32, 'lth_final_dim': 16, 'lth_thold': 0.0, 'lth_num_layers': 2}


In [5]:
# Build the USA module from the configuration above and switch it to eval mode.
usa_module = USA(num_heads=num_heads, head_dim=head_dim, usa_params=usa_config)
usa_module.eval()

print(
    f"USA module created with {usa_config['lth_num_layers']} layers",
    f"Key transformation heads: {len(usa_module.learning_to_hash_transformation_k)}",
    f"Query transformation heads: {len(usa_module.learning_to_hash_transformation_q)}",
    sep="\n",
)

# Show the layer stack of head 0 for the key side, then for the query side.
for heading, per_head_transforms in (
    (f"\nFirst head key transformation:", usa_module.learning_to_hash_transformation_k),
    (f"\nFirst head query transformation:", usa_module.learning_to_hash_transformation_q),
):
    print(heading)
    for i, layer in enumerate(per_head_transforms[0]):
        print(f"  Layer {i}: {layer}")


USA module created with 2 layers
Key transformation heads: 4
Query transformation heads: 4

First head key transformation:
  Layer 0: Linear(in_features=64, out_features=32, bias=True)
  Layer 1: SiLU()
  Layer 2: Linear(in_features=32, out_features=16, bias=True)

First head query transformation:
  Layer 0: Linear(in_features=64, out_features=32, bias=True)
  Layer 1: SiLU()
  Layer 2: Linear(in_features=32, out_features=16, bias=True)


In [6]:
def _stack_linear_params(per_head_mlps, num_heads, expected_linear_layers):
    """
    Stack the nn.Linear parameters of per-head MLPs, one stacked tensor per MLP layer.

    Args:
        per_head_mlps: Indexable collection holding one iterable of modules per head
        num_heads: Number of heads to read (indices 0 .. num_heads - 1)
        expected_linear_layers: Number of nn.Linear layers every head must contain

    Returns:
        Tuple (matrices, biases) of lists with one entry per linear layer:
        matrices[i] has shape (H, d_in, d_out) and biases[i] has shape (H, d_out)

    Raises:
        ValueError: If a head does not contain exactly expected_linear_layers
            nn.Linear modules
    """
    # Collect each head's Linear layers once instead of re-scanning per MLP layer.
    linears_by_head = [
        [module for module in per_head_mlps[head_idx] if isinstance(module, nn.Linear)]
        for head_idx in range(num_heads)
    ]

    # A head with a different layer count cannot be stacked meaningfully: skipping it
    # would yield fewer than H rows, and truncating it would drop part of its MLP.
    for head_idx, linears in enumerate(linears_by_head):
        if len(linears) != expected_linear_layers:
            raise ValueError(
                f"Head {head_idx} has {len(linears)} nn.Linear layers, "
                f"expected {expected_linear_layers}"
            )

    matrices = []
    biases = []
    for layer_pos in range(expected_linear_layers):
        layers = [linears[layer_pos] for linears in linears_by_head]
        # nn.Linear stores weight as (d_out, d_in); transpose to (d_in, d_out).
        # torch.stack allocates a new tensor, so the result never aliases the module.
        matrices.append(torch.stack([layer.weight.detach().T for layer in layers], dim=0))
        # A Linear built with bias=False has bias None; use an equivalent zero bias.
        biases.append(
            torch.stack(
                [
                    layer.bias.detach()
                    if layer.bias is not None
                    else layer.weight.new_zeros(layer.weight.shape[0])
                    for layer in layers
                ],
                dim=0,
            )
        )
    return matrices, biases


def extract_weights_from_usa(usa_module, layer_idx=0):
    """
    Extract weights from USA module and format them for HashAttentionTopKMasker.

    The per-head nn.Linear parameters found in
    usa_module.learning_to_hash_transformation_k / _q are stacked across heads,
    one tensor per linear layer. Shapes are printed as they are extracted.

    Args:
        usa_module: Object exposing num_heads and the two per-head transformation
            collections named above
        layer_idx: Key under which the weights are stored in the returned dictionary

    Returns:
        Dictionary {layer_idx: {"key_matrix", "key_bias", "query_matrix", "query_bias"}}
        where each value is a list with one tensor per linear layer; matrices are
        (H, d_in, d_out) and biases are (H, d_out)

    Raises:
        ValueError: If the heads do not all contain the same number of nn.Linear
            layers as head 0
    """
    num_heads = usa_module.num_heads
    k_mlps = usa_module.learning_to_hash_transformation_k
    q_mlps = usa_module.learning_to_hash_transformation_q

    # Head 0 defines how many linear layers (activations excluded) each side has.
    k_linear_layers = [layer for layer in k_mlps[0] if isinstance(layer, nn.Linear)]
    q_linear_layers = [layer for layer in q_mlps[0] if isinstance(layer, nn.Linear)]

    print(f"Found {len(k_linear_layers)} linear layers in key transformation")
    print(f"Found {len(q_linear_layers)} linear layers in query transformation")

    # Key side
    key_matrices, key_biases = _stack_linear_params(k_mlps, num_heads, len(k_linear_layers))
    for layer_idx_mlp, (key_matrix, key_bias) in enumerate(zip(key_matrices, key_biases)):
        print(f"Key layer {layer_idx_mlp}: weight shape {key_matrix.shape}, bias shape {key_bias.shape}")

    # Query side
    query_matrices, query_biases = _stack_linear_params(q_mlps, num_heads, len(q_linear_layers))
    for layer_idx_mlp, (query_matrix, query_bias) in enumerate(zip(query_matrices, query_biases)):
        print(f"Query layer {layer_idx_mlp}: weight shape {query_matrix.shape}, bias shape {query_bias.shape}")

    hat_weights = {
        layer_idx: {
            "key_matrix": key_matrices,
            "key_bias": key_biases,
            "query_matrix": query_matrices,
            "query_bias": query_biases,
        }
    }
    return hat_weights

# Pull the per-head MLP parameters out of the USA module, stored under layer key 0.
hat_weights = extract_weights_from_usa(usa_module, layer_idx=0)

# Report how many stacked layers were collected on each side.
print(
    f"\nExtracted weights structure:",
    f"  Key matrices: {len(hat_weights[0]['key_matrix'])} layers",
    f"  Query matrices: {len(hat_weights[0]['query_matrix'])} layers",
    sep="\n",
)


Found 2 linear layers in key transformation
Found 2 linear layers in query transformation
Key layer 0: weight shape torch.Size([4, 64, 32]), bias shape torch.Size([4, 32])
Key layer 1: weight shape torch.Size([4, 32, 16]), bias shape torch.Size([4, 16])
Query layer 0: weight shape torch.Size([4, 64, 32]), bias shape torch.Size([4, 32])
Query layer 1: weight shape torch.Size([4, 32, 16]), bias shape torch.Size([4, 16])

Extracted weights structure:
  Key matrices: 2 layers
  Query matrices: 2 layers


In [8]:
# Re-seed right before sampling so the tensors below do not depend on how many
# random numbers earlier cells consumed.
torch.manual_seed(42)

# Random keys, queries and values laid out as (batch_size, num_heads, seq_len, head_dim).
# The draw order (keys, then queries, then values) determines the sampled values.
sample_keys = torch.randn((batch_size, num_heads, seq_len_k, head_dim), dtype=torch.float32)
sample_queries = torch.randn((batch_size, num_heads, seq_len_q, head_dim), dtype=torch.float32)
sample_values = torch.randn((batch_size, num_heads, seq_len_k, head_dim), dtype=torch.float32)

print(
    f"Sample inputs created:",
    f"  Keys shape: {sample_keys.shape}",
    f"  Queries shape: {sample_queries.shape}",
    f"  Values shape: {sample_values.shape}",
    sep="\n",
)

# No attention mask is supplied (it is optional).
attention_mask = None

# Starting mask for the masker: one entry per (batch, head, query, key) position.
mask_shape = (batch_size, num_heads, seq_len_q, seq_len_k)
previous_mask = Mask.create_empty_mask(mask_shape)

print(
    f"  Previous mask shape: {mask_shape}",
    f"  Previous mask is full: {previous_mask.is_full_mask()}",
    sep="\n",
)


Sample inputs created:
  Keys shape: torch.Size([2, 4, 12, 64])
  Queries shape: torch.Size([2, 4, 8, 64])
  Values shape: torch.Size([2, 4, 12, 64])
  Previous mask shape: (2, 4, 8, 12)
  Previous mask is full: False


In [9]:
# Run the USA module on the sample inputs and record its signatures and scores.
print("=== USA Module Forward Pass ===")

# All USA computations run with autograd disabled; reporting happens afterwards.
with torch.no_grad():
    # Full forward pass on (B, H, seq_len, head_dim) inputs
    usa_scores, usa_key_embeddings = usa_module(sample_keys, sample_queries, hard=True)

    # Key-side and query-side embeddings computed individually
    usa_key_sigs = usa_module.k_embedding(sample_keys, hard=True)
    usa_query_sigs = usa_module.q_embedding(sample_queries, hard=True)

    # Reference scores: query signatures times transposed key signatures
    usa_manual_scores = usa_query_sigs @ usa_key_sigs.transpose(-2, -1)

print(
    f"USA scores shape: {usa_scores.shape}",
    f"USA key embeddings shape: {usa_key_embeddings.shape}",
    sep="\n",
)

print(
    f"USA key signatures shape: {usa_key_sigs.shape}",
    f"USA query signatures shape: {usa_query_sigs.shape}",
    sep="\n",
)

# With hard=True the signatures are expected to be -1 or 1
print(
    f"USA key signatures range: [{usa_key_sigs.min():.3f}, {usa_key_sigs.max():.3f}]",
    f"USA query signatures range: [{usa_query_sigs.min():.3f}, {usa_query_sigs.max():.3f}]",
    sep="\n",
)

print(
    f"Manual USA scores shape: {usa_manual_scores.shape}",
    f"Manual USA scores range: [{usa_manual_scores.min():.3f}, {usa_manual_scores.max():.3f}]",
    sep="\n",
)


=== USA Module Forward Pass ===
USA scores shape: torch.Size([2, 4, 8, 12])
USA key embeddings shape: torch.Size([2, 4, 12, 16])
USA key signatures shape: torch.Size([2, 4, 12, 16])
USA query signatures shape: torch.Size([2, 4, 8, 16])
USA key signatures range: [-1.000, 1.000]
USA query signatures range: [-1.000, 1.000]
Manual USA scores shape: torch.Size([2, 4, 8, 12])
Manual USA scores range: [-12.000, 12.000]


In [10]:
# Masker configuration mirroring the USA hyper-parameters, with the extracted weights.
masker_config = HashAttentionTopKMaskerConfig(
    heavy_size=0.5,  # Use 50% of keys
    hat_bits=usa_config['lth_final_dim'],
    hat_mlp_layers=usa_config['lth_num_layers'],
    hat_mlp_hidden_size=usa_config['lth_int_dim'],
    hat_mlp_activation="silu",  # USA uses SiLU activation
    hat_weights=hat_weights,
)

# Instantiate the masker from that configuration.
masker = HashAttentionTopKMasker(masker_config)

# Echo the settings stored on the config object.
print(
    f"HashAttentionTopKMasker created with:",
    f"  heavy_size: {masker_config.heavy_size}",
    f"  hat_bits: {masker_config.hat_bits}",
    f"  hat_mlp_layers: {masker_config.hat_mlp_layers}",
    f"  hat_mlp_hidden_size: {masker_config.hat_mlp_hidden_size}",
    f"  hat_mlp_activation: {masker_config.hat_mlp_activation}",
    sep="\n",
)


HashAttentionTopKMasker created with:
  heavy_size: 0.5
  hat_bits: 16
  hat_mlp_layers: 2
  hat_mlp_hidden_size: 32
  hat_mlp_activation: silu


In [11]:
# Compute signatures with the masker's own _get_signatures, using the extracted weights.
print("=== HashAttentionTopKMasker _get_signatures ===")

# Stacked per-layer weights stored under layer key 0
key_matrix_list = hat_weights[0]["key_matrix"]
key_bias_list = hat_weights[0]["key_bias"]
query_matrix_list = hat_weights[0]["query_matrix"]
query_bias_list = hat_weights[0]["query_bias"]

# All masker computations run with autograd disabled; reporting happens afterwards.
with torch.no_grad():
    masker_key_sigs = masker._get_signatures(
        sample_keys, key_matrix_list, key_bias_list
    )
    masker_query_sigs = masker._get_signatures(
        sample_queries, query_matrix_list, query_bias_list
    )

    # Reference scores: query signatures times transposed key signatures
    masker_manual_scores = masker_query_sigs @ masker_key_sigs.transpose(-2, -1)

print(
    f"Masker key signatures shape: {masker_key_sigs.shape}",
    f"Masker key signatures range: [{masker_key_sigs.min():.3f}, {masker_key_sigs.max():.3f}]",
    sep="\n",
)

print(
    f"Masker query signatures shape: {masker_query_sigs.shape}",
    f"Masker query signatures range: [{masker_query_sigs.min():.3f}, {masker_query_sigs.max():.3f}]",
    sep="\n",
)

print(
    f"Manual masker scores shape: {masker_manual_scores.shape}",
    f"Manual masker scores range: [{masker_manual_scores.min():.3f}, {masker_manual_scores.max():.3f}]",
    sep="\n",
)


=== HashAttentionTopKMasker _get_signatures ===
Masker key signatures shape: torch.Size([2, 4, 12, 16])
Masker key signatures range: [-1.000, 1.000]
Masker query signatures shape: torch.Size([2, 4, 8, 16])
Masker query signatures range: [-1.000, 1.000]
Manual masker scores shape: torch.Size([2, 4, 8, 12])
Manual masker scores range: [-12.000, 12.000]


In [12]:
# Element-wise comparison of the USA outputs against the masker outputs.
print("=== Signature Comparison ===")

# Absolute differences for keys, queries and the manually computed scores
key_diff = (usa_key_sigs - masker_key_sigs).abs()
query_diff = (usa_query_sigs - masker_query_sigs).abs()
scores_diff = (usa_manual_scores - masker_manual_scores).abs()

# Key signatures
print(
    f"Key signatures:",
    f"  Max absolute difference: {key_diff.max():.6f}",
    f"  Mean absolute difference: {key_diff.mean():.6f}",
    f"  Are they close? {torch.allclose(usa_key_sigs, masker_key_sigs, atol=1e-6)}",
    sep="\n",
)

# Query signatures
print(
    f"\nQuery signatures:",
    f"  Max absolute difference: {query_diff.max():.6f}",
    f"  Mean absolute difference: {query_diff.mean():.6f}",
    f"  Are they close? {torch.allclose(usa_query_sigs, masker_query_sigs, atol=1e-6)}",
    sep="\n",
)

# Manually computed scores
print(
    f"\nManual scores:",
    f"  Max absolute difference: {scores_diff.max():.6f}",
    f"  Mean absolute difference: {scores_diff.mean():.6f}",
    f"  Are they close? {torch.allclose(usa_manual_scores, masker_manual_scores, atol=1e-6)}",
    sep="\n",
)


=== Signature Comparison ===
Key signatures:
  Max absolute difference: 0.000000
  Mean absolute difference: 0.000000
  Are they close? True

Query signatures:
  Max absolute difference: 0.000000
  Mean absolute difference: 0.000000
  Are they close? True

Manual scores:
  Max absolute difference: 0.000000
  Mean absolute difference: 0.000000
  Are they close? True


In [13]:
# Final report: what each implementation produced and whether they agree.
print("=== SUMMARY ===")
print()

# Section 1: the USA module
print(
    "1. USA Module:",
    f"   - Successfully created with {usa_config['lth_num_layers']} layers",
    f"   - Produces signatures of shape: {usa_key_sigs.shape}",
    f"   - Signature values in range: [{usa_key_sigs.min():.3f}, {usa_key_sigs.max():.3f}]",
    sep="\n",
)
print()

# Section 2: the masker
print(
    "2. HashAttentionTopKMasker:",
    f"   - Successfully created with weights extracted from USA module",
    f"   - Produces signatures of shape: {masker_key_sigs.shape}",
    f"   - Signature values in range: [{masker_key_sigs.min():.3f}, {masker_key_sigs.max():.3f}]",
    sep="\n",
)
print()

# Section 3: agreement between the two
print(
    "3. Comparison Results:",
    f"   - Key signatures match: {torch.allclose(usa_key_sigs, masker_key_sigs, atol=1e-6)}",
    f"   - Query signatures match: {torch.allclose(usa_query_sigs, masker_query_sigs, atol=1e-6)}",
    f"   - Manual scores match: {torch.allclose(usa_manual_scores, masker_manual_scores, atol=1e-6)}",
    f"   - Max key signature difference: {key_diff.max():.6f}",
    f"   - Max query signature difference: {query_diff.max():.6f}",
    sep="\n",
)
print()

# Verdict: success requires both the key and the query signatures to match.
if not (torch.allclose(usa_key_sigs, masker_key_sigs, atol=1e-6) and torch.allclose(usa_query_sigs, masker_query_sigs, atol=1e-6)):
    print(
        "❌ MISMATCH: Implementations produce different results.",
        "   This might be due to:",
        "   - Different weight extraction/formatting",
        "   - Different activation functions",
        "   - Different numerical precision",
        "   - Different tensor operations",
        sep="\n",
    )
else:
    print("✅ SUCCESS: Both implementations produce identical results!")


=== SUMMARY ===

1. USA Module:
   - Successfully created with 2 layers
   - Produces signatures of shape: torch.Size([2, 4, 12, 16])
   - Signature values in range: [-1.000, 1.000]

2. HashAttentionTopKMasker:
   - Successfully created with weights extracted from USA module
   - Produces signatures of shape: torch.Size([2, 4, 12, 16])
   - Signature values in range: [-1.000, 1.000]

3. Comparison Results:
   - Key signatures match: True
   - Query signatures match: True
   - Manual scores match: True
   - Max key signature difference: 0.000000
   - Max query signature difference: 0.000000

✅ SUCCESS: Both implementations produce identical results!
