86 changes: 73 additions & 13 deletions helion/_compiler/inductor_lowering.py
@@ -1133,6 +1133,45 @@ def __init__(
self.cg = cg
self.input_name_lookup = input_name_lookup

def _expected_tensor_dtype(self) -> torch.dtype | None:
"""Best-effort retrieval of the current FX node's tensor dtype."""
current_node = V.current_node
if current_node is None:
return None
val = current_node.meta.get("val")
if isinstance(val, torch.Tensor):
return val.dtype
return None

def _create_cast_expr(self, x: object, target_dtype_str: str) -> ast.AST:
"""Create a tl.cast expression from AST or string input.

Args:
x: Input value (AST node or string/OpsValue)
target_dtype_str: Target Triton dtype as string (e.g., "tl.float32")

Returns:
AST expression for the cast operation
"""
if isinstance(x, ast.AST):
return expr_from_string(f"tl.cast({{x}}, {target_dtype_str})", x=x)
base = _unpack_opsvalue(x)
return expr_from_string(f"tl.cast({base}, {target_dtype_str})")

def _maybe_cast_to_expected_dtype(self, expr: ast.AST) -> ast.AST:
"""Cast expression to expected dtype if needed.

Args:
expr: Input expression to potentially cast

Returns:
Original or cast expression
"""
expected_dtype = self._expected_tensor_dtype()
if expected_dtype is None:
return expr
return self._create_cast_expr(expr, triton_type(expected_dtype))

def _default(
self, name: str, args: tuple[object, ...], kwargs: dict[str, object]
) -> str:
@@ -1155,12 +1194,7 @@ def to_dtype(
device context during compute-type selection, and to guarantee a visible
cast in generated code that matches PyTorch's dtype semantics.
"""
- # Accept both AST-like and string-like inputs from the parent pipeline
- if isinstance(x, ast.AST):
-     cast_expr = expr_from_string(f"tl.cast({{x}}, {triton_type(dtype)})", x=x)
- else:
-     base = _unpack_opsvalue(x)
-     cast_expr = expr_from_string(f"tl.cast({base}, {triton_type(dtype)})")
+ cast_expr = self._create_cast_expr(x, triton_type(dtype))
return self.cg.lift(cast_expr).id

def _is_scalar_like_str(self, x_str: str) -> bool:
@@ -1174,12 +1208,38 @@ def _is_scalar_like_str(self, x_str: str) -> bool:
# Ensure non-linear elementwise ops receive fp32 inputs for Triton
def sigmoid(self, x: object) -> str: # type: ignore[override]
# Build tl.sigmoid(tl.cast(x, tl.float32)) and lift
- if isinstance(x, ast.AST):
-     inner = expr_from_string("tl.cast({x}, tl.float32)", x=x)
- else:
-     base = _unpack_opsvalue(x)
-     inner = expr_from_string(f"tl.cast({base}, tl.float32)")
- return self.cg.lift(expr_from_string("tl.sigmoid({x})", x=inner)).id
+ inner = self._create_cast_expr(x, "tl.float32")
+ result = expr_from_string("tl.sigmoid({x})", x=inner)
+
+ # Only cast back if the expected dtype is not float32
+ expected_dtype = self._expected_tensor_dtype()
+ if expected_dtype is not None and expected_dtype != torch.float32:
+     result = self._maybe_cast_to_expected_dtype(result)
+
+ return self.cg.lift(result).id

def mul(self, a: object, b: object) -> str: # type: ignore[override]
def has_scalar_operand() -> bool:
current_node = V.current_node
if current_node is None:
return False
return any(isinstance(arg, (int, float, bool)) for arg in current_node.args)

result_str = _unpack_opsvalue(self.parent_handler.mul(a, b))
result_expr = expr_from_string(result_str)

# Only cast if we have a scalar operand and the expected dtype is not float32.
# This handles cases like `x_bf16 * 0.1`, where Triton would promote the result
# to float32, deviating from PyTorch's dtype semantics.
expected_dtype = self._expected_tensor_dtype()
if (
has_scalar_operand()
and expected_dtype is not None
and expected_dtype != torch.float32
):
result_expr = self._maybe_cast_to_expected_dtype(result_expr)

return self.cg.lift(result_expr).id

def load(self, name: str, index: sympy.Expr) -> str:
# TODO(jansel): assert the index is correct
Expand Down Expand Up @@ -1331,7 +1391,7 @@ def _collect_multi_outputs(

def run_node(self, n: Node) -> object:
if n.op == "call_function":
- with self._set_current_node(n), n.meta["location"]:
+ with self._set_current_node(n), n.meta["location"], V.set_current_node(n):
try:
lowering: Lowering = n.meta["lowering"]
result = lowering.codegen(self, n)
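The helpers above read the expected dtype from the FX node's meta["val"], which the new V.set_current_node(n) context in run_node makes visible during lowering. As a minimal standalone sketch of where that metadata comes from — plain torch.fx with fake-tensor propagation rather than Helion's actual pipeline, with function f and its input purely illustrative:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x):
    return x * 0.1

gm = symbolic_trace(f)
# Fake-tensor propagation records each node's eager result in node.meta["val"]
FakeTensorProp(gm).propagate(torch.randn(4, dtype=torch.bfloat16))
for node in gm.graph.nodes:
    if node.op == "call_function":
        val = node.meta.get("val")
        # mul -> torch.bfloat16: the dtype _expected_tensor_dtype() would return
        print(node.target, getattr(val, "dtype", None))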
22 changes: 11 additions & 11 deletions test/test_examples.expected
@@ -245,12 +245,12 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_0: tl.constexpr,
# src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
amax = tl.cast(tl.max(qk, 2), tl.float16)
v_0 = 0.18033688
- v_1 = amax * v_0
+ v_1 = tl.cast(amax * v_0, tl.float16)
v_2 = tl.cast(v_1, tl.float32)
v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
# src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
v_4 = 0.18033688
- v_5 = qk * v_4
+ v_5 = tl.cast(qk * v_4, tl.float16)
subscript = v_3[:, :, None]
v_6 = tl.cast(v_5, tl.float32)
v_7 = v_6 - subscript
@@ -523,12 +523,12 @@ def _helion_attention(q_view, k_view, v_view, out, _NUM_SM: tl.constexpr, _BLOCK
# src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
amax = tl.cast(tl.max(qk, 2), tl.float16)
v_0 = 0.18033688
- v_1 = amax * v_0
+ v_1 = tl.cast(amax * v_0, tl.float16)
v_2 = tl.cast(v_1, tl.float32)
v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
# src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
v_4 = 0.18033688
- v_5 = qk * v_4
+ v_5 = tl.cast(qk * v_4, tl.float16)
subscript = v_3[:, :, None]
v_6 = tl.cast(v_5, tl.float32)
v_7 = v_6 - subscript
@@ -1129,7 +1129,7 @@ def _helion_cross_entropy(labels, logits_flat, logits, losses, _RDIM_SIZE_1: tl.
labels_tile = tl.load(labels + indices_0 * 1, None)
# src[cross_entropy.py:N]: base_indices_tile = tile_n.index * v # [tile_size]
v_0 = tl.full([], 1000, tl.int32)
- v_1 = indices_0 * v_0
+ v_1 = tl.cast(indices_0 * v_0, tl.int32)
# src[cross_entropy.py:N]: flat_indices = base_indices_tile + labels_tile
v_2 = tl.cast(v_1, tl.int64)
v_3 = v_2 + labels_tile
@@ -2547,7 +2547,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
# src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_2 = v_4[:, :, None]
v_5 = tl.full([], 8, tl.int64)
- v_6 = subscript_2 * v_5
+ v_6 = tl.cast(subscript_2 * v_5, tl.int64)
subscript_3 = indices_1[None, None, :]
v_7 = tl.cast(subscript_3, tl.int64)
v_8 = v_6 + v_7
@@ -2616,7 +2616,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
# src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_8 = v_18[:, :, None]
v_19 = tl.full([], 8, tl.int64)
- v_20 = subscript_8 * v_19
+ v_20 = tl.cast(subscript_8 * v_19, tl.int64)
subscript_9 = indices_3[None, None, :]
v_21 = tl.cast(subscript_9, tl.int64)
v_22 = v_20 + v_21
@@ -2696,7 +2696,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
# src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_15 = v_38[:, :, None]
v_39 = tl.full([], 8, tl.int64)
- v_40 = subscript_15 * v_39
+ v_40 = tl.cast(subscript_15 * v_39, tl.int64)
subscript_16 = indices_5[None, None, :]
v_41 = tl.cast(subscript_16, tl.int64)
v_42 = v_40 + v_41
@@ -3014,7 +3014,7 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons
# src[jagged_softmax.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_2 = v_4[:, :, None]
v_5 = tl.full([], 8, tl.int64)
- v_6 = subscript_2 * v_5
+ v_6 = tl.cast(subscript_2 * v_5, tl.int64)
subscript_3 = indices_1[None, None, :]
v_7 = tl.cast(subscript_3, tl.int64)
v_8 = v_6 + v_7
@@ -3092,7 +3092,7 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons
# src[jagged_softmax.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_11 = v_28[:, :, None]
v_29 = tl.full([], 8, tl.int64)
- v_30 = subscript_11 * v_29
+ v_30 = tl.cast(subscript_11 * v_29, tl.int64)
subscript_12 = indices_1[None, None, :]
v_31 = tl.cast(subscript_12, tl.int64)
v_32 = v_30 + v_31
@@ -3228,7 +3228,7 @@ def _helion_jagged_sum_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.constexp
# src[jagged_sum.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
subscript_2 = v_4[:, :, None]
v_5 = tl.full([], 8, tl.int64)
- v_6 = subscript_2 * v_5
+ v_6 = tl.cast(subscript_2 * v_5, tl.int64)
subscript_3 = indices_1[None, None, :]
v_7 = tl.cast(subscript_3, tl.int64)
v_8 = v_6 + v_7
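Every expected-output change above follows the same rule: eager PyTorch keeps the tensor's dtype when the other operand is a Python scalar, while Triton promotes such products, so the generated code now casts back explicitly. A quick eager-mode check of the PyTorch side (shapes here are arbitrary):

import torch

amax = torch.randn(8, dtype=torch.float16)
assert (amax * 0.18033688).dtype == torch.float16  # float scalar: result stays fp16

indices = torch.arange(8, dtype=torch.int32)
assert (indices * 1000).dtype == torch.int32  # int scalar: result stays int32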
42 changes: 42 additions & 0 deletions test/test_generate_ast.expected
@@ -535,6 +535,48 @@ def inplace_mul(x, c, *, _launcher=_default_launcher):
# src[basic_kernels.py:N]: return x
return x

--- assertExpectedJournal(TestGenerateAst.test_sigmoid_scalar_autocast)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_se_block_fwd(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
# src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
# src[test_generate_ast.py:N]: x_tile = x[tile_m, :]
x_tile = tl.load(tl.make_block_ptr(x, [4096, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
# src[test_generate_ast.py:N]: sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
load_1 = tl.load(tl.make_block_ptr(w, [128, 128], [128, 1], [0, 0], [_RDIM_SIZE_1, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
mm = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
v_0 = tl.cast(tl.sigmoid(tl.cast(mm, tl.float32)), tl.bfloat16)
# src[test_generate_ast.py:N]: acc = 2.0 * x_tile * sigmoid_result
v_1 = 2.0
v_2 = tl.cast(x_tile * v_1, tl.bfloat16)
v_3 = v_2 * v_0
# src[test_generate_ast.py:N]: out[tile_m, :] = acc.to(x.dtype)
tl.store(tl.make_block_ptr(out, [4096, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])

def se_block_fwd(x: torch.Tensor, w: torch.Tensor, *, _launcher=_default_launcher):
# src[test_generate_ast.py:N]: m, n = x.size()
m, n = x.size()
# src[test_generate_ast.py:N]: out = torch.empty([m, n], dtype=x.dtype, device=x.device)
out = torch.empty([m, n], dtype=x.dtype, device=x.device)
# src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
_BLOCK_SIZE_0 = 32
_RDIM_SIZE_1 = 128
# src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
# src[test_generate_ast.py:N]: x_tile = x[tile_m, :]
# src[test_generate_ast.py:N]: sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
# src[test_generate_ast.py:N-N]: ...
_launcher(_helion_se_block_fwd, (triton.cdiv(4096, _BLOCK_SIZE_0),), x, w, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1)
# src[test_generate_ast.py:N]: return out
return out

--- assertExpectedJournal(TestGenerateAst.test_torch_ops_pointwise)
from __future__ import annotations

36 changes: 36 additions & 0 deletions test/test_generate_ast.py
@@ -12,6 +12,7 @@
from helion._testing import code_and_output
from helion._testing import import_path
from helion._testing import skipIfRefEager
import helion.language as hl

datadir = Path(__file__).parent / "data"
basic_kernels = import_path(datadir / "basic_kernels.py")
@@ -212,6 +213,41 @@ def test_final_cast_enforced_for_to_dtype(self):
# Ensure codegen emits a final tl.cast(..., tl.bfloat16)
assert "tl.cast" in code and "tl.bfloat16" in code

def test_sigmoid_scalar_autocast(self):
@helion.kernel(
config=helion.Config(
block_sizes=[32],
indexing="block_ptr",
),
static_shapes=True,
)
def se_block_fwd(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
m, n = x.size()
out = torch.empty([m, n], dtype=x.dtype, device=x.device)

for tile_m in hl.tile(m):
x_tile = x[tile_m, :]
sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
acc = 2.0 * x_tile * sigmoid_result
out[tile_m, :] = acc.to(x.dtype)

return out

m, n = 4096, 128
dtype = torch.bfloat16

x = torch.randn(m, n, device=DEVICE, dtype=dtype)
w = torch.randn(n, n, device=DEVICE, dtype=dtype)

code, result = code_and_output(se_block_fwd, (x, w))

x_fp32 = x.to(torch.float32)
w_fp32 = w.to(torch.float32)
expected = (2.0 * x_fp32 * torch.sigmoid(x_fp32 @ w_fp32)).to(dtype)

torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1)
self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
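The Triton-side behavior this test locks in can also be seen in isolation. A minimal hand-written kernel — hypothetical, not Helion-generated, and assuming a CUDA device — that reproduces the scalar-multiply promotion and the explicit cast back:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)  # bf16 elements
    y = x * 2.0  # Triton promotes the product to fp32
    # Explicit cast back, mirroring what the codegen now emits
    tl.store(out_ptr + offs, tl.cast(y, tl.bfloat16))

x = torch.randn(128, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(x)
scale_kernel[(1,)](x, out, BLOCK=128)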
4 changes: 2 additions & 2 deletions test/test_tensor_descriptor.expected
@@ -185,12 +185,12 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
# src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
amax = tl.cast(tl.max(qk, 2), tl.float16)
v_0 = 0.18033688
- v_1 = amax * v_0
+ v_1 = tl.cast(amax * v_0, tl.float16)
v_2 = tl.cast(v_1, tl.float32)
v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
# src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
v_4 = 0.18033688
- v_5 = qk * v_4
+ v_5 = tl.cast(qk * v_4, tl.float16)
subscript = v_3[:, :, None]
v_6 = tl.cast(v_5, tl.float32)
v_7 = v_6 - subscript
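Net effect for user kernels: scalar multiplies now stay in the tensor's dtype end-to-end. A minimal Helion kernel modeled on the test above (a sketch only — the explicit config is omitted, so the first call would autotune; device assumed CUDA):

import torch
import helion
import helion.language as hl

@helion.kernel(static_shapes=True)
def scale(x: torch.Tensor) -> torch.Tensor:
    m, n = x.size()
    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
    for tile_m in hl.tile(m):
        out[tile_m, :] = x[tile_m, :] * 0.1  # stays bf16, matching eager PyTorch
    return out

x = torch.randn(4096, 128, device="cuda", dtype=torch.bfloat16)
assert scale(x).dtype == torch.bfloat16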