Adding complex support for distributed functions. Fixes #45760 (#45879)
Conversation
[ghstack-poisoned]
if value is None:
    value = size
return torch.FloatTensor(size=[dim_size for _ in range(dim)]).fill_(value)
What's the use case for typed Tensor classes that are typed on their dtype, vs. the generic torch.Tensor class?
not sure about the history here, but I think your solution is better.
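For reference, a sketch of the generic alternative being discussed (the helper name is made up, not the PR's code): `torch.full` returns a plain `torch.Tensor`, so no dtype-specific constructor like `torch.FloatTensor` is needed, and the fill value is explicit.

```python
import torch

def filled_tensor(dim, dim_size, value=None):
    # Mirror the diff's fallback: if no fill value is given, use the size.
    if value is None:
        value = dim_size
    # torch.full replaces the FloatTensor(...).fill_(...) pattern.
    return torch.full([dim_size] * dim, float(value))
```

For example, `filled_tensor(2, 3)` produces a 3x3 float tensor filled with 3.0.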
@@ -1060,6 +1057,41 @@ def test_all_reduce_sum_cuda(self):
        rank_to_GPU,
    )

@unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
def test_all_reduce_sum_complex(self):
I only bothered adding complex tests for a single reduction op (sum) since it's unrelated to the actual op logic.
Update to that, since the above comment isn't really true: added explicit tests that a complex-unsupported ReduceOp like Max should error out properly :)
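A minimal sketch of the shape of such a test (the names and the stand-in check are illustrative, not the PR's actual test code): the reduce ops the PR rejects for complex tensors should raise a RuntimeError rather than silently return garbage.

```python
import unittest

# Stand-in for the check added in this PR: these reduce ops are not
# well-defined or not supported for complex numbers, so all_reduce
# should raise instead of returning garbage values.
UNSUPPORTED_FOR_COMPLEX = {"MAX", "MIN", "PRODUCT"}

def all_reduce_stub(is_complex, op):
    if is_complex and op in UNSUPPORTED_FOR_COMPLEX:
        raise RuntimeError(f"all_reduce does not support {op} on complex tensors")

class TestComplexReduceOps(unittest.TestCase):
    def test_max_errors_on_complex(self):
        with self.assertRaises(RuntimeError):
            all_reduce_stub(is_complex=True, op="MAX")

    def test_sum_allowed_on_complex(self):
        # SUM is well-defined for complex tensors and should not raise.
        all_reduce_stub(is_complex=True, op="SUM")
```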
…45760" [ghstack-poisoned]
ghstack-source-id: ff291d0f75451c20b1cdfd7d93738fb252a60fc9 Pull Request resolved: #45879
Codecov Report
@@ Coverage Diff @@
## gh/bdhirsh/22/base #45879 +/- ##
======================================================
- Coverage 68.20% 68.17% -0.03%
======================================================
Files 410 410
Lines 53453 53516 +63
======================================================
+ Hits 36457 36484 +27
- Misses 16996 17032 +36
Continue to review full report at Codecov.
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
updated docs ghstack-source-id: 6867916d2a3316d5896421a8906b1f64cb1495ad Pull Request resolved: #45879
@@ -929,11 +935,32 @@ def all_reduce(tensor,
    Async work handle, if async_op is set to True.
    None, if not async_op or if not part of the group

Example:
    Tensors are all of dtype torch.int64.
Double quotes on all code:
``torch.int64``
``tensor = [[1, 1], [2, 2]]``
rank 1 passes:
tensor = [[3+3i, 3+3i], [4+4i, 4+4i]]
both rank 0 and 1 get:
tensor = [[4+4i, 4+4i], [6+6i, 6+6i]]
could you please build the doc to verify that this renders correctly?
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
updated docs used standard python repl examples in the docs, tested the way that they render in the browser ghstack-source-id: f0fe2fdb0a1abf6f7a30019d132cc3d5900e8fd6 Pull Request resolved: #45879
>>> # Tensors are all of dtype torch.complex64.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank)
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j)
Unlike C++, Python can interpret the imaginary number j. Also, maybe we should use torch.complex128 or torch.cdouble, since above we show an example of torch.int64 and not torch.int32.
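As a sanity check of the arithmetic in the doc example (pure Python, no torch; the helper names are made up): Python's `1+1j` literal equals `complex(1, 1)`, and summing the two ranks' tensors elementwise reproduces the documented all_reduce result.

```python
# Per-rank input from the doc example: [1+1j, 2+2j] + 2 * rank * (1+1j).
def rank_tensor(rank):
    base = [1 + 1j, 2 + 2j]
    offset = 2 * rank * (1 + 1j)
    return [x + offset for x in base]

# all_reduce with SUM leaves every rank holding the elementwise sum.
def all_reduce_sum(per_rank):
    return [sum(vals) for vals in zip(*per_rank)]

# rank 0 contributes [1+1j, 2+2j], rank 1 contributes [3+3j, 4+4j],
# and both ranks end up with [4+4j, 6+6j].
```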
tensor([4, 6]) # Rank 0
tensor([4, 6]) # Rank 1

>>> # Tensors are all of dtype torch.complex64.
nit - All tensors below are of torch.complex64 dtype
@@ -929,11 +935,36 @@ def all_reduce(tensor,
    Async work handle, if async_op is set to True.
    None, if not async_op or if not part of the group

Examples:
    >>> # Tensors are all of dtype torch.int64.
nit - All tensors below are of torch.int64 dtype
@@ -1408,12 +1450,44 @@ def all_gather(tensor_list,
    Async work handle, if async_op is set to True.
    None, if not async_op or if not part of the group

Examples:
    >>> # Tensors are all of dtype torch.int64.
nit - All tensors below are of torch.int64 dtype
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1

>>> # Tensors are all of dtype torch.complex64.
nit - All tensors below are of torch.complex64 dtype
>>> tensor_list = [torch.zero(2, dtype=torch.complex64) for _ in range(2)]
>>> tensor_list
[tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
>>> tensor = torch.tensor([complex(1, 1), complex(2, 2)], dtype=torch.complex64) + 2 * complex(rank, rank)
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j)
)

@staticmethod
def _all_reduce_coalesced_min_test_cases(group_size):
    return (
        [1, 4],
        [2, 3],
        [1, 3],
        [torch.float, torch.float],
might be useful to explicitly mention in the documentation here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp that min, max are not supported for complex tensors.
agreed. I also added explicit error checking for that case
sorry I messed up in copy pasting earlier. updated the comment with the link to the doc I was referring to
yep that's probably a better place to put it. added
does complex support all of the reduce ops? Like, I don't think we support max with complex and it doesn't seem like viewing it as a real will give you something that makes a ton of sense.
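For context, a pure-Python illustration of the concern (mimicking what `torch.view_as_real` does; this is not the PR's code): SUM distributes correctly over the real view of a complex tensor, but an elementwise MAX over the components can produce a value that equals neither input, which is why an op like Max makes little sense for complex.

```python
# View [a+bj, ...] as interleaved real components [a, b, ...],
# the way torch.view_as_real lays out complex tensors.
def view_as_real(zs):
    out = []
    for z in zs:
        out.extend([z.real, z.imag])
    return out

def view_as_complex(pairs):
    return [complex(pairs[i], pairs[i + 1]) for i in range(0, len(pairs), 2)]

a = [3 + 1j]
b = [1 + 3j]

# SUM over the real view matches true complex addition: 4+4j.
summed = view_as_complex([x + y for x, y in zip(view_as_real(a), view_as_real(b))])

# MAX over the real view mixes components from different inputs,
# yielding 3+3j, which is neither a[0] nor b[0].
maxed = view_as_complex([max(x, y) for x, y in zip(view_as_real(a), view_as_real(b))])
```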
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
>>> # All tensors below are of torch.cdouble type.
>>> # We have 2 process groups, 2 ranks.
>>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cdouble) + 2 * rank * (1+1j)
Actually, maybe it's best to let it be torch.cfloat, because that's the default complex type if the default dtype is set to torch.float. This is consistent with the above example, since the default int dtype is torch.long or torch.int64.
If we want to use torch.cdouble, all the following prints of tensor would look like:
tensor([1.+1.j, 2.+2.j], dtype=torch.complex128)
Makes sense. done
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
>>> # All tensors below are of torch.cfloat dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zero(2, dtype=torch.float) for _ in range(2)]
nit - torch.zero(2, dtype=torch.cfloat)
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
updated docs used standard python repl examples in the docs, tested the way that they render in the browser more doc fixes. Add an explicit error check for ReduceOps that do not support complex (Max and Min), + tests for that case ghstack-source-id: 4920be3e0cb551612c2f76a4fbcea2444f097558 Pull Request resolved: #45879
💊 CI failures summary and remediations
As of commit 375fe36 (more details on the Dr. CI page):
codecov.io: 1 failed
This comment was automatically generated by Dr. CI.
Code LGTM! Please also get a stamp from @anjali411
Besides, there are too many irrelevant test failures. Please rebase and rerun tests before landing.
if reduceOp == ReduceOp.MAX or reduceOp == ReduceOp.MIN or reduceOp == ReduceOp.PRODUCT:
    return False
return True
Is this the same as: ``return True if reduceOp == ReduceOp.SUM else False``?
cleaned this up a little to make it more pythonic.
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
updated docs used standard python repl examples in the docs, tested the way that they render in the browser more doc fixes. Add an explicit error check for ReduceOps that do not support complex (Max and Min), + tests for that case make error checking a bit more pythonic ghstack-source-id: 57babd5380cf8eb464b66114a6532011b9aea4ac Pull Request resolved: #45879
# We'd like calls to unsupported ops to error out accordingly,
# rather than returning garbage values.
def supports_complex(reduceOp: ReduceOp) -> bool:
    denyList = [ReduceOp.MAX, ReduceOp.MIN, ReduceOp.PRODUCT]
this looks great!
LGTM overall, thanks Brian! My only other comment would be that we should add tests to ensure BAND, BOR, and BXOR work for complex.
@mrshenli will torch.distributed.autograd also work after this change?
…45760" Differential Revision: [D24127949](https://our.internmc.facebook.com/intern/diff/D24127949) [ghstack-poisoned]
updated docs used standard python repl examples in the docs, tested the way that they render in the browser more doc fixes. Add an explicit error check for ReduceOps that do not support complex (Max and Min), + tests for that case make error checking a bit more pythonic ghstack-source-id: bb0e96bc28664a059d8415124cc809556010ac6b Pull Request resolved: #45879
@@ -44,6 +44,17 @@
except ImportError:
    _GLOO_AVAILABLE = False

# Some reduce ops are not supported by complex numbers.
the way this comment is written reads like we allow calling them.
Stack from ghstack:
Differential Revision: D24127949