
Implementing NumPy-like function torch.broadcast_to #48997

Closed

Conversation

RockingJavaBean
Contributor

@RockingJavaBean RockingJavaBean commented Dec 8, 2020

Related #38349

Implement NumPy-like function torch.broadcast_to to broadcast the input tensor to a new shape.

@dr-ci

dr-ci bot commented Dec 8, 2020

💊 CI failures summary and remediations

As of commit 058d827 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@codecov

codecov bot commented Dec 8, 2020

Codecov Report

Merging #48997 (df707a3) into master (5a5e576) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##           master   #48997   +/-   ##
=======================================
  Coverage   80.56%   80.56%           
=======================================
  Files        1875     1875           
  Lines      202701   202715   +14     
=======================================
+ Hits       163307   163327   +20     
+ Misses      39394    39388    -6     

@RockingJavaBean RockingJavaBean force-pushed the torch_broadcast_to branch 2 times, most recently from 48e1cf2 to 779857d on December 9, 2020 08:04
@RockingJavaBean RockingJavaBean changed the title [WIP] Implementing NumPy-like function torch.broadcast_to Implementing NumPy-like function torch.broadcast_to Dec 9, 2020
@RockingJavaBean
Contributor Author

RockingJavaBean commented Dec 9, 2020

This PR adds the torch.broadcast_to function, similar to NumPy's broadcast_to; its implementation is based on the existing expand operator.

The differences between torch.broadcast_to and the expand operator are listed below:

  • torch.broadcast_to is exposed as both a torch function and a tensor method, while expand is only a tensor method.
  • The signature of the expand operator is expand(Tensor(a) self, int[] size, *, bool implicit=False); its implicit parameter lets the tracer distinguish between expands inserted by broadcasts and those requested by the user.
    This parameter is not exposed by torch.broadcast_to, both because it is only used internally and to keep the signature consistent with NumPy. A usage sketch is shown below.
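
For illustration, a minimal usage sketch (the expected outputs assume NumPy's broadcast_to semantics):

import torch

x = torch.tensor([1, 2, 3])

# As a torch function
y = torch.broadcast_to(x, (3, 3))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

# As a tensor method
z = x.broadcast_to((3, 3))

# Equivalent call with the existing expand operator (method only)
w = x.expand(3, 3)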

Please kindly help review this PR. @mruberry

Collaborator

@mruberry mruberry left a comment

Hey @RockingJavaBean, thanks for implementing broadcast_to!

Overall this PR looks very good. I made a few suggestions in the documentation, and I'd like to suggest trying to replace the method_tests entries, which are being deprecated, with an OpInfo. Looking forward to hearing your thoughts!

@@ -1109,6 +1109,13 @@ def method_tests():
        ('view_as', (S, S, S), (non_differentiable(torch.rand(S * S, S)),)),
        ('view_as', (), (non_differentiable(torch.tensor(5.5)),), 'scalar'),
        ('view_as', (), (non_differentiable(torch.rand(1, 1)),), 'scalar_to_dims'),
        ('broadcast_to', (S, 1, 1), (dont_convert((S, S, S)),), '', (False,)),
Collaborator

We're actually starting to deprecate method_tests() entries like these in favor of writing OpInfos for new operators, like this:

Basically, an OpInfo is added to this alphabetical list. It has to specify which dtypes it supports using the dtype functions defined here:

_floating_types = _dispatch_dtypes((torch.float32, torch.float64))

and provide the same information as these method_tests() entries, like addmm does. For example, instead of these tuples the operator defines a sample_inputs_func that returns them.

This is a new thing, which is a pain, but I think it will be helpful to learn and I would appreciate your feedback on it. It may be a little trickier to set up than method_tests(), but it will make testing much easier. A rough sketch of such an entry is included below.
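
As a rough sketch only (the dtype list and exact constructor arguments are assumptions, not taken from this PR), an entry could look like:

# Hypothetical sketch of an OpInfo entry for broadcast_to; the dtype helper and
# keyword arguments shown here are assumptions and may differ from the final PR.
OpInfo('broadcast_to',
       dtypes=floating_types(),
       test_inplace_grad=False,
       sample_inputs_func=sample_inputs_broadcast_to),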

One problem with OpInfos currently is that operations which support the bfloat16 datatype have to skip the following test:

skips=(
  # RuntimeError: "rsqrt_cuda" not implemented for 'BFloat16'
  SkipInfo(
    'TestCommon', 'test_variant_consistency_jit',
    device_type='cuda', dtypes=[torch.bfloat16])))

This is a bug in the test's implementation that we expect to fix soon. Please let me know if you have any questions about doing this. If creating an OpInfo is too much of a pain then we can ignore it for now and continue to use method_tests.

Contributor Author

I'm truly grateful for this comprehensive explanation of the new OpInfo test framework.
I think it is much more descriptive and readable compared to method_tests, as the long tuple lines full of testing parameters are often hard to read.
Moreover, OpInfo is awesome because it provides the standard way of testing newly-added operators. It helps us develop code in a more TDD style and write more robust code.

Contributor Author

I have rewritten the test using OpInfo and updated test_ops.py, as OpInfo currently does not support testing operators whose parameters are not torch.Tensor.

Collaborator

I think it does support non-tensor parameters, but let me take a look and I'll comment inline.

@H-Huang H-Huang added the triaged label Dec 10, 2020
@RockingJavaBean
Contributor Author

@mruberry Thank you so much for the kind review and patient guidance on using OpInfo for testing. I have updated this PR with the documentation improvements and new OpInfo-based tests.

As the current implementation in test_ops.py uses the clone method on inputs, it fails when the input is a sequence of ints, as for torch.broadcast_to, or a sequence of tensors, as for the PyTorch stack-related functions.

variant_forward = variant(*(input.clone() for input in sample.input), *sample.args, **sample.kwargs)

This PR adds a helper method in torch/testing/_internal/common_utils.py to handle cloning for the above scenarios.
The test/test_op_aliases.py, which tests torch.row_stack as the alias of torch.vstack, is updated to use this common helper.

Please kindly take a look.

@@ -201,7 +194,7 @@ def _fn(t):
         arg_string = ', '.join((str(arg) for arg in info.get_args(device)))
         script = fn_template.format(alias_name=info.alias_name, args=arg_string)
     else:
-        is_input_tensor_list = isinstance(info.get_input(device), collections.Sequence)
+        is_input_tensor_list = isinstance(info.get_input(device), Sequence)
Collaborator

I don't think you need to update this file, since this file doesn't currently use any OpInfos to test.

Contributor Author

Sequence here is imported from collections.abc to replace collections.Sequence.
The reason for the change is that collections.Sequence is deprecated and scheduled to stop working in Python 3.9:

In [4]: collections.Sequence
/root/anaconda3/envs/xx/bin/ipython:1: DeprecationWarning: Using or importing the ABCs from 'collections' 
instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
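
A minimal sketch of the preferred import, just to illustrate the change:

from collections.abc import Sequence  # rather than collections.Sequence

isinstance((1, 2, 3), Sequence)  # True, with no DeprecationWarning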

The other changes to test_op_aliases.py have been reverted in the latest revision.

r"""
broadcast_to(input, shape) -> Tensor

Broadcasts :attr:`input` to the :attr:`shape`.
Collaborator

@mruberry mruberry Dec 15, 2020

"... to the :attr:`shape`" -> "... to the shape :attr:`\shape`"

Sorry this didn't format correctly the first time I suggested it. I think the formatting is fixed now.

Contributor Author

Thanks so much for the kind suggestions for improving the doc; it has been updated accordingly.

        ((), (1, 3, 2)),
    )

    return [SampleInput((make_tensor(size, device, dtype,
Collaborator

Return these as a tuple, not a list, because tuples in Python are immutable. Lists are appropriate only when a mutable iterable is needed.

Contributor Author

Thank you so much for pointing this out; I updated sample_inputs_broadcast_to to return a tuple. A sketch of the updated function is below.
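
A minimal sketch of what the updated sample_inputs_func could look like, reconstructed from the snippets quoted in this thread (the exact make_tensor arguments and SampleInput packing are assumptions):

# Sketch only; the real sample_inputs_broadcast_to in
# common_methods_invocations.py may differ in its exact arguments.
def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad):
    test_cases = (
        ((S, 1, 1), (S, S, S)),
        ((), (1, 3, 2)),
    )

    return tuple(SampleInput((make_tensor(size, device, dtype,
                                          low=None, high=None,
                                          requires_grad=requires_grad),),
                             args=(shape,))
                 for size, shape in test_cases)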

def clone_input_helper(input):
    if isinstance(input, Sequence):
        # use torch.clone for tensor, copy.deepcopy otherwise
        return list(map(lambda x: torch.clone(x) if isinstance(x, torch.Tensor) else deepcopy(x),
Collaborator

Just map clone_input_helper so you're calling this recursively on sequences (that would also handle nested sequences)

Contributor Author

Sorry for not considering the behavior of this helper carefully; it has been updated to call itself recursively so that nested sequences are supported, roughly as sketched below.
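
A minimal sketch of the recursive version (assuming the same Sequence/deepcopy handling as the snippet above; the actual helper in torch/testing/_internal/common_utils.py may differ):

# Sketch only; mirrors the quoted helper but recurses into sequences.
def clone_input_helper(input):
    if isinstance(input, Sequence):
        # Recurse so nested sequences of tensors are cloned as well.
        return list(map(clone_input_helper, input))
    if isinstance(input, torch.Tensor):
        return torch.clone(input)
    return deepcopy(input)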

    if isinstance(input, torch.Tensor):
        return torch.clone(input)
    else:
        return deepcopy(input)
Collaborator

I think we can get away without copying non-tensor inputs, since I don't think we have any functions that edit non-tensor inputs?

Ideally we'd also only clone when we're performing the operation inplace, but that can be a TODO for a later PR.

Contributor Author

Thanks for the advice; clone_input_helper no longer copies non-tensor inputs in the latest revision.

Contributor Author

test_ops.py is updated to only clone tensor-type inputs for inplace operations.

variant_forward = variant(*(clone_input_helper(input) if variant is inplace else input
                            for input in sample.input),
                          *sample.args,
                          **sample.kwargs)
self.assertEqual(variant_forward, expected_forward)

Collaborator

@mruberry mruberry left a comment

Hey @RockingJavaBean!

I really like the approach you took here. The new helper function for making copies makes sense. I have a few very small comments but also suggest reverting the changes to test_op_aliases.py, since those tests are not affected by this PR. In the future test_op_aliases.py will be refactored to use OpInfos, and that will be a better opportunity to update its test.

@RockingJavaBean
Contributor Author

@mruberry I'm really thankful for your kind and thorough review.

  • the doc description of torch.broadcast_to has been updated.
  • the changes to test_op_aliases.py, apart from the collections.abc.Sequence one, have been reverted.
  • the new clone helper has been changed to a recursive method so that nested sequences are handled properly.

Please kindly take a look.

@@ -283,6 +283,22 @@ def sample_inputs_addmm(op_info, device, dtype, requires_grad):
                            low=None, high=None,
                            requires_grad=False))),)

def sample_inputs_broadcast_to(op_info, device, dtype, requires_grad):
    test_cases = (
Collaborator

This is a nice method of test generation.

       test_inplace_grad=False,
       sample_inputs_func=sample_inputs_broadcast_to,
       skips=(
           # RuntimeError: "isfinite" not implemented for 'BFloat16'
Collaborator

After a rebase this skip should no longer be necessary - would you try removing it?

Contributor Author

Thanks so much for pointing this out; the skips entry has been removed and the OpInfo-based broadcast_to tests pass.

Collaborator

@mruberry mruberry left a comment

Hey @RockingJavaBean, this looks great! Would you try removing the one skipped test and rebasing to resolve the merge conflict? After the rebase this should be OK to merge.

Looking forward to having broadcast_to in PyTorch!

@RockingJavaBean
Contributor Author

@mruberry I'm deeply grateful for the code review of this PR, and for your thorough guidance on writing tests with the newly introduced OpInfo. 🎄 🎅 🎄
I have rebased the code and removed the skipped tests.

@RockingJavaBean
Contributor Author

Clicking details on the pending CircleCI check shows that all steps of pytorch_linux_xenial_cuda10_2_cudnn7_py3_gcc7_build are successful.
It seems the pending check has already completed, but its status hasn't been synced to this PR page.

@mruberry mruberry self-requested a review December 21, 2020 11:28
Collaborator

@mruberry mruberry left a comment

Nice work, @RockingJavaBean!

Contributor

@facebook-github-bot facebook-github-bot left a comment

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@mruberry merged this pull request in 3779bde.

@RockingJavaBean RockingJavaBean deleted the torch_broadcast_to branch December 22, 2020 01:30
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
Summary:
Related pytorch#38349

Implement NumPy-like function `torch.broadcast_to` to broadcast the input tensor to a new shape.

Pull Request resolved: pytorch#48997

Reviewed By: anjali411, ngimel

Differential Revision: D25663937

Pulled By: mruberry

fbshipit-source-id: 0415c03f92f02684983f412666d0a44515b99373
Labels
cla signed, Merged, open source, triaged

5 participants