From 0af05f4ceac876b566090d7f05b73549248234b2 Mon Sep 17 00:00:00 2001
From: Jason Ansel
Date: Sat, 30 Aug 2025 10:32:39 -0700
Subject: [PATCH] Allow dynamic fill values in full

Fixes #448

stack-info: PR: https://github.com/pytorch/helion/pull/533, branch: jansel/stack/131
---
 helion/_compiler/inductor_lowering.py | 11 +++++++----
 test/test_broadcasting.py             | 18 ++++++++++++++++++
 2 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/helion/_compiler/inductor_lowering.py b/helion/_compiler/inductor_lowering.py
index e4c8c32c7..04e46f3f6 100644
--- a/helion/_compiler/inductor_lowering.py
+++ b/helion/_compiler/inductor_lowering.py
@@ -766,17 +766,20 @@ def codegen_getitem(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     )
 
 
 def codegen_full(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
     env = CompileEnvironment.current()
-    size, fill_value = map_arg(node.args, lambda n: n.meta["val"])
+    size = map_arg(node.args[0], lambda n: n.meta["val"])
     dtype = node.kwargs.get("dtype", torch.get_default_dtype())
     assert isinstance(dtype, torch.dtype)
     device = node.kwargs.get("device", env.device)
     assert device == env.device, f"expected {env.device}, got {device}"
     assert not node.kwargs.get("pin_memory"), "pin_memory not supported"
-    assert isinstance(fill_value, (int, float, bool))
-
+    value_ast = map_arg(node.args[1], lambda arg: ctx.env[arg])
+    if isinstance(value_ast, (int, float, bool)):
+        value_ast = expr_from_string(constant_repr(value_ast))
+    assert isinstance(value_ast, ast.AST), value_ast
     shape_str = ctx.cg.device_function.tile_strategy.shape_str([*size])  # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
     return expr_from_string(
-        f"tl.full({shape_str}, {constant_repr(fill_value)}, {triton_type(dtype)})"
+        f"tl.full({shape_str}, {{value}}, {triton_type(dtype)})",
+        value=value_ast,
     )
diff --git a/test/test_broadcasting.py b/test/test_broadcasting.py
index 64032d04b..f5ece7801 100644
--- a/test/test_broadcasting.py
+++ b/test/test_broadcasting.py
@@ -126,6 +126,24 @@ def fn(a, beta):
         torch.testing.assert_close(out, expected)
         self.assertExpectedJournal(code)
 
+    def test_lerp_scalar_weight(self):
+        # Repro for https://github.com/pytorch/helion/issues/448
+        # Using torch.lerp with a Python scalar weight should not crash.
+        @helion.kernel
+        def fn(a, b, w):
+            for tile0, tile1 in hl.tile(a.shape):
+                a[tile0, tile1] = torch.lerp(a[tile0, tile1], b[tile0, tile1], w)
+            return a
+
+        a = torch.randn(128, 128, device=DEVICE)
+        b = torch.randn(128, 128, device=DEVICE)
+        w = 0.5
+        args = (a.clone(), b, w)
+
+        expected = torch.lerp(a, b, w)
+        code, out = code_and_output(fn, args, block_sizes=[16, 16])
+        torch.testing.assert_close(out, expected)
+
 
 if __name__ == "__main__":
     unittest.main()
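
Reviewer note on the lowering change: codegen_full now accepts the fill value either as a Python constant or as an FX node whose generated expression is looked up in ctx.env; constants are normalized to an AST via constant_repr, and either way the result is spliced into the tl.full(...) template through the {value} placeholder. (Presumably this is the path torch.lerp with a Python-scalar weight exercises: the weight reaches the lowered aten.full as a node rather than a literal, which tripped the old isinstance assert.) Below is a minimal, self-contained sketch of that placeholder-splicing pattern using only the standard-library ast module; expr_from_string is a Helion-internal helper, and the stand-in expr_from_string_demo plus the names w_scalar/value are hypothetical, for illustration only.

    import ast

    def expr_from_string_demo(template: str, **subs: ast.AST) -> ast.AST:
        # Turn each {name} placeholder into a bare name, parse the template,
        # then splice the supplied AST over the matching Name node.
        tree = ast.parse(template.format(**{k: k for k in subs}), mode="eval")

        class Splice(ast.NodeTransformer):
            def visit_Name(self, node: ast.Name) -> ast.AST:
                return subs.get(node.id, node)

        return ast.fix_missing_locations(Splice().visit(tree)).body

    # A dynamic fill value: an expression, not a Python constant.
    value_ast = ast.parse("w_scalar * 2.0", mode="eval").body
    full_expr = expr_from_string_demo(
        "tl.full([16, 16], {value}, tl.float32)", value=value_ast
    )
    print(ast.unparse(full_expr))  # tl.full([16, 16], w_scalar * 2.0, tl.float32)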
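
To run the new test locally (the file wires up unittest.main(), so unittest's -k filter works; pytest should also discover it if installed):

    python test/test_broadcasting.py -k test_lerp_scalar_weight

The test allocates its tensors on the suite's DEVICE constant, so it needs whatever accelerator the test harness is configured for.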