# DeepSeek Sparse Attention (DSA) vs Dense MLA: Performance Benchmark

This notebook benchmarks DeepSeek Sparse Attention (DSA) against traditional dense
Multi-head Latent Attention (MLA) to demonstrate the performance improvement DSA achieves.

## Setup and Imports

In [None]:
# Install the Python packages that the benchmark and plotting cells import.
%pip install torch numpy matplotlib
# Environment switch set before FlashMLA is built below. Presumably it skips the
# SM100 kernels (the recorded run is on a compute-capability 9.0 GPU) —
# TODO confirm against the FlashMLA build instructions.
%env FLASH_MLA_DISABLE_SM100=1
# Fetch FlashMLA with its submodules, then build and install it from source.
!git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
%cd flash-mla
!git submodule update --init --recursive
%pip install -v .

# Important: Move this file into flash-mla/test
# The imports cell does `import quant`, so this notebook has to sit next to
# FlashMLA's quant.py. NOTE(review): the upstream repo appears to name that
# directory `tests/`, not `test/` — confirm before moving the file.

In [None]:
import torch
import triton
import numpy as np
import matplotlib.pyplot as plt
import quant
from typing import Tuple, Optional
import time
import flash_mla
from dataclasses import dataclass

In [2]:
# Report the PyTorch build and, when a CUDA device is visible, which GPU it is.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    gpu_name = torch.cuda.get_device_name()
    gpu_capability = torch.cuda.get_device_capability()
    print(f"GPU: {gpu_name}")
    print(f"CUDA Compute Capability: {gpu_capability}")

PyTorch version: 2.9.0+cu128
CUDA available: True
GPU: NVIDIA H200
CUDA Compute Capability: (9, 0)


In [3]:

@dataclass
class BenchmarkConfig:
    """Configuration for a single attention benchmark run.

    The three required fields set the problem size; the remaining fields
    default to the values used throughout this notebook. Leaving ``topk`` as
    ``None`` selects the dense path, while an integer enables sparse
    attention over that many sampled KV positions per query.
    """
    batch_size: int  # Number of sequences in the batch
    seq_len_q: int  # Query sequence length (typically 1 for decoding)
    seq_len_k: int  # KV cache length (identical for every sequence in the batch)
    num_heads_q: int = 128  # Number of query heads
    num_heads_kv: int = 1  # Number of KV heads
    head_dim: int = 576  # Q/K dimension
    head_dim_v: int = 512  # V dimension
    block_size: int = 64  # Page block size (tokens per KV cache block)
    is_fp8: bool = True  # Use FP8 quantization (K cache goes through quant.quantize_k_cache)
    topk: Optional[int] = None  # For sparse attention: KV positions per query; None = dense
    num_warmup: int = 3  # Untimed forward calls before measurement starts
    num_iterations: int = 10  # Timed forward calls that feed the statistics


## Helper Functions

In [4]:

def cdiv(a: int, b: int) -> int:
    """Integer ceiling of ``a / b`` (intended for a positive divisor ``b``)."""
    # Adding b - 1 before flooring bumps any non-zero remainder up to the next integer.
    quotient, _ = divmod(a + b - 1, b)
    return quotient

