Don't use subclass when tracing and call wait_tensor immediately. #98001
Conversation
This change expects that proper scheduling of the wait_tensor call will happen over the traced graph.
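For orientation, here is a minimal sketch of what that amounts to, assuming the helpers visible in this diff (_are_we_tracing, wait_tensor, AsyncCollectiveTensor, _register_wrapper_tensor) are in scope; the real helper appears in the excerpt just below.

```python
import torch

def _maybe_wrap_tensor_sketch(tensor: torch.Tensor) -> torch.Tensor:
    # Sketch only: control flow simplified; assumes _are_we_tracing, wait_tensor,
    # AsyncCollectiveTensor and _register_wrapper_tensor from this file are in scope.
    if _are_we_tracing():
        # While tracing, skip the AsyncCollectiveTensor subclass entirely and
        # record wait_tensor in the graph right away; scheduling the wait is
        # left to whatever consumes the traced graph.
        return wait_tensor(tensor)
    # In eager mode, keep the lazy wrapper and register it so a later
    # wait_tensor call can find the outstanding work.
    res = AsyncCollectiveTensor(tensor)
    _register_wrapper_tensor(res, tensor)
    return res
```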
        return False
    return mode.tracer is not None


def _maybe_wrap_tensor(self):
this should also be easy to support in dynamo:
we just have to implement one special case, for _maybe_wrap_tensor, which always traces wait_tensor and never traces the guts of _maybe_wrap_tensor or _are_we_tracing.
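For illustration, a tiny torch.fx example of the graph shape that special case would aim for: the helper shows up as a single wait_tensor call, with none of its internals traced. The wait_tensor here is a local stand-in, not the real op.

```python
import torch
import torch.fx as fx

def wait_tensor(t):
    # Local stand-in for the functional-collectives wait_tensor op.
    return t

fx.wrap("wait_tensor")  # keep it as a leaf node instead of tracing through it

def maybe_wrap_as_seen_by_the_tracer(t):
    # What "always traces wait_tensor, never traces the guts" should produce.
    return wait_tensor(t)

gm = fx.symbolic_trace(maybe_wrap_as_seen_by_the_tracer)
print(gm.graph)  # a single call_function node targeting wait_tensor
```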
), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" | ||
tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, scatter_dim, tag, rankset, group_size) # type: ignore[attr-defined] | ||
res = AsyncCollectiveTensor(tensor) | ||
_register_wrapper_tensor(res, tensor) |
i keep forgetting what the policy should be on calling _register_wrapper_tensor...
we need to put a better comment here to immortalize it...
my current thought is (see the sketch below):
- during eager we must register here
- during tracing we never register here, but we require that backends register when calling the actual collective
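One possible way to immortalize that policy is a docstring on the registration helper; the signature is assumed from the call sites in this diff and the wording is only a suggestion.

```python
def _register_wrapper_tensor(wrapper_tensor, inner_tensor):
    """Record that `inner_tensor` has an in-flight collective behind `wrapper_tensor`.

    Policy (as discussed in this review):
      * Eager: the collective entry points must call this, so that a later
        wait_tensor(inner_tensor) can find the outstanding work.
      * Tracing: entry points never call this; the compiling backend either
        registers when it lowers the collective, or skips registration if it
        also lowers the matching wait_tensor to work.wait() in the same graph.
    """
    ...  # sketch only; the real implementation lives in _functional_collectives.py
```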
Backends can optionally skip the _register_wrapper_tensor call if they emit a wait_tensor.
hmmm can they? i thought wait_tensor would call into this code https://github.com/pytorch/pytorch/blob/master/torch/distributed/_functional_collectives.py#L96
which will find that nothing has been registered, UNLESS the backend did the registration first
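A hedged sketch of the lookup being described; the dict and helper names here are illustrative, not the actual internals behind that link.

```python
import torch

# Illustrative registry: maps a tensor's data_ptr to its in-flight Work handle.
_outstanding_work: dict = {}

def _register_work(tensor: torch.Tensor, work) -> None:
    # What "the backend did the registration first" would amount to.
    _outstanding_work[tensor.data_ptr()] = work

def wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
    # If nothing was ever registered for this tensor (e.g. the backend emitted
    # its own wait instead), the lookup misses and this is effectively a no-op.
    work = _outstanding_work.pop(tensor.data_ptr(), None)
    if work is not None:
        work.wait()
    return tensor
```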
well, this is complicated.
in inductor, I made the lowered collectives call register and made the lowered wait_tensor call the eager wait_tensor.
- this way, if the wait happens in a separate inductor graph or in eager, it still works
but if your backend knows that the allreduce and its wait will land in the same graph, it can emit more efficient code that lowers the collective to work = dist.<collective> and lowers wait_tensor to work.wait(), which skips the need for ever registering.
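As a plain-c10d illustration (not Inductor's actual codegen, and assuming an initialized process group), the collective-and-wait-in-the-same-graph pattern corresponds to roughly:

```python
import torch
import torch.distributed as dist

def fused_allreduce_then_wait(t: torch.Tensor) -> torch.Tensor:
    # Lowered collective: keep the async Work handle instead of registering it.
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # ...independent compute could be scheduled here to overlap with comms...
    # Lowered wait_tensor: wait directly on the handle.
    work.wait()
    return t
```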
i guess the risk/concern is that if you aren't careful in backend design, the wait_tensor op can become a silent no-op
yeah, we can look into this sort of optimization later.
General thoughts:
(is 2 fully true? and what is the status of the passes we'll need for wait placement? cc @lessw2020 @fegin)
Can we add some unit tests to guard the behavior and make sure it works? You can take the tests I wrote in #97945
Re: 2 - correct, it doesn't matter where the waits start per se; we are going to roll them up with the comm fusion pass and use only the last wait for each fused section (the rest will be removed with DCE). Re: status - the fusion pass has been working for some time and is currently the pass that modifies the waits.
@pytorchmergebot merge
    mode = get_innermost_proxy_mode()
    if mode is None:
        return False
    return mode.tracer is not None
Did you intentionally check mode.tracer? It seems to me that this is guaranteed to be not None