Map float8 types to uint8 for allgather (#126556)
# Summary
A different take on #126338.

We should probably not allow this mapping for 'compute' ops (e.g. reductions), since NCCL would then be doing arithmetic on raw uint8 bit patterns rather than on float8 values. See the sketch below.
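
A minimal sketch of the intended split, assuming a NCCL process group is already initialized and each rank drives its own CUDA device (the `float8_collectives_demo` helper below is illustrative, not part of this PR): all-gather on float8 is pure data movement and goes through, while a reduction on float8 raises.

```python
# Illustrative sketch only; assumes dist.init_process_group("nccl") has
# already been called and that `rank` owns cuda:<rank>.
import torch
import torch.distributed as dist


def float8_collectives_demo(rank: int, world_size: int) -> None:
    device = torch.device("cuda", rank)
    x = torch.ones(8, 16, device=device).to(torch.float8_e4m3fn)

    # All-gather only moves bytes, so the float8 payload can travel as uint8.
    out = torch.zeros(world_size * 8, 16, device=device).to(torch.float8_e4m3fn)
    dist.all_gather_into_tensor(out, x)

    # Reductions would have NCCL do arithmetic on raw uint8 bit patterns,
    # so they are rejected with a RuntimeError.
    try:
        dist.all_reduce(x)
    except RuntimeError as e:
        print(f"rank {rank}: {e}")
```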

### Corresponding fp8 PR
pytorch-labs/float8_experimental#263

Pull Request resolved: #126556
Approved by: https://github.com/wanchaol
drisspg authored and pytorchmergebot committed May 18, 2024
1 parent bf099a0 commit d4704dc
Showing 2 changed files with 115 additions and 2 deletions.
93 changes: 93 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -2577,6 +2577,27 @@ def test_all_reduce_coalesced_nccl(self):
                ),
            )

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_nccl_float8_errors(self):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        process_group = c10d.distributed_c10d._get_default_group()
        device = torch.device("cuda:%d" % self.rank)
        tensors = [
            torch.full(
                (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
            ).to(torch.float8_e4m3fn)
            for i in range(5)
        ]
        with self.assertRaisesRegex(
            RuntimeError,
            "Float8 dtypes are not currently supported for NCCL reductions",
        ):
            torch.distributed.all_reduce_coalesced(tensors, group=process_group)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_all_reduce_coalesced_manager_nccl(self):
@@ -2940,6 +2961,56 @@ def test_reduce_scatter_tensor_coalesced(self):
                dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
        self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_base_k_float8_errors(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        output_tensor = (
            torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
        )
        input_tensors = (
            torch.arange(self.world_size * 2, dtype=torch.float32)
            .to(torch.float8_e4m3fn)
            .to(self.rank)
        )
        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
        with self.assertRaisesRegex(
            RuntimeError,
            "Float8 dtypes are not currently supported for NCCL reductions",
        ):
            dist.reduce_scatter_tensor(output_tensor, input_tensors)

    @requires_nccl()
    @skip_if_lt_x_gpu(2)
    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
        input_tensors = [
            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
            for _ in range(self.world_size)
        ]

        with self.assertRaisesRegex(
            RuntimeError,
            "Float8 dtypes are not currently supported for NCCL reductions",
        ):
            with dist._coalescing_manager():
                for i in range(self.world_size):
                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
            self.assertEqual(output_tensors, input_tensors[self.rank])


class SetDeviceMethod(Enum):
    TORCH_CUDA_SET = auto()  # torch.cuda.set_device
@@ -2980,6 +3051,28 @@ def test_allgather_base(self):
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor, tensor)

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
    def test_allgather_float8(self, float8_dtype):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda"
        tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype)
        output_tensor = torch.zeros(10, 16, device=torch.device(device)).to(
            float8_dtype
        )
        dist.all_gather_into_tensor(output_tensor, tensor)
        self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32))


instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests)


class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
    def setUp(self):
24 changes: 22 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
@@ -1,4 +1,3 @@

#ifdef USE_C10D_NCCL

#include <exception>
@@ -64,6 +63,10 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
    {at::kLong, ncclInt64},
    {at::kHalf, ncclHalf},
    {at::kBool, ncclUint8},
    {at::kFloat8_e5m2, ncclUint8},
    {at::kFloat8_e4m3fn, ncclUint8},
    {at::kFloat8_e4m3fnuz, ncclUint8},
    {at::kFloat8_e5m2fnuz, ncclUint8},
#if HAS_NCCL_BF16_DATATYPE
    {at::kBFloat16, ncclBfloat16},
#endif
@@ -3039,6 +3042,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
    const AllreduceOptions& opts) {
  TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
  auto tensor = tensors.back();
  TORCH_CHECK(
      !isFloat8Type(tensor.scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");
#ifdef IS_NCCLX
  tensor = tensor.coalesce();
  at::Tensor outputTensor =
@@ -3153,7 +3159,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
      return c10::make_intrusive<IntraNodeCommWork>();
    }
  }

  TORCH_CHECK(
      !isFloat8Type(tensor.scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");
  // @lint-ignore CLANGTIDY
  RECORD_PARAM_COMMS_DATA(
      static_cast<int>(
@@ -3180,6 +3188,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
    std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
  auto total_numel = check_gpu_tensors_same_device(tensors);
  TORCH_CHECK(
      !isFloat8Type(tensors.back().scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");

  // @lint-ignore CLANGTIDY
  RECORD_PARAM_COMMS_DATA(
@@ -3552,6 +3563,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
  check_gpu_single_tensor(outputTensor);
  // @lint-ignore CLANGTIDY
  auto inputTensors_ = inputTensors.back();
  TORCH_CHECK(
      !isFloat8Type(outputTensor.scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");

  RECORD_PARAM_COMMS_DATA(
      static_cast<int>(
@@ -3663,6 +3677,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(

  // @lint-ignore CLANGTIDY
  const auto& tensor = outputTensor;
  TORCH_CHECK(
      !isFloat8Type(tensor.scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");
  RECORD_PARAM_COMMS_DATA(
      static_cast<int>(
          this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
@@ -3723,6 +3740,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
    std::vector<at::Tensor>& outputs,
    std::vector<at::Tensor>& inputs,
    const ReduceScatterOptions& opts) {
  TORCH_CHECK(
      !isFloat8Type(inputs.back().scalar_type()),
      "Float8 dtypes are not currently supported for NCCL reductions");
  return collectiveCoalesced(
      inputs,
      outputs,
