[hop] support local_map + SAC #163322
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163322
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 72c863b with merge base 607489f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
[ghstack-poisoned]
hold on, updating test to use aot eager backend
    or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
):
    # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet
    return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE
Prefer save is completely ignored by the compiler btw, so generally it's recommended to use MUST_SAVE, but probably fine if we're only testing eager.
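For context, a minimal sketch of an SAC policy along the lines suggested above (the op set is taken from the diff; the wrapper and helper names are illustrative assumptions, not this PR's test code):

```python
# Sketch only: MUST_SAVE the attention ops so the compiler honors the policy,
# and let everything else default to recompute.
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

_SAVE_OPS = {
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
}

def policy_fn(ctx, op, *args, **kwargs):
    if op in _SAVE_OPS:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def run_with_sac(fn, *inputs):
    # Hypothetical wrapper: run fn under selective activation checkpointing.
    context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
    return checkpoint(fn, *inputs, use_reentrant=False, context_fn=context_fn)
```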
    op == torch.ops.aten._scaled_dot_product_flash_attention.default
    or op == torch.ops.aten._scaled_dot_product_efficient_attention.default
):
    # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet
Wait, do you mean "cannot recompute RNG ops"?
oh ignore this, this is autoparallel frontend specific
[ghstack-poisoned]
[ghstack-poisoned]
Need more PR description
""", | ||
) | ||
|
||
@requires_cuda_and_triton |
Why does this test require CUDA?
nice it works on cpu now
@requires_cuda_and_triton
@unittest.skipIf(
    not torch.distributed.is_available(), "Torch distributed not available."
)
Ditto this, doesn't seem like you need distributed
i need it to import local_map/device mesh
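For context, the imports in question live under the distributed namespace, roughly like this (a sketch of the dependency, not the test's exact import list):

```python
# Both local_map and DeviceMesh come from torch.distributed, which is why the
# test is skipped when torch.distributed is unavailable.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import local_map
```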
oh ok, please look forward to https://www.internalfb.com/diff/D82283623
):
    out = torch.compile(model, backend=backend)(*inputs)
    out.sum().backward()
except AttributeError as e:
What's going on here?
local_map HOP currently only works for AP-style compile, which interprets the nodes and directly accesses their target. But the graph's codegen is currently wrong; it should be torch._higher_order_ops.local_map.<locals>.call_local_map. I'll fix this later as it's not needed for AP.
I guess I'd prefer to make it obvious which code is live and which code is dead, with a comment saying what you said here.
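Something along these lines, perhaps (a sketch of the suggestion only, reusing the names from the snippet above; not the final test code):

```python
import torch

def compile_and_maybe_backward(model, inputs, backend):
    # Sketch: make it explicit which path is live today and which is dead.
    try:
        out = torch.compile(model, backend=backend)(*inputs)
        # Dead code for now on backends that execute the generated Python:
        # fx codegen emits the wrong target for the local_map HOP, so the
        # compiled call raises AttributeError before reaching backward().
        out.sum().backward()
    except AttributeError:
        # Expected today; AP-style compile interprets graph nodes directly
        # and therefore does not hit the codegen issue.
        pass
```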
# actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
# ):
# # can still be in fw_outs for post-graph bytecode
# self.assertFalse(node.name in bw_ins)
Is there a reason we can't use an expect test here?
we could, but we'd need to manually check for these properties whenever we update the expecttest anyway
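For reference, the kind of property check being discussed looks roughly like this (helper names and the forward/backward graph arguments are assumptions based on the snippet above, not this PR's test):

```python
import torch.fx

def backward_input_names(bw_gm: torch.fx.GraphModule) -> set[str]:
    # Placeholder names of the backward graph, i.e. what actually got saved.
    return {n.name for n in bw_gm.graph.nodes if n.op == "placeholder"}

def assert_not_saved(node_name: str, bw_gm: torch.fx.GraphModule) -> None:
    # A node tagged MUST_RECOMPUTE may still appear among the forward outputs
    # (e.g. for post-graph bytecode), but it should never be a backward input.
    assert node_name not in backward_input_names(bw_gm)
```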
The actual impl changes are plausible
[ghstack-poisoned]
Some ops, like local_map hop's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules. Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint. [ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Some ops, like local_map hop's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules. Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint. Pull Request resolved: #163322 Approved by: https://github.com/ezyang
Stack from ghstack (oldest at bottom):
Some ops, like local_map hop's deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out here: #162246 (comment). This PR adds those rules.
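For readers unfamiliar with the mechanism, here is a rough sketch of what such a rule looks like. The HOP, its body signature, and the rule bodies below are simplified illustrations (and the SAC mode classes are private internals whose names may change); this is not this PR's actual implementation:

```python
# Sketch: a HOP that make_fx leaves opaque must say how it behaves under the
# SAC dispatch modes; otherwise dispatch raises "no rule registered" for it.
import torch
from torch._ops import HigherOrderOperator
from torch.utils.checkpoint import (
    _CachedTorchDispatchMode,
    _CachingTorchDispatchMode,
)

class MyHop(HigherOrderOperator):
    # Hypothetical HOP standing in for local_map's deferred mode.
    def __init__(self):
        super().__init__("my_hop")

my_hop = MyHop()

@my_hop.py_impl(_CachingTorchDispatchMode)
def _my_hop_sac_record(mode, body, *operands):
    # SAC forward pass: run the body with the caching mode re-enabled so the
    # policy function sees (and can decide to save) the ops inside it.
    with mode:
        return body(*operands)

@my_hop.py_impl(_CachedTorchDispatchMode)
def _my_hop_sac_replay(mode, body, *operands):
    # SAC recompute pass: replay the body so previously saved values are reused.
    with mode:
        return body(*operands)
```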
Additionally, it fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint.
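The tangent coercion referred to here is, in spirit, something like the following (a deliberately minimal stand-in, not AOTAutograd's actual logic, which also handles memory formats and tensor subclasses):

```python
import torch

def coerce_runtime_tangent(tangent: torch.Tensor) -> torch.Tensor:
    # The backward graph is traced assuming strided, contiguous tangents, so
    # an incoming runtime tangent with a different layout is converted first.
    if tangent.layout != torch.strided:
        tangent = tangent.to_dense()
    if not tangent.is_contiguous():
        tangent = tangent.contiguous()
    return tangent
```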