# Compare Old vs New Assignment Kernel

This notebook compares the speed of:
- `assign_euclid_triton.py` (old - two reductions)
- `assign_euclid_triton2.py` (new - one reduction + extract)

In [6]:
# Clone the repo and install
!git clone https://github.com/vibhaddalagit/vibha-flash-kmeans.git
%cd vibha-flash-kmeans
!pip install -e . -q

Cloning into 'vibha-flash-kmeans'...
remote: Enumerating objects: 189, done.
remote: Counting objects: 100% (189/189), done.
remote: Compressing objects: 100% (131/131), done.
remote: Total 189 (delta 102), reused 137 (delta 56), pack-reused 0 (from 0)
Receiving objects: 100% (189/189), 6.48 MiB | 17.42 MiB/s, done.
Resolving deltas: 100% (102/102), done.
/content/flash-kmeans/flash-kmeans/vibha-flash-kmeans
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
  Building editable for flash-kmeans (pyproject.toml) ... done


In [7]:
import torch
import triton

# Report the software stack and, when a CUDA device is present, its name.
has_cuda = torch.cuda.is_available()

print(f"PyTorch version: {torch.__version__}")
print(f"Triton version: {triton.__version__}")
print(f"CUDA available: {has_cuda}")
if has_cuda:
    # Device index 0 is the one queried for the name.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu126
Triton version: 3.5.0
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


In [8]:
# Import both implementations.
# Each module exposes a function named `euclid_assign_triton`, so the two are
# aliased here to keep them apart:
#   euclid_assign_old -> flash_kmeans.assign_euclid_triton   (old kernel)
#   euclid_assign_new -> flash_kmeans.assign_euclid_triton2  (new kernel)
from flash_kmeans.assign_euclid_triton import euclid_assign_triton as euclid_assign_old
from flash_kmeans.assign_euclid_triton2 import euclid_assign_triton as euclid_assign_new

# Reaching this line means neither import raised.
print("Both kernels imported successfully!")

Both kernels imported successfully!


