Skip to content

Commit

Permalink
[c10d] Allow mixing complex and its element type in collectives.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rodrigo Kumpera committed Sep 12, 2022
1 parent 22486a1 commit 492fed9
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torch/distributed/distributed_c10d.py
Expand Up @@ -415,10 +415,15 @@ def _as_iterable(obj) -> collections.Iterable:
def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = None
for tensor in itertools.chain(*map(_as_iterable, tensors)):
tensor_dtype = tensor.dtype
# Mixing complex and its element type is allowed
if tensor_dtype.is_complex:
tensor_dtype = torch.float32 if tensor_dtype == torch.complex64 else torch.complex128

if last_dtype is None:
last_dtype = tensor.dtype
last_dtype = tensor_dtype
else:
if last_dtype != tensor.dtype:
if last_dtype != tensor_dtype:
raise RuntimeError(
"Invalid usage of tensors with different dtypes"
f"Found {last_dtype} and {tensor.dtype}"
Expand Down

0 comments on commit 492fed9

Please sign in to comment.