
Conversation

xmfan
Member

@xmfan xmfan commented Sep 19, 2025

Stack from ghstack (oldest at bottom):

Some ops, like the local_map HOP's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules.

Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint.
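
For readers following along, here is a minimal sketch of how SAC tags are applied through the public selective-checkpoint API. This is illustrative only: the op choices are made up, and the HOP-specific dispatch rules this PR adds are internal and not shown.

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # The SAC torch dispatch modes call back into this policy per op.
    # Ops that make_fx does not desugar (e.g. a HOP) reach those modes
    # directly, which is why they need explicit dispatch rules.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x, w):
    return torch.relu(x @ w).sum()

x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(
    fn, x, w,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.backward()

As for the tangent-layout fix, a rough illustration of the general idea, under the assumption that coercion here means normalizing incoming tangents to a contiguous layout the way AOTAutograd does; this is a sketch, not the actual partitioner code:

def coerce_tangent(t: torch.Tensor) -> torch.Tensor:
    # Ensure the tangent is contiguous so the partitioned backward sees
    # the layout it was traced with.
    return t if t.is_contiguous() else t.contiguous()

tangent = torch.randn(4, 8).t()  # non-contiguous incoming gradient
assert coerce_tangent(tangent).is_contiguous()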

[ghstack-poisoned]

pytorch-bot bot commented Sep 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163322

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 72c863b with merge base 607489f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Sep 19, 2025
ghstack-source-id: 719f629
Pull Request resolved: #163322
xmfan added a commit that referenced this pull request Sep 19, 2025
ghstack-source-id: a313618
Pull Request resolved: #163322
@xmfan xmfan added the topic: not user facing topic category label Sep 19, 2025
@xmfan xmfan marked this pull request as ready for review September 19, 2025 16:40
@xmfan xmfan requested a review from zou3519 as a code owner September 19, 2025 16:40
@xmfan xmfan requested review from soulitzer and ydwu4 September 19, 2025 16:40
@xmfan
Member Author

xmfan commented Sep 19, 2025

hold on, updating test to use aot eager backend

    or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
):
    # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet
    return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE
Contributor

PREFER_SAVE is completely ignored by the compiler, btw, so it's generally recommended to use MUST_SAVE, but it's probably fine if we're testing eager only.
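
To make the suggestion concrete, a small sketch mirroring the test excerpt above (not the exact test code); per the note above, MUST_SAVE is honored by the compiler while PREFER_SAVE is only a hint:

import torch
from torch.utils.checkpoint import CheckpointPolicy

def policy_fn(ctx, op, *args, **kwargs):
    if op in (
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten._scaled_dot_product_efficient_attention.default,
    ):
        # Force-save the attention outputs instead of merely preferring it,
        # so the intent survives compilation as well as eager SAC.
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE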

    op == torch.ops.aten._scaled_dot_product_flash_attention.default
    or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
):
    # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet
Contributor

Wait, do you mean "cannot recompute RNG ops"?

Member Author

Oh, ignore this; it's autoparallel frontend specific.

xmfan added a commit that referenced this pull request Sep 19, 2025
ghstack-source-id: 231ad97
Pull Request resolved: #163322
xmfan added a commit that referenced this pull request Sep 20, 2025
ghstack-source-id: c4ca088
Pull Request resolved: #163322
@xmfan xmfan requested a review from ezyang September 23, 2025 03:28
@ezyang
Contributor

ezyang commented Sep 23, 2025

Need more PR description

""",
)

@requires_cuda_and_triton
Contributor

Why does this test require CUDA?

Member Author

Nice, it works on CPU now.

@requires_cuda_and_triton
@unittest.skipIf(
    not torch.distributed.is_available(), "Torch distributed not available."
)
Contributor

Ditto this, doesn't seem like you need distributed

Member Author

I need it to import local_map/device mesh.

Contributor

oh ok, please look forward to https://www.internalfb.com/diff/D82283623

):
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
except AttributeError as e:
Contributor

What's going on here?

Member Author

The local_map HOP currently only works for AP-style compile, which interprets the nodes and directly accesses their target. But the graph's codegen is currently wrong; it should be torch._higher_order_ops.local_map.<locals>.call_local_map. I'll fix this later since it's not needed for AP.

Contributor

I guess I'd prefer to make it obvious which code is live and which code is dead, with a comment explaining what you said here.
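
For context on "interprets the nodes and directly accesses their target" above: a generic FX-interpreter-style sketch (an assumption about the shape of AP-style execution, not the actual autoparallel frontend code) that runs a graph by dispatching on node.target, sidestepping the graph's generated call path:

import torch
import torch.fx as fx
from torch.fx.node import map_arg

def run_by_target(gm: fx.GraphModule, *args):
    # Execute call_function nodes via node.target directly, rather than
    # through the GraphModule's codegen'd forward.
    env, it = {}, iter(args)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = next(it)
        elif node.op == "call_function":
            a = map_arg(node.args, lambda n: env[n])
            kw = map_arg(node.kwargs, lambda n: env[n])
            env[node] = node.target(*a, **kw)
        elif node.op == "output":
            return map_arg(node.args[0], lambda n: env[n])

gm = fx.symbolic_trace(lambda x: torch.relu(x * 2))
print(run_by_target(gm, torch.randn(3)))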

# actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
# ):
# # can still be in fw_outs for post-graph bytecode
# self.assertFalse(node.name in bw_ins)
Contributor

Is there a reason we can't use an expect test here?

Member Author

we could, but we'd need to manually check for these properties whenever we update the expecttest anyway

Contributor

@ezyang ezyang left a comment

The actual impl changes are plausible

Some ops, like the local_map HOP's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules.

Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint.

[ghstack-poisoned]
@xmfan
Member Author

xmfan commented Sep 24, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 24, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
Some ops, like the local_map HOP's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: pytorch#162246 (comment). This PR adds those rules.

Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint.
Pull Request resolved: pytorch#163322
Approved by: https://github.com/ezyang
jainapurva pushed a commit that referenced this pull request Sep 29, 2025
Some ops, like the local_map HOP's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules.

Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint.
Pull Request resolved: #163322
Approved by: https://github.com/ezyang
