Implementing NumPy-like function torch.broadcast_to #48997
Conversation
(force-pushed from 5ba5902 to c1ce7dd)
💊 CI failures summary: as of commit 058d827, CI looks good so far; there are no failures yet. (This comment was automatically generated by Dr. CI.)
Codecov Report

    @@            Coverage Diff             @@
    ##           master   #48997     +/-   ##
    =========================================
       Coverage   80.56%   80.56%
    =========================================
       Files        1875     1875
       Lines      202701   202715      +14
    =========================================
    +  Hits       163307   163327      +20
    +  Misses      39394    39388       -6
(force-pushed from 48e1cf2 to 779857d)
(force-pushed from 779857d to 2f22297)
This PR adds `torch.broadcast_to`, a NumPy-like function that broadcasts the input tensor to a new shape.
Please kindly help review this PR. @mruberry
Hey @RockingJavaBean, thanks for implementing `broadcast_to`!
Overall this PR looks very good. I made a few suggestions in the documentation, and I'd like to suggest trying to replace the `method_tests` entries, which are being deprecated, with an OpInfo. Looking forward to hearing your thoughts!
@@ -1109,6 +1109,13 @@ def method_tests():

    ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
    ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
    ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
    ('broadcast_to', (S, 1, 1), (dont_convert((S, S, S)),), '', (False,)),
We're actually starting to deprecate `method_tests()` entries like these in favor of writing OpInfos for new operators, like this:

    OpInfo('addmm', ...)
Basically, an OpInfo is added to this alphabetical list. It has to specify which dtypes it supports using the dtype functions defined here, in pytorch/torch/testing/__init__.py (Line 278 in 9417e92):

    _floating_types = _dispatch_dtypes((torch.float32, torch.float64))
and provide the same information as these `method_tests()` entries, like addmm does. For example, instead of these tuples, the operator defines a `sample_inputs_func` that returns them.
This is a new thing, which is a pain, but I think it will be helpful to learn, and I would appreciate your feedback on it. It may be a little trickier to set up than `method_tests()`, but it will make testing much easier.
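For concreteness, here is a sketch of what such an entry might look like for `broadcast_to`, pieced together from the fragments quoted later in this review. It assumes the names available inside `common_methods_invocations.py` (`OpInfo`, `SampleInput`, `make_tensor`, `all_types_and_complex_and`, and the size constant `S`), so treat it as illustrative rather than the exact code that landed:

```python
# Sketch only: assumes the scope of common_methods_invocations.py, where
# OpInfo, SampleInput, make_tensor, all_types_and_complex_and, and S live.
def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad):
    # Pairs of (input size, target broadcast shape).
    test_cases = (
        ((S, 1, 1), (S, S, S)),
        ((), (1, 3, 2)),
    )
    # Each SampleInput bundles the tensor together with the shape argument.
    return tuple(SampleInput((make_tensor(size, device, dtype,
                                          low=None, high=None,
                                          requires_grad=requires_grad),
                              shape))
                 for size, shape in test_cases)

# The entry itself goes into the alphabetical op_db list:
OpInfo('broadcast_to',
       dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
       test_inplace_grad=False,
       sample_inputs_func=sample_inputs_broadcast_to)
```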
One problem with OpInfos currently is that operations which support the bfloat16 datatype have to skip the following test:

    skips=(
        # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
        SkipInfo(
            'TestCommon', 'test_variant_consistency_jit',
            device_type='cuda', dtypes=[torch.bfloat16])))

This is a bug in the test's implementation that we expect to fix soon. Please let me know if you have any questions about doing this. If creating an OpInfo is too much of a pain, then we can ignore it for now and continue to use `method_tests`.
I'm truly grateful for this comprehensive explanation of the new `OpInfo` test framework.
I think it is much more descriptive and readable than `method_tests`, where the long tuples full of testing parameters are often hard to read.
Moreover, `OpInfo` is awesome because it provides a standard way of testing newly added operators. It helps us develop code in a more TDD style and write more robust code.
I have rewritten the test using `OpInfo`, and updated `test_ops.py`, as `OpInfo` currently does not support testing operators whose parameters are not `torch.Tensor`.
I think it does support non-tensor parameters, but let me take a look and I'll comment inline.
@mruberry Thank you so much for the kind review and patient guidance on how to use `OpInfo`. As the current implementation in test_ops.py (Line 208 in 5ab90b2) does not handle sample inputs that contain non-tensor arguments, this PR adds a helper method in test_ops.py to copy such inputs. Please kindly take a look.
(force-pushed from 9306345 to 7acc0ae)
@@ -201,7 +194,7 @@ def _fn(t):

        arg_string = ', '.join((str(arg) for arg in info.get_args(device)))
        script = fn_template.format(alias_name=info.alias_name, args=arg_string)
    else:
-       is_input_tensor_list = isinstance(info.get_input(device), collections.Sequence)
+       is_input_tensor_list = isinstance(info.get_input(device), Sequence)
I don't think you need to update this file, since this file doesn't currently use any OpInfos to test.
`Sequence` here is imported from `collections.abc` to replace `collections.Sequence`.
The reason for the change is that `collections.Sequence` is deprecated and will stop working after Python 3.9:

    In [4]: collections.Sequence
    /root/anaconda3/envs/xx/bin/ipython:1: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
Other changes to test_op_aliases.py have been reverted in the latest PR.
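As a quick illustration of the portable import (plain Python, independent of this PR):

```python
# collections.Sequence was a deprecated alias; the ABC itself lives in
# collections.abc, which is the spelling that keeps working on newer Pythons.
from collections.abc import Sequence

print(isinstance([1, 2], Sequence))  # True
print(isinstance((1, 2), Sequence))  # True
print(isinstance({1, 2}, Sequence))  # False, sets are not sequences
```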
torch/_torch_docs.py (Outdated)

    r"""
    broadcast_to(input, shape) -> Tensor

    Broadcasts :attr:`input` to the :attr:`shape`.
"... to the :attr:`shape`" -> "... to the shape :attr:`\shape`"
Sorry this didn't format correctly the first time I suggested it. I think the formatting is fixed now.
Thanks so much for the kind suggestions for improving the doc; it has been updated accordingly.
        ((), (1, 3, 2)),
    )

    return [SampleInput((make_tensor(size, device, dtype,
Return these as a tuple, not a list, because tuples in Python are immutable. Lists are appropriate only when a mutable iterable is needed.
Thank you so much for pointing this out; I updated `sample_inputs_broadcast_to` to return a `tuple`.
    def clone_input_helper(input):
        if isinstance(input, Sequence):
            # use torch.clone for tensor, copy.deepcopy otherwise
            return list(map(lambda x: torch.clone(x) if isinstance(x, torch.Tensor) else deepcopy(x),
Just map clone_input_helper so you're calling this recursively on sequences (that would also handle nested sequences).
Sorry for not considering the function of this helper carefully; the helper has been updated to call itself recursively so that it supports nested sequences.
    if isinstance(input, torch.Tensor):
        return torch.clone(input)
    else:
        return deepcopy(input)
I think we can get away without copying non-tensor inputs, since I don't think we have any functions that edit non-tensor inputs?
Ideally we'd also only clone when we're performing the operation inplace, but that can be a TODO for a later PR.
Thanks for the advice; `clone_input_helper` no longer copies non-tensor inputs in the latest PR.
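Putting the suggestions above together, the helper presumably ends up looking roughly like this (a sketch under the assumptions in the comments, not the verbatim test_ops.py code):

```python
import torch
from collections.abc import Sequence

def clone_input_helper(input):
    # Tensors are cloned so an in-place variant cannot corrupt the sample
    # shared with the out-of-place computation.
    if isinstance(input, torch.Tensor):
        return torch.clone(input)
    # Recurse into sequences so arbitrarily nested containers of tensors are
    # handled; strings are excluded because iterating a string yields strings
    # again, which would recurse forever.
    if isinstance(input, Sequence) and not isinstance(input, str):
        return tuple(map(clone_input_helper, input))
    # Per the review suggestion, non-tensor leaves (shapes, scalars, ...) are
    # passed through without copying.
    return input
```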
test_ops.py is updated to only clone tensor-type inputs for inplace operations.

    variant_forward = variant(*(clone_input_helper(input) if variant is inplace else input
                                for input in sample.input),
                              *sample.args,
                              **sample.kwargs)
    self.assertEqual(variant_forward, expected_forward)
Hey @RockingJavaBean!
I really like the approach you took here. The new helper function for making copies makes sense. I have a few very small comments but also suggest reverting the changes to test_op_aliases.py, since those tests are not affected by this PR. In the future test_op_aliases.py will be refactored to use OpInfos, and that will be a better opportunity to update its test.
@mruberry I'm really thankful for your kind and thorough review. Please kindly take a look.
(force-pushed from 54763f7 to ad226c0)
@@ -283,6 +283,22 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad):

                                 low=None, high=None,
                                 requires_grad=False))),)

    def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad):
        test_cases = (
This is a nice method of test generation.
    test_inplace_grad=False,
    sample_inputs_func=sample_inputs_broadcast_to,
    skips=(
        # RuntimeError: "isfinite" not implemented for 'BFloat16'
After a rebase this skip should no longer be necessary - would you try removing it?
Thanks so much for pointing this out; the `skips` entry has been removed and the OpInfo-based `broadcast_to` tests pass.
Hey @RockingJavaBean, this looks great! Would you try removing the one skipped test and rebasing to resolve the merge conflict? After the rebase this should be OK to merge.
Looking forward to having `broadcast_to` in PyTorch!
@mruberry I'm deeply grateful for the code review of this PR, and your exhaustive guidance regarding writing tests using the newly introduced `OpInfo` framework.
Clicking into the details of the pending CircleCI check shows that all of its steps have completed successfully.
Nice work, @RockingJavaBean!
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
Related pytorch#38349

Implement NumPy-like function `torch.broadcast_to` to broadcast the input tensor to a new shape.

Pull Request resolved: pytorch#48997
Reviewed By: anjali411, ngimel
Differential Revision: D25663937
Pulled By: mruberry
fbshipit-source-id: 0415c03f92f02684983f412666d0a44515b99373
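For readers landing here from search, a quick usage example of the operator this PR adds (behavior as documented: the result is a broadcast view, equivalent to calling input.expand(shape)):

```python
import torch

x = torch.tensor([1, 2, 3])        # shape (3,)
y = torch.broadcast_to(x, (3, 3))  # shape (3, 3); no data is copied
print(y)
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
```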