diff --git a/examples/rms_norm.py b/examples/rms_norm.py
index b7354712b..475228b29 100644
--- a/examples/rms_norm.py
+++ b/examples/rms_norm.py
@@ -11,6 +11,7 @@
 # -------
 from __future__ import annotations
 
+from typing import Any
 from typing import Callable
 
 import torch
@@ -23,8 +24,10 @@
 # %%
 # RMS Normalization Kernel
 # ---------------------
-@helion.kernel(static_shapes=True)
-def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
+@helion.kernel
+def rms_norm_fwd(
+    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5
+) -> tuple[torch.Tensor, torch.Tensor]:
     """
     Performs Root Mean Square (RMS) normalization on the input tensor.
 
@@ -38,25 +41,151 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.
 
     Returns:
         Output tensor of shape [M, N] with RMS normalization applied
+        Inverse RMS tensor of shape [M, N] holding the per-row 1/rms values
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
 
-    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    out = torch.empty_like(x)
+    inv_rms = torch.empty_like(x)
 
     for tile_m in hl.tile(m):
         x_tile = x[tile_m, :].to(torch.float32)
 
-        # Compute RMS: sqrt(mean(x^2))
+        # Compute inverse RMS: 1/sqrt(mean(x^2) + eps)
         x_squared = x_tile * x_tile
         mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True)
-        rms = torch.rsqrt(mean_x_squared + eps)
+        inv_rms_tile = torch.rsqrt(mean_x_squared + eps)
 
         # Apply normalization and weight
-        normalized = x_tile * rms
+        normalized = x_tile * inv_rms_tile
         out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
+        inv_rms[tile_m, :] = inv_rms_tile.to(out.dtype)
+
+    return out, inv_rms
+
+
+@helion.kernel
+def rms_norm_bwd_dw(
+    grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, inv_rms: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute gradients for weight (dW)
+
+    This kernel performs reduction across the batch dimension (M) to accumulate
+    gradients for each feature dimension's weight parameter.
+
+    Args:
+        grad_out: Gradient w.r.t rms norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter (used only for dtype/device info) [N]
+        inv_rms: Inverse RMS tensor [M, N]
+
+    Returns:
+        grad_weight: Gradients for weight with shape [N]
+    """
+    m, n = x.shape
+
+    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
+
+    # Reduce across rows (M) inside the kernel without atomics
+    rdim = hl.register_reduction_dim(m)
+
+    for tile_n in hl.tile(n):
+        rows = hl.arange(0, rdim)
+        # Load slices for all rows in rdim and this tile of columns
+        x_blk = x[rows, tile_n].to(torch.float32)
+        dy_blk = grad_out[rows, tile_n].to(torch.float32)
+        inv_rms_blk = inv_rms[rows, tile_n].to(torch.float32)
+
+        # Compute normalized input: x_normalized = x * inv_rms
+        x_normalized = x_blk * inv_rms_blk
+
+        # Weight gradient: dw = sum_over_batch(dy * x_normalized)
+        dw_tile = torch.sum(dy_blk * x_normalized, dim=0).to(weight.dtype)
+
+        dw[tile_n] = dw_tile
+
+    return dw
+
+
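The reduction in rms_norm_bwd_dw is the standard RMSNorm weight gradient: with x_hat = x * inv_rms, dW is the sum of grad_out * x_hat over the batch dimension. For readers checking the math, the following standalone eager-mode sketch (illustrative shapes; not part of the patch itself) reproduces the same reduction and cross-checks it with autograd:

import torch

# Illustrative sizes; any [M, N] input and [N] weight behave the same way.
M, N, eps = 32, 64, 1e-5
x = torch.randn(M, N, dtype=torch.float32)
dy = torch.randn(M, N, dtype=torch.float32)
inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# Same reduction as the kernel: dW = sum over the batch of dy * (x * inv_rms).
dw_ref = (dy * (x * inv_rms)).sum(dim=0)

# Cross-check with autograd on y = (x * inv_rms) * w; w does not influence
# inv_rms, so w.grad is exactly this batch reduction.
w = torch.ones(N, requires_grad=True)
((x * inv_rms) * w).backward(dy)
torch.testing.assert_close(w.grad, dw_ref)
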
+@helion.kernel
+def rms_norm_bwd_dx(
+    grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, inv_rms: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute gradient for input tensor (dX).
+
+    This kernel computes per-sample gradients by performing reductions across
+    the feature dimension (N) for each sample in the batch.
 
-    return out
+    Args:
+        grad_out: Gradient w.r.t rms norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter [N]
+        inv_rms: Inverse RMS tensor [M, N]
+
+    Returns:
+        grad_x: Gradient w.r.t input tensor, shape [M, N]
+    """
+    m, n = x.shape
+    n = hl.specialize(n)
+
+    grad_x = torch.empty_like(x)
+
+    for tile_m in hl.tile(m):
+        x_tile = x[tile_m, :].to(torch.float32)
+        dy_tile = grad_out[tile_m, :].to(torch.float32)
+        w = weight[:].to(torch.float32)
+        inv_rms_tile = inv_rms[tile_m, :].to(torch.float32)
+
+        dyw = dy_tile * w
+        normed = x_tile * inv_rms_tile
+        rowsum_dy_normed = (dyw * normed).sum(dim=-1, keepdim=True)
+        dx = inv_rms_tile / n * (n * dyw - normed * rowsum_dy_normed)
+
+        grad_x[tile_m, :] = dx.to(x.dtype)
+
+    return grad_x
+
+
+# %%
+class RMSNormFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        eps: float = 1e-5,
+    ) -> torch.Tensor:
+        """Forward pass for rms normalization."""
+        y, rms = rms_norm_fwd(x, weight, eps)
+        ctx.save_for_backward(x, weight)
+        ctx.rms = rms  # type: ignore[attr-defined]
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: Any,  # noqa: ANN401
+        grad_out: torch.Tensor,
+    ) -> tuple[torch.Tensor | None, torch.Tensor | None, None]:
+        """Backward pass for rms normalization split into two separate kernels for efficiency."""
+        x, weight = ctx.saved_tensors  # type: ignore[attr-defined]
+        rms = ctx.rms  # type: ignore[attr-defined]
+
+        # First kernel: Compute gradients for weight by reducing across batch dimension (M)
+        grad_weight = rms_norm_bwd_dw(grad_out, x, weight, rms)
+
+        # Second kernel: Compute gradient for input (dx) using per-sample reductions across feature dimension (N)
+        grad_x = rms_norm_bwd_dx(grad_out, x, weight, rms)
+
+        return grad_x, grad_weight, None
+
+
+# %%
+def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
+    """RMS normalization with forward + backward support."""
+    return RMSNormFunction.apply(x, weight, eps)  # type: ignore[no-any-return]
 
 
 # %%
@@ -117,7 +246,36 @@ def check(m: int, n: int) -> None:
     """
     x = torch.randn([m, n], device="cuda", dtype=torch.float16)
     weight = torch.randn([n], device="cuda", dtype=torch.float16)
-    run_example(rms_norm, rms_norm_pytorch, (x, weight, 1e-5))
+
+    # Test forward pass only
+    print("\n=== Forward Pass Test ===")
+    run_example(
+        rms_norm,
+        rms_norm_pytorch,
+        (x, weight, 1e-5),
+        kernel_name="helion_fwd_kernel",
+        baseline_name="torch",
+        rtol=1e-3,
+        atol=1e-3,
+    )
+
+    # Test forward + backward pass
+    print("\n\n=== Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, n], device="cuda", dtype=torch.float16, requires_grad=True)
+    weight_grad = torch.randn(
+        [n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+
+    run_example(
+        rms_norm,
+        rms_norm_pytorch,
+        (x_grad, weight_grad, 1e-5),
+        kernel_name="helion_autograd",
+        baseline_name="torch",
+        rtol=1e-3,
+        atol=1e-3,
+        bwd=True,
+    )
 
 
 # %%
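To exercise the new autograd path end to end outside the test suite, a standalone sanity check along these lines can be used. It assumes a CUDA device, that the repository root is on the import path so `examples.rms_norm` resolves, and uses illustrative sizes and tolerances. The dX kernel above implements the usual RMSNorm input gradient, dx = inv_rms / n * (n * dy * w - x_hat * rowsum(dy * w * x_hat)) with x_hat = x * inv_rms, and the comparison below exercises both backward kernels through RMSNormFunction:

import torch

from examples.rms_norm import rms_norm  # autograd-enabled wrapper defined above

M, N, eps = 128, 256, 1e-5
x = torch.randn(M, N, device="cuda", dtype=torch.float16, requires_grad=True)
w = torch.randn(N, device="cuda", dtype=torch.float16, requires_grad=True)

y = rms_norm(x, w, eps)         # runs rms_norm_fwd
y.backward(torch.ones_like(y))  # runs rms_norm_bwd_dw and rms_norm_bwd_dx

# PyTorch reference for the same computation.
x_ref = x.detach().clone().requires_grad_(True)
w_ref = w.detach().clone().requires_grad_(True)
y_ref = torch.nn.functional.rms_norm(x_ref, [N], w_ref, eps)
y_ref.backward(torch.ones_like(y_ref))

torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(w.grad, w_ref.grad, rtol=1e-2, atol=1e-2)
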
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 2d343d720..8459f7c5d 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -2030,7 +2030,120 @@ def moe_matmul_ogs(A: torch.Tensor, W: torch.Tensor, expert_token_counts: torch.
     _launcher(_helion_moe_matmul_ogs, (E,), expert_token_offsets, expert_token_counts, sorted_to_orig_token_idx, A, W, C, A.stride(0), A.stride(1), C.stride(0), C.stride(1), W.stride(0), W.stride(1), W.stride(2), expert_token_counts.stride(0), expert_token_offsets.stride(0), sorted_to_orig_token_idx.stride(0), max_T_per_expert, N, K, _BLOCK_SIZE_2, _BLOCK_SIZE_1, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
     return C
 
---- assertExpectedJournal(TestExamples.test_rms_norm)
+--- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dw)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stride_0, grad_out_stride_1, inv_rms_stride_0, inv_rms_stride_1, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < n
+    indices_0 = tl.arange(0, _RDIM_SIZE_0).to(tl.int32)
+    mask_0 = indices_0 < m
+    rows = tl.arange(0, _RDIM_SIZE_0)
+    load = tl.load(x + (rows[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    load_1 = tl.load(grad_out + (rows[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_1 = tl.cast(load_1, tl.float32)
+    load_2 = tl.load(inv_rms + (rows[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    v_2 = tl.cast(load_2, tl.float32)
+    v_3 = v_0 * v_2
+    v_4 = v_1 * v_3
+    sum_1 = tl.cast(tl.sum(v_4, 0), tl.float32)
+    v_5 = tl.cast(sum_1, tl.float16)
+    tl.store(dw + indices_1 * dw_stride_0, v_5, mask_1)
+
+def rms_norm_bwd_dw(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, inv_rms: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Compute gradients for weight (dW)
+
+    This kernel performs reduction across the batch dimension (M) to accumulate
+    gradients for each feature dimension's weight parameter.
+
+    Args:
+        grad_out: Gradient w.r.t rms norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter (used only for dtype/device info) [N]
+        inv_rms: Inverse RMS tensor [M, N]
+
+    Returns:
+        grad_weight: Gradients for weight with shape [N]
+    """
+    m, n = x.shape
+    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
+    _BLOCK_SIZE_1 = 32
+    _RDIM_SIZE_0 = triton.next_power_of_2(m)
+    _launcher(_helion_rms_norm_bwd_dw, (triton.cdiv(n, _BLOCK_SIZE_1),), x, grad_out, inv_rms, dw, dw.stride(0), grad_out.stride(0), grad_out.stride(1), inv_rms.stride(0), inv_rms.stride(1), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_1, _RDIM_SIZE_0, num_warps=4, num_stages=3)
+    return dw
+
+--- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dx)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, inv_rms_stride_0, inv_rms_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None], other=0)
+    v_0 = tl.cast(load, tl.float32)
+    load_1 = tl.load(grad_out + (indices_0[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None], other=0)
+    v_1 = tl.cast(load_1, tl.float32)
+    load_2 = tl.load(weight + indices_1 * weight_stride_0, None)
+    v_2 = tl.cast(load_2, tl.float32)
+    load_3 = tl.load(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None], other=0)
+    v_3 = tl.cast(load_3, tl.float32)
+    v_4 = v_2[None, :]
+    v_5 = v_1 * v_4
+    v_6 = v_0 * v_3
+    v_7 = v_5 * v_6
+    rowsum_dy_normed = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
+    v_8 = 0.015625
+    v_9 = v_3 * v_8
+    v_10 = 64.0
+    v_11 = v_5 * v_10
+    v_12 = v_6 * rowsum_dy_normed
+    v_13 = v_11 - v_12
+    v_14 = v_9 * v_13
+    v_15 = tl.cast(v_14, tl.float16)
+    tl.store(grad_x + (indices_0[:, None] * grad_x_stride_0 + indices_1[None, :] * grad_x_stride_1), v_15, mask_0[:, None])
+
+def rms_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tensor, inv_rms: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Compute gradient for input tensor (dX).
+
+    This kernel computes per-sample gradients by performing reductions across
+    the feature dimension (N) for each sample in the batch.
+
+    Args:
+        grad_out: Gradient w.r.t rms norm output [M, N]
+        x: Original input tensor [M, N]
+        weight: Weight parameter [N]
+        inv_rms: Inverse RMS tensor [M, N]
+
+    Returns:
+        grad_x: Gradient w.r.t input tensor, shape [M, N]
+    """
+    m, n = x.shape
+    grad_x = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 64
+    _launcher(_helion_rms_norm_bwd_dx, (triton.cdiv(m, _BLOCK_SIZE_0),), x, grad_out, weight, inv_rms, grad_x, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), inv_rms.stride(0), inv_rms.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return grad_x
+
+--- assertExpectedJournal(TestExamples.test_rms_norm_fwd)
 from __future__ import annotations
 
 import torch
@@ -2040,28 +2153,31 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm(x, weight, out, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, inv_rms_stride_1, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
     indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
-    load = tl.load(x + (indices_0[:, None] * 256 + indices_1[None, :] * 1), None)
+    mask_1 = indices_1 < n
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
     v_0 = tl.cast(load, tl.float32)
     v_1 = v_0 * v_0
     mean_x_squared_extra = tl.cast(tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
-    v_2 = 256
-    v_3 = mean_x_squared_extra / v_2.to(tl.float32)
-    v_4 = v_3 + eps
-    v_5 = libdevice.rsqrt(v_4)
-    v_6 = v_0 * v_5
-    load_1 = tl.load(weight + indices_1 * 1, None)
-    v_7 = tl.cast(load_1, tl.float32)
-    v_8 = v_7[None, :]
-    v_9 = v_6 * v_8
-    v_10 = tl.cast(v_9, tl.float16)
-    tl.store(out + (indices_0[:, None] * 256 + indices_1[None, :] * 1), v_10, None)
-
-def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    v_2 = mean_x_squared_extra / n.to(tl.float32)
+    v_3 = v_2 + eps
+    v_4 = libdevice.rsqrt(v_3)
+    v_5 = v_0 * v_4
+    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
+    v_6 = tl.cast(load_1, tl.float32)
+    v_7 = v_6[None, :]
+    v_8 = v_5 * v_7
+    v_9 = tl.cast(v_8, tl.float16)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :])
+    v_10 = tl.cast(v_4, tl.float16)
+    tl.store(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
+
+def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs Root Mean Square (RMS) normalization on the input tensor.
 
@@ -2075,14 +2191,16 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launch
 
     Returns:
         Output tensor of shape [M, N] with RMS normalization applied
+        Inverse RMS tensor of shape [M, N] holding the per-row 1/rms values
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
-    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    out = torch.empty_like(x)
+    inv_rms = torch.empty_like(x)
     _BLOCK_SIZE_0 = 16
-    _RDIM_SIZE_1 = 256
-    _launcher(_helion_rms_norm, (triton.cdiv(128, _BLOCK_SIZE_0),), x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
-    return out
+    _RDIM_SIZE_1 = triton.next_power_of_2(n)
+    _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), inv_rms.stride(1), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    return (out, inv_rms)
 
 --- assertExpectedJournal(TestExamples.test_segment_reduction)
 from __future__ import annotations
diff --git a/test/test_examples.py b/test/test_examples.py
index cd803d3c7..d77b56e6d 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -288,7 +288,7 @@ def test_cross_entropy(self):
             )
         )
 
-    def test_rms_norm(self):
+    def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
             torch.randn([256], device=DEVICE, dtype=torch.float16),
@@ -302,12 +302,102 @@ def test_rms_norm(self):
             check_example(
                 "rms_norm",
                 args,
-                expected,
+                (expected, None),  # Expected: (output, 1/rms)
+                fn_name="rms_norm_fwd",
                 block_sizes=[16],
                 indexing="pointer",
             )
         )
 
+    def test_rms_norm_bwd_dw(self):
+        """Test backward pass for rms norm weight gradient."""
+        batch_size, dim = 32, 64
+        x = torch.randn([batch_size, dim], device=DEVICE, dtype=torch.float16)
+        weight = torch.randn(
+            [dim], device=DEVICE, dtype=torch.float16, requires_grad=True
+        )
+        grad_out = torch.randn([batch_size, dim], device=DEVICE, dtype=torch.float16)
+        eps = 1e-5
+
+        # Compute forward pass to get rms
+        from examples.rms_norm import rms_norm_fwd
+
+        # Create configured kernel with explicit config
+        config = helion.Config(block_size=32, num_warps=4, num_stages=3)
+        configured_kernel = helion.kernel(rms_norm_fwd.fn, config=config)
+        y, rms = configured_kernel(x, weight, eps)
+
+        # Compute expected gradients with PyTorch
+        x_torch = x.detach().clone().requires_grad_(True)
+        weight_torch = weight.detach().clone().requires_grad_(True)
+        y_torch = torch.nn.functional.rms_norm(x_torch, [dim], weight_torch, eps)
+        y_torch.backward(grad_out)
+
+        # Test the kernel using check_example
+        args = (
+            grad_out,
+            x,
+            weight,
+            rms,
+        )
+
+        # rms_norm_bwd_dw returns grad_weight
+        self.assertExpectedJournal(
+            check_example(
+                "rms_norm",
+                args,
+                weight_torch.grad,  # Expected: grad_weight
+                fn_name="rms_norm_bwd_dw",
+                block_size=32,
+                num_warps=4,
+                num_stages=3,
+                rtol=1e-2,
+                atol=1e-2,
+            )
+        )
+
+    def test_rms_norm_bwd_dx(self):
+        """Test backward pass for rms norm input gradient."""
+        batch_size, dim = 32, 64
+        x = torch.randn(
+            [batch_size, dim], device=DEVICE, dtype=torch.float16, requires_grad=True
+        )
+        weight = torch.randn(
+            [dim], device=DEVICE, dtype=torch.float16, requires_grad=True
+        )
+        eps = 1e-5
+        grad_out = torch.randn([batch_size, dim], device=DEVICE, dtype=torch.float16)
+
+        # Compute forward pass to get rms
+        from examples.rms_norm import rms_norm_fwd
+
+        # Create configured kernel with explicit config
+        config = helion.Config(block_size=32, num_warps=4, num_stages=3)
+        configured_kernel = helion.kernel(rms_norm_fwd.fn, config=config)
+        y, rms = configured_kernel(x, weight, eps)
+
+        # Compute expected gradient with PyTorch
+        x_torch = x.detach().clone().requires_grad_(True)
+        weight_torch = weight.detach().clone().requires_grad_(True)
+        y_torch = torch.nn.functional.rms_norm(x_torch, [dim], weight_torch, eps)
+        y_torch.backward(grad_out)
+
+        args = (grad_out, x, weight, rms)
+
+        self.assertExpectedJournal(
+            check_example(
+                "rms_norm",
+                args,
+                x_torch.grad,
+                fn_name="rms_norm_bwd_dx",
+                block_size=32,
+                num_warps=4,
+                num_stages=3,
+                rtol=1e-3,
+                atol=1e-3,
+            )
+        )
+
     def test_embedding_pointers(self):
         args = (
             torch.randint(0, 1024, [8, 128], device=DEVICE, dtype=torch.int32),