
Conversation

rohan-varma
Contributor

@rohan-varma rohan-varma commented Jul 12, 2020

Stack from ghstack:

Closes #24137.
This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:

  1. Map `at::kBool` to `ncclUint8`.
  2. During reduction (allreduce, for example), if the operation is SUM, we instead override it to a MAX to avoid overflow issues. The rest of the operations work with no changes. In the boolean case, changing sum to max makes no correctness difference, since both function as a bitwise OR.

The reduction logic (for example, for reduce/allreduce) is as follows:
- sum, max → bitwise OR
- product, min → bitwise AND

Note that this PR doesn't add support for BAND/BOR/BXOR, because these reduction ops are currently not supported by the NCCL backend; see #41362.

Tests are added to ensure that the reductions work as expected.
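
For context, a minimal usage sketch of what this enables (not taken from this PR's test suite; it assumes a 2-GPU machine, and the `MASTER_ADDR`/`MASTER_PORT` bootstrapping and script layout are illustrative only):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # Illustrative rendezvous setup; any valid init_method works here.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Rank 0 contributes True, rank 1 contributes False.
    t = torch.tensor([rank == 0], dtype=torch.bool, device=f"cuda:{rank}")

    # With this change, SUM on bool tensors is internally run as MAX
    # (a logical OR), so every rank ends up with True.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    print(f"rank {rank}: {t.item()}")  # True on both ranks

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```
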
Differential Revision: D22496604

NOTE FOR REVIEWERS: This PR has internal Facebook specific changes or comments, please review them on Phabricator!

rohan-varma added a commit that referenced this pull request Jul 12, 2020
Closes #24137. Since bool is
not supported as a native ncclDataType_t, we add some upcasting + downcasting
logic to support it.

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22496604/)!

ghstack-source-id: 107598033
Pull Request resolved: #41318

Closes #24137. 
This PR adds support for the `torch.bool` tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since `bool` is not supported as a native `ncclDataType_t`, we add the following logic:
1) Detect if input tensors are of bool type. If so, cast inputs & outputs to int tensors. 
2) Run the specified reduction.
3) If we had to cast, cast the outputs back to boolean tensors. If this collective does not operate in place, then re-cast the inputs back to bool so that they are not modified as a result of the op. A standalone sketch of this pattern follows below.
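
For illustration, a minimal single-process sketch of this cast-and-restore pattern; `fake_allreduce_sum` is a hypothetical stand-in for the NCCL collective:

```python
import torch

# Hypothetical stand-in for the collective: elementwise sum across "ranks".
def fake_allreduce_sum(tensors):
    return torch.stack(tensors).sum(dim=0)

bool_inputs = [torch.tensor([True, False, False]),
               torch.tensor([False, True, False])]

# Step 1: upcast bool -> int so the reduction runs on a supported dtype.
int_inputs = [t.to(torch.int) for t in bool_inputs]
# Step 2: run the specified reduction on the int buffers.
int_result = fake_allreduce_sum(int_inputs)
# Step 3: downcast the output back to bool; any nonzero count maps back to True.
bool_result = int_result.bool()
print(bool_result)  # tensor([ True,  True, False]) -- sum behaves as a logical OR
```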

The reduction logic (for example, for reduce/allreduce) is as follows:
- sum, max → bitwise OR
- product, min → bitwise AND

Note that this PR doesn't add support for BAND/BOR/BXOR, because these reduction ops are currently not supported by the NCCL backend; see #41362.

Tests are added to ensure that the reductions work as expected. 
Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D22496604/)!

[ghstack-poisoned]
@dr-ci

dr-ci bot commented Jul 13, 2020

💊 CI failures summary and remediations

As of commit 1522331 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_multigpu_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun) <confirmed not flaky by 2 failures>

Jul 24 00:15:47 FAIL [4.424s]: test_reduce_multigpu (__main__.TestDistBackend)
Jul 24 00:15:47 Traceback (most recent call last): 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 204, in wrapper 
Jul 24 00:15:47     self._join_processes(fn) 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 311, in _join_processes 
Jul 24 00:15:47     self._check_return_codes(elapsed_time) 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 344, in _check_return_codes 
Jul 24 00:15:47     raise RuntimeError(error) 
Jul 24 00:15:47 RuntimeError: Processes 2 exited with error code 10 
Jul 24 00:15:47  
Jul 24 00:15:47 ====================================================================== 
Jul 24 00:15:47 FAIL [4.424s]: test_reduce_multigpu (__main__.TestDistBackend) 
Jul 24 00:15:47 ---------------------------------------------------------------------- 
Jul 24 00:15:47 Traceback (most recent call last): 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 204, in wrapper 
Jul 24 00:15:47     self._join_processes(fn) 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 311, in _join_processes 
Jul 24 00:15:47     self._check_return_codes(elapsed_time) 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 363, in _check_return_codes 
Jul 24 00:15:47     msg="Expected zero exit code but got {}".format(first_process.exitcode) 
Jul 24 00:15:47   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1122, in assertEqual 
Jul 24 00:15:47     self.assertTrue(result, msg=msg) 

Extra GitHub checks: 1 failed



@rohan-varma rohan-varma changed the title from "[WIP] NCCL Backend support for torch.bool" to "NCCL Backend support for torch.bool" on Jul 13, 2020
rohan-varma added a commit that referenced this pull request Jul 13, 2020
Pull Request resolved: #41318

ghstack-source-id: 107675254

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

rohan-varma added a commit that referenced this pull request Jul 14, 2020
Pull Request resolved: #41318

ghstack-source-id: 107698101

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

for (auto& tensor : tensors) {
// TODO: a simple tensor = tensor.to(long) won't work here. The allreduce
// will be correct, but the modified tensor won't be reflected in Python.
auto asIntBufTensor = tensor.to(at::kLong);
Contributor

Any reason for using at::kLong instead of at::kByte or at::kChar? They correspond to ncclUint8 and ncclInt8, respectively.

Contributor Author

The main reason is to prevent overflow issues; I think both sizes might be too small for very large-scale training use cases. Do we know an estimate of the largest number of processes we've seen in a process group?

Contributor

I see. I wonder if it makes sense to just support min (all) and max (any) for bool tensors? Sum and prod do not seem to apply here?

Contributor Author

I was going by Pieter's suggestion in a comment on this issue: #24137
In boolean logic, sum means OR and product means AND, so it might be reasonable for users to expect PyTorch to obey this, i.e., calling a sum reduction on boolean tensors would mean doing an OR of them.

Open to either approach though.

Contributor

> since sum means OR and product means AND, it might be reasonable for users to expect PyTorch to obey this?

I see. This makes sense. My concern was that converting bool into long might be a perf hit. How about we just use the bool tensor, let applications deal with the overflow issue, and highlight that in the doc? They do have options to avoid overflow by using MIN/MAX.

Another option is to internally replace sum with max and prod with min, since we expect true + true = true and true * false = false, etc.

Contributor Author
@rohan-varma rohan-varma Jul 17, 2020

Sure, so it seems that the best approach is:

  1. Add a type mapping of at::kBool to ncclUint8
  2. I'm assuming (need to check) this means that nccl will treat the payload as a buffer of uint8.
  3. If the data type is at::kBool, override the ncclOp: if it's sum or product, use max or min, respectively (this will prevent the overflow issue, which ideally users should not have to deal with manually); a rough sketch of this override follows below.
  4. With the above changes the reduction should work transparently.
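
For illustration, a rough Python mirror of the override in step 3 (the actual change lives in the C++ ProcessGroupNCCL code; this standalone helper is only a sketch):

```python
import torch
import torch.distributed as dist

def pick_bool_reduce_op(reduce_op, tensor):
    # For bool tensors (sent on the wire as uint8), swap SUM for MAX so that many
    # True contributions cannot wrap around a uint8 accumulator. Over {0, 1}
    # values, MAX behaves exactly like a logical OR.
    if tensor.dtype == torch.bool and reduce_op == dist.ReduceOp.SUM:
        return dist.ReduceOp.MAX
    return reduce_op

t = torch.zeros(2, dtype=torch.bool)
assert pick_bool_reduce_op(dist.ReduceOp.SUM, t) == dist.ReduceOp.MAX
assert pick_bool_reduce_op(dist.ReduceOp.PRODUCT, t) == dist.ReduceOp.PRODUCT
```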

Contributor Author

Updated the diff to use this approach.

return ncclDataType.at(type);
} catch (std::out_of_range& e) {
throw std::runtime_error("Unsupported data type for NCCL process group");
auto it = ncclDataType.find(type);
Contributor

Will it work if we add an entry to ncclDataType to map at::kBool to ncclUint8? How does at::kBool represent and interpret true and false? I recall having a discussion with @izdeby. IIRC, only 0 is interpreted as false and all others are true?

Contributor Author

Yes, the casting seems to be:

- bool: True -> 1, False -> 0
- int: any nonzero -> True, 0 -> False

The issue I see with casting to uint8 - either through the map or with similar logic as above - is that this will mean we can only support up to 255 processes calling all_reduce(). For example, if we have 256 processes and call allreduce with sum, and each process contributes a single set bit, we'd get 0 as our result, which would be wrong.
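
To make that wraparound concrete, a small single-process illustration with plain tensor arithmetic (no collectives involved):

```python
import torch

# 256 "ranks" each contributing a single True, accumulated in a uint8 buffer.
acc = torch.zeros(1, dtype=torch.uint8)
for _ in range(256):
    acc += torch.ones(1, dtype=torch.uint8)

print(acc)         # tensor([0], dtype=torch.uint8) -- the sum wrapped around to 0
print(acc.bool())  # tensor([False]) -- the wrong answer for a logical OR
```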

@rohan-varma
Contributor Author

Adding @jiayisuse to review as well

@rohan-varma rohan-varma requested a review from jiayisuse July 16, 2020 20:00
rohan-varma added a commit that referenced this pull request Jul 16, 2020
Pull Request resolved: #41318

ghstack-source-id: 107942247

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

rohan-varma added a commit that referenced this pull request Jul 17, 2020
Pull Request resolved: #41318

ghstack-source-id: 108017010

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

@rohan-varma rohan-varma requested a review from mrshenli July 17, 2020 19:09
rohan-varma added a commit that referenced this pull request Jul 21, 2020
Pull Request resolved: #41318

ghstack-source-id: 108185942

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

Contributor
@mrshenli mrshenli left a comment

LGTM! Thanks for adding this.

I guess it's impossible for us to test the bool/uint8 overflow problem, as we cannot spawn 256+ processes?

}

ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) {
if (reduceOp == ReduceOp::SUM && input.scalar_type() == at::kBool) {
Contributor

I guess prod is fine as True always maps to 1?

Contributor Author

Yes, per:

>>> torch.ones(1).to(bool)
tensor([True])

and I also verified that it works by modifying the test to all_reduce on True values, which returns True as expected.

@require_backends_available({"nccl"})
@skip_if_lt_x_gpu(2)
@skip_if_rocm
def test_nccl_backend_bool_reduction(self):
Contributor

Based on the names of the other tests, should this test be named "test_nccl_backend_bool_allreduce"?

rohan-varma added a commit that referenced this pull request Jul 23, 2020
Pull Request resolved: #41318

ghstack-source-id: 108315417

Differential Revision: [D22496604](https://our.internmc.facebook.com/intern/diff/D22496604/)

@facebook-github-bot
Contributor

This pull request has been merged in 3626473.

facebook-github-bot pushed a commit that referenced this pull request Jul 25, 2020
Summary:
Resubmit of #41318 pushed to ci-all branch.

Original description:
Closes #24137.
This PR adds support for the torch.bool tensor type to ProcessGroupNCCL. For most types we use the existing mapping, but since bool is not supported as a native ncclDataType_t, we add the following logic:

1. Map at::kBool to ncclUint8
2. During reduction (allreduce, for example), if the operation is SUM, we instead override it to a MAX to avoid overflow issues. The rest of the operations work with no changes. In the boolean case, changing sum to max makes no correctness difference, since both function as a bitwise OR.

The reduction logic (for example, for reduce/allreduce) is as follows:
- sum, max → bitwise OR
- product, min → bitwise AND

Note that this PR doesn't add support for BAND/BOR/BXOR, because these reduction ops are currently not supported by the NCCL backend; see #41362

Pull Request resolved: #41959

Reviewed By: mrshenli

Differential Revision: D22719665

Pulled By: rohan-varma

fbshipit-source-id: 8bc4194a8d1268589640242277124f277d2ec9f1
@facebook-github-bot facebook-github-bot deleted the gh/rohan-varma/149/head branch July 27, 2020 14:18