Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LANDED]Speed up gemm by reordering the for loops #17730

Closed

Conversation

zhangguanheng66
Copy link
Contributor

Summary:
Optimize the order of the "for" loops.

Note: For "transa = true" cases, the order of the "for" loops has been optimzied in the original code. Therefore, no significant improvement is observed in those case (i.e. "transa && transb" and "transa && !transb")

mode/opt (i.e. static libary)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops: 2229 x: 128 y: 128 z: 128 time: 2243ns => acceleration multiplier: 0.90
loops: 124 x: 128 y: 1024 z: 128 time: 40381ns => acceleration multiplier: 0.97
loops: 121 x: 1024 y: 128 z: 128 time: 41651ns => acceleration multiplier: 0.96
loops: 15 x: 1024 y: 1024 z: 128 time: 333771ns => acceleration multiplier: 0.98
loops: 4610 x: 128 y: 128 z: 64 time: 1084ns => acceleration multiplier: 0.95
loops: 252 x: 128 y: 1024 z: 64 time: 19860ns => acceleration multiplier: 0.98
loops: 248 x: 1024 y: 128 z: 64 time: 20232ns => acceleration multiplier: 0.98
loops: 30 x: 1024 y: 1024 z: 64 time: 167338ns => acceleration multiplier: 0.99

before:
loops: 2468 x: 128 y: 128 z: 128 time: 2026ns
loops: 128 x: 128 y: 1024 z: 128 time: 39338ns
loops: 126 x: 1024 y: 128 z: 128 time: 39930ns
loops: 16 x: 1024 y: 1024 z: 128 time: 327549ns
loops: 4840 x: 128 y: 128 z: 64 time: 1033ns
loops: 258 x: 128 y: 1024 z: 64 time: 19441ns
loops: 252 x: 1024 y: 128 z: 64 time: 19854ns
loops: 31 x: 1024 y: 1024 z: 64 time: 166254ns

//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops: 4880 x: 128 y: 128 z: 128 time: 1024ns => acceleration multiplier: 0.98
loops: 638 x: 128 y: 1024 z: 128 time: 7839ns => acceleration multiplier: 1.04
loops: 605 x: 1024 y: 128 z: 128 time: 8276ns => acceleration multiplier: 1.01
loops: 77 x: 1024 y: 1024 z: 128 time: 65713ns => acceleration multiplier: 1.00
loops: 9935 x: 128 y: 128 z: 64 time: 503ns => acceleration multiplier: 1.00
loops: 1252 x: 128 y: 1024 z: 64 time: 3994ns => acceleration multiplier: 1.00
loops: 1183 x: 1024 y: 128 z: 64 time: 4226ns => acceleration multiplier: 0.98
loops: 153 x: 1024 y: 1024 z: 64 time: 32766ns => acceleration multiplier: 0.99

before:
loops: 4985 x: 128 y: 128 z: 128 time: 1003ns
loops: 615 x: 128 y: 1024 z: 128 time: 8140ns
loops: 599 x: 1024 y: 128 z: 128 time: 8357ns
loops: 76 x: 1024 y: 1024 z: 128 time: 65934ns
loops: 9897 x: 128 y: 128 z: 64 time: 505ns
loops: 1248 x: 128 y: 1024 z: 64 time: 4008ns
loops: 1203 x: 1024 y: 128 z: 64 time: 4159ns
loops: 154 x: 1024 y: 1024 z: 64 time: 32499ns

//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops: 3919 x: 128 y: 128 z: 128 time: 1276ns => acceleration multiplier: 2.97
loops: 497 x: 128 y: 1024 z: 128 time: 10069ns => acceleration multiplier: 7.85
loops: 449 x: 1024 y: 128 z: 128 time: 11145ns => acceleration multiplier: 4.77
loops: 57 x: 1024 y: 1024 z: 128 time: 88595ns => acceleration multiplier: 7.12
loops: 7575 x: 128 y: 128 z: 64 time: 660ns => acceleration multiplier: 3.00
loops: 967 x: 128 y: 1024 z: 64 time: 5173ns => acceleration multiplier: 7.66
loops: 877 x: 1024 y: 128 z: 64 time: 5702ns => acceleration multiplier: 4.76
loops: 111 x: 1024 y: 1024 z: 64 time: 45232ns => acceleration multiplier: 7.03

