
Conversation


@aashaka aashaka commented Aug 5, 2022

Summary: A vector reduce_scatter requires each process to reduce and scatter an input tensor according to the input list provided. Internally, pg_nccl.reduce_scatter will coalesce a list of pg_nccl._reduce_oop calls to implement a vector reduce-scatter when the shapes in the input list are not all the same. Otherwise, it will perform a ncclReduceScatter as usual.

  • This change adds a CoalescedWorkNCCL class which encapsulates the WorkNCCL requests from coalesced operations. A .wait() on a CoalescedWorkNCCL request will call a wait on each of the WorkNCCL requests that are coalesced.

  • This change adds an out-of-place _reduce_oop function to ProcessGroupNCCL. It allows reducing an input tensor and placing the output in a separate output tensor. Since reduce_scatter provides an out-of-place API, a reduce_scatter_v semantic implemented inside pg_nccl.reduce_scatter also needs to support out-of-place operation, which is why an out-of-place reduce needs to be added.

Test Plan: Added a new test test_reduce_scatter_v_cuda for reduce_scatter_v to distributed_nccl_spawn.
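
For illustration, a minimal usage sketch of the resulting behavior, assuming the default process group is initialized with the NCCL backend (the function name, shapes, and values below are made up, not taken from the added test): with uneven splits, dist.reduce_scatter goes down the coalesced _reduce_oop path instead of a single ncclReduceScatter.

```python
import torch
import torch.distributed as dist

def demo_uneven_reduce_scatter():
    # Hypothetical helper; assumes dist.init_process_group("nccl") has run.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Uneven split: slot r holds r + 1 elements, so the list entries have
    # different shapes and the coalesced _reduce_oop path is taken.
    input_split_sizes = [r + 1 for r in range(world_size)]
    input_list = [
        torch.full((n,), float(rank + 1), device=f"cuda:{rank}")
        for n in input_split_sizes
    ]
    output = torch.empty(input_split_sizes[rank], device=f"cuda:{rank}")
    dist.reduce_scatter(output, input_list, op=dist.ReduceOp.SUM)
    # Every rank now holds the elementwise sum of its slot across ranks:
    # each element of output equals world_size * (world_size + 1) / 2.
    return output
```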

Differential Revision: D38478781

Contributor

facebook-github-bot commented Aug 5, 2022


✅ No Failures (1 Pending)

As of commit 4ace662 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

@facebook-github-bot facebook-github-bot added the oncall: distributed label Aug 5, 2022
Contributor

kwen2501 commented Aug 6, 2022

Maybe also worth writing a bit in the PR description about the context --
You need an out-of-place reduce from the backend because you want to compose a reduce_scatter_v pattern at the Python front end using coalesced reduces. Today, the dist.reduce_scatter API is out-of-place, so if the reduce_scatter_v pattern is implemented under the dist.reduce_scatter API, the coalesced reduces need out-of-place support as well.
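
For context, a rough pure-Python sketch of that composition (illustrative only; the PR does the coalescing inside ProcessGroupNCCL in C++, and the helper name here is made up). The clone() works around the in-place semantics of dist.reduce, which is exactly what an out-of-place backend reduce avoids:

```python
import torch
import torch.distributed as dist

def reduce_scatter_v_reference(output, input_list):
    # Decompose reduce_scatter_v into one reduce per destination rank:
    # rank `root` ends up with the reduction of input_list[root] over ranks.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for root in range(world_size):
        # dist.reduce is in-place, so reduce a scratch copy to avoid
        # clobbering the caller's input tensor.
        tmp = input_list[root].clone()
        dist.reduce(tmp, dst=root, op=dist.ReduceOp.SUM)
        if rank == root:
            output.copy_(tmp)
```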

Contributor

kwen2501 commented Aug 6, 2022

Also good to comment the above in code.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

@aashaka aashaka changed the title Expose an out-of-place _reduce from ProcessGroupNCCL Expose an out-of-place _reduce_oop from ProcessGroupNCCL Aug 8, 2022
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

@aashaka aashaka changed the title Expose an out-of-place _reduce_oop from ProcessGroupNCCL Enable pg_nccl.reduce_scatter to perform vector ReduceScatter for uneven input splits Aug 15, 2022
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

Enable pg_nccl.reduce_scatter to perform vector ReduceScatter for uneven input splits (pytorch#82924)

Summary:
Pull Request resolved: pytorch#82924

A vector reduce_scatter requires each process to reduce and scatter an input tensor according to the input list provided.
Internally, pg_nccl.reduce_scatter will coalesce a list of pg_nccl._reduce_oop calls to implement a vector reduce-scatter when the shapes in the input list are not all the same. Otherwise, it will perform a ncclReduceScatter as usual.

- This change adds a `CoalescedWorkNCCL` class which encapsulates the WorkNCCL requests from coalesced operations. A `.wait()` on a CoalescedWorkNCCL request will call a wait on each of the WorkNCCL requests that are coalesced.

- This change adds an out-of-place `_reduce_oop` function to ProcessGroupNCCL. It allows reducing an input tensor and placing the output in a separate output tensor. Since reduce_scatter provides an out-of-place API, a reduce_scatter_v semantic implemented inside `pg_nccl.reduce_scatter` also needs to support out-of-place operation, which is why an out-of-place reduce needs to be added.

Test Plan: Added a new test `test_reduce_scatter_v_cuda` for reduce_scatter_v to `distributed_nccl_spawn`.

Differential Revision: D38478781

fbshipit-source-id: 0157acdee4e9a1dd328a27d4e30c3b81c4523039
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D38478781

Contributor

@kwen2501 kwen2501 left a comment


LGTM.

@kwen2501
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions
Contributor

Hey @aashaka.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

if async_val:
    req.wait()

expected_value = 2 + (10 * (len(group) - 1))

nit: use the variables defined above instead of hard-coded numbers.
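
A possible rewrite of this value following the nit, assuming the master_value and worker_value variables from the next snippet are defined earlier in the test:

```python
# Same value as 2 + (10 * (len(group) - 1)), without the hard-coded numbers.
expected_value = master_value + worker_value * (len(group) - 1)
```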

end_len = start_len + input_split_sizes[rank]
sum_len = sum(input_split_sizes)
master_value = 2
worker_value = 10

nit: for clarity, rename master_value --> value_to_self and worker_value --> value_to_others.
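
For illustration, the snippet above with the suggested renames applied:

```python
end_len = start_len + input_split_sizes[rank]
sum_len = sum(input_split_sizes)
value_to_self = 2       # formerly master_value
value_to_others = 10    # formerly worker_value
```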

@kwen2501 kwen2501 added the release notes: distributed (c10d) and topic: new features labels Aug 18, 2022
facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2022
Enable pg_nccl.reduce_scatter to perform vector ReduceScatter for uneven input splits (#82924) (#82924)

Summary:
A vector reduce_scatter requires each process to reduce and scatter an input tensor according to the input list provided. Internally, pg_nccl.reduce_scatter will coalesce a list of pg_nccl._reduce_oop calls to implement a vector reduce-scatter when the shapes in the input list are not all the same. Otherwise, it will perform a ncclReduceScatter as usual.

- This change adds a `CoalescedWorkNCCL` class which encapsulates the WorkNCCL requests from coalesced operations. A `.wait()` on a CoalescedWorkNCCL request will call a wait on each of the WorkNCCL requests that are coalesced.

- This change adds an out-of-place `_reduce_oop` function to ProcessGroupNCCL. It allows reducing an input tensor and placing the output in a separate output tensor. Since reduce_scatter provides an out-of-place API, a reduce_scatter_v semantic implemented inside `pg_nccl.reduce_scatter` also needs to support out-of-place operation, which is why an out-of-place reduce needs to be added.

Pull Request resolved: #82924
Approved by: https://github.com/kwen2501

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/d6a30e213e2355e8ad553c02d205391c889a0254

Test plan from GitHub:
Added a new test `test_reduce_scatter_v_cuda` for reduce_scatter_v to `distributed_nccl_spawn`.

Original Phabricator Test Plan:
Added a new test `test_reduce_scatter_v_cuda` for reduce_scatter_v to `distributed_nccl_spawn`.

Reviewed By: kwen2501

Differential Revision: D38478781

Pulled By: aashaka

fbshipit-source-id: b5a203847241e83556e51b640b24eaa765dca6c4

Labels

cla signed, fb-exported, Merged, oncall: distributed, release notes: distributed (c10d), topic: new features
