NCCL Backend support for torch.bool #41318
Changes from all commits: 9c195f8, a3cccd4, c69b4cf, 4f989e9, 117c4bc, 0e24671, a8794fc, 29d2635, 1522331
@@ -52,18 +52,30 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kInt, ncclInt32},
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
+    {at::kBool, ncclUint8},
 #if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 301
     {at::kBFloat16, ncclBfloat16},
 #endif
 };

 // Helper function that gets the data type and issues error if not supported
 ncclDataType_t getNcclDataType(at::ScalarType type) {
-  try {
-    return ncclDataType.at(type);
-  } catch (std::out_of_range& e) {
-    throw std::runtime_error("Unsupported data type for NCCL process group");
-  }
+  auto it = ncclDataType.find(type);
+  TORCH_CHECK(
+      it != ncclDataType.end(),
+      "Input tensor data type is not supported for NCCL process group: ",
+      type);
+  return it->second;
 }

+ncclRedOp_t getNcclReduceOp(const ReduceOp reduceOp, at::Tensor& input) {
+  if (reduceOp == ReduceOp::SUM && input.scalar_type() == at::kBool) {
+    // For bool tensors, map sum to max, which both represent a bitwise or.
+    // This is to prevent overflow issues with sum, since we use uint8 to
+    // represent a bool (see ncclDataType mapping).
+    return ncclMax;
+  }
+  return ncclOp[reduceOp];
+}
+
 // Get the deviceList String from the list of devices

Review thread on the getNcclReduceOp bool-handling branch:

Comment: I guess prod is fine as

Reply: Yes, per:
and also verified that it works by modifying the test to all_reduce on True, and it returns True as expected.
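The sum-to-max remapping above can be sketched in Python. This is a hedged illustration rather than PyTorch's actual C++ code path: the helper name nccl_op_for is hypothetical, and it assumes a PyTorch build with torch.distributed available.

```python
# Sketch of the op remapping done by getNcclReduceOp (hypothetical Python analogue).
# Bool tensors travel over NCCL as uint8 (see the ncclDataType map above), so SUM
# could wrap around; mapping SUM to MAX keeps results in {0, 1} and acts as a
# logical OR.
import torch
import torch.distributed as dist

def nccl_op_for(reduce_op, tensor: torch.Tensor):
    if reduce_op == dist.ReduceOp.SUM and tensor.dtype == torch.bool:
        return dist.ReduceOp.MAX  # OR semantics over {0, 1}, no overflow
    return reduce_op

# Local check of the equivalence, no process group needed:
flags = torch.tensor([True, False, True]).to(torch.uint8)
assert bool(flags.max()) == bool(flags.any())  # MAX over {0, 1} == logical OR
```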
@@ -752,7 +764,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
         output.data_ptr(),
         input.numel(),
         getNcclDataType(input.scalar_type()),
-        ncclOp[opts.reduceOp],
+        getNcclReduceOp(opts.reduceOp, input),
         comm,
         stream.stream());
       });
@@ -806,7 +818,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
         output.data_ptr(),
         input.numel(),
         getNcclDataType(input.scalar_type()),
-        ncclOp[opts.reduceOp],
+        getNcclReduceOp(opts.reduceOp, input),
         root,
         comm,
         stream.stream());
@@ -888,7 +900,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
         output.data_ptr(),
         output.numel(),
         getNcclDataType(input.scalar_type()),
-        ncclOp[opts.reduceOp],
+        getNcclReduceOp(opts.reduceOp, input),
         comm,
         stream.stream());
       },
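With all three collectives routed through getNcclReduceOp, an end-to-end call matches the author's test description (all_reduce on True returning True). The sketch below is illustrative only: it assumes a host with at least two CUDA devices and an NCCL-enabled PyTorch build, and the rendezvous address/port values are placeholders.

```python
# Illustrative 2-rank NCCL all_reduce on a torch.bool tensor (assumes >= 2 GPUs).
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"  # placeholder rendezvous address
    os.environ["MASTER_PORT"] = "29500"      # placeholder port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Only rank 0 contributes True; SUM on a bool tensor behaves as a logical OR,
    # so every rank should see True after the all_reduce.
    flag = torch.tensor([rank == 0], dtype=torch.bool, device="cuda")
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    assert flag.item()

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```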
Review thread on the ncclDataType mapping:

Comment: Will it work if we add an entry to ncclDataType to map at::kBool to ncclUint8? How does at::kBool represent and interpret true and false? I recall having a discussion with @izdeby about this; IIRC, only 0 is interpreted as false and all others are true?

Reply: Yes, the casting seems to be:
bool -> uint8: True -> 1, False -> 0
uint8 -> bool: any nonzero -> True, 0 -> False

The issue I see with casting to uint8 - either through the map or with similar logic as above - is that it would mean we can only support up to 255 processes calling all_reduce(). For example, if we had 256 processes calling all_reduce with sum, and each process contributed a single set bit, we'd get 0 as our result, which would be wrong.
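The casting rules and the overflow scenario described above can be checked locally with plain tensor operations; this is a sketch with no process group involved.

```python
import torch

# bool -> uint8: True -> 1, False -> 0
assert torch.tensor([True, False]).to(torch.uint8).tolist() == [1, 0]

# uint8 -> bool: any nonzero -> True, 0 -> False
assert torch.tensor([0, 1, 255], dtype=torch.uint8).to(torch.bool).tolist() == [False, True, True]

# Why SUM over the uint8 encoding breaks at 256 ranks: accumulating one set bit
# per rank wraps modulo 256 and would read back as False.
acc = torch.zeros(1, dtype=torch.uint8)
for _ in range(256):                      # 256 ranks, each contributing True (1)
    acc = acc + torch.ones(1, dtype=torch.uint8)
assert acc.item() == 0                    # wrapped around -> decodes as False
# Mapping SUM to MAX (logical OR) avoids the wrap-around entirely.
```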