In [9]:
def benchmark_kernel(kernel_fn, x, centroids, x_sq, warmup=10, repeats=100):
    """Time a kernel with CUDA events and return its mean latency.

    Args:
        kernel_fn: Callable invoked as ``kernel_fn(x, centroids, x_sq, out, c_sq)``.
        x: Tensor whose shape unpacks as ``(B, N, D)``; the output buffer is
            allocated on ``x.device``.
        centroids: Centroid tensor; its squared norms over the last dimension
            are precomputed in float32 and passed to the kernel as ``c_sq``.
        x_sq: Passed through to ``kernel_fn`` unchanged.
        warmup: Number of untimed calls made before measuring.
        repeats: Number of timed calls; must be at least 1.

    Returns:
        Tuple ``(avg_time_ms, out)``: the elapsed time between the two CUDA
        events divided by ``repeats``, and the int32 ``(B, N)`` output buffer
        that was handed to the kernel.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    # Fail fast: with repeats == 0 the average below would divide by zero only
    # after all the GPU work has run, and a negative value would silently
    # produce a meaningless (negative) time.
    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    B, N, D = x.shape
    out = torch.empty((B, N), device=x.device, dtype=torch.int32)
    c_sq = (centroids.float() ** 2).sum(-1)

    # Warmup (untimed)
    for _ in range(warmup):
        kernel_fn(x, centroids, x_sq, out, c_sq)

    # Drain the warmup work so it is not included in the timed region.
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(repeats):
        kernel_fn(x, centroids, x_sq, out, c_sq)
    end.record()
    # Wait for the end event to complete before reading the elapsed time.
    torch.cuda.synchronize()

    avg_time_ms = start.elapsed_time(end) / repeats
    return avg_time_ms, out

In [10]:
# Test parameters (same as your benchmark)
B = 32         # number of batches
N = 74256      # points per batch
D = 128        # size of the last dimension of x and centroids
K = 1000       # centroids per batch
dtype = torch.float16

print(f"Test config: B={B}, N={N}, D={D}, K={K}, dtype={dtype}")
print(f"Points per batch: {N:,}")
print(f"Total points: {B*N:,}")

Test config: B=32, N=74256, D=128, K=1000, dtype=torch.float16
Points per batch: 74,256
Total points: 2,376,192


In [11]:
# Create test data: seeded so that x and centroids are reproducible.
torch.manual_seed(42)
x = torch.randn((B, N, D), device="cuda", dtype=dtype)
centroids = torch.randn((B, K, D), device="cuda", dtype=dtype)

# Squared norm of every point, accumulated in float32.
x_sq = x.float().square().sum(dim=-1)

print(f"x shape: {x.shape}")
print(f"centroids shape: {centroids.shape}")
print(f"x_sq shape: {x_sq.shape}")

x shape: torch.Size([32, 74256, 128])
centroids shape: torch.Size([32, 1000, 128])
x_sq shape: torch.Size([32, 74256])


In [12]:
# Correctness check
print("Checking correctness...")

# Squared centroid norms, shared by the reference and by both kernels.
c_sq = (centroids.float() ** 2).sum(-1)

# Reference implementation: ||x||^2 + ||c||^2 - 2 * <x, c>, evaluated in
# float32 and clamped at zero, followed by an argmin over the centroid axis.
dist = (
    x_sq.unsqueeze(-1)
    + c_sq.unsqueeze(1)
    - 2.0 * torch.einsum("bnd,bkd->bnk", x.float(), centroids.float())
).clamp_min_(0.0)
ref_ids = dist.argmin(dim=-1)
# Reference distance of each point to its reference-assigned centroid.
ref_best = dist.gather(-1, ref_ids.unsqueeze(-1)).squeeze(-1)

# Old kernel
out_old = torch.empty((B, N), device="cuda", dtype=torch.int32)
euclid_assign_old(x, centroids, x_sq, out_old, c_sq)

# New kernel
out_new = torch.empty((B, N), device="cuda", dtype=torch.int32)
euclid_assign_new(x, centroids, x_sq, out_new, c_sq)

print(f"Old kernel matches reference: {torch.equal(ref_ids, out_old.long())}")
print(f"New kernel matches reference: {torch.equal(ref_ids, out_new.long())}")
print(f"Old and New match each other: {torch.equal(out_old, out_new)}")


def report_mismatch(label, ids):
    """Print how far a kernel's assignments are from the float32 reference.

    Exact equality is a brittle test: a point whose two nearest centroids are
    almost equidistant can legitimately be assigned to either one when the
    arithmetic differs slightly. This reports the number of differing ids and
    the largest amount by which a chosen centroid is farther (in reference
    squared distance) than the reference optimum, so that near-ties can be
    told apart from genuinely wrong assignments.

    Args:
        label: Name printed in front of the report line.
        ids: Integer tensor of assignments with the same shape as ``ref_ids``.
    """
    ids = ids.long()
    num_k = dist.shape[-1]
    # gather() requires valid indices, so rule out garbage ids first.
    if not bool(((ids >= 0) & (ids < num_k)).all()):
        print(f"{label}: produced ids outside [0, {num_k}) - output is invalid")
        return
    n_diff = int((ids != ref_ids).sum())
    chosen = dist.gather(-1, ids.unsqueeze(-1)).squeeze(-1)
    max_gap = (chosen - ref_best).max().item()
    print(
        f"{label}: {n_diff:,} of {ids.numel():,} ids differ from reference "
        f"({n_diff / ids.numel():.4%}); max squared-distance gap = {max_gap:.3e}"
    )


report_mismatch("Old kernel", out_old)
report_mismatch("New kernel", out_new)

Checking correctness...
Old kernel matches reference: False
New kernel matches reference: False
Old and New match each other: True


In [13]:
# Benchmark both kernels on the main test configuration.
# benchmark_kernel returns (avg_time_ms, out); only the timing is kept here.
print("\nBenchmarking (warmup=10, repeats=100)...\n")

time_old = benchmark_kernel(euclid_assign_old, x, centroids, x_sq)[0]
print(f"Old kernel (2 reductions): {time_old:.3f} ms")

time_new = benchmark_kernel(euclid_assign_new, x, centroids, x_sq)[0]
print(f"New kernel (1 reduction):  {time_new:.3f} ms")

# "Speedup" is the relative reduction in time, as a percentage of the old
# kernel's time; a negative value would mean the new kernel is slower.
speedup = (time_old - time_new) / time_old * 100
print(f"\nSpeedup: {speedup:.1f}%")
print(f"Time saved per call: {time_old - time_new:.3f} ms")


Benchmarking (warmup=10, repeats=100)...

Old kernel (2 reductions): 4.321 ms
New kernel (1 reduction):  4.248 ms

Speedup: 1.7%
Time saved per call: 0.073 ms


In [14]:
# Test with different K values
print("\nBenchmark across different K values:\n")
print(f"{'K':>6} | {'Old (ms)':>10} | {'New (ms)':>10} | {'Speedup':>8}")
print("-" * 45)

# Loop-local names are used on purpose: iterating with `K` (and reusing
# `time_old` / `time_new` / `speedup`) would overwrite the notebook-level
# values set by earlier cells and leave K == 2000 for anything that runs later.
for num_centroids in [100, 256, 512, 1000, 2000]:
    centroids_k = torch.randn(B, num_centroids, D, device="cuda", dtype=dtype)

    sweep_old_ms, _ = benchmark_kernel(
        euclid_assign_old, x, centroids_k, x_sq, warmup=5, repeats=50
    )
    sweep_new_ms, _ = benchmark_kernel(
        euclid_assign_new, x, centroids_k, x_sq, warmup=5, repeats=50
    )

    # Relative time reduction in percent; negative means the new kernel was slower.
    sweep_speedup = (sweep_old_ms - sweep_new_ms) / sweep_old_ms * 100
    print(
        f"{num_centroids:>6} | {sweep_old_ms:>10.3f} | "
        f"{sweep_new_ms:>10.3f} | {sweep_speedup:>7.1f}%"
    )


Benchmark across different K values:

     K |   Old (ms) |   New (ms) |  Speedup
---------------------------------------------
   256 |      1.020 |      1.124 |   -10.3%
   512 |      1.999 |      2.134 |    -6.7%
  1000 |      4.088 |      4.258 |    -4.2%
  2000 |      7.752 |      8.281 |    -6.8%


In [None]:
# Summary: print a banner-framed recap of the device, the test size and the
# optimization being evaluated.
banner = "=" * 50

print("\n" + banner)
print("SUMMARY")
print(banner)
print(f"GPU: {torch.cuda.get_device_name(0)}")
# K is written out literally rather than read from a variable.
print(f"Test size: B={B}, N={N}, D={D}, K=1000")
print("\nOptimization: Removed redundant tl.min() call")
print("Old: tl.min() + tl.argmin() (2 reductions)")
print("New: tl.argmin() + extract (1 reduction)")
print(banner)