Migrate `_th_std_var` to ATen #59258
Conversation
Awesome, @peterbell10
@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
I'm going to assume the ROCm error is unrelated, as this PR doesn't touch anything CUDA- or ROCm-specific.
I took a quick look at the failing test and it's entirely unrelated. I've fixed the merge conflicts with the nonzero porting PR, so CI will get a chance to rerun anyway.
This pull request has been reverted by 8b4784a.
Weird, two subsequent tests for the tbb build were fine, so I'll try to reland. @peterbell10, if you have any idea why it could have failed, please comment. The tbb tests were not flaky before, and std_cpu broke on this PR only.
I think the issue might be that …
Hm, tbb failed again in the same way.
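The ATen kernel under discussion isn't quoted in this thread, but backend-dependent failures in a parallel reduction like this often come down to how per-thread partial results are merged. As a hedged illustration only (not the actual ATen or TBB code), here is a sketch of the standard parallel variance reduction: each chunk is reduced to `(count, mean, M2)` with Welford's algorithm, and partials are combined with Chan et al.'s update.

```python
def chunk_stats(xs):
    """Reduce one chunk to (count, mean, M2) via Welford's algorithm."""
    count, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)  # accumulates sum of squared deviations
    return count, mean, m2

def merge(a, b):
    """Combine two partial results (Chan et al.'s parallel update)."""
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
# Split into "per-thread" chunks and fold the partials together,
# as a multi-threaded all-reduce would.
parts = [chunk_stats(data[i:i + 3]) for i in range(0, len(data), 3)]
total = parts[0]
for p in parts[1:]:
    total = merge(total, p)
n, mean, m2 = total
var_unbiased = m2 / (n - 1)  # Bessel-corrected, as torch.var defaults to
```

The result is independent of the chunking, so a correct merge step should give identical answers across thread counts and threading backends.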
Summary:

Ref pytorch#49421

This migrates `std`/`var`'s special case all-reduction from TH to ATen. Using the benchmark from pytorch gh-43858 that was used to justify keeping the TH version, I find this PR has similar (slightly better) single-threaded performance. And unlike the TH version, this implementation is multi-threaded and so much faster for large tensors.

TH Results:

```
[----------------------------- Index ------------------------------]
          |  torch_var  |  torch_var0  |   stdfn  |  torch_sum0
1 threads: ---------------------------------------------------------
      8   |       3.6   |        3.8   |     8.2  |       1.2
     80   |       3.7   |        3.8   |     8.4  |       1.2
    800   |       4.2   |        4.3   |     8.7  |       1.2
   8000   |       9.0   |        9.1   |    11.2  |       1.5
  80000   |      58.3   |       59.0   |    30.6  |       4.2
 800000   |     546.9   |      546.9   |   183.4  |      31.3
8000000   |    5729.7   |     5701.0   |  6165.4  |     484.1
```

ATen results:

```
[----------------------------- Index ------------------------------]
          |  torch_var  |  torch_var0  |   stdfn  |  torch_sum0
1 threads: ---------------------------------------------------------
      8   |       4.0   |        4.0   |     8.7  |       1.2
     80   |       3.6   |        3.8   |     9.0  |       1.2
    800   |       4.1   |        4.3   |     8.9  |       1.2
   8000   |       8.9   |        9.2   |    10.6  |       1.5
  80000   |      57.0   |       57.4   |    28.8  |       4.3
 800000   |     526.9   |      526.9   |   178.3  |      30.2
8000000   |    5568.1   |     5560.6   |  6042.1  |     453.2

[----------------------------- Index ------------------------------]
          |  torch_var  |  torch_var0  |   stdfn  |  torch_sum0
8 threads: ---------------------------------------------------------
      8   |       3.9   |        3.8   |     9.1  |       1.2
     80   |       3.8   |        3.9   |     8.8  |       1.2
    800   |       4.2   |        4.3   |     8.9  |       1.3
   8000   |       9.0   |        9.2   |    10.4  |       1.5
  80000   |      26.0   |       26.8   |    26.4  |       4.4
 800000   |      92.9   |       87.3   |    72.1  |      22.4
8000000   |     793.5   |      791.8   |  5334.8  |     115.1
```

Pull Request resolved: pytorch#59258
Reviewed By: mruberry
Differential Revision: D28821216
Pulled By: ngimel
fbshipit-source-id: f35992c21f08a0a8878053680dc0ca7a8facd155
Reland of the same change, with an identical summary and benchmark results.

Pull Request resolved: pytorch#59258
Reviewed By: jbschlosser
Differential Revision: D28860353
Pulled By: ngimel
fbshipit-source-id: 80c9fe1db84dbc864eeb1a319076c7aaff0a04e5
Ref #49421
This migrates `std`/`var`'s special case all-reduction from TH to ATen. Using the benchmark from gh-43858 that was used to justify keeping the TH version, I find this PR has similar (slightly better) single-threaded performance. And unlike the TH version, this implementation is multi-threaded and so much faster for large tensors.

TH Results:

ATen results:
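The benchmark script from gh-43858 is not reproduced in this description. As a rough stand-in for a harness of the same shape, the comparison can be sketched with the stdlib `timeit` module; the column names here mirror the tables above, but `make_data`, `bench`, and the choice of `statistics.variance` as the `stdfn` stand-in are assumptions for illustration, not the original script.

```python
import random
import statistics
import timeit

random.seed(0)

def make_data(n):
    """Hypothetical input generator standing in for the benchmark tensors."""
    return [random.random() for _ in range(n)]

def bench(fn, data, number=100):
    """Median per-call time in microseconds for fn(data)."""
    times = timeit.repeat(lambda: fn(data), number=number, repeat=5)
    return statistics.median(times) / number * 1e6

# Smaller sizes than the PR tables, to keep the sketch quick to run.
for n in (8, 80, 800):
    data = make_data(n)
    row = {
        "stdfn": bench(statistics.variance, data),  # stand-in for the stdfn column
        "sum": bench(sum, data),                    # stand-in for torch_sum0
    }
    print(n, {k: round(v, 1) for k, v in row.items()})
```

The real benchmark times `torch.var` (and its keepdim variant) against a hand-written std function and `torch.sum`, which is where the `torch_var`/`torch_var0`/`stdfn`/`torch_sum0` columns in the tables come from.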