diff --git a/benchmarks/run.py b/benchmarks/run.py
index 68d0b7680..8fce9ce5b 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -136,7 +136,7 @@ class RunResult:
     "layer_norm": (
         "tritonbench.operators.layer_norm.operator",
         "examples.layer_norm",
-        "layer_norm",
+        "layer_norm_tritonbench",
     ),
     "jagged_softmax": (
         "tritonbench.operators.jagged_softmax.operator",
diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index 6b4dac108..a86699316 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -10,6 +10,7 @@ from __future__ import annotations
 
 from typing import Any
+from typing import Callable
 
 import torch
 
@@ -240,6 +241,34 @@ def layer_norm(
     return LayerNormFunction.apply(x, normalized_shape, weight, bias, eps)  # type: ignore[no-any-return]
 
 
+# %%
+# Benchmark Wrapper
+# --------------
+def layer_norm_tritonbench(
+    tb_op: object,
+    x: torch.Tensor,
+    normalized_shape: list[int],
+    weight: torch.Tensor,
+    bias: torch.Tensor | None = None,
+    eps: float = 1e-5,
+) -> Callable[[], torch.Tensor]:
+    """
+    Wrapper for tritonbench that matches expected interface.
+
+    Args:
+        tb_op: TritonBench operator instance
+        x: Input tensor
+        normalized_shape: Shape to normalize over
+        weight: Weight parameter
+        bias: Bias parameter (optional)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Callable that returns normalized tensor
+    """
+    return lambda: layer_norm(x, normalized_shape, weight, bias, eps)
+
+
 # %%
 def main() -> None:
     """
diff --git a/examples/rms_norm.py b/examples/rms_norm.py
index 475228b29..8f47fda33 100644
--- a/examples/rms_norm.py
+++ b/examples/rms_norm.py
@@ -41,13 +41,13 @@ def rms_norm_fwd(
     Returns:
         Output tensor of shape [M, N] with RMS normalization applied
-        RMS tensor of shape [M, N] with RMS values for each element
+        RMS tensor of shape [M, 1] with RMS values for each element
     """
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"
 
     out = torch.empty_like(x)
-    inv_rms = torch.empty_like(x)
+    inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)
 
     for tile_m in hl.tile(m):
         x_tile = x[tile_m, :].to(torch.float32)
@@ -79,7 +79,7 @@ def rms_norm_bwd_dw(
         grad_out: Gradient w.r.t rms norm output [M, N]
         x: Original input tensor [M, N]
         weight: Weight parameter (used only for dtype/device info) [N]
-        inv_rms: Inverse RMS tensor [M, N]
+        inv_rms: Inverse RMS tensor [M, 1]
 
     Returns:
         grad_weight: Gradients for weight with shape [N]
@@ -123,7 +123,7 @@ def rms_norm_bwd_dx(
         grad_out: Gradient w.r.t rms norm output [M, N]
         x: Original input tensor [M, N]
         weight: Weight parameter [N]
-        inv_rms: Inverse RMS tensor [M, N]
+        inv_rms: Inverse RMS tensor [M, 1]
 
     Returns:
         grad_x: Gradient w.r.t input tensor, shape [M, N]
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 2c5c33aff..5760c3ca8 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -2039,7 +2039,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stride_0, grad_out_stride_1, inv_rms_stride_0, inv_rms_stride_1, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_0: tl.constexpr):
+def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stride_0, grad_out_stride_1, inv_rms_stride_0, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
@@ -2051,7 +2051,7 @@ def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stri
     v_0 = tl.cast(load, tl.float32)
     load_1 = tl.load(grad_out + (rows[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
     v_1 = tl.cast(load_1, tl.float32)
-    load_2 = tl.load(inv_rms + (rows[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    load_2 = tl.load(inv_rms + rows[:, None] * inv_rms_stride_0, mask_0[:, None], other=0)
     v_2 = tl.cast(load_2, tl.float32)
     v_3 = v_0 * v_2
     v_4 = v_1 * v_3
@@ -2070,7 +2070,7 @@ def rms_norm_bwd_dw(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
         grad_out: Gradient w.r.t rms norm output [M, N]
         x: Original input tensor [M, N]
         weight: Weight parameter (used only for dtype/device info) [N]
-        inv_rms: Inverse RMS tensor [M, N]
+        inv_rms: Inverse RMS tensor [M, 1]
 
     Returns:
         grad_weight: Gradients for weight with shape [N]
@@ -2079,7 +2079,7 @@ def rms_norm_bwd_dw(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
     dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
     _BLOCK_SIZE_1 = 32
     _RDIM_SIZE_0 = triton.next_power_of_2(m)
-    _launcher(_helion_rms_norm_bwd_dw, (triton.cdiv(n, _BLOCK_SIZE_1),), x, grad_out, inv_rms, dw, dw.stride(0), grad_out.stride(0), grad_out.stride(1), inv_rms.stride(0), inv_rms.stride(1), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_1, _RDIM_SIZE_0, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_bwd_dw, (triton.cdiv(n, _BLOCK_SIZE_1),), x, grad_out, inv_rms, dw, dw.stride(0), grad_out.stride(0), grad_out.stride(1), inv_rms.stride(0), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_1, _RDIM_SIZE_0, num_warps=4, num_stages=3)
     return dw
 
 --- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dx)
@@ -2091,7 +2091,7 @@ import triton.language as tl
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, inv_rms_stride_0, inv_rms_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, inv_rms_stride_0, weight_stride_0, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
@@ -2103,7 +2103,7 @@ def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_strid
     v_1 = tl.cast(load_1, tl.float32)
     load_2 = tl.load(weight + indices_1 * weight_stride_0, None)
     v_2 = tl.cast(load_2, tl.float32)
-    load_3 = tl.load(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None], other=0)
+    load_3 = tl.load(inv_rms + indices_0[:, None] * inv_rms_stride_0, mask_0[:, None], other=0)
     v_3 = tl.cast(load_3, tl.float32)
     v_4 = v_2[None, :]
     v_5 = v_1 * v_4
@@ -2131,7 +2131,7 @@ def rms_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
         grad_out: Gradient w.r.t rms norm output [M, N]
         x: Original input tensor [M, N]
         weight: Weight parameter [N]
-        inv_rms: Inverse RMS tensor [M, N]
+        inv_rms: Inverse RMS tensor [M, 1]
 
     Returns:
         grad_x: Gradient w.r.t input tensor, shape [M, N]
@@ -2140,7 +2140,7 @@ def rms_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
     grad_x = torch.empty_like(x)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_1 = 64
-    _launcher(_helion_rms_norm_bwd_dx, (triton.cdiv(m, _BLOCK_SIZE_0),), x, grad_out, weight, inv_rms, grad_x, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), inv_rms.stride(0), inv_rms.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_bwd_dx, (triton.cdiv(m, _BLOCK_SIZE_0),), x, grad_out, weight, inv_rms, grad_x, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), inv_rms.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return grad_x
 
 --- assertExpectedJournal(TestExamples.test_rms_norm_fwd)
@@ -2153,7 +2153,7 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, inv_rms_stride_1, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_0 = pid_0 * _BLOCK_SIZE_0
     indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
@@ -2175,7 +2175,7 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stri
     v_9 = tl.cast(v_8, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :])
     v_10 = tl.cast(v_4, tl.float16)
-    tl.store(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
+    tl.store(inv_rms + indices_0[:, None] * inv_rms_stride_0, v_10, mask_0[:, None])
 
 def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     """
@@ -2191,15 +2191,15 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la
     Returns:
         Output tensor of shape [M, N] with RMS normalization applied
-        RMS tensor of shape [M, N] with RMS values for each element
+        RMS tensor of shape [M, 1] with RMS values for each element
     """
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
     out = torch.empty_like(x)
-    inv_rms = torch.empty_like(x)
+    inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)
     _BLOCK_SIZE_0 = 16
     _RDIM_SIZE_1 = triton.next_power_of_2(n)
-    _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), inv_rms.stride(1), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return (out, inv_rms)
 
 --- assertExpectedJournal(TestExamples.test_segment_reduction)
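
Note on the shape change above: inv_rms holds exactly one inverse-RMS value per row, so an [M, 1] allocation suffices and broadcasts across all N columns; that is why the generated kernels drop the inv_rms_stride_1 term and the column mask on its loads and stores. A minimal PyTorch sketch of the same semantics (illustrative only; rms_norm_ref and the tensor sizes below are not part of this patch):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # One inverse-RMS value per row -> shape [M, 1], not [M, N].
    inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    # The [M, 1] tensor broadcasts across the N columns of x, so no
    # per-element copy of the RMS value is ever needed.
    out = (x.float() * inv_rms) * weight.float()
    return out.to(x.dtype), inv_rms.to(x.dtype)

x = torch.randn(4, 8, dtype=torch.float16)
weight = torch.ones(8, dtype=torch.float16)
out, inv_rms = rms_norm_ref(x, weight)
assert inv_rms.shape == (4, 1)  # matches the new [m, 1] allocation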