diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index 009059241..667249e7c 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -1133,6 +1133,45 @@ def __init__(
         self.cg = cg
         self.input_name_lookup = input_name_lookup
 
+    def _expected_tensor_dtype(self) -> torch.dtype | None:
+        """Best-effort retrieval of the current FX node's tensor dtype."""
+        current_node = V.current_node
+        if current_node is None:
+            return None
+        val = current_node.meta.get("val")
+        if isinstance(val, torch.Tensor):
+            return val.dtype
+        return None
+
+    def _create_cast_expr(self, x: object, target_dtype_str: str) -> ast.AST:
+        """Create a tl.cast expression from AST or string input.
+
+        Args:
+            x: Input value (AST node or string/OpsValue)
+            target_dtype_str: Target Triton dtype as string (e.g., "tl.float32")
+
+        Returns:
+            AST expression for the cast operation
+        """
+        if isinstance(x, ast.AST):
+            return expr_from_string(f"tl.cast({{x}}, {target_dtype_str})", x=x)
+        base = _unpack_opsvalue(x)
+        return expr_from_string(f"tl.cast({base}, {target_dtype_str})")
+
+    def _maybe_cast_to_expected_dtype(self, expr: ast.AST) -> ast.AST:
+        """Cast expression to expected dtype if needed.
+
+        Args:
+            expr: Input expression to potentially cast
+
+        Returns:
+            Original or casted expression
+        """
+        expected_dtype = self._expected_tensor_dtype()
+        if expected_dtype is None:
+            return expr
+        return self._create_cast_expr(expr, triton_type(expected_dtype))
+
     def _default(
         self, name: str, args: tuple[object, ...], kwargs: dict[str, object]
     ) -> str:
@@ -1155,12 +1194,7 @@ def to_dtype(
        device context during compute-type selection, and to guarantee a visible
        cast in generated code that matches PyTorch's dtype semantics.
        """
-        # Accept both AST-like and string-like inputs from the parent pipeline
-        if isinstance(x, ast.AST):
-            cast_expr = expr_from_string(f"tl.cast({{x}}, {triton_type(dtype)})", x=x)
-        else:
-            base = _unpack_opsvalue(x)
-            cast_expr = expr_from_string(f"tl.cast({base}, {triton_type(dtype)})")
+        cast_expr = self._create_cast_expr(x, triton_type(dtype))
         return self.cg.lift(cast_expr).id
 
@@ -1174,12 +1208,38 @@ def _is_scalar_like_str(self, x_str: str) -> bool:
     # Ensure non-linear elementwise ops receive fp32 inputs for Triton
     def sigmoid(self, x: object) -> str:  # type: ignore[override]
         # Build tl.sigmoid(tl.cast(x, tl.float32)) and lift
-        if isinstance(x, ast.AST):
-            inner = expr_from_string("tl.cast({x}, tl.float32)", x=x)
-        else:
-            base = _unpack_opsvalue(x)
-            inner = expr_from_string(f"tl.cast({base}, tl.float32)")
-        return self.cg.lift(expr_from_string("tl.sigmoid({x})", x=inner)).id
+        inner = self._create_cast_expr(x, "tl.float32")
+        result = expr_from_string("tl.sigmoid({x})", x=inner)
+
+        # Only cast if expected dtype is not float32
+        expected_dtype = self._expected_tensor_dtype()
+        if expected_dtype is not None and expected_dtype != torch.float32:
+            result = self._maybe_cast_to_expected_dtype(result)
+
+        return self.cg.lift(result).id
+
+    def mul(self, a: object, b: object) -> str:  # type: ignore[override]
+        def has_scalar_operand() -> bool:
+            current_node = V.current_node
+            if current_node is None:
+                return False
+            return any(isinstance(arg, (int, float, bool)) for arg in current_node.args)
+
+        result_str = _unpack_opsvalue(self.parent_handler.mul(a, b))
+        result_expr = expr_from_string(result_str)
+
+        # Only cast if we have a scalar operand and expected dtype is not float32.
+        # This is to handle cases like `x_bf16 * 0.1` where Triton would promote the result to float32,
+        # deviating from PyTorch semantics.
+        expected_dtype = self._expected_tensor_dtype()
+        if (
+            has_scalar_operand()
+            and expected_dtype is not None
+            and expected_dtype != torch.float32
+        ):
+            result_expr = self._maybe_cast_to_expected_dtype(result_expr)
+
+        return self.cg.lift(result_expr).id
 
     def load(self, name: str, index: sympy.Expr) -> str:
         # TODO(jansel): assert the index is correct
@@ -1331,7 +1391,7 @@ def _collect_multi_outputs(
 
     def run_node(self, n: Node) -> object:
         if n.op == "call_function":
-            with self._set_current_node(n), n.meta["location"]:
+            with self._set_current_node(n), n.meta["location"], V.set_current_node(n):
                 try:
                     lowering: Lowering = n.meta["lowering"]
                     result = lowering.codegen(self, n)
diff --git a/test/test_examples.expected b/test/test_examples.expected
index bd43104da..526b1c2ca 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -245,12 +245,12 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_0: tl.constexpr,
         # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
-        v_1 = amax * v_0
+        v_1 = tl.cast(amax * v_0, tl.float16)
         v_2 = tl.cast(v_1, tl.float32)
         v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
         # src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
         v_4 = 0.18033688
-        v_5 = qk * v_4
+        v_5 = tl.cast(qk * v_4, tl.float16)
         subscript = v_3[:, :, None]
         v_6 = tl.cast(v_5, tl.float32)
         v_7 = v_6 - subscript
@@ -523,12 +523,12 @@ def _helion_attention(q_view, k_view, v_view, out, _NUM_SM: tl.constexpr, _BLOCK
         # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
-        v_1 = amax * v_0
+        v_1 = tl.cast(amax * v_0, tl.float16)
         v_2 = tl.cast(v_1, tl.float32)
         v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
         # src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
         v_4 = 0.18033688
-        v_5 = qk * v_4
+        v_5 = tl.cast(qk * v_4, tl.float16)
         subscript = v_3[:, :, None]
         v_6 = tl.cast(v_5, tl.float32)
         v_7 = v_6 - subscript
@@ -1129,7 +1129,7 @@ def _helion_cross_entropy(labels, logits_flat, logits, losses, _RDIM_SIZE_1: tl.
     labels_tile = tl.load(labels + indices_0 * 1, None)
     # src[cross_entropy.py:N]: base_indices_tile = tile_n.index * v  # [tile_size]
     v_0 = tl.full([], 1000, tl.int32)
-    v_1 = indices_0 * v_0
+    v_1 = tl.cast(indices_0 * v_0, tl.int32)
     # src[cross_entropy.py:N]: flat_indices = base_indices_tile + labels_tile
     v_2 = tl.cast(v_1, tl.int64)
     v_3 = v_2 + labels_tile
@@ -2547,7 +2547,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
         # src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_2 = v_4[:, :, None]
         v_5 = tl.full([], 8, tl.int64)
-        v_6 = subscript_2 * v_5
+        v_6 = tl.cast(subscript_2 * v_5, tl.int64)
         subscript_3 = indices_1[None, None, :]
         v_7 = tl.cast(subscript_3, tl.int64)
         v_8 = v_6 + v_7
@@ -2616,7 +2616,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
         # src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_8 = v_18[:, :, None]
         v_19 = tl.full([], 8, tl.int64)
-        v_20 = subscript_8 * v_19
+        v_20 = tl.cast(subscript_8 * v_19, tl.int64)
         subscript_9 = indices_3[None, None, :]
         v_21 = tl.cast(subscript_9, tl.int64)
         v_22 = v_20 + v_21
@@ -2696,7 +2696,7 @@ def _helion_jagged_layer_norm_kernel(x_offsets, x_flat, out_flat, eps, _BLOCK_SI
         # src[jagged_layer_norm.py:N]: flat_indices = indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_15 = v_38[:, :, None]
         v_39 = tl.full([], 8, tl.int64)
-        v_40 = subscript_15 * v_39
+        v_40 = tl.cast(subscript_15 * v_39, tl.int64)
         subscript_16 = indices_5[None, None, :]
         v_41 = tl.cast(subscript_16, tl.int64)
         v_42 = v_40 + v_41
@@ -3014,7 +3014,7 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons
         # src[jagged_softmax.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_2 = v_4[:, :, None]
         v_5 = tl.full([], 8, tl.int64)
-        v_6 = subscript_2 * v_5
+        v_6 = tl.cast(subscript_2 * v_5, tl.int64)
         subscript_3 = indices_1[None, None, :]
         v_7 = tl.cast(subscript_3, tl.int64)
         v_8 = v_6 + v_7
@@ -3092,7 +3092,7 @@ def _helion_jagged_softmax_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.cons
         # src[jagged_softmax.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_11 = v_28[:, :, None]
         v_29 = tl.full([], 8, tl.int64)
-        v_30 = subscript_11 * v_29
+        v_30 = tl.cast(subscript_11 * v_29, tl.int64)
         subscript_12 = indices_1[None, None, :]
         v_31 = tl.cast(subscript_12, tl.int64)
         v_32 = v_30 + v_31
@@ -3228,7 +3228,7 @@ def _helion_jagged_sum_kernel(x_offsets, x_flat, out, _BLOCK_SIZE_0: tl.constexp
         # src[jagged_sum.py:N]: base_indices[:, :, None] * M + tile_m.index[None, None, :]
         subscript_2 = v_4[:, :, None]
         v_5 = tl.full([], 8, tl.int64)
-        v_6 = subscript_2 * v_5
+        v_6 = tl.cast(subscript_2 * v_5, tl.int64)
         subscript_3 = indices_1[None, None, :]
         v_7 = tl.cast(subscript_3, tl.int64)
         v_8 = v_6 + v_7
diff --git a/test/test_generate_ast.expected b/test/test_generate_ast.expected
index fcf2ae633..3793ea31a 100644
--- a/test/test_generate_ast.expected
+++ b/test/test_generate_ast.expected
@@ -535,6 +535,48 @@ def inplace_mul(x, c, *, _launcher=_default_launcher):
     # src[basic_kernels.py:N]: return x
     return x
 
+--- assertExpectedJournal(TestGenerateAst.test_sigmoid_scalar_autocast)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_se_block_fwd(x, w, out, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr):
+    # src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    # src[test_generate_ast.py:N]: x_tile = x[tile_m, :]
+    x_tile = tl.load(tl.make_block_ptr(x, [4096, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+    # src[test_generate_ast.py:N]: sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
+    load_1 = tl.load(tl.make_block_ptr(w, [128, 128], [128, 1], [0, 0], [_RDIM_SIZE_1, _RDIM_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+    mm = tl.dot(tl.cast(x_tile, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='tf32', out_dtype=tl.float32)
+    v_0 = tl.cast(tl.sigmoid(tl.cast(mm, tl.float32)), tl.bfloat16)
+    # src[test_generate_ast.py:N]: acc = 2.0 * x_tile * sigmoid_result
+    v_1 = 2.0
+    v_2 = tl.cast(x_tile * v_1, tl.bfloat16)
+    v_3 = v_2 * v_0
+    # src[test_generate_ast.py:N]: out[tile_m, :] = acc.to(x.dtype)
+    tl.store(tl.make_block_ptr(out, [4096, 128], [128, 1], [offset_0, 0], [_BLOCK_SIZE_0, _RDIM_SIZE_1], [1, 0]), v_3, boundary_check=[0, 1])
+
+def se_block_fwd(x: torch.Tensor, w: torch.Tensor, *, _launcher=_default_launcher):
+    # src[test_generate_ast.py:N]: m, n = x.size()
+    m, n = x.size()
+    # src[test_generate_ast.py:N]: out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+    # src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
+    _BLOCK_SIZE_0 = 32
+    _RDIM_SIZE_1 = 128
+    # src[test_generate_ast.py:N]: for tile_m in hl.tile(m):
+    # src[test_generate_ast.py:N]:     x_tile = x[tile_m, :]
+    # src[test_generate_ast.py:N]:     sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
+    # src[test_generate_ast.py:N-N]: ...
+    _launcher(_helion_se_block_fwd, (triton.cdiv(4096, _BLOCK_SIZE_0),), x, w, out, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=1)
+    # src[test_generate_ast.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestGenerateAst.test_torch_ops_pointwise)
 from __future__ import annotations
diff --git a/test/test_generate_ast.py b/test/test_generate_ast.py
index 00e12994f..d0b428a47 100644
--- a/test/test_generate_ast.py
+++ b/test/test_generate_ast.py
@@ -12,6 +12,7 @@
 from helion._testing import code_and_output
 from helion._testing import import_path
 from helion._testing import skipIfRefEager
+import helion.language as hl
 
 datadir = Path(__file__).parent / "data"
 basic_kernels = import_path(datadir / "basic_kernels.py")
@@ -212,6 +213,41 @@ def test_final_cast_enforced_for_to_dtype(self):
         # Ensure codegen emits a final tl.cast(..., tl.bfloat16)
         assert "tl.cast" in code and "tl.bfloat16" in code
 
+    def test_sigmoid_scalar_autocast(self):
+        @helion.kernel(
+            config=helion.Config(
+                block_sizes=[32],
+                indexing="block_ptr",
+            ),
+            static_shapes=True,
+        )
+        def se_block_fwd(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
+            m, n = x.size()
+            out = torch.empty([m, n], dtype=x.dtype, device=x.device)
+
+            for tile_m in hl.tile(m):
+                x_tile = x[tile_m, :]
+                sigmoid_result = torch.sigmoid(x_tile @ w[:, :])
+                acc = 2.0 * x_tile * sigmoid_result
+                out[tile_m, :] = acc.to(x.dtype)
+
+            return out
+
+        m, n = 4096, 128
+        dtype = torch.bfloat16
+
+        x = torch.randn(m, n, device=DEVICE, dtype=dtype)
+        w = torch.randn(n, n, device=DEVICE, dtype=dtype)
+
+        code, result = code_and_output(se_block_fwd, (x, w))
+
+        x_fp32 = x.to(torch.float32)
+        w_fp32 = w.to(torch.float32)
+        expected = (2.0 * x_fp32 * torch.sigmoid(x_fp32 @ w_fp32)).to(dtype)
+
+        torch.testing.assert_close(result, expected, atol=1e-1, rtol=1e-1)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_tensor_descriptor.expected b/test/test_tensor_descriptor.expected
index 33be7a54e..2316c28d8 100644
--- a/test/test_tensor_descriptor.expected
+++ b/test/test_tensor_descriptor.expected
@@ -185,12 +185,12 @@ def _helion_attention(q_view, k_view, v_view, out, _BLOCK_SIZE_1: tl.constexpr,
         # src[attention.py:N]: m_ij = torch.maximum(m_i, torch.amax(qk, -1) * qk_scale)
         amax = tl.cast(tl.max(qk, 2), tl.float16)
         v_0 = 0.18033688
-        v_1 = amax * v_0
+        v_1 = tl.cast(amax * v_0, tl.float16)
         v_2 = tl.cast(v_1, tl.float32)
         v_3 = triton_helpers.maximum(m_i_copy_0, v_2)
         # src[attention.py:N]: qk = qk * qk_scale - m_ij[:, :, None]
         v_4 = 0.18033688
-        v_5 = qk * v_4
+        v_5 = tl.cast(qk * v_4, tl.float16)
         subscript = v_3[:, :, None]
         v_6 = tl.cast(v_5, tl.float32)
         v_7 = v_6 - subscript