def generate_test_tensors(
    config: BenchmarkConfig,
    seed: int = 42
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Generate the CUDA test tensors for one attention benchmark run.

    Every sequence in the batch has the same length (``config.seq_len_k``).
    The KV cache is laid out in pages of ``config.block_size`` tokens, and
    each sequence owns a contiguous run of pages sized for its length rounded
    up to a multiple of 256 tokens. Cache slots beyond the real sequence
    length are filled with NaN.

    Args:
        config: Shapes and mode (dense vs. sparse, bf16 vs. FP8) for the run.
        seed: Value passed to ``torch.manual_seed`` and
            ``torch.cuda.manual_seed`` so repeated calls produce the same data.

    Returns:
        q: Query tensor [batch, seq_q, h_q, d], bfloat16, clamped to [-1, 1]
        k_cache: Blocked key cache [total_blocks, block_size, h_kv, d] in
            bfloat16, or the result of ``quant.quantize_k_cache`` when
            ``config.is_fp8`` is set
        block_table: Block mapping table [batch, blocks_per_seq], int32
        cache_seqlens: Actual sequence lengths [batch], int32
        indices: Sparse attention indices [batch, seq_q, topk], int32, holding
            flat slot positions ``block * block_size + offset`` (None if
            ``config.topk`` is None). When ``topk`` exceeds ``seq_len_k`` the
            trailing entries are -1.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Generate query tensor
    q = torch.randn(
        config.batch_size,
        config.seq_len_q,
        config.num_heads_q,
        config.head_dim,
        device='cuda',
        dtype=torch.bfloat16
    )
    q.clamp_(min=-1.0, max=1.0)

    # Generate sequence lengths (all same for simplicity)
    cache_seqlens = torch.full(
        (config.batch_size,),
        config.seq_len_k,
        dtype=torch.int32,
        device='cuda'
    )

    # Calculate number of blocks needed (per-sequence capacity is rounded up
    # to a multiple of 256 tokens)
    max_seqlen_pad = cdiv(config.seq_len_k, 256) * 256
    num_blocks_per_seq = max_seqlen_pad // config.block_size
    total_blocks = config.batch_size * num_blocks_per_seq

    # Generate block table: sequence i owns blocks
    # [i * num_blocks_per_seq, (i + 1) * num_blocks_per_seq)
    block_table = torch.arange(
        total_blocks,
        dtype=torch.int32,
        device='cuda'
    ).view(config.batch_size, num_blocks_per_seq)

    # Generate blocked key-value cache
    blocked_k = torch.randn(
        total_blocks,
        config.block_size,
        config.num_heads_kv,
        config.head_dim,
        device='cuda',
        dtype=torch.bfloat16
    ) / 10
    blocked_k.clamp_(min=-1.0, max=1.0)

    # Mask unused cache slots with NaN. All sequences share one length, so the
    # used/unused split is computed once and applied to the whole batch.
    # NOTE(review): presumably NaN is used so that any read of an unused slot
    # would visibly corrupt the output — confirm intent.
    used_blocks = cdiv(config.seq_len_k, config.block_size)
    tail_len = config.seq_len_k % config.block_size
    if used_blocks < num_blocks_per_seq:
        # Whole blocks past the end of each sequence
        unused_blocks = block_table[:, used_blocks:].reshape(-1).long()
        blocked_k[unused_blocks] = float('nan')
    if tail_len != 0:
        # Unused positions in the partially filled last block of each sequence
        last_blocks = block_table[:, used_blocks - 1].long()
        blocked_k[last_blocks, tail_len:] = float('nan')

    # Generate sparse attention indices if needed
    indices = None
    if config.topk is not None:
        # Only seq_len_k distinct positions exist, so at most that many can be
        # sampled; any remaining entries keep the -1 fill value.
        # NOTE(review): -1 is assumed to be the "invalid index" marker accepted
        # by flash_mla_with_kvcache — confirm against the FlashMLA interface docs.
        num_valid = min(config.topk, config.seq_len_k)
        indices = torch.full(
            (config.batch_size, config.seq_len_q, config.topk),
            -1,
            dtype=torch.int32,
            device='cuda'
        )
        for i in range(config.batch_size):
            for j in range(config.seq_len_q):
                # Random sampling of distinct token positions
                sampled_indices = torch.randperm(config.seq_len_k, device='cuda')[:num_valid]
                # Convert logical positions to flat slots in the blocked cache
                blocked_indices = (
                    block_table[i, sampled_indices // config.block_size] * config.block_size +
                    (sampled_indices % config.block_size)
                )
                indices[i, j, :num_valid] = blocked_indices

    # Quantize K cache if using FP8
    # NOTE(review): the meaning of the trailing arguments (head_dim_v, 128) is
    # defined by quant.quantize_k_cache and is not visible here — verify.
    if config.is_fp8:
        blocked_k = quant.quantize_k_cache(blocked_k, config.head_dim_v, 128)

    return q, blocked_k, block_table, cache_seqlens, indices

def run_attention_benchmark(
    config: BenchmarkConfig,
    label: str
) -> dict:
    """
    Run one attention benchmark and return timing and throughput metrics.

    Prints the configuration, generates inputs with ``generate_test_tensors``,
    runs ``config.num_warmup`` untimed forward calls, then times
    ``config.num_iterations`` calls with a wall clock, synchronizing the GPU
    after each call.

    Args:
        config: Shapes and mode for the run. ``config.topk is None`` selects
            dense attention; any integer selects sparse attention.
        label: Human-readable name printed in the header and stored in the result.

    Returns:
        dict with keys:
            'label': the ``label`` argument
            'time_ms': mean latency per forward call, in milliseconds
            'time_std': population standard deviation of the latency, in ms
            'tflops': estimated FLOPs divided by mean latency, in TFLOPS
            'bandwidth_gbps': estimated bytes moved divided by mean latency, in GB/s
            'attended_tokens': KV tokens each query attends to
            'config': the ``config`` argument
    """
    # Use the same "is not None" test as generate_test_tensors so that the
    # printed mode and the metrics always match the tensors actually built.
    is_sparse = config.topk is not None
    # A query can never attend to more tokens than the cache holds.
    attended_tokens = min(config.topk, config.seq_len_k) if is_sparse else config.seq_len_k

    print(f"\n{'='*60}")
    print(f"Benchmarking: {label}")
    print(f"{'='*60}")
    print(f"Batch size: {config.batch_size}")
    print(f"Query length: {config.seq_len_q}")
    print(f"KV cache length: {config.seq_len_k}")
    print(f"Num heads (Q): {config.num_heads_q}")
    print(f"FP8 quantization: {config.is_fp8}")
    if is_sparse:
        print(f"Top-k (sparse): {config.topk}")
        # Sparsity is the share of the KV cache that is skipped, not attended.
        print(f"Sparsity: {(1 - attended_tokens / config.seq_len_k) * 100:.1f}%")
    else:
        print(f"Mode: Dense (full attention)")

    # Generate test data
    q, k_cache, block_table, cache_seqlens, indices = generate_test_tensors(config)

    # Get scheduling metadata
    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
        cache_seqlens,
        config.seq_len_q * config.num_heads_q // config.num_heads_kv,
        config.num_heads_kv,
        config.num_heads_q,
        config.is_fp8,
        config.topk
    )

    # Define benchmark function
    def run_forward():
        return flash_mla.flash_mla_with_kvcache(
            q,
            k_cache,
            block_table,
            cache_seqlens,
            config.head_dim_v,
            tile_scheduler_metadata,
            num_splits,
            causal=False,
            is_fp8_kvcache=config.is_fp8,
            indices=indices
        )

    # Warmup (untimed); synchronize so no warmup work leaks into the first sample
    for _ in range(config.num_warmup):
        run_forward()
    torch.cuda.synchronize()

    # Benchmark: synchronize inside the timed region so each sample covers the
    # full GPU execution, not just the asynchronous launch
    times = []
    for _ in range(config.num_iterations):
        start = time.perf_counter()
        run_forward()
        torch.cuda.synchronize()
        times.append(time.perf_counter() - start)

    # Calculate metrics
    mean_time = np.mean(times) * 1000  # Convert to ms
    std_time = np.std(times) * 1000

    # Calculate FLOPs (2 per multiply-accumulate)
    flops = config.batch_size * config.num_heads_q * config.seq_len_q * (
        2 * config.head_dim * attended_tokens +  # Q @ K^T
        2 * attended_tokens * config.head_dim_v  # attn @ V
    )
    tflops = (flops / (mean_time / 1000)) / 1e12

    # Calculate memory bandwidth
    q_elem_size = 2  # bfloat16
    # NOTE(review): 656 bytes/token is presumably FlashMLA's FP8 KV layout
    # (512 FP8 values + 4 float32 scales + 64 bf16 RoPE values) — confirm
    # against the FlashMLA README. The non-FP8 cache stores head_dim bf16 values.
    kv_token_size = 656 if config.is_fp8 else config.head_dim * 2
    memory_bytes = config.batch_size * (
        config.seq_len_q * config.num_heads_q * config.head_dim * q_elem_size +  # Q
        attended_tokens * config.num_heads_kv * kv_token_size +  # K/V
        config.seq_len_q * config.num_heads_q * config.head_dim_v * q_elem_size  # Output
    )
    bandwidth_gbps = (memory_bytes / (mean_time / 1000)) / 1e9

    print(f"\nResults:")
    print(f"  Time: {mean_time:.3f} ± {std_time:.3f} ms")
    print(f"  Throughput: {tflops:.1f} TFLOPS")
    print(f"  Bandwidth: {bandwidth_gbps:.1f} GB/s")

    return {
        'label': label,
        'time_ms': mean_time,
        'time_std': std_time,
        'tflops': tflops,
        'bandwidth_gbps': bandwidth_gbps,
        'attended_tokens': attended_tokens,
        'config': config
    }


## Benchmark 1: Dense vs Sparse at Different Sequence Lengths

In [5]:

# Benchmark 1: for each KV cache length, time dense bf16 attention and then
# sparse FP8 attention on the same single-token decode shape.
banner = "=" * 80
print(f"\n{banner}")
print("BENCHMARK 1: Dense MLA vs Sparse DSA across sequence lengths")
print(banner)

batch_size = 128
seq_lengths = [4096, 8192, 16384, 32768]
topk_sparse = 2048

results_dense, results_sparse = [], []

for seq_len in seq_lengths:
    # Both variants decode one query token against a cache of seq_len tokens.
    decode_shape = {"batch_size": batch_size, "seq_len_q": 1, "seq_len_k": seq_len}

    # Dense baseline: unquantized cache, every token attended.
    config_dense = BenchmarkConfig(is_fp8=False, topk=None, **decode_shape)
    result_dense = run_attention_benchmark(
        config_dense,
        f"Dense MLA (seq_len={seq_len})",
    )
    results_dense.append(result_dense)

    # Sparse variant: FP8 cache, only topk_sparse tokens attended.
    config_sparse = BenchmarkConfig(is_fp8=True, topk=topk_sparse, **decode_shape)
    result_sparse = run_attention_benchmark(
        config_sparse,
        f"Sparse DSA (seq_len={seq_len}, topk={topk_sparse})",
    )
    results_sparse.append(result_sparse)



BENCHMARK 1: Dense MLA vs Sparse DSA across sequence lengths

Benchmarking: Dense MLA (seq_len=4096)
Batch size: 128
Query length: 1
KV cache length: 4096
Num heads (Q): 128
FP8 quantization: False
Mode: Dense (full attention)

Results:
  Time: 0.308 ± 0.146 ms
  Throughput: 474.1 TFLOPS
  Bandwidth: 2076.6 GB/s

Benchmarking: Sparse DSA (seq_len=4096, topk=2048)
Batch size: 128
Query length: 1
KV cache length: 4096
Num heads (Q): 128
FP8 quantization: True
Top-k (sparse): 2048
Sparsity: 50.0%


NameError: name 'quant' is not defined

## Benchmark 2: Impact of Top-K Selection

In [None]:

# Benchmark 2: hold the KV cache length fixed and sweep the number of
# attended tokens (top-k) for the sparse FP8 path.
banner = "=" * 80
print(f"\n{banner}")
print("BENCHMARK 2: Impact of top-k value on sparse attention")
print(banner)

batch_size = 128
seq_len = 16384
topk_values = [128, 512, 1024, 2048]

results_topk = []

for topk in topk_values:
    sweep_label = f"Sparse DSA (topk={topk})"
    config = BenchmarkConfig(
        batch_size=batch_size, seq_len_q=1, seq_len_k=seq_len, is_fp8=True, topk=topk
    )
    result = run_attention_benchmark(config, sweep_label)
    results_topk.append(result)

# Visualization
#
# Draws a 2x2 figure from the benchmark results and saves it as a PNG:
#   [0, 0] latency bars, dense vs sparse, per sequence length
#   [0, 1] speedup (dense / sparse) vs sequence length
#   [1, 0] latency vs top-k, annotated with the share of the cache skipped
#   [1, 1] throughput bars, dense vs sparse, per sequence length
# Reads results_dense, results_sparse, results_topk and topk_sparse from the
# benchmark cells.

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
fig.suptitle('DeepSeek Sparse Attention Performance Analysis', fontsize=16, fontweight='bold')

# Plot 1: Latency comparison across sequence lengths
ax1 = axes[0, 0]
seq_lens_plot = [r['config'].seq_len_k for r in results_dense]
times_dense = [r['time_ms'] for r in results_dense]
times_sparse = [r['time_ms'] for r in results_sparse]

x_pos = np.arange(len(seq_lens_plot))
width = 0.35

ax1.bar(x_pos - width/2, times_dense, width, label='Dense MLA', color='#d62728', alpha=0.8)
ax1.bar(x_pos + width/2, times_sparse, width, label=f'Sparse DSA (topk={topk_sparse})', color='#2ca02c', alpha=0.8)

ax1.set_xlabel('Sequence Length', fontsize=11, fontweight='bold')
ax1.set_ylabel('Latency (ms)', fontsize=11, fontweight='bold')
ax1.set_title('Latency: Dense vs Sparse Attention', fontsize=12, fontweight='bold')
ax1.set_xticks(x_pos)
ax1.set_xticklabels([f'{s//1024}K' for s in seq_lens_plot])
ax1.legend()
ax1.grid(axis='y', alpha=0.3)

# Add speedup annotations just above the taller bar of each pair
for i, (d, s) in enumerate(zip(times_dense, times_sparse)):
    speedup = d / s
    ax1.text(i, max(d, s) * 1.05, f'{speedup:.1f}x', ha='center', fontweight='bold', fontsize=9)

# Plot 2: Speedup vs sequence length
ax2 = axes[0, 1]
speedups = [d['time_ms'] / s['time_ms'] for d, s in zip(results_dense, results_sparse)]
ax2.plot(seq_lens_plot, speedups, marker='o', linewidth=2.5, markersize=8, color='#1f77b4')
ax2.axhline(y=1, color='gray', linestyle='--', alpha=0.5, label='1x (no speedup)')
ax2.set_xlabel('Sequence Length', fontsize=11, fontweight='bold')
ax2.set_ylabel('Speedup (Dense / Sparse)', fontsize=11, fontweight='bold')
ax2.set_title('Speedup Achieved by Sparse Attention', fontsize=12, fontweight='bold')
# Sequence lengths double at each step, so a base-2 log axis spaces them evenly
ax2.set_xscale('log', base=2)
ax2.set_xticks(seq_lens_plot)
ax2.set_xticklabels([f'{s//1024}K' for s in seq_lens_plot])
ax2.grid(True, alpha=0.3)
ax2.legend()

# Annotate speedup values
for x, y in zip(seq_lens_plot, speedups):
    ax2.annotate(f'{y:.1f}x', xy=(x, y), xytext=(0, 10),
                textcoords='offset points', ha='center', fontweight='bold', fontsize=9)

# Plot 3: Impact of top-k on latency
ax3 = axes[1, 0]
topk_vals = [r['config'].topk for r in results_topk]
times_topk = [r['time_ms'] for r in results_topk]
# Take the cache length from the stored configs rather than the global
# `seq_len`, which other cells reuse as a loop variable.
topk_seq_len = results_topk[0]['config'].seq_len_k if results_topk else seq_len
# Sparsity is the share of the cache that is skipped: 1 - topk / seq_len.
sparsity_pct = [
    (1 - r['config'].topk / r['config'].seq_len_k) * 100 for r in results_topk
]

bars = ax3.bar(range(len(topk_vals)), times_topk, color='#ff7f0e', alpha=0.8)
ax3.set_xlabel('Top-K Value', fontsize=11, fontweight='bold')
ax3.set_ylabel('Latency (ms)', fontsize=11, fontweight='bold')
ax3.set_title(f'Impact of Top-K Selection (seq_len={topk_seq_len})', fontsize=12, fontweight='bold')
ax3.set_xticks(range(len(topk_vals)))
ax3.set_xticklabels([f'{tk}' for tk in topk_vals])
ax3.grid(axis='y', alpha=0.3)

# Add sparsity percentage labels
for bar, sp in zip(bars, sparsity_pct):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height * 1.02,
            f'{sp:.1f}%\nsparse', ha='center', va='bottom', fontsize=8)