before:
loops: 1320 x: 128 y: 128 z: 128 time: 3789ns
loops: 64 x: 128 y: 1024 z: 128 time: 79061ns
loops: 95 x: 1024 y: 128 z: 128 time: 53107ns
loops: 8 x: 1024 y: 1024 z: 128 time: 631161ns
loops: 2521 x: 128 y: 128 z: 64 time: 1983ns
loops: 127 x: 128 y: 1024 z: 64 time: 39604ns
loops: 185 x: 1024 y: 128 z: 64 time: 27128ns
loops: 16 x: 1024 y: 1024 z: 64 time: 318155ns

//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops: 3895 x: 128 y: 128 z: 128 time: 1283ns => acceleration multiplier: 1.73
loops: 393 x: 128 y: 1024 z: 128 time: 12746ns => acceleration multiplier: 3.36
loops: 411 x: 1024 y: 128 z: 128 time: 12170ns => acceleration multiplier: 1.93
loops: 46 x: 1024 y: 1024 z: 128 time: 110116ns => acceleration multiplier: 3.17
loops: 7404 x: 128 y: 128 z: 64 time: 675ns => acceleration multiplier: 1.58
loops: 636 x: 128 y: 1024 z: 64 time: 7872ns => acceleration multiplier: 2.70
loops: 724 x: 1024 y: 128 z: 64 time: 6911ns => acceleration multiplier: 1.32
loops: 73 x: 1024 y: 1024 z: 64 time: 68502ns => acceleration multiplier: 2.49

before:
loops: 2253 x: 128 y: 128 z: 128 time: 2219ns
loops: 117 x: 128 y: 1024 z: 128 time: 42788ns
loops: 214 x: 1024 y: 128 z: 128 time: 23465ns
loops: 15 x: 1024 y: 1024 z: 128 time: 349076ns
loops: 4694 x: 128 y: 128 z: 64 time: 1065ns
loops: 236 x: 128 y: 1024 z: 64 time: 21251ns
loops: 549 x: 1024 y: 128 z: 64 time: 9108ns
loops: 30 x: 1024 y: 1024 z: 64 time: 170799ns

Differential Revision: D14325149

@zhangguanheng66 zhangguanheng66 changed the title Speed up gemm Speed up gemm by reordering the for loops Mar 6, 2019
@zhangguanheng66
Copy link
Contributor Author

@cpuhrsch The pull request has been created for review.


res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
def _test_mm(n, m, p, dtype, genf):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be worth adding test coverage for torch.mm's out flag to explicitly exercise the _out variant of this function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you show me more details on how to exercise the _out variant in the torch.mm. I will try to figure it out tomorrow as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By calling e.g. torch.mm(a, b, out=c), where c has been preallocated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add two additional cases in test_mm() to explicitly exercise the _out variant in torch.mm()

@cpuhrsch
Copy link
Contributor

cpuhrsch commented Mar 7, 2019

Looks good! Once all tests pass and the comments are resolved we should be good to go.

