Adding complex support for distributed functions. Fix #45760 #45879

Closed · wants to merge 14 commits · Changes from 4 commits
79 changes: 79 additions & 0 deletions torch/distributed/distributed_c10d.py
@@ -869,6 +869,8 @@ def all_reduce_multigpu(tensor_list,
After the call, each ``tensor`` in ``tensor_list`` is going to be bitwise
identical in all processes.

Complex tensors are supported.

Only nccl and gloo backends are currently supported;
tensors should only be GPU tensors

@@ -892,6 +894,8 @@ def all_reduce_multigpu(tensor_list,
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
@@ -916,6 +920,8 @@ def all_reduce(tensor,

After the call ``tensor`` is going to be bitwise identical in all processes.

Complex tensors are supported.

Arguments:
tensor (Tensor): Input and output of the collective. The function
operates in-place.
@@ -929,11 +935,36 @@ def all_reduce(tensor,
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group

Examples:
>>> # Tensors are all of dtype torch.int64.
Contributor comment: nit - All tensors below are of torch.int64 dtype

>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1

>>> # Tensors are all of dtype torch.complex64.
Contributor comment: nit - All tensors below are of torch.complex64 dtype

>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank)
Contributor comment: suggested `tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j)`. Unlike C++, Python can interpret the imaginary number j. Also, maybe we should use torch.complex128 or torch.cdouble, since above we show an example with torch.int64 and not torch.int32.

>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_reduce(tensor, op=ReduceOp.SUM)
>>> tensor
tensor([4.+4.j, 6.+6.j]) # Rank 0
tensor([4.+4.j, 6.+6.j]) # Rank 1

"""
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

opts = AllreduceOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
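
For a runnable picture of the complex all_reduce path above, a hypothetical two-process harness; it assumes this patch and the gloo backend, and the worker function name, address, and port are placeholders:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # Placeholder rendezvous; any init_method reachable by both processes works.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)
    tensor = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64) + 2 * rank * (1 + 1j)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    # Expected on both ranks: tensor([4.+4.j, 6.+6.j])
    print(f"rank {rank}: {tensor}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```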
@@ -967,6 +998,8 @@ def all_reduce_coalesced(tensors,
After the call each tensor in tensors is going to be bitwise identical
in all processes.

Complex tensors are supported.

Arguments:
tensors (List[Tensor]): Input and output of the collective. The function
operates in-place.
@@ -985,6 +1018,8 @@ def all_reduce_coalesced(tensors,
if _rank_not_in_group(group):
return

tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors]

opts = AllreduceCoalescedOptions()
opts.reduceOp = op
if group == GroupMember.WORLD:
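
The same real-view rewrite is applied to every tensor in the coalesced list. It is sound for componentwise reductions such as SUM because addition acts independently on the real and imaginary parts, which the following local check (illustrative only, not part of the diff) makes concrete:

```python
import torch

a = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64)
b = torch.tensor([3 + 3j, 4 + 4j], dtype=torch.complex64)

# Summing the real views elementwise ...
summed_real = torch.view_as_real(a) + torch.view_as_real(b)
# ... gives the same result as summing the complex tensors directly.
assert (torch.view_as_complex(summed_real) == a + b).all()
```

Reductions that mix the components, for example a PRODUCT over complex values, would not be captured by this view-based rewrite.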
@@ -1114,6 +1149,8 @@ def all_gather_multigpu(output_tensor_lists,
Only the nccl backend is currently supported;
tensors should only be GPU tensors

Complex tensors are supported.

Arguments:
output_tensor_lists (List[List[Tensor]]): Output lists. It should
contain correctly-sized tensors on each GPU to be used for output
@@ -1149,6 +1186,9 @@ def all_gather_multigpu(output_tensor_lists,
if _rank_not_in_group(group):
return

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather(output_tensor_lists, input_tensor_list)
@@ -1397,6 +1437,8 @@ def all_gather(tensor_list,
"""
Gathers tensors from the whole group in a list.

Complex tensors are supported.

Arguments:
tensor_list (list[Tensor]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1408,12 +1450,44 @@
Async work handle, if async_op is set to True.
None, if not async_op or if not part of the group

Examples:
>>> # Tensors are all of dtype torch.int64.
Contributor comment: nit - All tensors below are of torch.int64 dtype

>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1

>>> # Tensors are all of dtype torch.complex64.
Contributor comment: nit - All tensors below are of torch.complex64 dtype

>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.complex64) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank)
Contributor comment: suggested `tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j)`

>>> tensor
tensor([1.+1.j, 2.+2.j]) # Rank 0
tensor([3.+3.j, 4.+4.j]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

"""
_check_tensor_list(tensor_list, "tensor_list")
_check_single_tensor(tensor, "tensor")
if _rank_not_in_group(group):
return

tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather([tensor_list], [tensor])
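
Since torch.view_as_real returns views that share storage with the caller's complex tensors, the gathered values land directly in the complex tensor_list passed by the user. A hypothetical two-process sketch mirroring the docstring example, assuming this patch and the gloo backend (address and port are placeholders):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    # Pre-allocate complex outputs; the collective writes through their real views,
    # and because those views share storage the results appear here directly.
    tensor_list = [torch.zeros(2, dtype=torch.complex64) for _ in range(world_size)]
    tensor = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.complex64) + 2 * rank * (1 + 1j)
    dist.all_gather(tensor_list, tensor)
    # Expected on both ranks: [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])]
    print(f"rank {rank}: {tensor_list}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```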
@@ -1432,6 +1506,8 @@ def all_gather_coalesced(output_tensor_lists,
"""
Gathers input tensors from the whole group in a list in a coalesced manner.

Complex tensors are supported.

Arguments:
output_tensor_lists (list[list[Tensor]]): Output list. It should contain
correctly-sized tensors to be used for output of the collective.
@@ -1480,6 +1556,9 @@ def all_gather_coalesced(output_tensor_lists,
for output_tensor_list in output_tensor_lists:
_check_tensor_list(output_tensor_list, "output_tensor_lists")

output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists]
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list]

if group == GroupMember.WORLD:
_check_default_pg()
work = _default_pg.allgather_coalesced(