-
Notifications
You must be signed in to change notification settings - Fork 61
Add torch.stack support #524
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
c076d0e to
46cc702
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could this be handled similar to:
helion/helion/_compiler/inductor_lowering.py
Lines 823 to 836 in 3c0348a
| @register_lowering( | |
| torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue] | |
| masked_value_fn=passthrough_masked_value, | |
| ) | |
| def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object: | |
| assert not node.kwargs, "getitem kwargs not supported" | |
| tensor, dims = map_arg(node.args, lambda arg: ctx.env[arg]) | |
| assert isinstance(tensor, ast.AST) | |
| dims = [*dims] # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable] | |
| assert {*dims} == {*range(len(dims))}, dims | |
| return expr_from_string( | |
| f"tl.permute({{tensor}}, {dims!r})", | |
| tensor=tensor, | |
| ) |
I'd expect that to be simpler with less need for special casing view ops.
helion/_compiler/device_ir.py
Outdated
| ).graph | ||
| decomp_table = select_decomp_table() | ||
| decomp_table.pop(torch.ops.aten.stack.default, None) | ||
| return proxy_tensor.make_fx(fn, decomposition_table=decomp_table)(*args).graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normally, torch.stack is decomposed to unsqueeze + cat, but I haven’t figure out a way to make codegen_cat work, so as a workaround we disable the decomp for torch.stack and implement codegen_stack instead.
e8138a1 to
dcfb3fe
Compare
b766ebf to
a83637d
Compare
| torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5) | ||
| self.assertExpectedJournal(code) | ||
|
|
||
| # Verify torch.compile still decomposes aten.stack to aten.cat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added test to make sure _get_custom_decomp_table doesn't affect normal torch.compile decomp for torch.stack
Fixes #523.