Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[c10d] Ensure collectives are called with the same dtype for all tensor params. #84664

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
78 changes: 78 additions & 0 deletions test/distributed/test_c10d_common.py
Expand Up @@ -972,6 +972,10 @@ def op_timeout_sec(self):
def world_size(self):
    """Number of ranks each test in this suite runs with."""
    return 2

@property
def device(self):
    # Abstract hook: concrete subclasses must override this to return the
    # device their backend's tensors live on (e.g. "cpu" for gloo,
    # "cuda:<rank>" for nccl). Reaching this default is a test error.
    self.fail("test subclass didn't override device")

def _verify_sequence_number_across_pg(self, pg, verify_pg):

seq_num = pg._get_sequence_number_for_group()
Expand Down Expand Up @@ -1144,7 +1148,81 @@ def _test_rank_membership(self, backend):

self.assertEqual(dist.get_process_group_ranks(group), [1])

def _test_tensor_dtype_mismatch(self, backend):
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend,
world_size=self.world_size,
rank=self.rank,
store=store,
)

tensor = torch.ones(2, 2, device=self.device) * 7
tensor_h = tensor.half()
tensor_list = [torch.zeros(2, 2, device=self.device) for _ in range(self.world_size)]
tensor_list_h = list(tensor_list)
tensor_list_h[1] = tensor_list_h[1].half()


