Skip to content

Conversation

@yf225
Copy link
Contributor

@yf225 yf225 commented Aug 29, 2025

Fixes #523.

@yf225 yf225 requested review from jansel and oulgen August 29, 2025 06:09
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 29, 2025
@yf225 yf225 force-pushed the torch_stack_v1 branch 2 times, most recently from c076d0e to 46cc702 Compare August 29, 2025 07:45
Copy link
Contributor

@jansel jansel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be handled similar to:

@register_lowering(
torch.ops.aten.permute.default, # pyright: ignore[reportAttributeAccessIssue]
masked_value_fn=passthrough_masked_value,
)
def codegen_permute(ctx: GraphInterpreter, node: torch.fx.Node) -> object:
assert not node.kwargs, "getitem kwargs not supported"
tensor, dims = map_arg(node.args, lambda arg: ctx.env[arg])
assert isinstance(tensor, ast.AST)
dims = [*dims] # pyright: ignore[reportGeneralTypeIssues,reportOptionalIterable]
assert {*dims} == {*range(len(dims))}, dims
return expr_from_string(
f"tl.permute({{tensor}}, {dims!r})",
tensor=tensor,
)

I'd expect that to be simpler with less need for special casing view ops.

).graph
decomp_table = select_decomp_table()
decomp_table.pop(torch.ops.aten.stack.default, None)
return proxy_tensor.make_fx(fn, decomposition_table=decomp_table)(*args).graph
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Normally, torch.stack is decomposed to unsqueeze + cat, but I haven’t figure out a way to make codegen_cat work, so as a workaround we disable the decomp for torch.stack and implement codegen_stack instead.

@yf225 yf225 requested a review from jansel August 30, 2025 23:04
@yf225 yf225 force-pushed the torch_stack_v1 branch 2 times, most recently from e8138a1 to dcfb3fe Compare August 31, 2025 00:37
@yf225 yf225 force-pushed the torch_stack_v1 branch 2 times, most recently from b766ebf to a83637d Compare August 31, 2025 04:58
torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-5)
self.assertExpectedJournal(code)

# Verify torch.compile still decomposes aten.stack to aten.cat
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added test to make sure _get_custom_decomp_table doesn't affect normal torch.compile decomp for torch.stack

@yf225 yf225 requested a review from jansel August 31, 2025 05:03
@yf225 yf225 merged commit 0f3e2d5 into main Sep 1, 2025
13 checks passed
lolpack pushed a commit to lolpack/helion that referenced this pull request Oct 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

torch.stack support

3 participants