diff --git a/benchmarks/run.py b/benchmarks/run.py
index 724423528..380af33db 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -156,6 +156,11 @@ class RunResult:
             ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
         ],
     ),
+    "welford": (
+        "tritonbench.operators.welford.operator",
+        "examples.welford",
+        "welford",
+    ),
 }
@@ -240,6 +245,14 @@ class RunResult:
         "helion_jsd_tritonbench-speedup": "helion_speedup",
         "helion_jsd_tritonbench-accuracy": "helion_accuracy",
     },
+    "welford": {
+        "test_welford-speedup": "triton_speedup",
+        "test_welford-accuracy": "triton_accuracy",
+        "torch_compile_layer_norm-speedup": "torch_compile_speedup",
+        "torch_compile_layer_norm-accuracy": "torch_compile_accuracy",
+        "helion_welford-speedup": "helion_speedup",
+        "helion_welford-accuracy": "helion_accuracy",
+    },
 }
diff --git a/examples/welford.py b/examples/welford.py
new file mode 100644
index 000000000..3cc20b77c
--- /dev/null
+++ b/examples/welford.py
@@ -0,0 +1,123 @@
+"""
+Welford Example
+================
+
+This example demonstrates how to implement LayerNorm in Helion using Welford's
+algorithm to compute the mean and variance.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+import torch
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+
+# %%
+# Welford Kernel Implementation
+# -------------------
+@helion.kernel()
+def welford(
+    weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
+) -> torch.Tensor:
+    """
+    Applies LayerNorm using Welford's algorithm for mean/variance.
+    Args:
+        weight: weight tensor of shape [N]
+        bias: bias tensor of shape [N]
+        x: input tensor of shape [M, N]
+    Returns:
+        Output tensor of shape [M, N]
+    """
+    m, n = x.size()
+
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+
+    for tile_m in hl.tile(m):
+        # Per-row running statistics: element count, mean, and sum of squared
+        # deviations from the mean (M2).
+        acc_cnt = torch.zeros_like(x[tile_m, 0], dtype=torch.float32)
+        acc_mean = torch.zeros_like(acc_cnt)
+        acc_m2 = torch.zeros_like(acc_cnt)
+
+        # First pass: reduce one tile of columns at a time, merging each
+        # chunk's statistics into the accumulators.
+        for tile_n in hl.tile(n):
+            chunk = x[tile_m, tile_n]
+            Tn = chunk.size(-1)
+            sum_x = torch.sum(chunk, dim=-1)
+            sum_x2 = torch.sum(chunk * chunk, dim=-1)
+            mean_c = sum_x / Tn
+            m2_c = sum_x2 - (sum_x * sum_x) / Tn
+
+            # Chan et al.'s parallel combination of two sets of statistics.
+            delta = mean_c - acc_mean
+            new_cnt = acc_cnt + Tn
+            new_mean = acc_mean + delta * (Tn / new_cnt)
+            new_m2 = acc_m2 + m2_c + delta * delta * (acc_cnt * Tn / new_cnt)
+
+            acc_cnt, acc_mean, acc_m2 = new_cnt, new_mean, new_m2
+
+        rstd_tile = torch.rsqrt(acc_m2 / acc_cnt + eps)
+        mean_col = acc_mean[:, None]
+        rstd_col = rstd_tile[:, None]
+
+        # Second pass: normalize and apply the affine transform.
+        for tile_n in hl.tile(n):
+            xi_chunk = x[tile_m, tile_n]
+            w_chunk = weight[tile_n][None, :]
+            b_chunk = bias[tile_n][None, :]
+
+            y = (xi_chunk - mean_col) * rstd_col
+            y = y * w_chunk + b_chunk
+
+            out[tile_m, tile_n] = y.to(x.dtype)
+    return out
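+
+
+# %%
+# Reference Sketch
+# -------------------
+# The accumulator update above is Chan et al.'s parallel variance formula. As
+# an illustration only (this helper is not used by the kernel, the benchmark,
+# or the tests; its name and chunk_size are arbitrary choices), here is a
+# minimal pure-PyTorch sketch of the same chunked merge:
+def welford_reference(
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    x: torch.Tensor,
+    eps: float = 1e-05,
+    chunk_size: int = 256,
+) -> torch.Tensor:
+    m, n = x.size()
+    cnt = x.new_zeros(m)
+    mean = x.new_zeros(m)
+    m2 = x.new_zeros(m)
+    for start in range(0, n, chunk_size):
+        chunk = x[:, start : start + chunk_size]
+        t = chunk.size(-1)
+        # Statistics of this chunk alone.
+        mean_c = chunk.mean(dim=-1)
+        m2_c = ((chunk - mean_c[:, None]) ** 2).sum(dim=-1)
+        # Merge into the running statistics (same formula as the kernel).
+        delta = mean_c - mean
+        new_cnt = cnt + t
+        mean = mean + delta * (t / new_cnt)
+        m2 = m2 + m2_c + delta * delta * (cnt * t / new_cnt)
+        cnt = new_cnt
+    rstd = torch.rsqrt(m2 / cnt + eps)
+    return (x - mean[:, None]) * rstd[:, None] * weight[None, :] + bias[None, :]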
+
+
+# %%
+# Baseline Function
+# -------------------
+def eager_layer_norm(
+    weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
+) -> torch.Tensor:
+    return torch.nn.functional.layer_norm(
+        x, normalized_shape=(x.shape[-1],), weight=weight, bias=bias, eps=eps
+    )
+
+
+# %%
+# Verification Function
+# -------------------
+def check(s: int, d: int) -> None:
+    """
+    Verify the welford kernel implementation against PyTorch's native
+    layer_norm function.
+
+    Args:
+        s: First dimension of the test tensor
+        d: Second dimension of the test tensor
+    """
+
+    weight = torch.rand((d,), device="cuda:0", dtype=torch.float32)
+    bias = torch.rand((d,), device="cuda:0", dtype=torch.float32)
+    x = torch.rand((s, d), device="cuda:0", dtype=torch.float32)
+
+    kernels = {"helion": welford}
+    run_example(kernels, eager_layer_norm, (weight, bias, x))
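+
+
+# %%
+# Standalone Usage
+# -------------------
+# A minimal sketch of calling the kernel directly rather than through
+# run_example. This assumes a CUDA device is available; the shape and the
+# tolerances below are illustrative choices, not values used by the benchmark.
+def standalone_example() -> None:
+    weight = torch.rand((2048,), device="cuda:0", dtype=torch.float32)
+    bias = torch.rand((2048,), device="cuda:0", dtype=torch.float32)
+    x = torch.rand((4096, 2048), device="cuda:0", dtype=torch.float32)
+    out = welford(weight, bias, x)
+    expected = eager_layer_norm(weight, bias, x)
+    torch.testing.assert_close(out, expected, rtol=1e-4, atol=1e-4)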
+
+
+# %%
+# Main Function
+# -----------
+def main() -> None:
+    """
+    Main entry point that runs the welford kernel verification with different
+    tensor sizes.
+
+    Tests with two configurations:
+    - 262144x1536
+    - 262144x2048
+    """
+    check(262144, 1536)
+    check(262144, 2048)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index e1d689835..97ed6985b 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -3040,3 +3040,102 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
     _BLOCK_SIZE_2 = 16
     _launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
     return out
+
+--- assertExpectedJournal(TestExamples.test_welford)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_welford(x, weight, bias, out, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < m
+    acc_cnt = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
+    acc_mean = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
+    acc_m2 = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
+    for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < n
+        acc_mean_copy = acc_mean
+        acc_cnt_copy = acc_cnt
+        acc_m2_copy = acc_m2
+        acc_mean_copy_0 = acc_mean_copy
+        acc_cnt_copy_0 = acc_cnt_copy
+        acc_m2_copy_0 = acc_m2_copy
+        chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        sum_x = tl.cast(tl.sum(chunk, 1), tl.float32)
+        v_0 = chunk * chunk
+        sum_x2 = tl.cast(tl.sum(v_0, 1), tl.float32)
+        _BLOCK_SIZE_1_ = _BLOCK_SIZE_1
+        v_1 = tl.cast(_BLOCK_SIZE_1_, tl.float32)
+        v_2 = sum_x / v_1
+        v_3 = sum_x * sum_x
+        _BLOCK_SIZE_1__1 = _BLOCK_SIZE_1
+        v_4 = tl.cast(_BLOCK_SIZE_1__1, tl.float32)
+        v_5 = v_3 / v_4
+        v_6 = sum_x2 - v_5
+        v_7 = v_2 - acc_mean_copy_0
+        _BLOCK_SIZE_1__2 = _BLOCK_SIZE_1
+        v_8 = tl.cast(_BLOCK_SIZE_1__2, tl.float32)
+        acc_cnt = acc_cnt_copy_0 + v_8
+        v_10 = tl.full([], 1, tl.int32)
+        v_11 = v_10 / acc_cnt
+        _BLOCK_SIZE_1__3 = _BLOCK_SIZE_1
+        v_12 = tl.cast(_BLOCK_SIZE_1__3, tl.float32)
+        v_13 = v_11 * v_12
+        v_14 = v_7 * v_13
+        acc_mean = acc_mean_copy_0 + v_14
+        v_16 = acc_m2_copy_0 + v_6
+        v_17 = v_7 * v_7
+        _BLOCK_SIZE_1__4 = _BLOCK_SIZE_1
+        v_18 = tl.cast(_BLOCK_SIZE_1__4, tl.float32)
+        v_19 = acc_cnt_copy_0 * v_18
+        v_20 = v_19 / acc_cnt
+        v_21 = v_17 * v_20
+        acc_m2 = v_16 + v_21
+    v_23 = acc_m2 / acc_cnt
+    v_24 = v_23 + eps
+    v_25 = libdevice.rsqrt(v_24)
+    mean_col = acc_mean[:, None]
+    rstd_col = v_25[:, None]
+    for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < n
+        mean_col_copy = mean_col
+        rstd_col_copy = rstd_col
+        mean_col_copy_0 = mean_col_copy
+        rstd_col_copy_0 = rstd_col_copy
+        xi_chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        load_1 = tl.load(weight + indices_2 * weight_stride_0, mask_2, other=0)
+        w_chunk = load_1[None, :]
+        load_2 = tl.load(bias + indices_2 * bias_stride_0, mask_2, other=0)
+        b_chunk = load_2[None, :]
+        v_26 = xi_chunk - mean_col_copy_0
+        v_27 = v_26 * rstd_col_copy_0
+        v_28 = v_27 * w_chunk
+        v_29 = v_28 + b_chunk
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_29, mask_0[:, None] & mask_2[None, :])
+
+def welford(weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+    """
+    Applies LayerNorm using Welford's algorithm for mean/variance.
+    Args:
+        weight: weight tensor of shape [N]
+        bias: bias tensor of shape [N]
+        x: input tensor of shape [M, N]
+    Returns:
+        Output tensor of shape [M, N]
+    """
+    m, n = x.size()
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _BLOCK_SIZE_2 = 16
+    _launcher(_helion_welford, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
+    return out
diff --git a/test/test_examples.py b/test/test_examples.py
index f66686891..9679abb17 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -288,6 +288,26 @@ def test_cross_entropy(self):
             )
         )
 
+    def test_welford(self):
+        s, d = 128, 1024
+        weight = torch.rand((d,), device=DEVICE, dtype=torch.float32)
+        bias = torch.rand((d,), device=DEVICE, dtype=torch.float32)
+        x = torch.rand((s, d), device=DEVICE, dtype=torch.float32)
+
+        self.assertExpectedJournal(
+            check_example(
+                "welford",
+                (weight, bias, x),
+                torch.nn.functional.layer_norm(
+                    x,
+                    normalized_shape=(x.shape[-1],),
+                    weight=weight,
+                    bias=bias,
+                    eps=1e-05,
+                ),
+            )
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),