Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions helion/_compiler/indexing_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from torch._inductor.utils import triton_type
from torch._prims_common import compute_required_storage_length
from triton import next_power_of_2

from .. import exc
from .._compat import get_tensor_descriptor_fn_name
Expand All @@ -32,6 +33,32 @@
ShapeLike = Sequence[SymIntLike]


def _get_padded_iota_original_length(
state: CodegenState, index_position: int
) -> int | None:
"""Get the original length of a padded iota node at the given index position.

Args:
state: The codegen state containing fx_node information
index_position: The position in the index list to check

Returns:
The original (unpadded) length if the index is a padded iota, None otherwise
"""
try:
index_node = state.fx_node.args[1][index_position] # type: ignore[union-attr, index]
if (
isinstance(index_node, torch.fx.Node)
and index_node.target == torch.ops.prims.iota.default # pyright: ignore[reportAttributeAccessIssue]
and isinstance(length_arg := index_node.args[0], int)
and length_arg != next_power_of_2(length_arg)
):
return length_arg
except (AttributeError, IndexError, TypeError):
pass
return None


class IndexingStrategy:
def codegen_load(
self,
Expand Down Expand Up @@ -634,6 +661,11 @@ def _is_size_one(size: int | torch.SymInt) -> bool:
if (block_idx := env.get_block_id(output_size[output_idx])) is not None:
if mask := state.codegen.mask_var(block_idx):
mask_values.setdefault(f"({mask}){expand}")
# Check if this index comes from a padded hl.arange and generate mask
if (
original_length := _get_padded_iota_original_length(state, n)
) is not None:
mask_values.setdefault(f"({index_var} < {original_length}){expand}")
output_idx += 1
elif (
isinstance(k, torch.Tensor) and len(index) == 1 and fake_value.ndim == 1
Expand Down
11 changes: 9 additions & 2 deletions helion/_compiler/inductor_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from torch.fx.interpreter import Interpreter
from torch.fx.node import Node
from torch.fx.node import map_arg
from triton import next_power_of_2

from .. import exc
from ..exc import InductorLoweringError
Expand Down Expand Up @@ -1451,15 +1452,21 @@ def sympy_expr(self, expr: sympy.Expr) -> str:

@register_lowering(torch.ops.prims.iota.default) # pyright: ignore[reportAttributeAccessIssue]
def codegen_iota(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
"""Generate tl.arange for torch.ops.prims.iota.default operations."""
"""Generate tl.arange for torch.ops.prims.iota.default operations with automatic power-of-2 padding."""
start = node.kwargs.get("start", 0)
step = node.kwargs.get("step", 1)
dtype = (
node.kwargs.get("dtype") or CompileEnvironment.current().settings.index_dtype
)
assert isinstance(dtype, torch.dtype)
(length_arg,) = node.args # expecting a single argument for length
expr = "tl.arange(0, {length})"

# Pad static non-power-of-2 lengths to next power of 2
length_expr = "{length}"
if isinstance(length_arg, int) and length_arg != next_power_of_2(length_arg):
length_expr = str(next_power_of_2(length_arg))

expr = f"tl.arange(0, {length_expr})"
if step != 1:
expr = f"{{step}} * {expr}"
if start != 0:
Expand Down
90 changes: 45 additions & 45 deletions helion/language/atomic_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
import itertools
from typing import TYPE_CHECKING
from typing import Callable

Expand Down Expand Up @@ -125,15 +126,22 @@ def _ref_apply(
if tensor_indices:
# Element-wise processing for tensor indices (handle first tensor index)
i, tensor_idx = tensor_indices[0]
for j, elem in enumerate(tensor_idx):

if tensor_idx.ndim == 0:
coords_iter = [()]
else:
ranges = [range(dim) for dim in tensor_idx.shape]
coords_iter = itertools.product(*ranges)

for coords in coords_iter:
elem = tensor_idx[coords].item()
new_index = processed_index.copy()
new_index[i] = int(elem.item())
val = (
value[j]
if isinstance(value, torch.Tensor) and value.numel() > 1
else value
)
apply_fn(target, tuple(new_index), val)
new_index[i] = int(elem)
if isinstance(value, torch.Tensor) and value.numel() > 1:
next_value = value[coords]
else:
next_value = value
_ref_apply(target, new_index, apply_fn, next_value)
else:
apply_fn(target, tuple(processed_index), value)

Expand Down Expand Up @@ -208,58 +216,50 @@ def _(
_validate_sem(sem)
from .ref_tile import RefTile

# Convert indices and detect tensor indices for element-wise updates
# Convert indices for shape computation and fast path detection
processed_index: list[object] = []
tensor_indices: list[tuple[int, torch.Tensor]] = []
for i, idx in enumerate(index):
has_tensor_index = False
for idx in index:
if isinstance(idx, RefTile):
processed_index.append(idx._slice)
elif isinstance(idx, torch.Tensor):
if idx.numel() == 1:
processed_index.append(int(idx.item()))
else:
processed_index.append(idx)
tensor_indices.append((i, idx))
has_tensor_index = True
else:
processed_index.append(idx)

if tensor_indices:
# Element-wise processing for the first tensor index to ensure correct semantics
i, idx_tensor = tensor_indices[0]
ret = torch.empty_like(idx_tensor, dtype=target.dtype, device=target.device)
# Flatten to assign easily
flat_ret = ret.reshape(-1)
flat_idx = idx_tensor.reshape(-1)
# Prepare value per element
if isinstance(value, torch.Tensor) and value.numel() > 1:
flat_val = value.reshape(-1)
def _convert_value_to_target_dtype(val: object) -> torch.Tensor:
if isinstance(val, torch.Tensor):
vt = val.to(device=target.device)
if vt.dtype != target.dtype:
vt = vt.to(dtype=target.dtype)
return vt
return torch.as_tensor(val, dtype=target.dtype, device=target.device)

if has_tensor_index:
ret_shape = SubscriptIndexing.compute_shape(target, processed_index)
prev_chunks: list[torch.Tensor] = []

def apply(t: torch.Tensor, idx_tuple: tuple, v: object) -> None:
prev_val = t[idx_tuple].clone() # pyright: ignore[reportArgumentType]
val_tensor = _convert_value_to_target_dtype(v)
t[idx_tuple] = t[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType]
prev_chunks.append(prev_val.reshape(-1))

_ref_apply(target, index, apply, value)
if prev_chunks:
flat_prev = torch.cat(prev_chunks)
else:
flat_val = None
for j, elem in enumerate(flat_idx):
new_index = list(processed_index)
new_index[i] = int(elem.item())
new_index_t = tuple(new_index)
prev = target[new_index_t] # pyright: ignore[reportArgumentType]
vj = flat_val[j] if flat_val is not None else value
# Convert scalar to tensor on device
vj_t = (
vj
if isinstance(vj, torch.Tensor)
else torch.as_tensor(vj, dtype=target.dtype, device=target.device)
)
target[new_index_t] = target[new_index_t] + vj_t # pyright: ignore[reportArgumentType]
flat_ret[j] = prev # pyright: ignore[reportArgumentType]
return ret
flat_prev = target.new_empty(0, dtype=target.dtype, device=target.device)
return flat_prev.reshape(ret_shape)

# Scalar or simple indexing path
idx_tuple = tuple(processed_index)
prev = target[idx_tuple].clone() # pyright: ignore[reportArgumentType]
val = (
value
if isinstance(value, torch.Tensor)
else torch.as_tensor(value, dtype=target.dtype, device=target.device)
)
target[idx_tuple] = target[idx_tuple] + val # pyright: ignore[reportArgumentType]
val_tensor = _convert_value_to_target_dtype(value)
target[idx_tuple] = target[idx_tuple] + val_tensor # pyright: ignore[reportArgumentType]
return prev


Expand Down
69 changes: 69 additions & 0 deletions test/test_indexing.expected
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,75 @@ def broadcast_add_3d(x: torch.Tensor, bias1: torch.Tensor, bias2: torch.Tensor,
_launcher(_helion_broadcast_add_3d, (triton.cdiv(d0, _BLOCK_SIZE_0) * triton.cdiv(d1, _BLOCK_SIZE_1) * triton.cdiv(d2, _BLOCK_SIZE_2),), x, bias1, bias2, out, bias1.size(1), bias1.size(2), bias2.size(0), bias2.size(2), out.size(0), out.size(1), out.size(2), x.size(0), x.size(1), x.size(2), bias1.stride(0), bias1.stride(1), bias1.stride(2), bias2.stride(0), bias2.stride(1), bias2.stride(2), out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), d0, d1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=2)
return out

--- assertExpectedJournal(TestIndexing.test_hl_arange_non_power_of_2)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion__matmul_layernorm_bwd_dxdy(z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out_stride_0, grad_out_stride_1, grad_x_stride_0, grad_x_stride_1, grad_y_stride_0, grad_y_stride_1, mean_stride_0, rstd_stride_0, weight_stride_0, x_stride_0, x_stride_1, y_stride_0, y_stride_1, z_stride_0, z_stride_1, m, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < m
indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32)
mask_1 = indices_1 < 7
indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
mask_2 = indices_2 < 3
load = tl.load(z + (indices_0[:, None] * z_stride_0 + indices_1[None, :] * z_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_0 = tl.cast(load, tl.float32)
load_1 = tl.load(grad_out + (indices_0[:, None] * grad_out_stride_0 + indices_1[None, :] * grad_out_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
v_1 = tl.cast(load_1, tl.float32)
load_2 = tl.load(weight + indices_1 * weight_stride_0, mask_1, other=0)
v_2 = tl.cast(load_2, tl.float32)
mean_tile = tl.load(mean + indices_0 * mean_stride_0, mask_0, other=0)
rstd_tile = tl.load(rstd + indices_0 * rstd_stride_0, mask_0, other=0)
subscript = mean_tile[:, None]
v_3 = v_0 - subscript
subscript_1 = rstd_tile[:, None]
v_4 = v_3 * subscript_1
v_5 = v_2[None, :]
v_6 = v_5 * v_1
v_7 = v_4 * v_6
sum_1 = tl.cast(tl.reshape(tl.sum(v_7, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_8 = 0.14285714285714285
v_9 = sum_1 * v_8
sum_2 = tl.cast(tl.reshape(tl.sum(v_6, 1), [_BLOCK_SIZE_0, 1]), tl.float32)
v_10 = 0.14285714285714285
v_11 = sum_2 * v_10
v_12 = v_4 * v_9
v_13 = v_12 + v_11
v_14 = v_6 - v_13
subscript_2 = rstd_tile[:, None]
v_15 = v_14 * subscript_2
load_5 = tl.load(y + (indices_2[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_2[:, None] & mask_1[None, :], other=0)
permute = tl.permute(load_5, [1, 0])
v_16 = tl.cast(permute, tl.float32)
mm = tl.dot(tl.reshape(tl.permute(tl.join(tl.cast(v_15, tl.float32), tl.zeros_like(tl.cast(v_15, tl.float32))), [0, 2, 1]), [16, 16]), tl.reshape(tl.permute(tl.join(tl.cast(v_16, tl.float32), tl.zeros_like(tl.cast(v_16, tl.float32))), [2, 0, 1]), [16, 4]), input_precision='tf32', out_dtype=tl.float32)
v_17 = tl.cast(mm, tl.float16)
tl.store(grad_x + (indices_0[:, None] * grad_x_stride_0 + indices_2[None, :] * grad_x_stride_1), v_17, mask_0[:, None] & mask_2[None, :])
load_6 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
permute_1 = tl.permute(load_6, [1, 0])
v_18 = tl.cast(permute_1, tl.float32)
mm_1 = tl.dot(tl.cast(v_18, tl.float32), tl.cast(v_15, tl.float32), input_precision='tf32', out_dtype=tl.float32)
v_19 = tl.cast(mm_1, tl.float16)
iota = tl.arange(0, 4)
iota_1 = tl.arange(0, 8)
tl.atomic_add(grad_y + (iota[:, None] * grad_y_stride_0 + iota_1[None, :] * grad_y_stride_1), v_19, mask=(iota < 3)[:, None] & (iota_1 < 7)[None, :], sem='relaxed')

def _matmul_layernorm_bwd_dxdy(grad_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, mean: torch.Tensor, rstd: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launcher):
m, n = z.shape
grad_x = torch.empty_like(x)
grad_y = torch.zeros_like(y)
_BLOCK_SIZE_0 = 16
_RDIM_SIZE_1 = 8
_RDIM_SIZE_2 = 4
_launcher(_helion__matmul_layernorm_bwd_dxdy, (triton.cdiv(m, _BLOCK_SIZE_0),), z, grad_out, weight, mean, rstd, y, grad_x, x, grad_y, grad_out.stride(0), grad_out.stride(1), grad_x.stride(0), grad_x.stride(1), grad_y.stride(0), grad_y.stride(1), mean.stride(0), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), y.stride(0), y.stride(1), z.stride(0), z.stride(1), m, _BLOCK_SIZE_0, _RDIM_SIZE_1, _RDIM_SIZE_2, num_warps=4, num_stages=2)
return (grad_x, grad_y)

--- assertExpectedJournal(TestIndexing.test_mask_load)
from __future__ import annotations

Expand Down
84 changes: 84 additions & 0 deletions test/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,90 @@ def arange(length: int, device: torch.device) -> torch.Tensor:
)
self.assertExpectedJournal(code)

def test_hl_arange_non_power_of_2(self):
    """Exercise hl.arange indices whose static lengths are not powers of two.

    Runs a kernel that atomically accumulates into ``grad_y`` through
    ``hl.arange(0, k)`` / ``hl.arange(0, n)`` with k=3 and n=7, compares both
    returned gradients against a PyTorch reference, and checks the generated
    code against the recorded journal.
    """

    @helion.kernel
    def _matmul_layernorm_bwd_dxdy(
        grad_out: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        mean: torch.Tensor,
        rstd: torch.Tensor,
        weight: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        m, n = z.shape
        k = x.shape[1]
        # presumably makes n and k compile-time constants (the recorded
        # kernel bakes in 7 and 3) — confirm against hl.specialize docs
        n = hl.specialize(n)
        k = hl.specialize(k)

        grad_x = torch.empty_like(x)
        # grad_y starts at zero because each tile adds its contribution below
        grad_y = torch.zeros_like(y)

        for tile_m in hl.tile(m):
            z_tile = z[tile_m, :].to(torch.float32)
            dy_tile = grad_out[tile_m, :].to(torch.float32)
            w = weight[:].to(torch.float32)
            mean_tile = mean[tile_m]
            rstd_tile = rstd[tile_m]

            # same formulas as the PyTorch reference computed after the kernel call
            z_hat = (z_tile - mean_tile[:, None]) * rstd_tile[:, None]
            wdy = w * dy_tile
            c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n)
            c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n)
            dz = (wdy - (z_hat * c1 + c2)) * rstd_tile[:, None]

            grad_x[tile_m, :] = (dz @ y[:, :].t().to(torch.float32)).to(x.dtype)
            grad_y_update = (x[tile_m, :].t().to(torch.float32) @ dz).to(y.dtype)

            # The part under test: both index ranges have non-power-of-2
            # lengths (k=3, n=7 in this test).
            hl.atomic_add(
                grad_y,
                [
                    hl.arange(0, k),
                    hl.arange(0, n),
                ],
                grad_y_update,
            )

        return grad_x, grad_y

    # k and n are deliberately not powers of two
    m, k, n = 5, 3, 7
    eps = 1e-5

    x = torch.randn((m, k), device=DEVICE, dtype=torch.float16)
    y = torch.randn((k, n), device=DEVICE, dtype=torch.float16)
    weight = torch.randn((n,), device=DEVICE, dtype=torch.float16)
    grad_out = torch.randn((m, n), device=DEVICE, dtype=torch.float16)

    # Forward statistics computed in float32; z is cast back to x.dtype
    # only when it is passed to the kernel.
    z = (x @ y).to(torch.float32)
    var, mean = torch.var_mean(z, dim=-1, keepdim=True, correction=0)
    rstd = torch.rsqrt(var + eps)

    code, (grad_x, grad_y) = code_and_output(
        _matmul_layernorm_bwd_dxdy,
        (
            grad_out,
            x,
            y,
            z.to(x.dtype),
            mean.squeeze(-1),
            rstd.squeeze(-1),
            weight,
        ),
    )

    # PyTorch reference gradients
    z_hat = (z - mean) * rstd
    wdy = weight.to(torch.float32) * grad_out.to(torch.float32)
    c1 = torch.sum(z_hat * wdy, dim=-1, keepdim=True) / float(n)
    c2 = torch.sum(wdy, dim=-1, keepdim=True) / float(n)
    dz = (wdy - (z_hat * c1 + c2)) * rstd
    ref_grad_x = (dz @ y.to(torch.float32).t()).to(grad_x.dtype)
    ref_grad_y = (x.to(torch.float32).t() @ dz).to(grad_y.dtype)

    torch.testing.assert_close(grad_x, ref_grad_x, rtol=1e-3, atol=2e-3)
    torch.testing.assert_close(grad_y, ref_grad_y, rtol=1e-3, atol=2e-3)
    self.assertExpectedJournal(code)

def test_pairwise_add(self):
@helion.kernel()
def pairwise_add(x: torch.Tensor) -> torch.Tensor:
Expand Down
Loading