Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions examples/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,22 @@ def rms_norm_fwd(
assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {n}"

out = torch.empty_like(x)
inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)
inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)

for tile_m in hl.tile(m):
x_tile = x[tile_m, :].to(torch.float32)

# Compute inverse RMS: 1/sqrt(mean(x^2) + eps)
x_squared = x_tile * x_tile
mean_x_squared = torch.mean(x_squared, dim=-1, keepdim=True)
mean_x_squared = torch.mean(x_squared, dim=-1)
inv_rms_tile = torch.rsqrt(mean_x_squared + eps)

# Apply normalization and weight
normalized = x_tile * inv_rms_tile
normalized = x_tile * inv_rms_tile[:, None]
out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
inv_rms[tile_m, :] = inv_rms_tile.to(out.dtype)
inv_rms[tile_m] = inv_rms_tile.to(out.dtype)

return out, inv_rms
return out, inv_rms.reshape(-1, 1)


@helion.kernel
Expand Down
11 changes: 6 additions & 5 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -2972,19 +2972,20 @@ def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0
load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = tl.cast(load, tl.float32)
v_1 = v_0 * v_0
mean_x_squared_extra = tl.cast(tl.reshape(tl.sum(v_1, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
mean_x_squared_extra = tl.cast(tl.sum(v_1, 1), tl.float32)
v_2 = mean_x_squared_extra / n.to(tl.float32)
v_3 = v_2 + eps
v_4 = libdevice.rsqrt(v_3)
v_5 = v_0 * v_4
subscript = v_4[:, None]
v_5 = v_0 * subscript
load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
v_6 = tl.cast(load_1, tl.float32)
v_7 = v_6[None, :]
v_8 = v_5 * v_7
v_9 = tl.cast(v_8, tl.float16)
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :])
v_10 = tl.cast(v_4, tl.float16)
tl.store(inv_rms + indices_0[:, None] * inv_rms_stride_0, v_10, mask_0[:, None])
tl.store(inv_rms + indices_0 * inv_rms_stride_0, v_10, mask_0)

def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
"""
Expand All @@ -3005,11 +3006,11 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _la
m, n = x.size()
assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
out = torch.empty_like(x)
inv_rms = torch.empty([m, 1], dtype=x.dtype, device=x.device)
inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
_BLOCK_SIZE_0 = 16
_RDIM_SIZE_1 = triton.next_power_of_2(n)
_launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return (out, inv_rms)
return (out, inv_rms.reshape(-1, 1))

--- assertExpectedJournal(TestExamples.test_segment_reduction)
from __future__ import annotations
Expand Down
Loading