Adding complex support for distributed functions, fix #45760 #45879

Closed · wants to merge 14 commits · Changes from 10 commits
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/init.cpp
@@ -220,6 +220,8 @@ An enum-like class for available reduction operations: ``SUM``, ``PRODUCT``,
Note that ``BAND``, ``BOR``, and ``BXOR`` reductions are not available when
using the ``NCCL`` backend.

Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors.

The values of this class can be accessed as attributes, e.g., ``ReduceOp.SUM``.
They are used in specifying strategies for reduction collectives, e.g.,
:func:`reduce`, :func:`all_reduce_multigpu`, etc.)")
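
(Reviewer-style aside, not part of this diff: a minimal sketch of why ``SUM`` survives the real-view trick used in the Python changes below while ``PRODUCT`` does not; the tensors here are just illustrative values.)

import torch

a, b = torch.tensor([2 + 3j]), torch.tensor([3 + 2j])

# SUM is safe: adding the real views equals adding the complex values.
s = torch.view_as_real(a) + torch.view_as_real(b)      # tensor([[5., 5.]])
assert torch.equal(torch.view_as_complex(s), a + b)    # 5+5j

# PRODUCT is not: an elementwise product of real views is not a complex product.
p = torch.view_as_real(a) * torch.view_as_real(b)      # tensor([[6., 6.]])
print(torch.view_as_complex(p))                        # tensor([6.+6.j])
print(a * b)                                           # tensor([0.+13.j])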
97 changes: 97 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -43,6 +43,18 @@
except ImportError:
_GLOO_AVAILABLE = False

# Some reduce ops are not supported by complex numbers.
Review comment (Contributor): the way this comment is written reads like we allow calling them.

Review reply (Author): @gchanan quick fix here: #46599

# We currently provide complex support to the distributed API by viewing
# complex tensors as real (torch.view_as_real), meaning that calling
# these unsupported ops will return garbage values rather than error out.
# (e.g. max(2+3i, 3+2i) = 3+3i)
# We'd like calls to unsupported ops to error out accordingly,
# rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT:
return False
return True
Review comment (Contributor): is this the same as ``return True if reduceOp == ReduceOp.SUM else False``?

Review reply (Author): cleaned this up a little to make it more Pythonic.
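
To make the thread concrete, a sketch of the kind of garbage value the comment above warns about, plus one hypothetical shape the "more Pythonic" check could take (the actual cleanup lives in a later commit and is not reproduced here):

import torch
from torch.distributed import ReduceOp

# MAX over the real views mixes real and imaginary parts across inputs:
x, y = torch.tensor([2 + 3j]), torch.tensor([3 + 2j])
m = torch.maximum(torch.view_as_real(x), torch.view_as_real(y))  # tensor([[3., 3.]])
print(torch.view_as_complex(m))  # tensor([3.+3.j]) -- equals neither input

# Hypothetical, more compact version of supports_complex (illustrative only):
def supports_complex(reduceOp: ReduceOp) -> bool:
    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT]
    return reduceOp not in denyList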



class Backend(object):
"""
@@ -869,6 +881,8 @@ def all_reduce_multigpu(tensor_list,
After the call, each ``tensor`` in ``tensor_list`` is going to be bitwise identical in all processes.

Complex tensors are supported.

Only the ``nccl`` and ``gloo`` backends are currently supported;
tensors should only be GPU tensors.

@@ -892,6 +906,8 @@ def all_reduce_multigpu(tensor_list,
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
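
Tying this to the docstring above, a hedged usage sketch for ``all_reduce_multigpu`` with complex inputs (assumes an initialized NCCL process group and two GPUs driven by this process; names, values, and devices are illustrative):

import torch
import torch.distributed as dist

# One complex tensor per GPU owned by this process; all take part in the reduction.
tensor_list = [
    torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat, device=f"cuda:{i}")
    for i in range(2)
]
dist.all_reduce_multigpu(tensor_list, op=dist.ReduceOp.SUM)
# Afterwards every tensor in tensor_list holds the same summed values.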
@@ -916,6 +932,8 @@ def all_reduce(tensor,

After the call ``tensor`` is going to be bitwise identical in all processes.

Complex tensors are supported.

Arguments:
tensor (Tensor): Input and output of the collective. The function
operates in-place.
@@ -929,11 +947,39 @@ def all_reduce(tensor,
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group

Examples:
>>> # All tensors below are of torch.int64 type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1

>>> # All tensors below are of torch.cfloat type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1

"""
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

if tensor.is_complex():
if not supports_complex(op):
raise RuntimeError(f"{op} is unsupported on complex tensors")
tensor = torch.view_as_real(tensor)

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
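
Since the docstring documents the async work handle, here is a minimal sketch of the async pattern for ``all_reduce`` (assumes the default process group is already initialized; values are illustrative):

import torch
import torch.distributed as dist

t = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat)
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
work.wait()  # block until the collective has completed
# t now holds the summed values on every rank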
@@ -967,6 +1013,8 @@ def all_reduce_coalesced(tensors,
After the call, each tensor in ``tensors`` is going to be bitwise identical
in all processes.

Complex tensors are supported.

Arguments:
tensors (List[Tensor]): Input and output of the collective. The function
operates in-place.
@@ -985,6 +1033,11 @@ def all_reduce_coalesced(tensors,
if _rank_not_in_group(group):
return

if any(t.is_complex() for t in tensors) and not supports_complex(op):
raise RuntimeError(f"all_reduce_coalesced does not support {op} on complex tensors")

tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
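
A hedged usage sketch for the coalesced path, mixing real and complex inputs, which the guard above allows for ``SUM`` (assumes an initialized process group; shapes and values are illustrative):

import torch
import torch.distributed as dist

tensors = [
    torch.ones(2),                                        # real
    torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat),   # complex
]
dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM)
# Each tensor is summed across ranks; complex entries are viewed as real internally.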
@@ -1114,6 +1167,8 @@ def all_gather_multigpu(output_tensor_lists,
Only the ``nccl`` backend is currently supported;
tensors should only be GPU tensors.

Complex tensors are supported.

Arguments:
output_tensor_lists (List[List[Tensor]]): Output lists. It should
contain correctly-sized tensors on each GPU to be used for output
@@ -1149,6 +1204,9 @@ def all_gather_multigpu(output_tensor_lists,
if _rank_not_in_group(group):
return

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
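
For readers unfamiliar with the multi-GPU variant, a hedged sketch of the expected list shapes (assumes a world size of 2 with one GPU per rank; device placement and sizes are illustrative):

import torch
import torch.distributed as dist

world_size = 2
# One input tensor per GPU owned by this process.
input_tensor_list = [torch.tensor([1 + 1j], dtype=torch.cfloat, device="cuda:0")]
# For each local GPU, world_size * len(input_tensor_list) correctly-sized output slots.
output_tensor_lists = [
    [torch.empty(1, dtype=torch.cfloat, device="cuda:0")
     for _ in range(world_size * len(input_tensor_list))]
]
dist.all_gather_multigpu(output_tensor_lists, input_tensor_list)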
@@ -1397,6 +1455,8 @@ def all_gather(tensor_list,
"""
Gathers tensors from the whole group in a list.

Complex tensors are supported.

Arguments:
tensor_list (list[Tensor]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1408,12 +1468,44 @@ def all_gather(tensor_list,
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group

Examples:
>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1

>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

"""
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather([tensor_list], [tensor])
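
A standalone note on why replacing the output tensors with their real views still fills the caller's complex tensors: ``torch.view_as_real`` returns a view over the same storage, so writes through the real view land in the original complex tensor.

import torch

out = torch.zeros(2, dtype=torch.cfloat)   # caller-visible complex output slot
out_real = torch.view_as_real(out)         # shape (2, 2) view, shared storage
out_real.copy_(torch.tensor([[1., 1.], [2., 2.]]))  # what the collective writes
print(out)                                 # tensor([1.+1.j, 2.+2.j])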
@@ -1432,6 +1524,8 @@ def all_gather_coalesced(output_tensor_lists,
"""
Gathers input tensors from the whole group in a list in a coalesced manner.

Complex tensors are supported.

Arguments:
output_tensor_lists (list[list[Tensor]]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1480,6 +1574,9 @@ def all_gather_coalesced(output_tensor_lists,
for output_tensor_list in output_tensor_lists:
_check_tensor_list(output_tensor_list, "output_tensor_lists")

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather_coalesced(
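
Finally, a hedged usage sketch for ``all_gather_coalesced`` with a mixed real/complex input list (assumes an initialized process group and a world size of 2; shapes and values are illustrative):

import torch
import torch.distributed as dist

world_size = 2
input_tensor_list = [
    torch.tensor([1 + 1j], dtype=torch.cfloat),
    torch.ones(2),
]
# One output list per rank, mirroring the shapes of input_tensor_list.
output_tensor_lists = [
    [torch.empty_like(t) for t in input_tensor_list] for _ in range(world_size)
]
dist.all_gather_coalesced(output_tensor_lists, input_tensor_list)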