# Compare Centroid Update Methods

Compare the runtime and the numerical agreement of two centroid-update implementations:
- `triton_centroid_update_sorted_euclid` (current - with sort)
- `triton_centroid_update_euclid` (atomic - no sort)

In [None]:
# Fetch the flash-kmeans fork, enter the checkout, and install it in editable mode.
# NOTE(review): git's stderr is discarded and 'Already cloned' is echoed on ANY
# non-zero exit status, so a network/auth failure looks the same as an existing checkout.
!git clone https://github.com/vibhaddalagit/vibha-flash-kmeans.git 2>/dev/null || echo 'Already cloned'
# Change the notebook's working directory to the checkout.
# NOTE(review): on a re-run the cell appears to start from inside the checkout, so the
# clone above would create a nested copy — confirm, or restart the kernel before re-running.
%cd vibha-flash-kmeans
# Editable install of the package in the current directory (-q asks pip for quieter output).
!pip install -e . -q

In [None]:
import torch

# Report the name of CUDA device 0.
gpu_name = torch.cuda.get_device_name(0)
print(f"GPU: {gpu_name}")

In [None]:
# Assignment kernel plus the two centroid-update variants under comparison.
from flash_kmeans.assign_euclid_triton import euclid_assign_triton
# Variant that sorts the assignments first.
from flash_kmeans.centroid_update_triton import triton_centroid_update_sorted_euclid
# Atomic variant, no sort.
from flash_kmeans.centroid_update_triton import triton_centroid_update_euclid

print("Imports successful!")

In [None]:
# Benchmark configuration: x is (B, N, D) and centroids is (B, K, D).
B = 32
N = 74256
D = 128
K = 1000
dtype = torch.float16

# Fixed seed for reproducible inputs; x is drawn before centroids.
torch.manual_seed(42)
x = torch.randn(B, N, D, dtype=dtype, device="cuda")
centroids = torch.randn(B, K, D, dtype=dtype, device="cuda")

# Sum of squares over the last axis, computed on a float32 copy of x.
x_sq = (x.float() ** 2).sum(dim=-1)

# Cluster assignments for the random centroids; reused by the benchmarks below.
cluster_ids = euclid_assign_triton(x, centroids, x_sq)
print(
    f"x: {x.shape}, centroids: {centroids.shape}, cluster_ids: {cluster_ids.shape}"
)

In [None]:
# Time torch.sort on the assignment tensor in isolation.
print("Benchmarking torch.sort alone...")
n_warmup = 10
n_timed = 100

torch.cuda.synchronize()
# Timing events; the later benchmark cells reuse `start` and `end`.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# Untimed warmup launches.
for _step in range(n_warmup):
    _ = torch.sort(cluster_ids, dim=-1)

# Timed region: n_timed back-to-back sorts bracketed by the two events.
start.record()
for _step in range(n_timed):
    _ = torch.sort(cluster_ids, dim=-1)
end.record()
# Wait for the recorded work to complete before reading the elapsed time.
torch.cuda.synchronize()

# elapsed_time() reports milliseconds; divide by the iteration count for the mean.
sort_time = start.elapsed_time(end) / n_timed
print(f"Sort time: {sort_time:.3f} ms")

In [None]:
# Time the centroid update variant that includes the sort.
print("\nBenchmarking sorted centroid update (with sort)...")
n_warmup = 10
n_timed = 100

# Untimed warmup calls.
for _step in range(n_warmup):
    _ = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)

# Drain pending work, then time n_timed consecutive calls between the events.
torch.cuda.synchronize()
start.record()
for _step in range(n_timed):
    _ = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
end.record()
torch.cuda.synchronize()

# Mean milliseconds per call.
sorted_time = start.elapsed_time(end) / n_timed
print(f"Sorted update time: {sorted_time:.3f} ms")

In [None]:
# Time the atomic centroid update variant, which skips the sort.
print("\nBenchmarking atomic centroid update (no sort)...")
n_warmup = 10
n_timed = 100

# Untimed warmup calls.
for _step in range(n_warmup):
    _ = triton_centroid_update_euclid(x, cluster_ids, centroids)

# Drain pending work, then time n_timed consecutive calls between the events.
torch.cuda.synchronize()
start.record()
for _step in range(n_timed):
    _ = triton_centroid_update_euclid(x, cluster_ids, centroids)
