New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding complex support for distributed functions all_reduce and all_gather. Fix #45760 #45879
Changes from 4 commits
0c15698
d590fe6
517f458
8c0578d
2c5e872
74cb182
cb38d26
ef8c031
e8fcc0b
e048c75
ef546af
3d6a692
8ddb50a
375fe36
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -869,6 +869,8 @@ def all_reduce_multigpu(tensor_list, | |
After the call, every ``tensor`` in ``tensor_list`` is going to be bitwise | ||
identical in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Only the nccl and gloo backends are currently supported | ||
tensors should only be GPU tensors | ||
|
||
|
@@ -892,6 +894,8 @@ def all_reduce_multigpu(tensor_list, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] | ||
|
||
opts = AllreduceOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -916,6 +920,8 @@ def all_reduce(tensor, | |
|
||
After the call ``tensor`` is going to be bitwise identical in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensor (Tensor): Input and output of the collective. The function | ||
operates in-place. | ||
|
@@ -929,11 +935,36 @@ def all_reduce(tensor, | |
Async work handle, if async_op is set to True. | ||
None, if not async_op or if not part of the group | ||
|
||
Examples: | ||
>>> # Tensors are all of dtype torch.int64. | ||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank | ||
>>> tensor | ||
tensor([1, 2]) # Rank 0 | ||
tensor([3, 4]) # Rank 1 | ||
>>> dist.all_reduce(tensor, op=ReduceOp.SUM) | ||
>>> tensor | ||
tensor([4, 6]) # Rank 0 | ||
tensor([4, 6]) # Rank 1 | ||
|
||
>>> # Tensors are all of dtype torch.complex64. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit - All tensors below are of torch.complex64 dtype |
||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Unlike C++, Python can interpret the imaginary number |
||
>>> tensor | ||
tensor([1.+1.j, 2.+2.j]) # Rank 0 | ||
tensor([3.+3.j, 4.+4.j]) # Rank 1 | ||
>>> dist.all_reduce(tensor, op=ReduceOp.SUM) | ||
>>> tensor | ||
tensor([4.+4.j, 6.+6.j]) # Rank 0 | ||
tensor([4.+4.j, 6.+6.j]) # Rank 1 | ||
|
||
""" | ||
_check_single_tensor(tensor, "tensor") | ||
if _rank_not_in_group(group): | ||
return | ||
|
||
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) | ||
|
||
opts = AllreduceOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -967,6 +998,8 @@ def all_reduce_coalesced(tensors, | |
After the call each tensor in tensors is going to be bitwise identical | ||
in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensors (List[Tensor]): Input and output of the collective. The function | ||
operates in-place. | ||
|
@@ -985,6 +1018,8 @@ def all_reduce_coalesced(tensors, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] | ||
|
||
opts = AllreduceCoalescedOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -1114,6 +1149,8 @@ def all_gather_multigpu(output_tensor_lists, | |
Only nccl backend is currently supported | ||
tensors should only be GPU tensors | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
output_tensor_lists (List[List[Tensor]]): Output lists. It should | ||
contain correctly-sized tensors on each GPU to be used for output | ||
|
@@ -1149,6 +1186,9 @@ def all_gather_multigpu(output_tensor_lists, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] | ||
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather(output_tensor_lists, input_tensor_list) | ||
|
@@ -1397,6 +1437,8 @@ def all_gather(tensor_list, | |
""" | ||
Gathers tensors from the whole group in a list. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensor_list (list[Tensor]): Output list. It should contain | ||
correctly-sized tensors to be used for output of the collective. | ||
|
@@ -1408,12 +1450,44 @@ def all_gather(tensor_list, | |
Async work handle, if async_op is set to True. | ||
None, if not async_op or if not part of the group | ||
|
||
Examples: | ||
>>> # Tensors are all of dtype torch.int64. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit - All tensors below are of torch.int64 dtype |
||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)] | ||
>>> tensor_list | ||
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1 | ||
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank | ||
>>> tensor | ||
tensor([1, 2]) # Rank 0 | ||
tensor([3, 4]) # Rank 1 | ||
>>> dist.all_gather(tensor_list, tensor) | ||
>>> tensor_list | ||
[tensor([1, 2]), tensor([3, 4])] # Rank 0 | ||
[tensor([1, 2]), tensor([3, 4])] # Rank 1 | ||
|
||
>>> # Tensors are all of dtype torch.complex64. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit - All tensors below are of torch.complex64 dtype |
||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor_list = [torch.zeros(2, dtype=torch.complex64) for _ in range(2)] | ||
>>> tensor_list | ||
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 | ||
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
>>> tensor | ||
tensor([1.+1.j, 2.+2.j]) # Rank 0 | ||
tensor([3.+3.j, 4.+4.j]) # Rank 1 | ||
>>> dist.all_gather(tensor_list, tensor) | ||
>>> tensor_list | ||
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0 | ||
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1 | ||
|
||
""" | ||
_check_tensor_list(tensor_list, "tensor_list") | ||
_check_single_tensor(tensor, "tensor") | ||
if _rank_not_in_group(group): | ||
return | ||
|
||
tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] | ||
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather([tensor_list], [tensor]) | ||
|
@@ -1432,6 +1506,8 @@ def all_gather_coalesced(output_tensor_lists, | |
""" | ||
Gathers input tensors from the whole group in a list in a coalesced manner. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
output_tensor_lists (list[list[Tensor]]): Output list. It should contain | ||
correctly-sized tensors to be used for output of the collective. | ||
|
@@ -1480,6 +1556,9 @@ def all_gather_coalesced(output_tensor_lists, | |
for output_tensor_list in output_tensor_lists: | ||
_check_tensor_list(output_tensor_list, "output_tensor_lists") | ||
|
||
output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] | ||
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather_coalesced( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit - All tensors below are of torch.int64 dtype