with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_gather(tensor_list_h, tensor)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_gather(tensor_list, tensor_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_gather_coalesced([tensor_list_h], tensor_list)
dist.all_gather_coalesced([tensor_list], tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_reduce_coalesced(tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.reduce_scatter(tensor, tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.reduce_scatter(tensor_h, tensor_list)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_to_all_single(tensor_h, tensor)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_to_all(tensor_list_h, tensor_list)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.all_to_all(tensor_list, tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.scatter(tensor, tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.gather(tensor_h, tensor_list)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.gather(tensor, tensor_list_h)

with self.assertRaisesRegex(RuntimeError, "tensors with different dtypes"):
dist.scatter(tensor_h, tensor_list)

def _test_tensor_dtype_complex(self, backend):
    """Check that all_gather accepts complex tensors mixed with their
    real element dtype (a complex tensor is viewed as real internally),
    in every input/output combination."""
    store = dist.FileStore(self.file_name, self.world_size)
    dist.init_process_group(
        backend,
        world_size=self.world_size,
        rank=self.rank,
        store=store,
    )

    tensor = torch.rand(2, device=self.device)
    tensor_c = torch.view_as_complex(tensor)
    tensor_list = [torch.rand(2, device=self.device) for _ in range(self.world_size)]
    # Copy the output list and turn one element complex to mix dtypes.
    tensor_list_c = list(tensor_list)
    tensor_list_c[1] = torch.view_as_complex(tensor_list_c[1])

    # None of these combinations may raise.
    dist.all_gather(tensor_list, tensor)
    dist.all_gather(tensor_list, tensor_c)
    dist.all_gather(tensor_list_c, tensor)
    dist.all_gather(tensor_list_c, tensor_c)

class CommTest(AbstractCommTest, MultiProcessTestCase):
def setUp(self):
Expand Down
14 changes: 14 additions & 0 deletions test/distributed/test_c10d_gloo.py
Expand Up @@ -2232,6 +2232,11 @@ def test_forward_backward_optimizer(self):


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
@property
def device(self):
    # Gloo collectives in these tests operate on CPU tensors.
    return "cpu"


def setUp(self):
super(CommTest, self).setUp()
self._spawn_processes()
Expand Down Expand Up @@ -2343,6 +2348,15 @@ def test_gloo_warn_not_in_group(self):
def test_gloo_rank_membership(self):
    # Thin wrapper: run the shared rank-membership check with gloo.
    self._test_rank_membership(backend="gloo")

@requires_gloo()
def test_tensor_dtype_mismatch(self):
    """Mixed-dtype collective args must raise (gloo backend, CPU)."""
    # BUG FIX: dropped @skip_if_lt_x_gpu(2). This test runs on CPU
    # tensors (this class's `device` property returns "cpu"), so gating
    # it on GPU count needlessly skipped it on CPU-only machines — the
    # decorator was copy-pasted from the NCCL variant.
    self._test_tensor_dtype_mismatch(backend="gloo")

@requires_gloo()
def test_tensor_dtype_complex(self):
    """Complex/real dtype mixing must be accepted (gloo backend, CPU)."""
    # BUG FIX: dropped @skip_if_lt_x_gpu(2). This gloo test uses CPU
    # tensors only, so requiring 2 GPUs skipped it on CPU-only machines.
    self._test_tensor_dtype_complex(backend="gloo")

class CompilerTest(test_c10d_common.CompilerTest):

Expand Down
15 changes: 15 additions & 0 deletions test/distributed/test_c10d_nccl.py
Expand Up @@ -2548,6 +2548,11 @@ def test_nccl_timeout(self):


class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
@property
def device(self):
    # One CUDA device per rank so NCCL collectives don't share a GPU.
    return f"cuda:{self.rank}"


def setUp(self):
super(CommTest, self).setUp()
# NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING hence tests
Expand Down Expand Up @@ -2806,6 +2811,16 @@ def test_nccl_warn_not_in_group_debug_off(self):
def test_nncl_rank_membership(self):
    # NOTE(review): "nncl" is a typo for "nccl". Left unrenamed here
    # because the method name is the public test ID that CI selection
    # and skip lists reference; renaming should be its own change.
    self._test_rank_membership(backend="nccl")

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_tensor_dtype_mismatch(self):
    """Mixed-dtype collective args must raise (NCCL backend, 2+ GPUs)."""
    self._test_tensor_dtype_mismatch(backend="nccl")

@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_tensor_dtype_complex(self):
    """Complex/real dtype mixing must be accepted (NCCL backend, 2+ GPUs)."""
    self._test_tensor_dtype_complex(backend="nccl")


class CompilerTest(test_c10d_common.CompilerTest):

Expand Down
31 changes: 31 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -1,3 +1,5 @@
import itertools
import collections.abc
import contextlib
import io
import logging
Expand Down Expand Up @@ -407,6 +409,26 @@ def _check_tensor_list(param, param_name):
"to be of type List[torch.Tensor].".format(param_name)
)

def _as_iterable(obj) -> collections.abc.Iterable:
    """Return *obj* unchanged if it is a list, else wrap it in a 1-tuple."""
    return obj if isinstance(obj, list) else (obj,)

def _ensure_all_tensors_same_dtype(*tensors) -> None:
    """Raise RuntimeError unless all given tensors share one dtype.

    Each positional argument may be a single tensor or a list of tensors.
    Mixing a complex dtype with its real element dtype is allowed
    (complex64 <-> float32, complex128 <-> float64), because collectives
    view complex tensors as real internally.
    """
    last_dtype = None
    for tensor in itertools.chain(*map(_as_iterable, tensors)):
        tensor_dtype = tensor.dtype
        # Normalize complex dtypes to their real element dtype so that a
        # complex/real mix compares equal below.
        if tensor_dtype.is_complex:
            # BUG FIX: complex128 previously fell through to complex128
            # (itself) instead of float64, so mixing complex128 with
            # float64 incorrectly raised. Map each complex dtype to its
            # real element dtype.
            tensor_dtype = (
                torch.float32 if tensor_dtype == torch.complex64 else torch.float64
            )

        if last_dtype is None:
            last_dtype = tensor_dtype
        else:
            if last_dtype != tensor_dtype:
                raise RuntimeError(
                    # BUG FIX: added a separator — the two adjacent string
                    # literals previously concatenated to "dtypesFound".
                    "Invalid usage of tensors with different dtypes "
                    f"Found {last_dtype} and {tensor.dtype}"
                )


def _check_op(op):
"""
Expand Down Expand Up @@ -1458,6 +1480,7 @@ def all_reduce_coalesced(tensors, op=ReduceOp.SUM, group=None, async_op=False):

"""
_check_tensor_list(tensors, "tensor")
_ensure_all_tensors_same_dtype(tensors)
if _rank_not_in_group(group):
_warn_not_in_group("all_reduce_coalesced")
return
Expand Down Expand Up @@ -2125,6 +2148,7 @@ def all_gather(tensor_list, tensor, group=None, async_op=False):
"""
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
_ensure_all_tensors_same_dtype(tensor_list, tensor)
if _rank_not_in_group(group):
_warn_not_in_group("all_gather")
return
Expand Down Expand Up @@ -2265,12 +2289,14 @@ def all_gather_coalesced(
_warn_not_in_group("all_gather_coalesced")
return
_check_tensor_list(input_tensor_list, "tensor_list")
_ensure_all_tensors_same_dtype(input_tensor_list)
if not isinstance(output_tensor_lists, list):
raise RuntimeError(
"Invalid function argument: " "output_tensor_lists should be a list"
)
for output_tensor_list in output_tensor_lists:
_check_tensor_list(output_tensor_list, "output_tensor_lists")
_ensure_all_tensors_same_dtype(output_tensor_list)

output_tensor_lists = [
[t if not t.is_complex() else torch.view_as_real(t) for t in l]
Expand Down Expand Up @@ -2331,6 +2357,7 @@ def gather(tensor, gather_list=None, dst=0, group=None, async_op=False):
_check_tensor_list(gather_list, "gather_list")
else:
gather_list = []
_ensure_all_tensors_same_dtype(tensor, gather_list)

if _rank_not_in_group(group):
_warn_not_in_group("gather")
Expand Down Expand Up @@ -2388,6 +2415,7 @@ def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False):
_check_tensor_list(scatter_list, "scatter_list")
else:
scatter_list = []
_ensure_all_tensors_same_dtype(tensor, scatter_list)

if _rank_not_in_group(group):
_warn_not_in_group("scatter")
Expand Down Expand Up @@ -2512,6 +2540,7 @@ def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=Fal
"""
_check_single_tensor(output, "output")
_check_tensor_list(input_list, "input_list")
_ensure_all_tensors_same_dtype(output, input_list)
if _rank_not_in_group(group):
_warn_not_in_group("reduce_scatter")
return
Expand Down Expand Up @@ -2673,6 +2702,7 @@ def all_to_all_single(
opts = AllToAllOptions()
_check_single_tensor(output, "output")
_check_single_tensor(input, "input")
_ensure_all_tensors_same_dtype(output, input)

if input.is_complex():
input = torch.view_as_real(input)
Expand Down Expand Up @@ -2796,6 +2826,7 @@ def all_to_all(output_tensor_list, input_tensor_list, group=None, async_op=False
opts = AllToAllOptions()
_check_tensor_list(output_tensor_list, "output_tensor_list")
_check_tensor_list(input_tensor_list, "input_tensor_list")
_ensure_all_tensors_same_dtype(output_tensor_list, input_tensor_list)

input_tensor_list = [
t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list
Expand Down