
Conversation

RockingJavaBean
Contributor

@RockingJavaBean RockingJavaBean commented Jan 26, 2021

Close #51108
Related #38349

This PR implements cpu_kernel_multiple_outputs to support returning multiple values from a CPU kernel.

auto iter = at::TensorIteratorConfig()
  .add_output(out1)  // first output: elementwise sum
  .add_output(out2)  // second output: elementwise product
  .add_input(in1)
  .add_input(in2)
  .build();

at::native::cpu_kernel_multiple_outputs(iter,
  [=](float a, float b) -> std::tuple<float, float> {
    float add = a + b;
    float mul = a * b;
    // The i-th element of the returned tuple is written to the i-th output.
    return std::tuple<float, float>(add, mul);
  }
);

out1 will equal torch.add(in1, in2), while out2 will equal torch.mul(in1, in2).
This makes it more convenient for developers to implement new torch functions that return two tensors, such as the NumPy-like functions divmod and frexp.
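As a sanity check of these semantics, the same elementwise computation can be sketched in plain Python, with hypothetical lists standing in for the tensors (this is an illustration, not code from the PR):

```python
# Plain-Python sketch of what the kernel above computes elementwise.
# in1/in2 are hypothetical example inputs, not from the PR.
in1 = [1.0, 2.0, 3.0]
in2 = [4.0, 5.0, 6.0]

# Each lambda invocation returns a tuple; the iterator scatters the
# tuple elements into out1 and out2 respectively.
pairs = [(a + b, a * b) for a, b in zip(in1, in2)]
out1 = [p[0] for p in pairs]  # corresponds to torch.add(in1, in2)
out2 = [p[1] for p in pairs]  # corresponds to torch.mul(in1, in2)
```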

This PR adds the torch.frexp function to exercise the new functionality provided by cpu_kernel_multiple_outputs.
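For reference, Python's standard library math.frexp has the same semantics that torch.frexp targets: it decomposes a float x into a mantissa m with 0.5 <= |m| < 1 (or 0) and an integer exponent e such that x == m * 2**e. A minimal sketch:

```python
import math

# Decompose 12.0: 12.0 == 0.75 * 2**4, so the mantissa is 0.75
# and the exponent is 4.
mantissa, exponent = math.frexp(12.0)
assert (mantissa, exponent) == (0.75, 4)

# Reconstruction is exact, since scaling by a power of two is
# lossless for binary floating point.
assert mantissa * 2 ** exponent == 12.0
```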

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 26, 2021

💊 CI failures summary and remediations

As of commit c8198e1 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

@RockingJavaBean RockingJavaBean force-pushed the cpu_kernel_mutiple_outputs branch from 91dbe3f to c9a493a Compare January 26, 2021 09:23
@RockingJavaBean RockingJavaBean changed the title [WIP] implement cpu_kernel_multiple_outputs to support returning multiple values implement cpu_kernel_multiple_outputs to support returning multiple values Jan 26, 2021
@codecov

codecov bot commented Jan 26, 2021

Codecov Report

Merging #51097 (c8198e1) into master (f6df18f) will increase coverage by 0.98%.
The diff coverage is 96.42%.

@@            Coverage Diff             @@
##           master   #51097      +/-   ##
==========================================
+ Coverage   76.36%   77.35%   +0.98%     
==========================================
  Files        1886     1887       +1     
  Lines      184699   184804     +105     
==========================================
+ Hits       141048   142946    +1898     
+ Misses      43651    41858    -1793     

@heitorschueroff heitorschueroff added module: reductions, module: TensorIterator, and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Jan 26, 2021
@RockingJavaBean RockingJavaBean force-pushed the cpu_kernel_mutiple_outputs branch from c9a493a to a76fe8d Compare January 27, 2021 05:25
@mruberry mruberry requested review from mruberry and removed request for colesbury February 1, 2021 06:44
@mruberry
Collaborator

mruberry commented Feb 1, 2021

We should add a function that exercises this new functionality to test it, too. What about Numpy's frexp? (https://numpy.org/doc/stable/reference/generated/numpy.frexp.html).

@RockingJavaBean
Contributor Author

@mruberry Thanks for the kind suggestion; I will update this PR by adding torch.frexp using both cpu_kernel_multiple_outputs and gpu_kernel_multiple_outputs.

@mruberry
Collaborator

Thanks @RockingJavaBean! I'm still catching up from being out on vacation, but I'll take a look ASAP!

Contributor

@heitorschueroff heitorschueroff left a comment


Hi @RockingJavaBean , @mruberry and I took another look at your updates together. Overall, this PR is looking pretty good and close to ready. We really appreciate all the work you put into implementing frexp and also taking extra time to fix bugs in our test suite such as separating the float values.

There are still a few suggestions we have for your review:

  • We think we can reuse test_reference_numerics with a tweak, and while we appreciate you demonstrating that the other test_unary_ufunc.py tests can be adapted to work with frexp(), we're a little worried that change is too big. For the final PR, we'd like to suggest skipping those tests and leaving them unmodified.
  • For test_frexp_out, add cases for incorrectly sized and noncontiguous inputs.
  • Add supports_tensor_out=False to the UnaryUfuncInfo and fix the test_out_arg... to correctly query for the metadata instead of skipping the test.

We look forward to your next updates!

@RockingJavaBean RockingJavaBean force-pushed the cpu_kernel_mutiple_outputs branch from 123a787 to 95f662e Compare March 9, 2021 08:20
@RockingJavaBean RockingJavaBean force-pushed the cpu_kernel_mutiple_outputs branch from 95f662e to 52abf28 Compare March 9, 2021 11:14
@RockingJavaBean RockingJavaBean force-pushed the cpu_kernel_mutiple_outputs branch from de5c23e to 742a99f Compare March 10, 2021 12:30
@heitorschueroff heitorschueroff self-requested a review March 10, 2021 15:31
@RockingJavaBean
Contributor Author

I'm really grateful for the thorough review and invaluable suggestions throughout this PR.

This PR has been updated with the following changes:

  • update the docs according to the suggestions, and add comments to the cpu_kernel_multiple_outputs-related methods in Loops.h
  • revert the changes to test_ops.py, as well as the changes to the existing tests in test_unary_ufuncs.py, except for the tweak to test_reference_numerics
  • add test cases for incorrectly sized and noncontiguous inputs to test_frexp_out, explicitly check the exponent dtype, and compare the results with the NumPy counterpart in the customized test method test_frexp

Please kindly take a look. @heitorschueroff @mruberry
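The invariants such a frexp test needs to check can be sketched in plain Python, with math.frexp standing in for the torch and NumPy implementations (the helper below is illustrative, not the PR's actual test code):

```python
import math

def check_frexp(values):
    """Check the frexp reconstruction invariant x == m * 2**e for each value."""
    for x in values:
        m, e = math.frexp(x)
        # The mantissa is 0 (for x == 0) or lies in [0.5, 1) in magnitude.
        assert x == 0 or 0.5 <= abs(m) < 1.0
        # The exponent is an integer (torch.frexp returns an integer tensor).
        assert isinstance(e, int)
        # Reconstruction is exact, since scaling by 2**e is lossless.
        assert m * 2 ** e == x

check_frexp([0.0, 1.0, -3.5, 1024.0, 0.1])
```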

Contributor

@heitorschueroff heitorschueroff left a comment


@RockingJavaBean This last version looks ready, great work! Just one last change before I land this: it looks like the PR picked up some changes from another PR, which I commented on. Could you please confirm and fix this? I'll land it then.

@RockingJavaBean
Contributor Author

@heitorschueroff I'm truly thankful for your kind review.
The torch.special_gammaln changes above belong to a merge commit resolving code conflicts and will not be in the final diff of this PR; please kindly help confirm this.

@heitorschueroff
Contributor

> @heitorschueroff I'm truly thankful for your kind review.
> The torch.special_gammaln changes above belong to a merge commit resolving code conflicts and will not be in the final diff of this PR; please kindly help confirm this.

You're correct. I'm landing the PR now, thank you for this great PR.

Contributor

@facebook-github-bot facebook-github-bot left a comment


@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Collaborator

@mruberry mruberry left a comment


Nice job, @RockingJavaBean! This is a very technically complicated PR. I appreciate your thoughtfulness on both the technical challenges and the test architecture, too.

And thanks @heitorschueroff for reviewing this!

@RockingJavaBean
Contributor Author

It is my honor to contribute to the PyTorch project, and this could not have been done without your generous help and guidance.
I'm really grateful to you. @mruberry @heitorschueroff

Contributor

@facebook-github-bot facebook-github-bot left a comment


@heitorschueroff has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@heitorschueroff merged this pull request in da10ccd.

@RockingJavaBean RockingJavaBean deleted the cpu_kernel_mutiple_outputs branch March 16, 2021 01:07
Labels
cla signed, Merged, module: reductions, module: TensorIterator, open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

Function request: support returning multiple values in CPU kernel
7 participants