Supporting compilation of distributed_c10d.send and distributed_c10d.recv #155070
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155070
Note: Links to docs will display an error until the docs builds have been completed.
❌ 14 New Failures, 1 Unrelated Failure as of commit 68feaac with merge base 065c446.
NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from d9dd5d1 to 3453f31.
@qingyi-yan Hi, this is great work! Kindly asking whether isend and irecv will also be supported?
Thanks @RabbitWhite1 for the compliment. Right now I am focusing on getting this pull request merged. Supporting isend and irecv is certainly doable, but it depends on the availability of my resources for this work, which is currently uncertain.
Hi, just checking: I believe this pull request is ready for review and possibly merge. It has been a long time since I did my last update. Is there anything I need to do? Thanks.
wconstab
left a comment
This is an interesting case because, by definition, if we compile a send or recv op, our graph is non-SPMD. We have been designing how compiler optimizations should behave for distributed programs, and making sure the resulting program is still valid is very difficult unless we assume/enforce that it is SPMD (the same graph on every rank).
I think we should support capturing send/recv, but we should have some rules. If we capture a p2p op, we also need to make sure we are not doing unsafe collective optimizations, for example. For now, I think all we need to do is raise an error if any of the SPMD-mode flags or compiler passes are registered and we encounter a p2p op during tracing. What do folks think?
@bdhirsh would be the best person to advise on this issue though he is out for a week or two. Also cc @xmfan @ezyang
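To make the proposed rule concrete, here is a minimal sketch of such a tracing-time check. The flag and registry names (spmd_mode, active_spmd_passes) are hypothetical illustrations, not existing torch._dynamo config options:

def _check_p2p_capture_allowed(op_name: str, config) -> None:
    # Hypothetical guard: refuse to trace a point-to-point op while any
    # SPMD-only optimization is active, since the traced graph is then no
    # longer guaranteed to be identical on every rank.
    p2p_ops = {"send", "recv"}
    if op_name in p2p_ops and (config.spmd_mode or config.active_spmd_passes):
        raise RuntimeError(
            f"Capturing p2p op '{op_name}' is unsafe while SPMD-mode "
            "compiler passes are enabled."
        )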
I mean we just have to actually implement SPMD mode, IMO.
Agreed with @wconstab that more consistency checking would make support of p2p ops safer. Waiting for more detailed feedback on the relevant rules (checks) that are needed.
I think this is basically reasonable. To address wconstab's concern, I suggest we only enable this if a config flag is set, and set it to False by default. I haven't reviewed the rest of the PR carefully, but if you're willing to do the config flag, I'll do the rest of the review.
Yes, I agree to adding a config flag that is set to False by default. I will work on this and hope to have it ready in a week or so. Thanks for the feedback!
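As a rough illustration of the agreed-upon gating (the flag name capture_distributed_p2p_ops is hypothetical, not an existing torch._dynamo config option), the change could look along these lines:

# In torch/_dynamo/config.py one would add something like:
#     capture_distributed_p2p_ops: bool = False
# and gate the new tracing behavior on it at the capture site.
import torch

def p2p_capture_enabled() -> bool:
    # getattr keeps this sketch importable even though the flag does not
    # exist in current torch._dynamo.config; tracing of send/recv would
    # fall back to a graph break whenever this returns False.
    return getattr(torch._dynamo.config, "capture_distributed_p2p_ops", False)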
| "torch.sparse_compressed_tensor": SkipFunctionVariable, | ||
| # Specially handle system-level communication functions | ||
| "torch.distributed.distributed_c10d.send": CommunicationFunctionVariable, | ||
| "torch.distributed.distributed_c10d.recv": CommunicationFunctionVariable, |
Help me understand why these aren't handled the same way as other traceable collectives?
The recv function has an API that does not follow the functional paradigm: specifically, the recv(tensor) interface modifies the given tensor in place. Another concern is that system-level communications may modify the underlying system state, which means they may have unknown side effects. So I assumed it is not safe to treat them as functional collectives?
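For illustration, this is the in-place contract of the eager p2p API (existing torch.distributed behavior, shown only to contrast with functional collectives; it assumes an already-initialized process group):

import torch
import torch.distributed as dist

def receive_from(src: int) -> torch.Tensor:
    buf = torch.empty(4)
    # dist.recv writes the received data into `buf` in place and returns the
    # sender's rank, so the "result" is a mutation of an existing tensor
    # rather than a freshly produced value, unlike functional collectives.
    sender = dist.recv(buf, src=src)
    assert sender == src
    return buf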
    use_fallback = False

    import torch.distributed.distributed_c10d as c10d
    # Fall back to not enable autograd if mutation has to be supported.
Do the tests you added fail due to this? It feels like this is just a DCE problem? It should work to run AOTAutograd here.
The issue is the mutation of the input tensor in the recv function. Because this mutation violates the functional assumption, the tensor fails to be updated when autograd is enabled, due to the use of functionalized objects. I am not aware of a way around this, short of adopting an alternative functional API for recv.
torch/_dynamo/variables/functions.py (Outdated)
fn = fn_var.fn
return variables.TorchInGraphFunctionVariable(fn, nonstrict_traceable=True)
name = self.fn.__name__
print (f"name={name}")
Don't forget to remove this debug print.
Definitely. Thanks for catching my carelessness!
| tx.output.create_proxy("call_function", self.fn, | ||
| *proxy_args_kwargs(args, kwargs)) | ||
| return variables.ConstantVariable(None) | ||
| return super().call_function(tx, args, kwargs) |
This is probably not quite right. I think it would be better to use something analogous to traceable collectives remapping to support this. See this:
def _traceable_collective_remaps():
    # We can't rely on importing from distributed, since it's not always built
    if torch.distributed.is_available():
        from torch.distributed._functional_collectives import (
            traceable_collective_remaps,
        )
        return traceable_collective_remaps
    return {}

def _traceable_collectives_source(tx: "InstructionTranslator", fn):
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()
    inner_name = fn.__name__
    path_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(path_source, inner_name)
Essentially, we need functional versions of send and recv. Then you can use the CollectiveFunctionRewriteVariable apparatus to get to the functional collective.
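A rough sketch of what functional p2p wrappers could look like; the names send_functional and recv_functional are hypothetical, not existing entries in torch.distributed._functional_collectives:

import torch
import torch.distributed as dist

def recv_functional(shape, dtype, src, group=None):
    # Allocates its own output instead of mutating a caller-supplied tensor,
    # so the op looks side-effect free to the tracer.
    out = torch.empty(shape, dtype=dtype)
    dist.recv(out, src=src, group=group)
    return out

def send_functional(tensor, dst, group=None):
    # Returns the (unchanged) input so the call produces a value for the graph.
    dist.send(tensor, dst=dst, group=group)
    return tensor

# These wrappers would then be registered in traceable_collective_remaps so
# CollectiveFunctionRewriteVariable can rewrite dist.send/dist.recv to them,
# e.g. traceable_collective_remaps.update({dist.send: send_functional,
#                                          dist.recv: recv_functional})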
Thank you @ezyang for the helpful feedback! I will try what you suggested and get back to you.
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Fixes #153642
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @EikanWang @jgong5 @wenzhe-nrv @sanchitintel @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @mingfeima @XiaobingSuper @ashokei @jingxu10 @jerryzh168 @aditew01 @ezyang @voznesenskym @penguinwu @Guobing-Chen @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela @xmfan @SherlockNoMad