adding complex support for distributed functions. fix #45760
updated docs
used standard Python REPL examples in the docs and tested the way
that they render in the browser

more doc fixes. Add an explicit error check for ReduceOps that
do not support complex (MAX, MIN, and PRODUCT), plus tests for that case

ghstack-source-id: 4920be3e0cb551612c2f76a4fbcea2444f097558
Pull Request resolved: #45879
bdhirsh committed Oct 8, 2020
1 parent f65ab89 commit 0977393
Showing 3 changed files with 244 additions and 51 deletions.
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -220,6 +220,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
using the ``NCCL`` backend.
Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.
The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.)")
97 changes: 97 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -43,6 +43,18 @@
except ImportError:
_GLOO_AVAILABLE = False

# Some reduce ops are not supported for complex numbers.
# We currently provide complex support to the distributed API by viewing
# complex tensors as real (torch.view_as_real), meaning that calling
# these unsupported ops will return garbage values rather than error out.
# (e.g. max(2+3i, 3+2i) = 3+3i)
# We'd like calls to unsupported ops to error out accordingly,
# rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT:
return False
return True
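
A minimal single-process sketch (plain torch, not part of this diff) of the failure mode the comment above describes: an element-wise MAX taken over the real view of complex inputs can produce a value that is neither input.

import torch

# max(2+3i, 3+2i) taken over the real views is [3., 3.], which maps back to
# 3+3j -- a value that is neither of the two inputs.
a = torch.view_as_real(torch.tensor([2 + 3j]))   # tensor([[2., 3.]])
b = torch.view_as_real(torch.tensor([3 + 2j]))   # tensor([[3., 2.]])
print(torch.view_as_complex(torch.max(a, b)))    # tensor([3.+3.j])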


class Backend(object):
"""
@@ -869,6 +881,8 @@ def all_reduce_multigpu(tensor_list,
After the call, each ``tensor`` in ``tensor_list`` is going to be bitwise
identical in all processes.
Complex tensors are supported.
Only the nccl and gloo backends are currently supported
tensors should only be GPU tensors
@@ -892,6 +906,8 @@
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -916,6 +932,8 @@ def all_reduce(tensor,
After the call ``tensor`` is going to be bitwise identical in all processes.
Complex tensors are supported.
Arguments:
tensor (Tensor): Input and output of the collective. The function
operates in-place.
@@ -929,11 +947,39 @@
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
Examples:
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1
>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1
"""
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

if tensor.is_complex():
if not supports_complex(op):
raise RuntimeError(f"{op} is unsupported on complex tensors")
tensor = torch.view_as_real(tensor)

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
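
A hedged usage sketch of the new error check (not part of this diff): a one-process gloo group, where the rendezvous address and port below are placeholders chosen for illustration.

import os
import torch
import torch.distributed as dist

# Placeholder rendezvous settings for a single-process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat)
dist.all_reduce(t, op=dist.ReduceOp.SUM)       # SUM supports complex tensors
try:
    dist.all_reduce(t, op=dist.ReduceOp.MAX)   # MAX is rejected by supports_complex()
except RuntimeError as err:
    print(err)

dist.destroy_process_group()
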
@@ -967,6 +1013,8 @@ def all_reduce_coalesced(tensors,
After the call, each tensor in ``tensors`` is going to be bitwise identical
in all processes.
Complex tensors are supported.
Arguments:
tensors (List[Tensor]): Input and output of the collective. The function
operates in-place.
@@ -985,6 +1033,11 @@
if _rank_not_in_group(group):
return

if any([t.is_complex() for t in tensors]) and not supports_complex(op):
raise RuntimeError(f"all_reduce does not support {op} on complex tensors")

tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -1114,6 +1167,8 @@ def all_gather_multigpu(output_tensor_lists,
Only the nccl backend is currently supported
tensors should only be GPU tensors
Complex tensors are supported.
Arguments:
output_tensor_lists (List[List[Tensor]]): Output lists. It should
contain correctly-sized tensors on each GPU to be used for output
@@ -1149,6 +1204,9 @@ def all_gather_multigpu(output_tensor_lists,
if _rank_not_in_group(group):
return

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
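
A small standalone sketch (made-up shapes, no GPUs, not part of this diff) of what the nested conversion above does to the lists handed to the backend: each complex tensor is replaced by a real view with a trailing dimension of size 2, while real tensors pass through untouched.

import torch

output_tensor_lists = [[torch.zeros(2, dtype=torch.cfloat), torch.zeros(2)]]
converted = [[t if not t.is_complex() else torch.view_as_real(t) for t in l]
             for l in output_tensor_lists]
print([[tuple(t.shape) for t in l] for l in converted])   # [[(2, 2), (2,)]]
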
@@ -1397,6 +1455,8 @@ def all_gather(tensor_list,
"""
Gathers tensors from the whole group in a list.
Complex tensors are supported.
Arguments:
tensor_list (list[Tensor]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1408,12 +1468,44 @@
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group
Examples:
>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
"""
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather([tensor_list], [tensor])
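
A brief sketch (single process, plain torch, not part of this diff) of why the in-place semantics survive the conversion: torch.view_as_real returns a view that shares storage with the complex tensor, so the backend writing into the real views also fills the caller's complex output tensors.

import torch

out = torch.zeros(2, dtype=torch.cfloat)
real_view = torch.view_as_real(out)                   # shape [2, 2], same storage as `out`
real_view.copy_(torch.tensor([[1., 1.], [2., 2.]]))   # stand-in for the backend's write
print(out)                                            # tensor([1.+1.j, 2.+2.j])
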
@@ -1432,6 +1524,8 @@ def all_gather_coalesced(output_tensor_lists,
"""
Gathers input tensors from the whole group in a list in a coalesced manner.
Complex tensors are supported.
Arguments:
output_tensor_lists (list[list[Tensor]]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1480,6 +1574,9 @@
for output_tensor_list in output_tensor_lists:
_check_tensor_list(output_tensor_list, "output_tensor_lists")

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather_coalesced(
