From b09d450aa709de5d4485e39f403435582378c679 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 9 Oct 2025 14:17:09 -0700 Subject: [PATCH 1/5] test --- test/test_indexing.expected | 69 ++++++++++++++++++++++++++++++ test/test_indexing.py | 84 +++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/test/test_indexing.expected b/test/test_indexing.expected index 37d0b6aa6..d9b37bc06 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -185,6 +185,75 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor, _launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2) return out +--- assertExpectedJournal(TestIndexing.test_hl_arange_non_power_of_2) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _helion__matmul_layernorm_bwd_dxdy(z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, grad_y_stride_0, grad_y_stride_1, mean_stride_0, rstd_stride_0, weight_stride_0, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, z_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < m + indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) + mask_1 = indices_1 < 7 + indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32) + mask_2 = indices_2 < 3 + load = tl.load(z + (indices_0[:, None] * z_stride_0 + indices_1[None, :] * z_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_0 = tl.cast(load, tl.float32) + load_1 = tl.load(grad_out + (indices_0[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + v_1 = tl.cast(load_1, tl.float32) + load_2 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0) + v_2 = tl.cast(load_2, tl.float32) + mean_tile = tl.load(mean + indices_0 * mean_stride_0, mask_0, other=0) + rstd_tile = tl.load(rstd + indices_0 * rstd_stride_0, mask_0, other=0) + subscript = mean_tile[:, None] + v_3 = v_0 - subscript + subscript_1 = rstd_tile[:, None] + v_4 = v_3 * subscript_1 + v_5 = v_2[None, :] + v_6 = v_5 * v_1 + v_7 = v_4 * v_6 + sum_1 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32) + v_8 = 0.14285714285714285 + v_9 = sum_1 * v_8 + sum_2 = tl.cast(tl.reshape(tl.sum(v_6, 1), [_BLOCK_SIZE_0, 1]), tl.float32) + v_10 = 0.14285714285714285 + v_11 = sum_2 * v_10 + v_12 = v_4 * v_9 + v_13 = v_12 + v_11 + v_14 = v_6 - v_13 + subscript_2 = rstd_tile[:, None] + v_15 = v_14 * subscript_2 + load_5 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0) + permute = tl.permute(load_5, [1, 0]) + v_16 = tl.cast(permute, tl.float32) + mm = tl.dot(tl.reshape(tl.permute(tl.join(tl.cast(v_15, tl.float32), tl.zeros_like(tl.cast(v_15, tl.float32))), [0, 2, 1]), [16, 16]), tl.reshape(tl.permute(tl.join(tl.cast(v_16, tl.float32), tl.zeros_like(tl.cast(v_16, tl.float32))), [2, 0, 1]), [16, 4]), input_precision='tf32', out_dtype=tl.float32) + v_17 = tl.cast(mm, tl.float16) + tl.store(grad_x + (indices_0[:, None] * grad_x_stride_0 + indices_2[None, :] * grad_x_stride_1), v_17, mask_0[:, None] & mask_2[None, :]) + load_6 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + permute_1 = tl.permute(load_6, [1, 0]) + v_18 = tl.cast(permute_1, tl.float32) + mm_1 = tl.dot(tl.cast(v_18, tl.float32), tl.cast(v_15, tl.float32), input_precision='tf32', out_dtype=tl.float32) + v_19 = tl.cast(mm_1, tl.float16) + iota = tl.arange(0, 4) + iota_1 = tl.arange(0, 8) + tl.atomic_add(grad_y + (iota[:, None] * grad_y_stride_0 + iota_1[None, :] * grad_y_stride_1), v_19, mask=(iota < 3)[:, None] & (iota_1 < 7)[None, :], sem='relaxed') + +def _matmul_layernorm_bwd_dxdy(grad_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launcher): + m, n = z.shape + grad_x = torch.empty_like(x) + grad_y = torch.zeros_like(y) + _BLOCK_SIZE_0 = 16 + _RDIM_SIZE_1 = 8 + _RDIM_SIZE_2 = 4 + _launcher(_helion__matmul_layernorm_bwd_dxdy, (triton.cdiv(m, _BLOCK_SIZE_0),), z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=3) + return (grad_x, grad_y) + --- assertExpectedJournal(TestIndexing.test_mask_load) from __future__ import annotations diff --git a/test/test_indexing.py b/test/test_indexing.py index 215ba4c0b..585a9d3c4 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -63,6 +63,90 @@ def arange(length: int, device: torch.device) -> torch.Tensor: ) self.assertExpectedJournal(code) + def test_hl_arange_non_power_of_2(self): + @helion.kernel + def _matmul_layernorm_bwd_dxdy( + grad_out: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + z: torch.Tensor, + mean: torch.Tensor, + rstd: torch.Tensor, + weight: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + m, n = z.shape + k = x.shape[1] + n = hl.specialize(n) + k = hl.specialize(k) + + grad_x = torch.empty_like(x) + grad_y = torch.zeros_like(y) + + for tile_m in hl.tile(m): + z_tile = z[tile_m, :].to(torch.float32) + dy_tile = grad_out[tile_m, :].to(torch.float32) + w = weight[:].to(torch.float32) + mean_tile = mean[tile_m] + rstd_tile = rstd[tile_m] + + z_hat = (z_tile - mean_tile[:, None]) * rstd_tile[:, None] + wdy = w * dy_tile + c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n) + c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n) + dz = (wdy - (z_hat * c1 + c2)) * rstd_tile[:, None] + + grad_x[tile_m, :] = (dz @ y[:, :].t().to(torch.float32)).to(x.dtype) + grad_y_update = (x[tile_m, :].t().to(torch.float32) @ dz).to(y.dtype) + + hl.atomic_add( + grad_y, + [ + hl.arange(0, k), + hl.arange(0, n), + ], + grad_y_update, + ) + + return grad_x, grad_y + + m, k, n = 5, 3, 7 + eps = 1e-5 + + x = torch.randn((m, k), device=DEVICE, dtype=torch.float16) + y = torch.randn((k, n), device=DEVICE, dtype=torch.float16) + weight = torch.randn((n,), device=DEVICE, dtype=torch.float16) + grad_out = torch.randn((m, n), device=DEVICE, dtype=torch.float16) + + z = (x @ y).to(torch.float32) + var, mean = torch.var_mean(z, dim=-1, keepdim=True, correction=0) + rstd = torch.rsqrt(var + eps) + + code, (grad_x, grad_y) = code_and_output( + _matmul_layernorm_bwd_dxdy, + ( + grad_out, + x, + y, + z.to(x.dtype), + mean.squeeze(-1), + rstd.squeeze(-1), + weight, + ), + ) + + # PyTorch reference gradients + z_hat = (z - mean) * rstd + wdy = weight.to(torch.float32) * grad_out.to(torch.float32) + c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n) + c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n) + dz = (wdy - (z_hat * c1 + c2)) * rstd + ref_grad_x = (dz @ y.to(torch.float32).t()).to(grad_x.dtype) + ref_grad_y = (x.to(torch.float32).t() @ dz).to(grad_y.dtype) + + torch.testing.assert_close(grad_x, ref_grad_x, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(grad_y, ref_grad_y, rtol=1e-3, atol=1e-3) + self.assertExpectedJournal(code) + def test_pairwise_add(self): @helion.kernel() def pairwise_add(x: torch.Tensor) -> torch.Tensor: From 562acbdd74403916b4cab329ded6726550963baf Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 9 Oct 2025 14:17:13 -0700 Subject: [PATCH 2/5] fix --- helion/_compiler/indexing_strategy.py | 32 +++++++++++++++++++++++++++ helion/_compiler/inductor_lowering.py | 11 +++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index 7f329c391..32b479f9a 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -10,6 +10,7 @@ import torch from torch._inductor.utils import triton_type from torch._prims_common import compute_required_storage_length +from triton import next_power_of_2 from .. import exc from .._compat import get_tensor_descriptor_fn_name @@ -32,6 +33,32 @@ ShapeLike = Sequence[SymIntLike] +def _get_padded_iota_original_length( + state: CodegenState, index_position: int +) -> int | None: + """Get the original length of a padded iota node at the given index position. + + Args: + state: The codegen state containing fx_node information + index_position: The position in the index list to check + + Returns: + The original (unpadded) length if the index is a padded iota, None otherwise + """ + try: + index_node = state.fx_node.args[1][index_position] # type: ignore[union-attr, index] + if ( + isinstance(index_node, torch.fx.Node) + and index_node.target == torch.ops.prims.iota.default # pyright: ignore[reportAttributeAccessIssue] + and isinstance(length_arg := index_node.args[0], int) + and length_arg != next_power_of_2(length_arg) + ): + return length_arg + except (AttributeError, IndexError, TypeError): + pass + return None + + class IndexingStrategy: def codegen_load( self, @@ -634,6 +661,11 @@ def _is_size_one(size: int | torch.SymInt) -> bool: if (block_idx := env.get_block_id(output_size[output_idx])) is not None: if mask := state.codegen.mask_var(block_idx): mask_values.setdefault(f"({mask}){expand}") + # Check if this index comes from a padded hl.arange and generate mask + if ( + original_length := _get_padded_iota_original_length(state, n) + ) is not None: + mask_values.setdefault(f"({index_var} < {original_length}){expand}") output_idx += 1 elif ( isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1 diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py index dc4496c10..009059241 100644 --- a/helion/_compiler/inductor_lowering.py +++ b/helion/_compiler/inductor_lowering.py @@ -39,6 +39,7 @@ from torch.fx.interpreter import Interpreter from torch.fx.node import Node from torch.fx.node import map_arg +from triton import next_power_of_2 from .. import exc from ..exc import InductorLoweringError @@ -1451,7 +1452,7 @@ def sympy_expr(self, expr: sympy.Expr) -> str: @register_lowering(torch.ops.prims.iota.default) # pyright: ignore[reportAttributeAccessIssue] def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object: - """Generate tl.arange for torch.ops.prims.iota.default operations.""" + """Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding.""" start = node.kwargs.get("start", 0) step = node.kwargs.get("step", 1) dtype = ( @@ -1459,7 +1460,13 @@ def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object: ) assert isinstance(dtype, torch.dtype) (length_arg,) = node.args # expecting a single argument for length - expr = "tl.arange(0, {length})" + + # Pad static non-power-of-2 lengths to next power of 2 + length_expr = "{length}" + if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg): + length_expr = str(next_power_of_2(length_arg)) + + expr = f"tl.arange(0, {length_expr})" if step != 1: expr = f"{{step}} * {expr}" if start != 0: From 779c5202848e3527c005b9e22095e03608cc1625 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 13 Oct 2025 14:12:35 -0700 Subject: [PATCH 3/5] up --- test/test_indexing.expected | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_indexing.expected b/test/test_indexing.expected index d9b37bc06..a0e8ca6b6 100644 --- a/test/test_indexing.expected +++ b/test/test_indexing.expected @@ -251,7 +251,7 @@ def _matmul_layernorm_bwd_dxdy(grad_out: torch.Tensor, x: torch.Tensor, y: torch _BLOCK_SIZE_0 = 16 _RDIM_SIZE_1 = 8 _RDIM_SIZE_2 = 4 - _launcher(_helion__matmul_layernorm_bwd_dxdy, (triton.cdiv(m, _BLOCK_SIZE_0),), z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=3) + _launcher(_helion__matmul_layernorm_bwd_dxdy, (triton.cdiv(m, _BLOCK_SIZE_0),), z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2) return (grad_x, grad_y) --- assertExpectedJournal(TestIndexing.test_mask_load) From d93dfa740ba2cdc98123c67b204d9922ee583468 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 13 Oct 2025 14:47:23 -0700 Subject: [PATCH 4/5] adjust atol --- test/test_indexing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_indexing.py b/test/test_indexing.py index 585a9d3c4..b2c29452f 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -143,8 +143,8 @@ def _matmul_layernorm_bwd_dxdy( ref_grad_x = (dz @ y.to(torch.float32).t()).to(grad_x.dtype) ref_grad_y = (x.to(torch.float32).t() @ dz).to(grad_y.dtype) - torch.testing.assert_close(grad_x, ref_grad_x, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(grad_y, ref_grad_y, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(grad_x, ref_grad_x, rtol=1e-3, atol=2e-3) + torch.testing.assert_close(grad_y, ref_grad_y, rtol=1e-3, atol=2e-3) self.assertExpectedJournal(code) def test_pairwise_add(self): From 075d8667add074f4ae68772ba2143e656fd3bdc9 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 13 Oct 2025 15:12:06 -0700 Subject: [PATCH 5/5] fix ref eager mode --- helion/language/atomic_ops.py | 90 +++++++++++++++++------------------ 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/helion/language/atomic_ops.py b/helion/language/atomic_ops.py index f2674b6a1..417daf25d 100644 --- a/helion/language/atomic_ops.py +++ b/helion/language/atomic_ops.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import itertools from typing import TYPE_CHECKING from typing import Callable @@ -125,15 +126,22 @@ def _ref_apply( if tensor_indices: # Element-wise processing for tensor indices (handle first tensor index) i, tensor_idx = tensor_indices[0] - for j, elem in enumerate(tensor_idx): + + if tensor_idx.ndim == 0: + coords_iter = [()] + else: + ranges = [range(dim) for dim in tensor_idx.shape] + coords_iter = itertools.product(*ranges) + + for coords in coords_iter: + elem = tensor_idx[coords].item() new_index = processed_index.copy() - new_index[i] = int(elem.item()) - val = ( - value[j] - if isinstance(value, torch.Tensor) and value.numel() > 1 - else value - ) - apply_fn(target, tuple(new_index), val) + new_index[i] = int(elem) + if isinstance(value, torch.Tensor) and value.numel() > 1: + next_value = value[coords] + else: + next_value = value + _ref_apply(target, new_index, apply_fn, next_value) else: apply_fn(target, tuple(processed_index), value) @@ -208,10 +216,10 @@ def _( _validate_sem(sem) from .ref_tile import RefTile - # Convert indices and detect tensor indices for element-wise updates + # Convert indices for shape computation and fast path detection processed_index: list[object] = [] - tensor_indices: list[tuple[int, torch.Tensor]] = [] - for i, idx in enumerate(index): + has_tensor_index = False + for idx in index: if isinstance(idx, RefTile): processed_index.append(idx._slice) elif isinstance(idx, torch.Tensor): @@ -219,47 +227,39 @@ def _( processed_index.append(int(idx.item())) else: processed_index.append(idx) - tensor_indices.append((i, idx)) + has_tensor_index = True else: processed_index.append(idx) - if tensor_indices: - # Element-wise processing for the first tensor index to ensure correct semantics - i, idx_tensor = tensor_indices[0] - ret = torch.empty_like(idx_tensor, dtype=target.dtype, device=target.device) - # Flatten to assign easily - flat_ret = ret.reshape(-1) - flat_idx = idx_tensor.reshape(-1) - # Prepare value per element - if isinstance(value, torch.Tensor) and value.numel() > 1: - flat_val = value.reshape(-1) + def _convert_value_to_target_dtype(val: object) -> torch.Tensor: + if isinstance(val, torch.Tensor): + vt = val.to(device=target.device) + if vt.dtype != target.dtype: + vt = vt.to(dtype=target.dtype) + return vt + return torch.as_tensor(val, dtype=target.dtype, device=target.device) + + if has_tensor_index: + ret_shape = SubscriptIndexing.compute_shape(target, processed_index) + prev_chunks: list[torch.Tensor] = [] + + def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None: + prev_val = t[idx_tuple].clone() # pyright: ignore[reportArgumentType] + val_tensor = _convert_value_to_target_dtype(v) + t[idx_tuple] = t[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType] + prev_chunks.append(prev_val.reshape(-1)) + + _ref_apply(target, index, apply, value) + if prev_chunks: + flat_prev = torch.cat(prev_chunks) else: - flat_val = None - for j, elem in enumerate(flat_idx): - new_index = list(processed_index) - new_index[i] = int(elem.item()) - new_index_t = tuple(new_index) - prev = target[new_index_t] # pyright: ignore[reportArgumentType] - vj = flat_val[j] if flat_val is not None else value - # Convert scalar to tensor on device - vj_t = ( - vj - if isinstance(vj, torch.Tensor) - else torch.as_tensor(vj, dtype=target.dtype, device=target.device) - ) - target[new_index_t] = target[new_index_t] + vj_t # pyright: ignore[reportArgumentType] - flat_ret[j] = prev # pyright: ignore[reportArgumentType] - return ret + flat_prev = target.new_empty(0, dtype=target.dtype, device=target.device) + return flat_prev.reshape(ret_shape) - # Scalar or simple indexing path idx_tuple = tuple(processed_index) prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType] - val = ( - value - if isinstance(value, torch.Tensor) - else torch.as_tensor(value, dtype=target.dtype, device=target.device) - ) - target[idx_tuple] = target[idx_tuple] + val # pyright: ignore[reportArgumentType] + val_tensor = _convert_value_to_target_dtype(value) + target[idx_tuple] = target[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType] return prev