examples/rms_norm.py (37 changes: 26 additions & 11 deletions)
@@ -49,18 +49,33 @@ def rms_norm_fwd(
     out = torch.empty_like(x)
     inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
 
-    for tile_m in hl.tile(m):
-        x_tile = x[tile_m, :].to(torch.float32)
-
-        # Compute inverse RMS: 1/sqrt(mean(x^2) + eps)
-        x_squared = x_tile * x_tile
-        mean_x_squared = torch.mean(x_squared, dim=-1)
-        inv_rms_tile = torch.rsqrt(mean_x_squared + eps)
+    block_size_n = hl.register_block_size(n)
+    n_spec = hl.specialize(n)
 
-        # Apply normalization and weight
-        normalized = x_tile * inv_rms_tile[:, None]
-        out[tile_m, :] = (normalized * weight[:].to(torch.float32)).to(out.dtype)
-        inv_rms[tile_m] = inv_rms_tile.to(out.dtype)
+    for tile_m in hl.tile(m):

(Inline review thread on this line)

Contributor: That's a smart way to address the performance issue!
Just to call out — longer term, it might be ideal for the autotuner to handle eviction policies and loop reduction setup automatically, rather than users specifying them directly.

Contributor (author): Yep, I agree, I was planning to implement that next but wanted to check if this is doable at all.
+        # First pass: accumulate sum of squares across N in blocks
+        sum_sq = hl.zeros([tile_m], dtype=torch.float32)
+        for tile_n in hl.tile(n, block_size=block_size_n):
+            xi_chunk = hl.load(x, [tile_m, tile_n], eviction_policy="evict_last").to(
+                torch.float32
+            )
+            sum_sq = sum_sq + (xi_chunk * xi_chunk).sum(dim=1)
+
+        mean_sq = sum_sq / n_spec
+        inv_tile = torch.rsqrt(mean_sq + eps)
+        inv_rms[tile_m] = inv_tile.to(inv_rms.dtype)
+
+        # Second pass: apply normalization and weight
+        for tile_n in hl.tile(n, block_size=block_size_n):
+            w_chunk = hl.load(weight, [tile_n], eviction_policy="evict_last").to(
+                torch.float32
+            )
+            x_chunk = hl.load(x, [tile_m, tile_n], eviction_policy="evict_first").to(
+                torch.float32
+            )
+            out[tile_m, tile_n] = (x_chunk * inv_tile[:, None] * w_chunk[None, :]).to(
+                out.dtype
+            )
 
     return out, inv_rms.reshape(-1, 1)

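For context on the change above: the old single-pass version loaded each full row at once, which forces the generated kernel to materialize a reduction dimension of next_power_of_2(n); the new version tiles the reduction over n in two passes. The eviction hints presumably reflect reuse: first-pass x tiles are re-read in the second pass (evict_last keeps them cached), while second-pass x tiles are done after use (evict_first). For reference, the two passes compute the same result as this eager PyTorch sketch, a minimal oracle for checking the kernel (the helper name is hypothetical, not part of the PR):

import torch

def rms_norm_fwd_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # First pass in the kernel: mean of squares, accumulated in float32.
    xf = x.to(torch.float32)
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1) + eps)
    # Second pass: scale by inv_rms and weight, then cast back to the input dtype.
    out = (xf * inv_rms[:, None] * weight.to(torch.float32)[None, :]).to(x.dtype)
    return out, inv_rms.to(x.dtype).reshape(-1, 1)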
test/test_examples.expected (61 changes: 36 additions & 25 deletions)
@@ -3318,30 +3318,41 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_rms_norm_fwd(x, weight, out, inv_rms, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, n, eps, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+def _helion_rms_norm_fwd(x, inv_rms, weight, out, inv_rms_stride_0, out_stride_0, out_stride_1, weight_stride_0, x_stride_0, x_stride_1, m, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
-    offset_0 = pid_0 * _BLOCK_SIZE_0
-    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
-    mask_0 = indices_0 < m
-    indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
-    mask_1 = indices_1 < n
-    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
-    v_0 = tl.cast(load, tl.float32)
-    v_1 = v_0 * v_0
-    mean_x_squared_extra = tl.cast(tl.sum(v_1, 1), tl.float32)
-    v_2 = mean_x_squared_extra / n.to(tl.float32)
-    v_3 = v_2 + eps
-    v_4 = libdevice.rsqrt(v_3)
-    subscript = v_4[:, None]
-    v_5 = v_0 * subscript
-    load_1 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
-    v_6 = tl.cast(load_1, tl.float32)
-    v_7 = v_6[None, :]
-    v_8 = v_5 * v_7
-    v_9 = tl.cast(v_8, tl.float16)
-    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_9, mask_0[:, None] & mask_1[None, :])
-    v_10 = tl.cast(v_4, tl.float16)
-    tl.store(inv_rms + indices_0 * inv_rms_stride_0, v_10, mask_0)
+    offset_1 = pid_0 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < m
+    sum_sq = tl.full([_BLOCK_SIZE_1], 0.0, tl.float32)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_0):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        sum_sq_copy = sum_sq
+        sum_sq_copy_0 = sum_sq_copy
+        load = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[:, None], other=0, eviction_policy='evict_last')
+        v_0 = tl.cast(load, tl.float32)
+        v_1 = v_0 * v_0
+        sum_1 = tl.cast(tl.sum(v_1, 1), tl.float32)
+        sum_sq = sum_sq_copy_0 + sum_1
+    v_3 = 0.00390625
+    v_4 = sum_sq * v_3
+    v_5 = v_4 + eps
+    v_6 = libdevice.rsqrt(v_5)
+    v_7 = tl.cast(v_6, tl.float16)
+    tl.store(inv_rms + indices_1 * inv_rms_stride_0, v_7, mask_1)
+    for offset_2 in tl.range(0, 256, _BLOCK_SIZE_0):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
+        v_6_copy = v_6
+        v_6_copy_0 = v_6_copy
+        load_1 = tl.load(weight + indices_2 * weight_stride_0, None, eviction_policy='evict_last')
+        v_8 = tl.cast(load_1, tl.float32)
+        load_2 = tl.load(x + (indices_1[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_1[:, None], other=0, eviction_policy='evict_first')
+        v_9 = tl.cast(load_2, tl.float32)
+        subscript = v_6_copy_0[:, None]
+        v_10 = v_9 * subscript
+        subscript_1 = v_8[None, :]
+        v_11 = v_10 * subscript_1
+        v_12 = tl.cast(v_11, tl.float16)
+        tl.store(out + (indices_1[:, None] * out_stride_0 + indices_2[None, :] * out_stride_1), v_12, mask_1[:, None])
 
 def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     """
@@ -3363,9 +3374,9 @@ def rms_norm_fwd(x: torch.Tensor, weight: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {n}'
     out = torch.empty_like(x)
     inv_rms = torch.empty([m], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_1 = 16
     _BLOCK_SIZE_0 = 16
-    _RDIM_SIZE_1 = triton.next_power_of_2(n)
-    _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, inv_rms, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
+    _launcher(_helion_rms_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_1),), x, inv_rms, weight, out, inv_rms.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
     return (out, inv_rms.reshape(-1, 1))
 
 --- assertExpectedJournal(TestExamples.test_segment_reduction)
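In the regenerated expected output above, hl.specialize(n) bakes the row length into the kernel: both inner loops iterate tl.range(0, 256, _BLOCK_SIZE_0), and the mean is formed by multiplying with the constant 0.00390625 = 1/256, replacing the old full-row reduction over _RDIM_SIZE_1 = triton.next_power_of_2(n). The grid still covers only the row dimension. A quick sanity check of that arithmetic (m = 1024 is an assumed example; n = 256 matches the specialized kernel):

import triton

m, n = 1024, 256
_BLOCK_SIZE_1 = 16  # rows per program, as in the expected config above
grid = (triton.cdiv(m, _BLOCK_SIZE_1),)  # 64 programs, each owning a block of rows
assert 1.0 / n == 0.00390625  # the constant v_3 baked in by hl.specialize(n)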
test/test_examples.py (4 changes: 2 additions & 2 deletions)
@@ -324,7 +324,7 @@ def test_rms_norm_fwd(self):
                 args,
                 (expected, None),  # Expected: (output, 1/rms)
                 fn_name="rms_norm_fwd",
-                block_sizes=[16],
+                block_sizes=[16, 16],
                 indexing="pointer",
             )
         )
@@ -343,7 +343,7 @@ def test_rms_norm_bwd(self):
         from examples.rms_norm import rms_norm_fwd
 
         # Create configured kernel with explicit config
-        config = helion.Config(block_size=32, num_warps=4, num_stages=3)
+        config = helion.Config(block_size=[32, 32], num_warps=4, num_stages=3)
         configured_kernel = helion.kernel(rms_norm_fwd.fn, config=config)
         y, rms = configured_kernel(x, weight, eps)
 
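Since the kernel now has two tile loops (rows, plus the blocked reduction over n), configs supply one block size per loop, as the updated tests show. A minimal usage sketch under assumed shapes and device (only the Config and kernel calls are taken from this PR):

import torch
import helion
from examples.rms_norm import rms_norm_fwd

# Assumed example shapes; float16 matches the casts in the expected kernel output.
x = torch.randn(32, 256, device="cuda", dtype=torch.float16)
weight = torch.randn(256, device="cuda", dtype=torch.float16)

config = helion.Config(block_size=[32, 32], num_warps=4, num_stages=3)  # [tile_m, tile_n]
configured_kernel = helion.kernel(rms_norm_fwd.fn, config=config)
y, inv_rms = configured_kernel(x, weight, 1e-5)  # inv_rms is returned with shape (m, 1)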