2 changes: 1 addition & 1 deletion benchmarks/run.py
@@ -136,7 +136,7 @@ class RunResult:
"layer_norm": (
"tritonbench.operators.layer_norm.operator",
"examples.layer_norm",
"layer_norm",
"layer_norm_tritonbench",
),
"jagged_softmax": (
"tritonbench.operators.jagged_softmax.operator",
29 changes: 29 additions & 0 deletions examples/layer_norm.py
@@ -10,6 +10,7 @@
from __future__ import annotations

from typing import Any
from typing import Callable

import torch

@@ -240,6 +241,34 @@ def layer_norm(
return LayerNormFunction.apply(x, normalized_shape, weight, bias, eps) # type: ignore[no-any-return]


# %%
# Benchmark Wrapper
# -----------------
def layer_norm_tritonbench(
tb_op: object,
x: torch.Tensor,
normalized_shape: list[int],
weight: torch.Tensor,
bias: torch.Tensor | None = None,
eps: float = 1e-5,
) -> Callable[[], torch.Tensor]:
"""
Wrapper for tritonbench that matches the expected interface.

Args:
tb_op: TritonBench operator instance
x: Input tensor
normalized_shape: Shape to normalize over
weight: Weight parameter
bias: Bias parameter (optional)
eps: Small constant for numerical stability

Returns:
Callable that returns normalized tensor
"""
return lambda: layer_norm(x, normalized_shape, weight, bias, eps)
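For context, the wrapper defers the actual call so the harness can time it. A minimal usage sketch (the shapes, device, and the `None` stand-in for the operator instance are illustrative assumptions, not tritonbench's real driver):

```python
# Illustrative only: tritonbench normally supplies the operator instance.
x = torch.randn(32, 512, device="cuda", dtype=torch.float16)
weight = torch.randn(512, device="cuda", dtype=torch.float16)
bias = torch.randn(512, device="cuda", dtype=torch.float16)
fn = layer_norm_tritonbench(None, x, [512], weight, bias)
y = fn()  # the deferred call is what the benchmark times
```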


# %%
def main() -> None:
"""
8 changes: 4 additions & 4 deletions examples/rms_norm.py
@@ -41,13 +41,13 @@ def rms_norm_fwd(

Returns:
Output tensor of shape [M, N] with RMS normalization applied
-RMS tensor of shape [M, N] with RMS values for each element
+RMS tensor of shape [M, 1] with one RMS value per row
"""
m, n = x.size()
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"

out = torch.empty_like(x)
-inv_rms = torch.empty_like(x)
+inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)

for tile_m in hl.tile(m):
x_tile = x[tile_m, :].to(torch.float32)
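The allocation fix above reflects that RMS normalization computes one statistic per row, not one per element. A plain PyTorch paraphrase of the forward semantics (a sketch for reference, not code from this PR):

```python
import torch

def rms_norm_fwd_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # One inverse-RMS value per row, hence the [M, 1] shape.
    inv_rms = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)  # [M, 1]
    out = (x.float() * inv_rms * weight.float()).to(x.dtype)                  # [M, N]
    return out, inv_rms.to(x.dtype)
```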
@@ -79,7 +79,7 @@ def rms_norm_bwd_dw(
grad_out: Gradient w.r.t rms norm output [M, N]
x: Original input tensor [M, N]
weight: Weight parameter (used only for dtype/device info) [N]
-inv_rms: Inverse RMS tensor [M, N]
+inv_rms: Inverse RMS tensor [M, 1]

Returns:
grad_weight: Gradients for weight with shape [N]
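For the weight gradient, the [M, 1] statistic simply broadcasts across columns before the row reduction; a reference one-liner (a sketch of the math, not this PR's kernel):

```python
import torch

def rms_norm_bwd_dw_reference(grad_out, x, inv_rms, weight_dtype):
    # dL/dw_j = sum_i grad_out[i, j] * x[i, j] * inv_rms[i]; inv_rms is [M, 1].
    return (grad_out.float() * x.float() * inv_rms.float()).sum(dim=0).to(weight_dtype)
```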
@@ -123,7 +123,7 @@ def rms_norm_bwd_dx(
grad_out: Gradient w.r.t rms norm output [M, N]
x: Original input tensor [M, N]
weight: Weight parameter [N]
-inv_rms: Inverse RMS tensor [M, N]
+inv_rms: Inverse RMS tensor [M, 1]

Returns:
grad_x: Gradient w.r.t input tensor, shape [M, N]
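The [M, 1] statistic also drives the input gradient. For reference, a hedged sketch of the standard RMSNorm dx derivation (from y = x · inv_rms · w; the kernel body itself is collapsed in this diff):

```python
import torch

def rms_norm_bwd_dx_reference(grad_out, x, weight, inv_rms):
    # Standard derivation: d(inv_rms)/dx_i = -x_i * inv_rms**3 / N.
    n = x.size(-1)
    gw = grad_out.float() * weight.float()            # [M, N]
    row = (gw * x.float()).sum(dim=-1, keepdim=True)  # [M, 1]
    grad_x = inv_rms.float() * gw - x.float() * inv_rms.float().pow(3) * row / n
    return grad_x.to(x.dtype)
```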
26 changes: 13 additions & 13 deletions test/test_examples.expected
@@ -2039,7 +2039,7 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
-def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stride_0, grad_out_stride_1, inv_rms_stride_0, inv_rms_stride_1, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_0: tl.constexpr):
+def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stride_0, grad_out_stride_1, inv_rms_stride_0, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_1 = pid_0 * _BLOCK_SIZE_1
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
@@ -2051,7 +2051,7 @@ def _helion_rms_norm_bwd_dw(x, grad_out, inv_rms, dw, dw_stride_0, grad_out_stri
v_0 = tl.cast(load, tl.float32)
load_1 = tl.load(grad_out + (rows[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_1 = tl.cast(load_1, tl.float32)
-load_2 = tl.load(inv_rms + (rows[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+load_2 = tl.load(inv_rms + rows[:, None] * inv_rms_stride_0, mask_0[:, None], other=0)
v_2 = tl.cast(load_2, tl.float32)
v_3 = v_0 * v_2
v_4 = v_1 * v_3
@@ -2070,7 +2070,7 @@ def rms_norm_bwd_dw(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
grad_out: Gradient w.r.t rms norm output [M, N]
x: Original input tensor [M, N]
weight: Weight parameter (used only for dtype/device info) [N]
-inv_rms: Inverse RMS tensor [M, N]
+inv_rms: Inverse RMS tensor [M, 1]

Returns:
grad_weight: Gradients for weight with shape [N]
@@ -2079,7 +2079,7 @@ def rms_norm_bwd_dw(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
_BLOCK_SIZE_1 = 32
_RDIM_SIZE_0 = triton.next_power_of_2(m)
-_launcher(_helion_rms_norm_bwd_dw, (triton.cdiv(n, _BLOCK_SIZE_1),), x, grad_out, inv_rms, dw, dw.stride(0), grad_out.stride(0), grad_out.stride(1), inv_rms.stride(0), inv_rms.stride(1), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_1, _RDIM_SIZE_0, num_warps=4, num_stages=3)
+_launcher(_helion_rms_norm_bwd_dw, (triton.cdiv(n, _BLOCK_SIZE_1),), x, grad_out, inv_rms, dw, dw.stride(0), grad_out.stride(0), grad_out.stride(1), inv_rms.stride(0), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_1, _RDIM_SIZE_0, num_warps=4, num_stages=3)
return dw

--- assertExpectedJournal(TestExamples.test_rms_norm_bwd_dx)
@@ -2091,7 +2091,7 @@ import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
-def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, inv_rms_stride_0, inv_rms_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, inv_rms_stride_0, weight_stride_0, x_stride_0, x_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
@@ -2103,7 +2103,7 @@ def _helion_rms_norm_bwd_dx(x, grad_out, weight, inv_rms, grad_x, grad_out_strid
v_1 = tl.cast(load_1, tl.float32)
load_2 = tl.load(weight + indices_1 * weight_stride_0, None)
v_2 = tl.cast(load_2, tl.float32)
-load_3 = tl.load(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), mask_0[:, None], other=0)
+load_3 = tl.load(inv_rms + indices_0[:, None] * inv_rms_stride_0, mask_0[:, None], other=0)
v_3 = tl.cast(load_3, tl.float32)
v_4 = v_2[None, :]
v_5 = v_1 * v_4
@@ -2131,7 +2131,7 @@ def rms_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
grad_out: Gradient w.r.t rms norm output [M, N]
x: Original input tensor [M, N]
weight: Weight parameter [N]
-inv_rms: Inverse RMS tensor [M, N]
+inv_rms: Inverse RMS tensor [M, 1]

Returns:
grad_x: Gradient w.r.t input tensor, shape [M, N]
@@ -2140,7 +2140,7 @@ def rms_norm_bwd_dx(grad_out: torch.Tensor, x: torch.Tensor, weight: torch.Tenso
grad_x = torch.empty_like(x)
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_1 = 64
-_launcher(_helion_rms_norm_bwd_dx, (triton.cdiv(m, _BLOCK_SIZE_0),), x, grad_out, weight, inv_rms, grad_x, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), inv_rms.stride(0), inv_rms.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+_launcher(_helion_rms_norm_bwd_dx, (triton.cdiv(m, _BLOCK_SIZE_0),), x, grad_out, weight, inv_rms, grad_x, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), inv_rms.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return grad_x

--- assertExpectedJournal(TestExamples.test_rms_norm_fwd)
@@ -2153,7 +2153,7 @@ from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
-def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, inv_rms_stride_1, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
@@ -2175,7 +2175,7 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, inv_rms_stri
v_9 = tl.cast(v_8, tl.float16)
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :])
v_10 = tl.cast(v_4, tl.float16)
-tl.store(inv_rms + (indices_0[:, None] * inv_rms_stride_0 + indices_1[None, :] * inv_rms_stride_1), v_10, mask_0[:, None] & mask_1[None, :])
+tl.store(inv_rms + indices_0[:, None] * inv_rms_stride_0, v_10, mask_0[:, None])

def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
"""
@@ -2191,15 +2191,15 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la

Returns:
Output tensor of shape [M, N] with RMS normalization applied
-RMS tensor of shape [M, N] with RMS values for each element
+RMS tensor of shape [M, 1] with one RMS value per row
"""
m, n = x.size()
assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
out = torch.empty_like(x)
-inv_rms = torch.empty_like(x)
+inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)
_BLOCK_SIZE_0 = 16
_RDIM_SIZE_1 = triton.next_power_of_2(n)
-_launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), inv_rms.stride(1), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+_launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return (out, inv_rms)
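The regenerated journal above now passes only `inv_rms.stride(0)`, consistent with the [M, 1] allocation. A quick sanity check one could run against a PyTorch reference (import path, sizes, and tolerances are assumptions):

```python
# Hypothetical check, not part of this PR.
import torch
from examples.rms_norm import rms_norm_fwd  # assumed import path

x = torch.randn(128, 64, device="cuda", dtype=torch.float16)
w = torch.randn(64, device="cuda", dtype=torch.float16)
out, inv_rms = rms_norm_fwd(x, w, eps=1e-5)
assert inv_rms.shape == (128, 1)  # one inverse-RMS value per row
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5) * w.float()
torch.testing.assert_close(out.float(), ref, rtol=1e-2, atol=1e-2)
```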

--- assertExpectedJournal(TestExamples.test_segment_reduction)