From 5448e5b46610341d51bd706b4cdd8d12fda07f2a Mon Sep 17 00:00:00 2001 From: Markus Hoehnerbach Date: Wed, 24 Sep 2025 09:49:53 -0700 Subject: [PATCH] rms norm: improve fwd perf --- examples/rms_norm.py | 10 +++++----- test/test_examples.expected | 11 ++++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/rms_norm.py b/examples/rms_norm.py index 823183a3f..76ca007be 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -47,22 +47,22 @@ def rms_norm_fwd( assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}" out = torch.empty_like(x) - inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device) + inv_rms = torch.empty([m], dtype=x.dtype, device=x.device) for tile_m in hl.tile(m): x_tile = x[tile_m, :].to(torch.float32) # Compute inverse RMS: 1/sqrt(mean(x^2) + eps) x_squared = x_tile * x_tile - mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True) + mean_x_squared = torch.mean(x_squared, dim=-1) inv_rms_tile = torch.rsqrt(mean_x_squared + eps) # Apply normalization and weight - normalized = x_tile * inv_rms_tile + normalized = x_tile * inv_rms_tile[:, None] out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype) - inv_rms[tile_m, :] = inv_rms_tile.to(out.dtype) + inv_rms[tile_m] = inv_rms_tile.to(out.dtype) - return out, inv_rms + return out, inv_rms.reshape(-1, 1) @helion.kernel diff --git a/test/test_examples.expected b/test/test_examples.expected index 5b733b3cb..00da612f9 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -2972,11 +2972,12 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0 load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) v_0 = tl.cast(load, tl.float32) v_1 = v_0 * v_0 - mean_x_squared_extra = tl.cast(tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32) + mean_x_squared_extra = tl.cast(tl.sum(v_1, 1), tl.float32) v_2 = mean_x_squared_extra / n.to(tl.float32) v_3 = v_2 + eps v_4 = libdevice.rsqrt(v_3) - v_5 = v_0 * v_4 + subscript = v_4[:, None] + v_5 = v_0 * subscript load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0) v_6 = tl.cast(load_1, tl.float32) v_7 = v_6[None, :] @@ -2984,7 +2985,7 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0 v_9 = tl.cast(v_8, tl.float16) tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :]) v_10 = tl.cast(v_4, tl.float16) - tl.store(inv_rms + indices_0[:, None] * inv_rms_stride_0, v_10, mask_0[:, None]) + tl.store(inv_rms + indices_0 * inv_rms_stride_0, v_10, mask_0) def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher): """ @@ -3005,11 +3006,11 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la m, n = x.size() assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}' out = torch.empty_like(x) - inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device) + inv_rms = torch.empty([m], dtype=x.dtype, device=x.device) _BLOCK_SIZE_0 = 16 _RDIM_SIZE_1 = triton.next_power_of_2(n) _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) - return (out, inv_rms) + return (out, inv_rms.reshape(-1, 1)) --- assertExpectedJournal(TestExamples.test_segment_reduction) from __future__ import annotations