# Plot 4: Throughput comparison
ax4 = axes[1, 1]
tflops_dense = [r['tflops'] for r in results_dense]
tflops_sparse = [r['tflops'] for r in results_sparse]

x_pos = np.arange(len(seq_lens_plot))
ax4.bar(x_pos - width/2, tflops_dense, width, label='Dense MLA', color='#d62728', alpha=0.8)
ax4.bar(x_pos + width/2, tflops_sparse, width, label=f'Sparse DSA (topk={topk_sparse})', color='#2ca02c', alpha=0.8)

ax4.set_xlabel('Sequence Length', fontsize=11, fontweight='bold')
ax4.set_ylabel('Throughput (TFLOPS)', fontsize=11, fontweight='bold')
ax4.set_title('Computational Throughput', fontsize=12, fontweight='bold')
ax4.set_xticks(x_pos)
ax4.set_xticklabels([f'{s//1024}K' for s in seq_lens_plot])
ax4.legend()
ax4.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.savefig('dsa_performance_analysis.png', dpi=300, bbox_inches='tight')
plt.show()


## Summary Statistics

In [None]:

# Text summary of both benchmarks. Reads results_dense, results_sparse,
# results_topk and topk_sparse from the benchmark cells; every other value is
# derived here so the cell does not depend on the plotting cell having run.

