From 8bb79f44086175641e85bc80f26b6118ed1ad1d9 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 30 Sep 2025 13:46:40 -0700 Subject: [PATCH] Improve rms_norm perf Fixes #660 On my local 4080 laptop GPU perf improved from 3.37x to 6.12x. Will run CI benchmarks on B200 to validate. stack-info: PR: https://github.com/pytorch/helion/pull/727, branch: oulgen/stack/108 --- examples/rms_norm.py | 37 +++++++++++++++------- test/test_examples.expected | 61 ++++++++++++++++++++++--------------- test/test_examples.py | 4 +-- 3 files changed, 64 insertions(+), 38 deletions(-) diff --git a/examples/rms_norm.py b/examples/rms_norm.py index 76ca007be..03a2b7ae2 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -49,18 +49,33 @@ def rms_norm_fwd( out = torch.empty_like(x) inv_rms = torch.empty([m], dtype=x.dtype, device=x.device) - for tile_m in hl.tile(m): - x_tile = x[tile_m, :].to(torch.float32) - - # Compute inverse RMS: 1/sqrt(mean(x^2) + eps) - x_squared = x_tile * x_tile - mean_x_squared = torch.mean(x_squared, dim=-1) - inv_rms_tile = torch.rsqrt(mean_x_squared + eps) + block_size_n = hl.register_block_size(n) + n_spec = hl.specialize(n) - # Apply normalization and weight - normalized = x_tile * inv_rms_tile[:, None] - out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype) - inv_rms[tile_m] = inv_rms_tile.to(out.dtype) + for tile_m in hl.tile(m): + # First pass: accumulate sum of squares across N in blocks + sum_sq = hl.zeros([tile_m], dtype=torch.float32) + for tile_n in hl.tile(n, block_size=block_size_n): + xi_chunk = hl.load(x, [tile_m, tile_n], eviction_policy="evict_last").to( + torch.float32 + ) + sum_sq = sum_sq + (xi_chunk * xi_chunk).sum(dim=1) + + mean_sq = sum_sq / n_spec + inv_tile = torch.rsqrt(mean_sq + eps) + inv_rms[tile_m] = inv_tile.to(inv_rms.dtype) + + # Second pass: apply normalization and weight + for tile_n in hl.tile(n, block_size=block_size_n): + w_chunk = hl.load(weight, [tile_n], eviction_policy="evict_last").to( + torch.float32 + ) + x_chunk = hl.load(x, [tile_m, tile_n], eviction_policy="evict_first").to( + torch.float32 + ) + out[tile_m, tile_n] = (x_chunk * inv_tile[:, None] * w_chunk[None, :]).to( + out.dtype + ) return out, inv_rms.reshape(-1, 1) diff --git a/test/test_examples.expected b/test/test_examples.expected index 61ba85a79..b68c2c2a4 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -3043,30 +3043,41 @@ from torch._inductor.runtime.triton_compat import libdevice from helion.runtime import default_launcher as _default_launcher @triton.jit -def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): +def _helion_rms_norm_fwd(x, inv_rms, weight, out, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) - offset_0 = pid_0 * _BLOCK_SIZE_0 - indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) - mask_0 = indices_0 < m - indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) - mask_1 = indices_1 < n - load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - v_0 = tl.cast(load, tl.float32) - v_1 = v_0 * v_0 - mean_x_squared_extra = tl.cast(tl.sum(v_1, 1), tl.float32) - v_2 = mean_x_squared_extra / n.to(tl.float32) - v_3 = v_2 + eps - v_4 = libdevice.rsqrt(v_3) - subscript = v_4[:, None] - v_5 = v_0 * subscript - load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0) - v_6 = tl.cast(load_1, tl.float32) - v_7 = v_6[None, :] - v_8 = v_5 * v_7 - v_9 = tl.cast(v_8, tl.float16) - tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :]) - v_10 = tl.cast(v_4, tl.float16) - tl.store(inv_rms + indices_0 * inv_rms_stride_0, v_10, mask_0) + offset_1 = pid_0 * _BLOCK_SIZE_1 + indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) + mask_1 = indices_1 < m + sum_sq = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32) + for offset_2 in tl.range(0, 256, _BLOCK_SIZE_0): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + sum_sq_copy = sum_sq + sum_sq_copy_0 = sum_sq_copy + load = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[:, None], other=0, eviction_policy='evict_last') + v_0 = tl.cast(load, tl.float32) + v_1 = v_0 * v_0 + sum_1 = tl.cast(tl.sum(v_1, 1), tl.float32) + sum_sq = sum_sq_copy_0 + sum_1 + v_3 = 0.00390625 + v_4 = sum_sq * v_3 + v_5 = v_4 + eps + v_6 = libdevice.rsqrt(v_5) + v_7 = tl.cast(v_6, tl.float16) + tl.store(inv_rms + indices_1 * inv_rms_stride_0, v_7, mask_1) + for offset_2 in tl.range(0, 256, _BLOCK_SIZE_0): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) + v_6_copy = v_6 + v_6_copy_0 = v_6_copy + load_1 = tl.load(weight + indices_2 * weight_stride_0, None, eviction_policy='evict_last') + v_8 = tl.cast(load_1, tl.float32) + load_2 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[:, None], other=0, eviction_policy='evict_first') + v_9 = tl.cast(load_2, tl.float32) + subscript = v_6_copy_0[:, None] + v_10 = v_9 * subscript + subscript_1 = v_8[None, :] + v_11 = v_10 * subscript_1 + v_12 = tl.cast(v_11, tl.float16) + tl.store(out + (indices_1[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_12, mask_1[:, None]) def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher): """ @@ -3088,9 +3099,9 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}' out = torch.empty_like(x) inv_rms = torch.empty([m], dtype=x.dtype, device=x.device) + _BLOCK_SIZE_1 = 16 _BLOCK_SIZE_0 = 16 - _RDIM_SIZE_1 = triton.next_power_of_2(n) - _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_1),), x, inv_rms, weight, out, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return (out, inv_rms.reshape(-1, 1)) --- assertExpectedJournal(TestExamples.test_segment_reduction) diff --git a/test/test_examples.py b/test/test_examples.py index c1680ceb3..ae6485234 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -324,7 +324,7 @@ def test_rms_norm_fwd(self): args, (expected, None), # Expected: (output, 1/rms) fn_name="rms_norm_fwd", - block_sizes=[16], + block_sizes=[16, 16], indexing="pointer", ) ) @@ -343,7 +343,7 @@ def test_rms_norm_bwd(self): from examples.rms_norm import rms_norm_fwd # Create configured kernel with explicit config - config = helion.Config(block_size=32, num_warps=4, num_stages=3) + config = helion.Config(block_size=[32, 32], num_warps=4, num_stages=3) configured_kernel = helion.kernel(rms_norm_fwd.fn, config=config) y, rms = configured_kernel(x, weight, eps)