Skip to content

Commit

Permalink
fp8 support in all_gather
Browse files Browse the repository at this point in the history
  • Loading branch information
jspark1105 committed Oct 14, 2023
1 parent 6b0edba commit dd4c868
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3342,11 +3342,20 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
c10::cuda::CUDACachingAllocator::recordStream(
output.storage().data_ptr(), stream);
}

ncclDataType_t nccl_dtype;
if (input.scalar_type() == at::kFloat8_e5m2 ||
input.scalar_type() == at::kFloat8_e4m3fn) {
nccl_dtype = ncclInt8;
} else {
nccl_dtype = getNcclDataType(input.scalar_type());
}

return ncclAllGather(
input.data_ptr(),
output.data_ptr(),
input.numel(),
- getNcclDataType(input.scalar_type()),
+ nccl_dtype,
comm,
stream.stream());
},
Expand Down

0 comments on commit dd4c868

Please sign in to comment.