print("\n" + "="*80)
print("PERFORMANCE SUMMARY")
print("="*80)

# Pair each dense result with the sparse result for the same sequence length.
# zip() stops at the shorter list, so a run that failed part-way still
# summarizes the pairs that completed.
paired_results = list(zip(results_dense, results_sparse))
pair_speedups = [d['time_ms'] / s['time_ms'] for d, s in paired_results]

print(f"\n1. Dense vs Sparse (topk={topk_sparse}) at Different Sequence Lengths:")
print("-" * 80)
print(f"{'Seq Length':<15} {'Dense (ms)':<15} {'Sparse (ms)':<15} {'Speedup':<15} {'Memory Saved':<15}")
print("-" * 80)
for (d, s), pair_speedup in zip(paired_results, pair_speedups):
    memory_ratio = s['attended_tokens'] / d['attended_tokens']
    # Attach the unit before padding so it stays next to the number and the
    # columns line up with the header.
    speedup_cell = f"{pair_speedup:.2f}x"
    saved_cell = f"{(1 - memory_ratio) * 100:.1f}%"
    print(f"{d['config'].seq_len_k:<15} {d['time_ms']:<15.2f} {s['time_ms']:<15.2f} {speedup_cell:<15} {saved_cell:<15}")

if pair_speedups:
    avg_speedup = np.mean(pair_speedups)
    print(f"\nAverage speedup: {avg_speedup:.2f}x")