c[j*ldc+i] = alpha*sum;
else
c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
if (beta == 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we explicitly test for beta == 0 and beta != 0? It's important to confirm this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the caffe2/test under fbcode (i.e. "buck test @mode/dev caffe2/test:torch"), the value of beta is set to zero.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, can you look for callsites of THBlas_(gemm) to see which ones might call it with beta != 0? I can also help you with this one-on-one if necessary.

@cpuhrsch
Copy link
Contributor

This is ready to go after we did a bit more work on investigating why "beta" is set to 0 for all tests.

@zhangguanheng66
Copy link
Contributor Author

Avoid using memset() function in branch 1 and 3.
The updated is able to pass the previously failed test: caffe2_onnx_py2_gcc5_ubuntu16_04_test.
The PR is ready to land. @cpuhrsch

@cpuhrsch
Copy link
Contributor

Thank you! Go ahead @zhangguanheng66 .

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhangguanheng66 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhangguanheng66 is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Summary:
Optimize the order of the "for" loops.

Note: For "transa = true" cases, the order of the "for" loops has been optimzied in the original code. Therefore, no significant improvement is observed in those case (i.e. "transa && transb" and "transa && !transb")

mode/opt (i.e. static libary)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops:  2229     x:     128      y:     128      z:     128      time:  2243ns      =>  acceleration multiplier:  0.90
loops:  124      x:     128      y:     1024     z:     128      time:  40381ns      =>  acceleration multiplier:  0.97
loops:  121      x:     1024     y:     128      z:     128      time:  41651ns      =>  acceleration multiplier:  0.96
loops:  15       x:     1024     y:     1024     z:     128      time:  333771ns       =>  acceleration multiplier:  0.98
loops:  4610     x:     128      y:     128      z:     64       time:  1084ns       =>  acceleration multiplier:  0.95
loops:  252      x:     128      y:     1024     z:     64       time:  19860ns      =>  acceleration multiplier:  0.98
loops:  248      x:     1024     y:     128      z:     64       time:  20232ns      =>  acceleration multiplier:  0.98
loops:  30       x:     1024     y:     1024     z:     64       time:  167338ns      =>  acceleration multiplier:  0.99

before:
loops:  2468     x:     128      y:     128      z:     128      time:  2026ns
loops:  128      x:     128      y:     1024     z:     128      time:  39338ns
loops:  126      x:     1024     y:     128      z:     128      time:  39930ns
loops:  16       x:     1024     y:     1024     z:     128      time:  327549ns
loops:  4840     x:     128      y:     128      z:     64       time:  1033ns
loops:  258      x:     128      y:     1024     z:     64       time:  19441ns
loops:  252      x:     1024     y:     128      z:     64       time:  19854ns
loops:  31       x:     1024     y:     1024     z:     64       time:  166254ns

//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops:  4880     x:     128      y:     128      z:     128      time:  1024ns      =>  acceleration multiplier:  0.98
loops:  638      x:     128      y:     1024     z:     128      time:  7839ns      =>  acceleration multiplier:  1.04
loops:  605      x:     1024     y:     128      z:     128      time:  8276ns      =>  acceleration multiplier:  1.01
loops:  77       x:     1024     y:     1024     z:     128      time:  65713ns      =>  acceleration multiplier:  1.00
loops:  9935     x:     128      y:     128      z:     64       time:  503ns      =>  acceleration multiplier:  1.00
loops:  1252     x:     128      y:     1024     z:     64       time:  3994ns      =>  acceleration multiplier:  1.00
loops:  1183     x:     1024     y:     128      z:     64       time:  4226ns      =>  acceleration multiplier:  0.98
loops:  153      x:     1024     y:     1024     z:     64       time:  32766ns      =>  acceleration multiplier:  0.99

before:
loops:  4985     x:     128      y:     128      z:     128      time:  1003ns
loops:  615      x:     128      y:     1024     z:     128      time:  8140ns
loops:  599      x:     1024     y:     128      z:     128      time:  8357ns
loops:  76       x:     1024     y:     1024     z:     128      time:  65934ns
loops:  9897     x:     128      y:     128      z:     64       time:  505ns
loops:  1248     x:     128      y:     1024     z:     64       time:  4008ns
loops:  1203     x:     1024     y:     128      z:     64       time:  4159ns
loops:  154      x:     1024     y:     1024     z:     64       time:  32499ns

//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops:  3919     x:     128      y:     128      z:     128      time:  1276ns      =>  acceleration multiplier:  2.97
loops:  497      x:     128      y:     1024     z:     128      time:  10069ns      =>  acceleration multiplier:  7.85
loops:  449      x:     1024     y:     128      z:     128      time:  11145ns      =>  acceleration multiplier:  4.77
loops:  57       x:     1024     y:     1024     z:     128      time:  88595ns      =>  acceleration multiplier:  7.12
loops:  7575     x:     128      y:     128      z:     64       time:  660ns      =>  acceleration multiplier:  3.00
loops:  967      x:     128      y:     1024     z:     64       time:  5173ns      =>  acceleration multiplier:  7.66
loops:  877      x:     1024     y:     128      z:     64       time:  5702ns      =>  acceleration multiplier:  4.76
loops:  111      x:     1024     y:     1024     z:     64       time:  45232ns      =>  acceleration multiplier:  7.03

before:
loops:  1320     x:     128      y:     128      z:     128      time:  3789ns
loops:  64       x:     128      y:     1024     z:     128      time:  79061ns
loops:  95       x:     1024     y:     128      z:     128      time:  53107ns
loops:  8        x:     1024     y:     1024     z:     128      time:  631161ns
loops:  2521     x:     128      y:     128      z:     64       time:  1983ns
loops:  127      x:     128      y:     1024     z:     64       time:  39604ns
loops:  185      x:     1024     y:     128      z:     64       time:  27128ns
loops:  16       x:     1024     y:     1024     z:     64       time:  318155ns

//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops:  3895     x:     128      y:     128      z:     128      time:  1283ns      =>  acceleration multiplier:  1.73
loops:  393      x:     128      y:     1024     z:     128      time:  12746ns      =>  acceleration multiplier:  3.36
loops:  411      x:     1024     y:     128      z:     128      time:  12170ns      =>  acceleration multiplier:  1.93
loops:  46       x:     1024     y:     1024     z:     128      time:  110116ns      =>  acceleration multiplier:  3.17
loops:  7404     x:     128      y:     128      z:     64       time:  675ns      =>  acceleration multiplier:  1.58
loops:  636      x:     128      y:     1024     z:     64       time:  7872ns      =>  acceleration multiplier:  2.70
loops:  724      x:     1024     y:     128      z:     64       time:  6911ns      =>  acceleration multiplier:  1.32
loops:  73       x:     1024     y:     1024     z:     64       time:  68502ns      =>  acceleration multiplier:  2.49

before:
loops:  2253     x:     128      y:     128      z:     128      time:  2219ns
loops:  117      x:     128      y:     1024     z:     128      time:  42788ns
loops:  214      x:     1024     y:     128      z:     128      time:  23465ns
loops:  15       x:     1024     y:     1024     z:     128      time:  349076ns
loops:  4694     x:     128      y:     128      z:     64       time:  1065ns
loops:  236      x:     128      y:     1024     z:     64       time:  21251ns
loops:  549      x:     1024     y:     128      z:     64       time:  9108ns
loops:  30       x:     1024     y:     1024     z:     64       time:  170799ns
Pull Request resolved: pytorch#17730

Differential Revision: D14325149

fbshipit-source-id: ea8679b5f89826817eed7b7fe8ed621b7a2247d7
facebook-github-bot pushed a commit that referenced this pull request Mar 13, 2019
Summary:
Optimize the order of the "for" loops.

Note: For "transa = true" cases, the order of the "for" loops has been optimzied in the original code. Therefore, no significant improvement is observed in those case (i.e. "transa && transb" and "transa && !transb")

mode/opt (i.e. static libary)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops:  2229     x:     128      y:     128      z:     128      time:  2243ns      =>  acceleration multiplier:  0.90
loops:  124      x:     128      y:     1024     z:     128      time:  40381ns      =>  acceleration multiplier:  0.97
loops:  121      x:     1024     y:     128      z:     128      time:  41651ns      =>  acceleration multiplier:  0.96
loops:  15       x:     1024     y:     1024     z:     128      time:  333771ns       =>  acceleration multiplier:  0.98
loops:  4610     x:     128      y:     128      z:     64       time:  1084ns       =>  acceleration multiplier:  0.95
loops:  252      x:     128      y:     1024     z:     64       time:  19860ns      =>  acceleration multiplier:  0.98
loops:  248      x:     1024     y:     128      z:     64       time:  20232ns      =>  acceleration multiplier:  0.98
loops:  30       x:     1024     y:     1024     z:     64       time:  167338ns      =>  acceleration multiplier:  0.99

before:
loops:  2468     x:     128      y:     128      z:     128      time:  2026ns
loops:  128      x:     128      y:     1024     z:     128      time:  39338ns
loops:  126      x:     1024     y:     128      z:     128      time:  39930ns
loops:  16       x:     1024     y:     1024     z:     128      time:  327549ns
loops:  4840     x:     128      y:     128      z:     64       time:  1033ns
loops:  258      x:     128      y:     1024     z:     64       time:  19441ns
loops:  252      x:     1024     y:     128      z:     64       time:  19854ns
loops:  31       x:     1024     y:     1024     z:     64       time:  166254ns

//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops:  4880     x:     128      y:     128      z:     128      time:  1024ns      =>  acceleration multiplier:  0.98
loops:  638      x:     128      y:     1024     z:     128      time:  7839ns      =>  acceleration multiplier:  1.04
loops:  605      x:     1024     y:     128      z:     128      time:  8276ns      =>  acceleration multiplier:  1.01
loops:  77       x:     1024     y:     1024     z:     128      time:  65713ns      =>  acceleration multiplier:  1.00
loops:  9935     x:     128      y:     128      z:     64       time:  503ns      =>  acceleration multiplier:  1.00
loops:  1252     x:     128      y:     1024     z:     64       time:  3994ns      =>  acceleration multiplier:  1.00
loops:  1183     x:     1024     y:     128      z:     64       time:  4226ns      =>  acceleration multiplier:  0.98
loops:  153      x:     1024     y:     1024     z:     64       time:  32766ns      =>  acceleration multiplier:  0.99

before:
loops:  4985     x:     128      y:     128      z:     128      time:  1003ns
loops:  615      x:     128      y:     1024     z:     128      time:  8140ns
loops:  599      x:     1024     y:     128      z:     128      time:  8357ns
loops:  76       x:     1024     y:     1024     z:     128      time:  65934ns
loops:  9897     x:     128      y:     128      z:     64       time:  505ns
loops:  1248     x:     128      y:     1024     z:     64       time:  4008ns
loops:  1203     x:     1024     y:     128      z:     64       time:  4159ns
loops:  154      x:     1024     y:     1024     z:     64       time:  32499ns

//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops:  3919     x:     128      y:     128      z:     128      time:  1276ns      =>  acceleration multiplier:  2.97
loops:  497      x:     128      y:     1024     z:     128      time:  10069ns      =>  acceleration multiplier:  7.85
loops:  449      x:     1024     y:     128      z:     128      time:  11145ns      =>  acceleration multiplier:  4.77
loops:  57       x:     1024     y:     1024     z:     128      time:  88595ns      =>  acceleration multiplier:  7.12
loops:  7575     x:     128      y:     128      z:     64       time:  660ns      =>  acceleration multiplier:  3.00
loops:  967      x:     128      y:     1024     z:     64       time:  5173ns      =>  acceleration multiplier:  7.66
loops:  877      x:     1024     y:     128      z:     64       time:  5702ns      =>  acceleration multiplier:  4.76
loops:  111      x:     1024     y:     1024     z:     64       time:  45232ns      =>  acceleration multiplier:  7.03

before:
loops:  1320     x:     128      y:     128      z:     128      time:  3789ns
loops:  64       x:     128      y:     1024     z:     128      time:  79061ns
loops:  95       x:     1024     y:     128      z:     128      time:  53107ns
loops:  8        x:     1024     y:     1024     z:     128      time:  631161ns
loops:  2521     x:     128      y:     128      z:     64       time:  1983ns
loops:  127      x:     128      y:     1024     z:     64       time:  39604ns
loops:  185      x:     1024     y:     128      z:     64       time:  27128ns
loops:  16       x:     1024     y:     1024     z:     64       time:  318155ns

//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops:  3895     x:     128      y:     128      z:     128      time:  1283ns      =>  acceleration multiplier:  1.73
loops:  393      x:     128      y:     1024     z:     128      time:  12746ns      =>  acceleration multiplier:  3.36
loops:  411      x:     1024     y:     128      z:     128      time:  12170ns      =>  acceleration multiplier:  1.93
loops:  46       x:     1024     y:     1024     z:     128      time:  110116ns      =>  acceleration multiplier:  3.17
loops:  7404     x:     128      y:     128      z:     64       time:  675ns      =>  acceleration multiplier:  1.58
loops:  636      x:     128      y:     1024     z:     64       time:  7872ns      =>  acceleration multiplier:  2.70
loops:  724      x:     1024     y:     128      z:     64       time:  6911ns      =>  acceleration multiplier:  1.32
loops:  73       x:     1024     y:     1024     z:     64       time:  68502ns      =>  acceleration multiplier:  2.49

before:
loops:  2253     x:     128      y:     128      z:     128      time:  2219ns
loops:  117      x:     128      y:     1024     z:     128      time:  42788ns
loops:  214      x:     1024     y:     128      z:     128      time:  23465ns
loops:  15       x:     1024     y:     1024     z:     128      time:  349076ns
loops:  4694     x:     128      y:     128      z:     64       time:  1065ns
loops:  236      x:     128      y:     1024     z:     64       time:  21251ns
loops:  549      x:     1024     y:     128      z:     64       time:  9108ns
loops:  30       x:     1024     y:     1024     z:     64       time:  170799ns
Pull Request resolved: #17730

Differential Revision: D14325149

Pulled By: zhangguanheng66

fbshipit-source-id: a7a5a83890fdf99fee6eb87a3a5060b7b6bd862f
zdevito pushed a commit to zdevito/ATen that referenced this pull request Mar 13, 2019
Summary:
Optimize the order of the "for" loops.

Note: For "transa = true" cases, the order of the "for" loops has been optimzied in the original code. Therefore, no significant improvement is observed in those case (i.e. "transa && transb" and "transa && !transb")

mode/opt (i.e. static libary)
//////////////////////////////////////////////////////////////////////////////
transa && transb
after:
loops:  2229     x:     128      y:     128      z:     128      time:  2243ns      =>  acceleration multiplier:  0.90
loops:  124      x:     128      y:     1024     z:     128      time:  40381ns      =>  acceleration multiplier:  0.97
loops:  121      x:     1024     y:     128      z:     128      time:  41651ns      =>  acceleration multiplier:  0.96
loops:  15       x:     1024     y:     1024     z:     128      time:  333771ns       =>  acceleration multiplier:  0.98
loops:  4610     x:     128      y:     128      z:     64       time:  1084ns       =>  acceleration multiplier:  0.95
loops:  252      x:     128      y:     1024     z:     64       time:  19860ns      =>  acceleration multiplier:  0.98
loops:  248      x:     1024     y:     128      z:     64       time:  20232ns      =>  acceleration multiplier:  0.98
loops:  30       x:     1024     y:     1024     z:     64       time:  167338ns      =>  acceleration multiplier:  0.99

before:
loops:  2468     x:     128      y:     128      z:     128      time:  2026ns
loops:  128      x:     128      y:     1024     z:     128      time:  39338ns
loops:  126      x:     1024     y:     128      z:     128      time:  39930ns
loops:  16       x:     1024     y:     1024     z:     128      time:  327549ns
loops:  4840     x:     128      y:     128      z:     64       time:  1033ns
loops:  258      x:     128      y:     1024     z:     64       time:  19441ns
loops:  252      x:     1024     y:     128      z:     64       time:  19854ns
loops:  31       x:     1024     y:     1024     z:     64       time:  166254ns

//////////////////////////////////////////////////////////////////////////////
transa && !transb
after:
loops:  4880     x:     128      y:     128      z:     128      time:  1024ns      =>  acceleration multiplier:  0.98
loops:  638      x:     128      y:     1024     z:     128      time:  7839ns      =>  acceleration multiplier:  1.04
loops:  605      x:     1024     y:     128      z:     128      time:  8276ns      =>  acceleration multiplier:  1.01
loops:  77       x:     1024     y:     1024     z:     128      time:  65713ns      =>  acceleration multiplier:  1.00
loops:  9935     x:     128      y:     128      z:     64       time:  503ns      =>  acceleration multiplier:  1.00
loops:  1252     x:     128      y:     1024     z:     64       time:  3994ns      =>  acceleration multiplier:  1.00
loops:  1183     x:     1024     y:     128      z:     64       time:  4226ns      =>  acceleration multiplier:  0.98
loops:  153      x:     1024     y:     1024     z:     64       time:  32766ns      =>  acceleration multiplier:  0.99

before:
loops:  4985     x:     128      y:     128      z:     128      time:  1003ns
loops:  615      x:     128      y:     1024     z:     128      time:  8140ns
loops:  599      x:     1024     y:     128      z:     128      time:  8357ns
loops:  76       x:     1024     y:     1024     z:     128      time:  65934ns
loops:  9897     x:     128      y:     128      z:     64       time:  505ns
loops:  1248     x:     128      y:     1024     z:     64       time:  4008ns
loops:  1203     x:     1024     y:     128      z:     64       time:  4159ns
loops:  154      x:     1024     y:     1024     z:     64       time:  32499ns

//////////////////////////////////////////////////////////////////////////////
!transa && transb
after:
loops:  3919     x:     128      y:     128      z:     128      time:  1276ns      =>  acceleration multiplier:  2.97
loops:  497      x:     128      y:     1024     z:     128      time:  10069ns      =>  acceleration multiplier:  7.85
loops:  449      x:     1024     y:     128      z:     128      time:  11145ns      =>  acceleration multiplier:  4.77
loops:  57       x:     1024     y:     1024     z:     128      time:  88595ns      =>  acceleration multiplier:  7.12
loops:  7575     x:     128      y:     128      z:     64       time:  660ns      =>  acceleration multiplier:  3.00
loops:  967      x:     128      y:     1024     z:     64       time:  5173ns      =>  acceleration multiplier:  7.66
loops:  877      x:     1024     y:     128      z:     64       time:  5702ns      =>  acceleration multiplier:  4.76
loops:  111      x:     1024     y:     1024     z:     64       time:  45232ns      =>  acceleration multiplier:  7.03

before:
loops:  1320     x:     128      y:     128      z:     128      time:  3789ns
loops:  64       x:     128      y:     1024     z:     128      time:  79061ns
loops:  95       x:     1024     y:     128      z:     128      time:  53107ns
loops:  8        x:     1024     y:     1024     z:     128      time:  631161ns
loops:  2521     x:     128      y:     128      z:     64       time:  1983ns
loops:  127      x:     128      y:     1024     z:     64       time:  39604ns
loops:  185      x:     1024     y:     128      z:     64       time:  27128ns
loops:  16       x:     1024     y:     1024     z:     64       time:  318155ns

//////////////////////////////////////////////////////////////////////////////
!transa && !transb
after:
loops:  3895     x:     128      y:     128      z:     128      time:  1283ns      =>  acceleration multiplier:  1.73
loops:  393      x:     128      y:     1024     z:     128      time:  12746ns      =>  acceleration multiplier:  3.36
loops:  411      x:     1024     y:     128      z:     128      time:  12170ns      =>  acceleration multiplier:  1.93
loops:  46       x:     1024     y:     1024     z:     128      time:  110116ns      =>  acceleration multiplier:  3.17
loops:  7404     x:     128      y:     128      z:     64       time:  675ns      =>  acceleration multiplier:  1.58
loops:  636      x:     128      y:     1024     z:     64       time:  7872ns      =>  acceleration multiplier:  2.70
loops:  724      x:     1024     y:     128      z:     64       time:  6911ns      =>  acceleration multiplier:  1.32
loops:  73       x:     1024     y:     1024     z:     64       time:  68502ns      =>  acceleration multiplier:  2.49

before:
loops:  2253     x:     128      y:     128      z:     128      time:  2219ns
loops:  117      x:     128      y:     1024     z:     128      time:  42788ns
loops:  214      x:     1024     y:     128      z:     128      time:  23465ns
loops:  15       x:     1024     y:     1024     z:     128      time:  349076ns
loops:  4694     x:     128      y:     128      z:     64       time:  1065ns
loops:  236      x:     128      y:     1024     z:     64       time:  21251ns
loops:  549      x:     1024     y:     128      z:     64       time:  9108ns
loops:  30       x:     1024     y:     1024     z:     64       time:  170799ns
Pull Request resolved: pytorch/pytorch#17730

Differential Revision: D14325149

Pulled By: zhangguanheng66

fbshipit-source-id: a7a5a83890fdf99fee6eb87a3a5060b7b6bd862f
@@ -4446,7 +4468,7 @@ def test_unbind(self):

def test_linspace(self):
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
for device in devices:
for _device in devices:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we change device to _device? this breaks the CI. I have created a fix: #17995
Let's whether it's enough or not

If not, let's revert this change.

@zhangguanheng66 zhangguanheng66 changed the title Speed up gemm by reordering the for loops [LANDED]Speed up gemm by reordering the for loops Mar 18, 2019
@zhangguanheng66 zhangguanheng66 deleted the export-D14325149 branch April 9, 2019 00:23
@ezyang ezyang added the merged label Jun 25, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants