
[ProcessGroupNCCL] Avoid recording stream for synchronous ops #111431


Closed
kwen2501 wants to merge 4 commits

Conversation

@kwen2501 (Contributor) commented Oct 17, 2023

For synchronous ops (i.e. `asyncOp = False`), we don't want to record streams because we know that the NCCL stream will join back to the "current" stream right after this op. So we may just as well keep the stream ownership of the input/output tensors unchanged (i.e. not tell the caching allocator that there is a stream-to-stream hand-off). The benefit is that the allocation and free of these tensors look deterministic to the "current" stream, so the caching allocator can reuse the memory pool for this stream in a clever way.
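
To make the stream reasoning concrete, here is a minimal, hypothetical sketch at the Python level (not ProcessGroupNCCL's actual C++ internals) of why recording a stream is unnecessary when the communication stream immediately joins back to the current stream:

```python
import torch

def sync_collective_sketch(inp: torch.Tensor) -> torch.Tensor:
    # Assumes a CUDA device; "comm" stands in for the internal NCCL stream.
    current = torch.cuda.current_stream()
    comm = torch.cuda.Stream()
    out = torch.empty_like(inp)

    comm.wait_stream(current)          # comm sees inp/out as ready
    with torch.cuda.stream(comm):
        out.copy_(inp)                 # placeholder for the NCCL kernel

    # Synchronous case: the current stream waits on comm right away, so any
    # later free of inp/out is already ordered after the collective. There is
    # no need for inp.record_stream(comm), and the caching allocator keeps
    # seeing a single-stream allocate/free pattern it can reuse.
    current.wait_stream(comm)
    return out
```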

To prevent the input/output tensors from being recycled by Python, we rely on the stashing mechanism in ProcessGroupNCCL (which can also be turned on by setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1`).
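
As a rough illustration of the stashing idea (the class and names below are hypothetical; the real mechanism lives in WorkNCCL on the C++ side), holding references until the work is waited on keeps the tensors alive long enough:

```python
import torch

class ToyWork:
    """Toy stand-in for a Work handle that stashes tensor references so
    the tensors cannot be freed before the communication kernel has run."""

    def __init__(self, tensors, comm_stream: torch.cuda.Stream):
        self._stash = list(tensors)        # keeps refcounts above zero
        self._done = torch.cuda.Event()
        self._done.record(comm_stream)

    def wait(self) -> None:
        # Order the current stream after the collective, then drop the stash
        # so the caching allocator can hand the memory back out.
        torch.cuda.current_stream().wait_event(self._done)
        self._stash.clear()
```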

This mechanism change is for libraries like FSDP, which use `all_gather_into_tensor` and `reduce_scatter_tensor` in a synchronous way and which cannot set `TORCH_NCCL_AVOID_RECORD_STREAMS=1` on behalf of their users. Therefore, the change is limited to these two collectives for now.
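
For context, a minimal sketch of the calling pattern this change targets (assumes an initialized NCCL process group; shapes and the gradient are illustrative):

```python
import torch
import torch.distributed as dist

def fsdp_style_step(param_shard: torch.Tensor, grad_shard: torch.Tensor) -> None:
    world_size = dist.get_world_size()

    # Unshard: every rank gathers the full flat parameter. async_op defaults
    # to False, so this is exactly the synchronous path affected by this PR.
    full_param = param_shard.new_empty(world_size * param_shard.numel())
    dist.all_gather_into_tensor(full_param, param_shard)

    # ... forward/backward would produce a full flat gradient here ...
    full_grad = torch.ones_like(full_param)

    # Reshard gradients: also issued synchronously by FSDP.
    dist.reduce_scatter_tensor(grad_shard, full_grad)
```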

Cc: @awgu @janeyx99 @albanD

@pytorch-bot (bot) commented Oct 17, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111431

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 4113d9c with merge base 6f06832:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@albanD (Collaborator) left a comment

Change sounds good once CI is happy!

@@ -25,15 +25,15 @@ TORCH_LIBRARY(c10d, m) {
   m.def(
       "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
Collaborator

Is the timeout arg new here?

Contributor Author

It is new for this op. However, it exists for every other op, so I just added it here to stay in line.

@H-Huang (Member) left a comment

Looks good. I have a few comments; feel free to address them at your convenience.

   m.def(
       "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
   m.def(
       "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group) -> __torch__.torch.classes.c10d.Work");
   m.def(
       "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
Member

The ALLOW_LIST will need to be updated with `_reduce_scatter_base_` and `_allgather_base_` to fix the backward-compat complaint in CI.
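
For reference, a hedged sketch of what such entries usually look like (the file path and tuple format below follow the common PyTorch convention and are assumptions, not taken from this PR):

```python
# test/forward_backward_compatibility/check_forward_backward_compatibility.py
import datetime

ALLOW_LIST = [
    # (schema name, date until which the schema change is tolerated by the BC check)
    ("c10d::_allgather_base_", datetime.date(2023, 12, 30)),
    ("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)),
]
```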

@@ -137,6 +137,7 @@ struct ReduceOptions {

 struct AllgatherOptions {
   std::chrono::milliseconds timeout = kUnsetTimeout;
+  bool asyncOp = true;
Member

Should these be true by default? Our current collectives have async = false by default, right?

Contributor Author

For the Python APIs, yes, async = false is the default.
But for the C++ APIs, i.e. the APIs defined by the Backend class, async = true is the default behavior.

Contributor Author

Here, `struct AllgatherOptions` is mainly for passing the async option into the C++ implementations, so I chose true as the default behavior.
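
A small sketch of the Python-side default for comparison (assumes an initialized process group):

```python
import torch
import torch.distributed as dist

def demo(t: torch.Tensor) -> None:
    # async_op defaults to False: the call returns None and is synchronous
    # with respect to the current stream.
    dist.all_reduce(t)

    # Opting into async behaviour returns a Work handle, matching the
    # asyncOp = true default of the C++ Backend APIs.
    work = dist.all_reduce(t, async_op=True)
    work.wait()
```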

@kwen2501 (Contributor Author)

@pytorchbot merge -f "The failure in test_python_ref_executor__refs_sinc_executor_aten_cuda_complex128 does not seem related"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 28, 2023
Even after PR #111431, the `collective(...)` function still uses the underscore-suffixed member `avoidRecordStreams_` internally and does not respect each collective call's preference, since `avoidRecordStreams_` is controlled only by the environment variable.

As a fix, we pass `avoidRecordStreams` into the collective() function.

Pull Request resolved: #112195
Approved by: https://github.com/awgu
kwen2501 added a commit to kwen2501/pytorch that referenced this pull request Nov 6, 2023
…ytorch#112896)

Summary:

Follows PR pytorch#111431, save memory for DTensor init

Test Plan: Sandcastle

Reviewed By: wanchaol

Differential Revision: D50985365
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…h#111431)

For synchronous ops (i.e. `asyncOp = False`), we don't want to record streams because we know that the NCCL stream will join back to the "current" stream right after this op. So we may just as well keep the stream ownership of the input/output tensors unchanged. The benefit is that the allocation and free of these tensors look deterministic to the "current" stream, so the caching allocator can reuse the memory pool for this stream in a clever way.

To prevent the input/output tensors from being recycled by Python, we rely on the stashing mechanism in ProcessGroupNCCL (which can also be turned on by setting `TORCH_NCCL_AVOID_RECORD_STREAMS=1`).

This mechanism change is for libraries like FSDP, which use `all_gather_into_tensor` and `reduce_scatter_tensor` in a synchronous way and which cannot set `TORCH_NCCL_AVOID_RECORD_STREAMS=1` on behalf of their users. Therefore, the change is limited to these two collectives for now.

Cc: @awgu @janeyx99 @albanD
Pull Request resolved: pytorch#111431
Approved by: https://github.com/H-Huang
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
…2195)

Even after PR pytorch#111431, the `collective(...)` function still uses the underscore-suffixed member `avoidRecordStreams_` internally and does not respect each collective call's preference, since `avoidRecordStreams_` is controlled only by the environment variable.

As a fix, we pass `avoidRecordStreams` into the collective() function.

Pull Request resolved: pytorch#112195
Approved by: https://github.com/awgu
pytorchmergebot pushed a commit that referenced this pull request Nov 7, 2023
…112896)

Summary: Follows PR #111431, save memory for DTensor init

Test Plan: Sandcastle

Reviewed By: wanchaol

Differential Revision: D50985365

Pull Request resolved: #112896
Approved by: https://github.com/wanchaol
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…2195)

Even after PR pytorch#111431, the `collective(...)` function still uses the underscore-suffixed member `avoidRecordStreams_` internally and does not respect each collective call's preference, since `avoidRecordStreams_` is controlled only by the environment variable.

As a fix, we pass `avoidRecordStreams` into the collective() function.

Pull Request resolved: pytorch#112195
Approved by: https://github.com/awgu
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
…ytorch#112896)

Summary: Follows PR pytorch#111431, save memory for DTensor init

Test Plan: Sandcastle

Reviewed By: wanchaol

Differential Revision: D50985365

Pull Request resolved: pytorch#112896
Approved by: https://github.com/wanchaol
github-actions bot deleted the avoid_record_ag_rs branch on April 20, 2025 02:18