else:
    avg_speedup = float('nan')
    print("\nAverage speedup: n/a (no paired dense/sparse results)")

# Use the cache length recorded with the top-k results; a global `seq_len`
# may have been overwritten by a loop in another cell.
topk_seq_len_label = results_topk[0]['config'].seq_len_k if results_topk else "n/a"
print(f"\n2. Impact of Top-K Selection (seq_len={topk_seq_len_label}):")
print("-" * 80)
print(f"{'Top-K':<15} {'Sparsity':<15} {'Latency (ms)':<15} {'TFLOPS':<15}")
print("-" * 80)
for r in results_topk:
    topk_cfg = r['config']
    # Sparsity is the share of the cache that is skipped: 1 - topk / seq_len.
    sparsity_cell = f"{(1 - topk_cfg.topk / topk_cfg.seq_len_k) * 100:.1f}%"
    print(f"{topk_cfg.topk:<15} {sparsity_cell:<15} {r['time_ms']:<15.2f} {r['tflops']:<15.1f}")

print("\n" + "="*80)
print("KEY INSIGHTS")
print("="*80)
if paired_results:
    # The last pair is the longest sequence length that completed in both modes.
    last_seq_len = paired_results[-1][0]['config'].seq_len_k
    print(f"""
1. Sparse DSA achieves {avg_speedup:.1f}x average speedup over dense MLA
2. Speedup increases with sequence length (scales better for long contexts)
3. At {last_seq_len} tokens, sparse attention is {pair_speedups[-1]:.1f}x faster
4. Memory bandwidth reduced by ~{(1 - topk_sparse/last_seq_len)*100:.0f}% with topk={topk_sparse}
5. Performance scales predictably with top-k selection
6. FP8 quantization enables efficient sparse computation
""")
else:
    print("\nNo paired dense/sparse results available; run Benchmark 1 to completion first.\n")

print("\n" + "="*80)
print("CONCLUSION")
print("="*80)
print("""
DeepSeek Sparse Attention (DSA) provides substantial performance improvements
for long-context inference workloads. By attending to only the most relevant
tokens (top-k), DSA reduces both computational cost and memory bandwidth
requirements while maintaining model quality.

This makes DSA particularly valuable for:
- Long-context language models (>32K tokens)
- Cost-sensitive production deployments
- Real-time inference applications
- Multi-user serving scenarios

The performance gains scale with sequence length, making DSA increasingly
valuable as context windows continue to grow.
""")