New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
adding complex support for distributed functions and . fix #45760 #45879
Closed
Closed
Changes from 10 commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
0c15698
adding complex support for distributed functions and . fix #45760
bdhirsh d590fe6
Update on "adding complex support for distributed functions and . fi…
bdhirsh 517f458
Update on "adding complex support for distributed functions and . fi…
bdhirsh 8c0578d
Update on "adding complex support for distributed functions and . fi…
bdhirsh 2c5e872
Update on "adding complex support for distributed functions and . fi…
bdhirsh 74cb182
Update on "adding complex support for distributed functions and . fi…
bdhirsh cb38d26
Update on "adding complex support for distributed functions and . fi…
bdhirsh ef8c031
Update on "adding complex support for distributed functions and . fi…
bdhirsh e8fcc0b
Update on "adding complex support for distributed functions and . fi…
bdhirsh e048c75
Update on "adding complex support for distributed functions and . fi…
bdhirsh ef546af
Update on "adding complex support for distributed functions and . fi…
bdhirsh 3d6a692
Update on "adding complex support for distributed functions and . fi…
bdhirsh 8ddb50a
Update on "adding complex support for distributed functions and . fi…
bdhirsh 375fe36
Update on "adding complex support for distributed functions and . fi…
bdhirsh File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,18 @@ | |
except ImportError: | ||
_GLOO_AVAILABLE = False | ||
|
||
# Some reduce ops are not supported by complex numbers. | ||
# We currently provide complex support to the distributed API by viewing | ||
# complex tensors as real (torch.view_as_real), meaning that calling | ||
# these unsupported ops will return garbage values rather than error out. | ||
# (e.g. max(2+3i, 3+2i) = 3+3i) | ||
# We'd like calls to unsupported ops to error out accordingly, | ||
# rather than returning garbage values. | ||
def supports_complex(reduceOp: ReduceOp) -> bool: | ||
if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT: | ||
return False | ||
return True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this same as: return True if reduceOp == ReduceOp.SUM else False? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cleaned this up a little to make it more pythonic. |
||
|
||
|
||
class Backend(object): | ||
""" | ||
|
@@ -869,6 +881,8 @@ def all_reduce_multigpu(tensor_list, | |
After the call, all ``tensor`` in ``tensor_list`` is going to be bitwise | ||
identical in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Only nccl and gloo backend is currently supported | ||
tensors should only be GPU tensors | ||
|
||
|
@@ -892,6 +906,8 @@ def all_reduce_multigpu(tensor_list, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] | ||
|
||
opts = AllreduceOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -916,6 +932,8 @@ def all_reduce(tensor, | |
|
||
After the call ``tensor`` is going to be bitwise identical in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensor (Tensor): Input and output of the collective. The function | ||
operates in-place. | ||
|
@@ -929,11 +947,39 @@ def all_reduce(tensor, | |
Async work handle, if async_op is set to True. | ||
None, if not async_op or if not part of the group | ||
|
||
Examples: | ||
>>> # All tensors below are of torch.int64 type. | ||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank | ||
>>> tensor | ||
tensor([1, 2]) # Rank 0 | ||
tensor([3, 4]) # Rank 1 | ||
>>> dist.all_reduce(tensor, op=ReduceOp.SUM) | ||
>>> tensor | ||
tensor([4, 6]) # Rank 0 | ||
tensor([4, 6]) # Rank 1 | ||
|
||
>>> # All tensors below are of torch.cfloat type. | ||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) | ||
>>> tensor | ||
tensor([1.+1.j, 2.+2.j]) # Rank 0 | ||
tensor([3.+3.j, 4.+4.j]) # Rank 1 | ||
>>> dist.all_reduce(tensor, op=ReduceOp.SUM) | ||
>>> tensor | ||
tensor([4.+4.j, 6.+6.j]) # Rank 0 | ||
tensor([4.+4.j, 6.+6.j]) # Rank 1 | ||
|
||
""" | ||
_check_single_tensor(tensor, "tensor") | ||
if _rank_not_in_group(group): | ||
return | ||
|
||
if tensor.is_complex(): | ||
if not supports_complex(op): | ||
raise RuntimeError(f"{op} is unsupported on complex tensors") | ||
tensor = torch.view_as_real(tensor) | ||
|
||
opts = AllreduceOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -967,6 +1013,8 @@ def all_reduce_coalesced(tensors, | |
After the call each tensor in tensors is going to bitwise identical | ||
in all processes. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensors (List[Tensor]): Input and output of the collective. The function | ||
operates in-place. | ||
|
@@ -985,6 +1033,11 @@ def all_reduce_coalesced(tensors, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
if any([t.is_complex() for t in tensors]) and not supports_complex(op): | ||
raise RuntimeError(f"all_reduce does not support {op} on complex tensors") | ||
|
||
tensors = [t if not t.is_complex() else torch.view_as_real(t) for t in tensors] | ||
|
||
opts = AllreduceCoalescedOptions() | ||
opts.reduceOp = op | ||
if group == GroupMember.WORLD: | ||
|
@@ -1114,6 +1167,8 @@ def all_gather_multigpu(output_tensor_lists, | |
Only nccl backend is currently supported | ||
tensors should only be GPU tensors | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
output_tensor_lists (List[List[Tensor]]): Output lists. It should | ||
contain correctly-sized tensors on each GPU to be used for output | ||
|
@@ -1149,6 +1204,9 @@ def all_gather_multigpu(output_tensor_lists, | |
if _rank_not_in_group(group): | ||
return | ||
|
||
output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] | ||
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather(output_tensor_lists, input_tensor_list) | ||
|
@@ -1397,6 +1455,8 @@ def all_gather(tensor_list, | |
""" | ||
Gathers tensors from the whole group in a list. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
tensor_list (list[Tensor]): Output list. It should contain | ||
correctly-sized tensors to be used for output of the collective. | ||
|
@@ -1408,12 +1468,44 @@ def all_gather(tensor_list, | |
Async work handle, if async_op is set to True. | ||
None, if not async_op or if not part of the group | ||
|
||
Examples: | ||
>>> # All tensors below are of torch.int64 dtype. | ||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)] | ||
>>> tensor_list | ||
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1 | ||
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank | ||
>>> tensor | ||
tensor([1, 2]) # Rank 0 | ||
tensor([3, 4]) # Rank 1 | ||
>>> dist.all_gather(tensor_list, tensor) | ||
>>> tensor_list | ||
[tensor([1, 2]), tensor([3, 4])] # Rank 0 | ||
[tensor([1, 2]), tensor([3, 4])] # Rank 1 | ||
|
||
>>> # All tensors below are of torch.cfloat dtype. | ||
>>> # We have 2 process groups, 2 ranks. | ||
>>> tensor_list = [torch.zero(2, dtype=torch.cfloat) for _ in range(2)] | ||
>>> tensor_list | ||
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1 | ||
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j) | ||
>>> tensor | ||
tensor([1.+1.j, 2.+2.j]) # Rank 0 | ||
tensor([3.+3.j, 4.+4.j]) # Rank 1 | ||
>>> dist.all_gather(tensor_list, tensor) | ||
>>> tensor_list | ||
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0 | ||
[tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1 | ||
|
||
""" | ||
_check_tensor_list(tensor_list, "tensor_list") | ||
_check_single_tensor(tensor, "tensor") | ||
if _rank_not_in_group(group): | ||
return | ||
|
||
tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list] | ||
tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor) | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather([tensor_list], [tensor]) | ||
|
@@ -1432,6 +1524,8 @@ def all_gather_coalesced(output_tensor_lists, | |
""" | ||
Gathers input tensors from the whole group in a list in a coalesced manner. | ||
|
||
Complex tensors are supported. | ||
|
||
Arguments: | ||
output_tensor_lists (list[list[Tensor]]): Output list. It should contain | ||
correctly-sized tensors to be used for output of the collective. | ||
|
@@ -1480,6 +1574,9 @@ def all_gather_coalesced(output_tensor_lists, | |
for output_tensor_list in output_tensor_lists: | ||
_check_tensor_list(output_tensor_list, "output_tensor_lists") | ||
|
||
output_tensor_lists = [[t if not t.is_complex() else torch.view_as_real(t) for t in l] for l in output_tensor_lists] | ||
input_tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in input_tensor_list] | ||
|
||
if group == GroupMember.WORLD: | ||
_check_default_pg() | ||
work = _default_pg.allgather_coalesced( | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the way this comment is written reads like we allow calling them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gchanan quick fix here: #46599