Conversation

@kumpera (Contributor) commented Mar 30, 2023

This change expects that proper scheduling of the wait_tensor call will happen over the traced graph.
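(For context, a rough usage sketch of the flow this description refers to; it is not part of the PR, the group argument and process-group setup are assumptions, and the API names follow torch.distributed._functional_collectives.)

```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c

# Assumes a default process group has already been initialized elsewhere.
t = torch.ones(4)
out = ft_c.all_reduce(t, "sum", dist.group.WORLD)

# In eager mode `out` is an AsyncCollectiveTensor wrapper and the wait is
# issued lazily, on first use. Under tracing, this change instead relies on an
# explicit wait_tensor call being emitted into the traced graph, leaving its
# scheduling to whatever consumes that graph.
print(out + 1)
```
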
@pytorch-bot (bot) commented Mar 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/98001

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e8a26c1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kumpera requested a review from d4l3k as a code owner (March 30, 2023 18:53)
    return False
    return mode.tracer is not None

def _maybe_wrap_tensor(self):
Contributor:

this should also be easy to support in dynamo:

we just have to implement one special case, for '_maybe_wrap_tensor', which always traces wait_tensor and never traces the guts of _maybe_wrap_tensor or _are_we_tracing.
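(A hedged reconstruction of the helper being discussed, pieced together from the diff context and this thread; the exact body in the PR may differ.)

```python
def _maybe_wrap_tensor(self) -> torch.Tensor:
    if _are_we_tracing():
        # While tracing, emit an explicit wait_tensor into the graph and let
        # the compiler/scheduler place it (the point of this change).
        return wait_tensor(self)
    # In eager mode, wrap the result so the wait can happen lazily on first
    # use, and register the wrapper so wait_tensor can find the pending work.
    res = AsyncCollectiveTensor(self)
    _register_wrapper_tensor(res, self)
    return res
```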

), f"input dimension 0 ({self.size(0)}) must be a multiple of group_size {group_size}"
tensor = torch._C._nn.reduce_scatter_tensor(self, reduceOp, scatter_dim, tag, rankset, group_size) # type: ignore[attr-defined]
res = AsyncCollectiveTensor(tensor)
_register_wrapper_tensor(res, tensor)
Contributor:

i keep forgetting what the policy should be on calling _register...

we need to put a better comment here to immortalize it...

my current thought is:

  • during eager we must register here
  • during tracing we never register here, but we require that backends register when calling the actual collective

Contributor Author:

Backends can optionally skip the _register_wrapper_tensor call if they emit a wait_tensor.

Contributor:

hmmm, can they? i thought wait_tensor will call into this code: https://github.com/pytorch/pytorch/blob/master/torch/distributed/_functional_collectives.py#L96

which will find that nothing has been registered, UNLESS the backend did the registration first
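(A simplified sketch of the eager wait path being referenced; the registry name _tensor_to_work and the exact lookup are hypothetical, see the linked line for the real code.)

```python
# Hypothetical registry, populated by _register_wrapper_tensor or the backend.
_tensor_to_work: dict = {}

def wait_tensor_eager(tensor):
    # Look up whatever work was registered for this tensor and wait on it.
    work = _tensor_to_work.pop(id(tensor), None)
    if work is not None:
        work.wait()
    # If nothing was ever registered (and no wait was emitted elsewhere), this
    # returns without waiting, which is the "silent no-op" concern raised
    # further down in this thread.
    return tensor
```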

Contributor:

well, this is complicated.

in inductor, I made collectives call register, and I made the wait_tensor op call wait_tensor.

  • this way, if the wait happens in a separate inductor graph or in eager, it still works

but if your backend knows that the allreduce and its wait will be in the same graph, it can emit more efficient code that lowers the collective to work = dist.<collective> and lowers wait_tensor to work.wait(), which skips the need to ever register.
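(A small sketch of the lowering described above, using the standard process-group API; this is only an illustration, not inductor's actual codegen.)

```python
import torch
import torch.distributed as dist

def lowered_allreduce_then_wait(t: torch.Tensor) -> torch.Tensor:
    # The collective lowers to a call that returns a work handle...
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # ...and wait_tensor lowers to work.wait(), so no wrapper registration is
    # needed when both ops are known to live in the same graph.
    work.wait()
    return t
```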

Contributor:

i guess the risk/concern is that, if you aren't careful in the backend design, the wait_tensor op can become a silent no-op

Contributor Author:

yeah, we can look into this sort of optimization later.

@wconstab (Contributor) commented:

General thoughts:

  1. easy to reverse this decision if we need to: we just have to find a way to fix the subclass tracing problems
  2. we're already implementing compiler passes for comm optimization; it shouldn't matter to them whether we start out with the wait in a near-optimal or non-optimal place

(is 2 fully true? and what is the status of the passes we'll need for wait placement? cc @lessw2020 @fegin)

@wanchaol (Collaborator) left a comment:

Can we add some unit tests to guard the behavior and make sure it works? You can take the tests I wrote in #97945
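(A hedged sketch of the kind of test being asked for; the setup is an assumption, it needs an initialized process group such as the multi-threaded test process group used by the existing distributed tests, and the all_reduce call shown here is simplified.)

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.distributed._functional_collectives as ft_c

def allreduce(x):
    # Group handling is simplified; the real tests pass a properly set-up group.
    return ft_c.all_reduce(x, "sum", [0])

gm = make_fx(allreduce)(torch.ones(4))
# Under tracing we expect an explicit wait_tensor node in the graph rather
# than an AsyncCollectiveTensor being returned.
assert any("wait_tensor" in str(node.target) for node in gm.graph.nodes)
```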

@lessw2020 (Contributor) commented Mar 30, 2023

> 2. we're already implementing compiler passes for comm optimization; it shouldn't matter to them whether we start out with the wait in a near-optimal or non-optimal place
>
> (is 2 fully true? and what is the status of the passes we'll need for wait placement? cc @lessw2020 @fegin)

Re: 2 - correct, it doesn't matter where the waits start per se; we are going to roll them up with the comm fusion pass and use only the last wait for each fused section (the rest will be removed by DCE).

Re: status - the fusion pass has been working for some time, and at the moment it is the pass that modifies the waits.
The cross-iteration pass is still in progress, but it's going to pick up and move the relevant wait along with whatever it moves, the same as the fusion pass, where we'll move the waits directly as needed. (cc @wconstab)
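(To make the roll-up idea above concrete, a toy FX sketch; the pass structure and the wait_target argument are assumptions, and the real fusion pass is more involved.)

```python
import torch.fx as fx

def drop_redundant_waits(gm: fx.GraphModule, wait_target) -> fx.GraphModule:
    # Collect all wait_tensor calls in the (already fused) section of the graph.
    waits = [n for n in gm.graph.nodes
             if n.op == "call_function" and n.target == wait_target]
    # Keep only the last wait; earlier waits become pass-throughs and are then
    # removed by dead-code elimination.
    for node in waits[:-1]:
        node.replace_all_uses_with(node.args[0])
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return gm
```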

@kumpera (Contributor, Author) commented Mar 31, 2023

@pytorchmergebot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 31, 2023
@kumpera added the module: c10d (Issues/PRs related to collective communications and process groups) and topic: not user facing labels Mar 31, 2023
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

    mode = get_innermost_proxy_mode()
    if mode is None:
        return False
    return mode.tracer is not None
Contributor:

Did you intentionally check mode.tracer? It seems to me that this is guaranteed to be not None