end.record()
torch.cuda.synchronize()

# Mean milliseconds per call.
atomic_time = start.elapsed_time(end) / n_timed
print(f"Atomic update time: {atomic_time:.3f} ms")

In [None]:
# Run both variants once on identical inputs and compare their outputs.
print("\nChecking correctness...")
result_sorted = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
result_atomic = triton_centroid_update_euclid(x, cluster_ids, centroids)

# Largest element-wise absolute deviation between the two results.
max_diff = (
    (result_sorted - result_atomic)
    .abs()
    .max()
    .item()
)
print(f"Max difference: {max_diff:.6f}")

# Tolerance check with atol=1e-3; rtol is left at torch.allclose's default.
results_match = torch.allclose(result_sorted, result_atomic, atol=1e-3)
print(f"Results match: {results_match}")

In [None]:
# Summary of the per-call timings measured by the benchmark cells above.
rule = "=" * 50
# Milliseconds saved per update call; negative when the atomic variant is slower.
time_saved = sorted_time - atomic_time

print("\n" + rule)
print("SUMMARY")
print(rule)
print(f"Sort alone:      {sort_time:.3f} ms")
print(f"Sorted update:   {sorted_time:.3f} ms")
print(f"Atomic update:   {atomic_time:.3f} ms")
print()
# Speedup is the time ratio (sorted / atomic). The percentage is the share of the
# sorted time that was saved; it was previously the only figure shown and was
# labelled "Speedup", which overstates or understates the real factor.
print(
    f"Speedup (atomic vs sorted): {sorted_time / atomic_time:.2f}x "
    f"({time_saved / sorted_time * 100:.1f}% less time)"
)
print(f"Time saved per iteration: {time_saved:.3f} ms")
print(rule)

In [None]:
# Header for the end-to-end (assign + update) iteration comparison.
print("\nFull iteration comparison:", "=" * 50, sep="\n")

def iter_sorted(x, x_sq, centroids):
    """Run one assign-then-update step using the sorted centroid update.

    Calls ``euclid_assign_triton`` to get the assignments, then
    ``triton_centroid_update_sorted_euclid`` to get the new centroids.

    Returns:
        A ``(new_centroids, cluster_ids)`` tuple.
    """
    assignments = euclid_assign_triton(x, centroids, x_sq)
    updated = triton_centroid_update_sorted_euclid(x, assignments, centroids)
    return updated, assignments

def iter_atomic(x, x_sq, centroids):
    """Run one assign-then-update step using the atomic centroid update.

    Calls ``euclid_assign_triton`` to get the assignments, then
    ``triton_centroid_update_euclid`` to get the new centroids.

    Returns:
        A ``(new_centroids, cluster_ids)`` tuple.
    """
    assignments = euclid_assign_triton(x, centroids, x_sq)
    updated = triton_centroid_update_euclid(x, assignments, centroids)
    return updated, assignments

# Warmup: run both variants a few times untimed before measuring.
for _ in range(5):
    iter_sorted(x, x_sq, centroids)
    iter_atomic(x, x_sq, centroids)

n_timed = 50

# Benchmark the sorted iteration; `start`/`end` are the CUDA timing events
# created in the sort benchmark cell.
torch.cuda.synchronize()
start.record()
for _ in range(n_timed):
    iter_sorted(x, x_sq, centroids)
end.record()
torch.cuda.synchronize()
iter_sorted_time = start.elapsed_time(end) / n_timed  # mean ms per iteration

# Benchmark the atomic iteration the same way.
torch.cuda.synchronize()
start.record()
for _ in range(n_timed):
    iter_atomic(x, x_sq, centroids)
end.record()
torch.cuda.synchronize()
iter_atomic_time = start.elapsed_time(end) / n_timed  # mean ms per iteration

print(f"Full iter (sorted): {iter_sorted_time:.3f} ms")
print(f"Full iter (atomic): {iter_atomic_time:.3f} ms")
# Speedup is the time ratio (sorted / atomic). The percentage of time saved was
# previously the only figure printed under the "Speedup" label; it is kept
# alongside the ratio so the label matches the number.
iter_time_saved = iter_sorted_time - iter_atomic_time
print(
    f"Speedup: {iter_sorted_time / iter_atomic_time:.2f}x "
    f"({iter_time_saved / iter_sorted_time * 100:.1f}% less time)"
)