13 changes: 13 additions & 0 deletions benchmarks/run.py
@@ -156,6 +156,11 @@ class RunResult:
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
],
),
"welford": (
"tritonbench.operators.welford.operator",
"examples.welford",
"welford",
),
}


@@ -240,6 +245,14 @@ class RunResult:
"helion_jsd_tritonbench-speedup": "helion_speedup",
"helion_jsd_tritonbench-accuracy": "helion_accuracy",
},
"welford": {
"test_welford-speedup": "triton_speedup",
"test_welford-accuracy": "triton_accuracy",
"torch_compile_layer_norm-speedup": "torch_compile_speedup",
"torch_compile_layer_norm-accuracy": "torch_compile_accuracy",
"helion_welford-speedup": "helion_speedup",
"helion_welford-accuracy": "helion_accuracy",
},
}


123 changes: 123 additions & 0 deletions examples/welford.py
@@ -0,0 +1,123 @@
"""
Welford Example
================

This example demonstrates how to implement a Welford-based LayerNorm using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Welford Kernel Implementation
# -----------------------------
@helion.kernel()
def welford(
weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
) -> torch.Tensor:
"""
Applies LayerNorm using Welford's algorithm for mean/variance.
Args:
weight: weight tensor of shape [N]
bias: bias tensor of shape [N]
x: input tensor of shape [M, N]
Returns:
Output tensor of shape [M, N]
"""
m, n = x.size()

out = torch.empty([m, n], dtype=x.dtype, device=x.device)

for tile_m in hl.tile(m):
acc_cnt = torch.zeros_like(x[tile_m, 0], dtype=torch.float32)
acc_mean = torch.zeros_like(acc_cnt)
acc_m2 = torch.zeros_like(acc_cnt)

for tile_n in hl.tile(n):
chunk = x[tile_m, tile_n]
Tn = chunk.size(-1)
sum_x = torch.sum(chunk, dim=-1)
sum_x2 = torch.sum(chunk * chunk, dim=-1)
mean_c = sum_x / Tn
m2_c = sum_x2 - (sum_x * sum_x) / Tn

delta = mean_c - acc_mean
new_cnt = acc_cnt + Tn
new_mean = acc_mean + delta * (Tn / new_cnt)
new_m2 = acc_m2 + m2_c + delta * delta * (acc_cnt * Tn / new_cnt)

acc_cnt, acc_mean, acc_m2 = new_cnt, new_mean, new_m2

rstd_tile = torch.rsqrt(acc_m2 / acc_cnt + eps)
mean_col = acc_mean[:, None]
rstd_col = rstd_tile[:, None]

for tile_n in hl.tile(n):
            xi_chunk = x[tile_m, tile_n]
            w_chunk = weight[tile_n][None, :]
            b_chunk = bias[tile_n][None, :]

            y = (xi_chunk - mean_col) * rstd_col
            y = y * w_chunk + b_chunk

            out[tile_m, tile_n] = y.to(x.dtype)
return out
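

# %%
# Note on the Chunk Merge
# -----------------------
# The inner loop folds each chunk's (count, mean, M2) statistics into the
# running accumulators using the parallel-variance merge (Chan et al.):
#
#     delta = mean_b - mean_a
#     mean  = mean_a + delta * n_b / (n_a + n_b)
#     M2    = M2_a + M2_b + delta**2 * n_a * n_b / (n_a + n_b)
#
# The helper below is an illustrative sketch only (``_verify_merge`` is a
# hypothetical name, not part of the kernel or the benchmark); it checks the
# merge identity against a direct computation on random data.
def _verify_merge() -> None:
    a, b = torch.randn(1000), torch.randn(500)
    n_a, n_b = a.numel(), b.numel()
    mean_a, mean_b = a.mean(), b.mean()
    m2_a = ((a - mean_a) ** 2).sum()  # sum of squared deviations within a
    m2_b = ((b - mean_b) ** 2).sum()
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / (n_a + n_b)
    m2 = m2_a + m2_b + delta**2 * n_a * n_b / (n_a + n_b)
    full = torch.cat([a, b])
    # The merged statistics should match those of the concatenated data.
    torch.testing.assert_close(mean, full.mean())
    torch.testing.assert_close(m2 / (n_a + n_b), full.var(unbiased=False))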


# %%
# Baseline Function
# -------------------
def eager_layer_norm(
weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float = 1e-05
) -> torch.Tensor:
return torch.nn.functional.layer_norm(
x, normalized_shape=(x.shape[-1],), weight=weight, bias=bias, eps=eps
)


# %%
# Verification Function
# ---------------------
def check(s: int, d: int) -> None:
"""
    Verify the Welford kernel implementation against PyTorch's native layer_norm function.

Args:
s: First dimension of the test tensor
d: Second dimension of the test tensor
"""

weight = torch.rand((d,), device="cuda:0", dtype=torch.float32)
bias = torch.rand((d,), device="cuda:0", dtype=torch.float32)
x = torch.rand((s, d), device="cuda:0", dtype=torch.float32)

kernels = {"helion": welford}
run_example(kernels, eager_layer_norm, (weight, bias, x))


# %%
# Main Function
# -------------
def main() -> None:
"""
    Main entry point that runs the Welford kernel verification with different tensor sizes.

Tests with two configurations:
- 262144x1536
- 262144x2048
"""
check(262144, 1536)
check(262144, 2048)


if __name__ == "__main__":
main()
99 changes: 99 additions & 0 deletions test/test_examples.expected
@@ -3040,3 +3040,102 @@ def matmul(x: Tensor, y: Tensor, epilogue: Callable[[Tensor, tuple[Tensor, ...]]
_BLOCK_SIZE_2 = 16
_launcher(_helion_matmul, (triton.cdiv(1024, _BLOCK_SIZE_0) * triton.cdiv(1024, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
return out

--- assertExpectedJournal(TestExamples.test_welford)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_welford(x, weight, bias, out, bias_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
acc_cnt = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
acc_mean = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
acc_m2 = tl.full([_BLOCK_SIZE_0], 0, tl.float32)
for offset_1 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < n
acc_mean_copy = acc_mean
acc_cnt_copy = acc_cnt
acc_m2_copy = acc_m2
acc_mean_copy_0 = acc_mean_copy
acc_cnt_copy_0 = acc_cnt_copy
acc_m2_copy_0 = acc_m2_copy
chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
sum_x = tl.cast(tl.sum(chunk, 1), tl.float32)
v_0 = chunk * chunk
sum_x2 = tl.cast(tl.sum(v_0, 1), tl.float32)
_BLOCK_SIZE_1_ = _BLOCK_SIZE_1
v_1 = tl.cast(_BLOCK_SIZE_1_, tl.float32)
v_2 = sum_x / v_1
v_3 = sum_x * sum_x
_BLOCK_SIZE_1__1 = _BLOCK_SIZE_1
v_4 = tl.cast(_BLOCK_SIZE_1__1, tl.float32)
v_5 = v_3 / v_4
v_6 = sum_x2 - v_5
v_7 = v_2 - acc_mean_copy_0
_BLOCK_SIZE_1__2 = _BLOCK_SIZE_1
v_8 = tl.cast(_BLOCK_SIZE_1__2, tl.float32)
acc_cnt = acc_cnt_copy_0 + v_8
v_10 = tl.full([], 1, tl.int32)
v_11 = v_10 / acc_cnt
_BLOCK_SIZE_1__3 = _BLOCK_SIZE_1
v_12 = tl.cast(_BLOCK_SIZE_1__3, tl.float32)
v_13 = v_11 * v_12
v_14 = v_7 * v_13
acc_mean = acc_mean_copy_0 + v_14
v_16 = acc_m2_copy_0 + v_6
v_17 = v_7 * v_7
_BLOCK_SIZE_1__4 = _BLOCK_SIZE_1
v_18 = tl.cast(_BLOCK_SIZE_1__4, tl.float32)
v_19 = acc_cnt_copy_0 * v_18
v_20 = v_19 / acc_cnt
v_21 = v_17 * v_20
acc_m2 = v_16 + v_21
v_23 = acc_m2 / acc_cnt
v_24 = v_23 + eps
v_25 = libdevice.rsqrt(v_24)
mean_col = acc_mean[:, None]
rstd_col = v_25[:, None]
for offset_2 in tl.range(0, n.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < n
mean_col_copy = mean_col
rstd_col_copy = rstd_col
mean_col_copy_0 = mean_col_copy
rstd_col_copy_0 = rstd_col_copy
        xi_chunk = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
        load_1 = tl.load(weight + indices_2 * weight_stride_0, mask_2, other=0)
        w_chunk = load_1[None, :]
        load_2 = tl.load(bias + indices_2 * bias_stride_0, mask_2, other=0)
        b_chunk = load_2[None, :]
        v_26 = xi_chunk - mean_col_copy_0
        v_27 = v_26 * rstd_col_copy_0
        v_28 = v_27 * w_chunk
        v_29 = v_28 + b_chunk
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_29, mask_0[:, None] & mask_2[None, :])

def welford(weight: torch.Tensor, bias: torch.Tensor, x: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
"""
Applies LayerNorm using Welford's algorithm for mean/variance.
Args:
weight: weight tensor of shape [N]
bias: bias tensor of shape [N]
x: input tensor of shape [M, N]
Returns:
Output tensor of shape [M, N]
"""
m, n = x.size()
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 16
_BLOCK_SIZE_2 = 16
_launcher(_helion_welford, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, bias, out, bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out
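
For orientation, the generated wrapper launches one Triton program per tile of
_BLOCK_SIZE_0 rows; each program then loops over the n columns twice (once to
accumulate statistics, once to normalize). A minimal sketch of the launch-grid
arithmetic (plain Python; the m=128 value is taken from the test below, not
from the journal):

import math
m, block_m = 128, 16
grid_0 = math.ceil(m / block_m)  # == triton.cdiv(m, block_m) -> 8 programs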
20 changes: 20 additions & 0 deletions test/test_examples.py
@@ -288,6 +288,26 @@ def test_cross_entropy(self):
)
)

def test_welford(self):
s, d = 128, 1024
weight = torch.rand((d,), device=DEVICE, dtype=torch.float32)
bias = torch.rand((d,), device=DEVICE, dtype=torch.float32)
x = torch.rand((s, d), device=DEVICE, dtype=torch.float32)

self.assertExpectedJournal(
check_example(
"welford",
(weight, bias, x),
torch.nn.functional.layer_norm(
x,
normalized_shape=(x.shape[-1],),
weight=weight,
bias=bias,
eps=1e-05,
),
)
)

def test_rms_norm_fwd(self):
args = (
torch.randn([128, 256], device=DEVICE, dtype=torch.float16),