-
Notifications
You must be signed in to change notification settings - Fork 25k
Add fused layer norm impl on CUDA in PyTorch #27634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This pull request was exported from Phabricator. Differential Revision: D17462420 |
linked to Issue #27633 |
Should we benchmark on |
443a61d
to
5074d64
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
The reason I didn't test input with such small size is the overhead in the ops will become comparable large while the result is not very stable. One of the test with such small input is here https://gist.github.com/BIT-silence/4c762eb6c4db55d58a57a43f492674df Here is the profiler result for the large input https://gist.github.com/BIT-silence/ba01562912bf5815a8059093a87f7c15. BTW, I found I made a mistake when doing the benchmark that I included the gc.collect time for the 1000 runs, I have changed the benchmark script and update the results above. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, can you please do a minor cleanup?
5074d64
to
1297972
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
1297972
to
2640458
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
2640458
to
3829aa3
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
3829aa3
to
218cf04
Compare
d50d0e5
to
832a8d3
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
832a8d3
to
bb13141
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
Summary: Pull Request resolved: pytorch#27634 Add fused layer norm impl on CUDA in PyTorch Performance benchmark compare to apex.FusedLayerNorm on a V100 machine. ************************************** Shape = (128, 2097152) curr LayerNorm forward: 7.252584544941783ms apex LayerNorm forward: 10.366813436849043ms curr LayerNorm backward: 15.568048988003284ms apex LayerNorm backward: 20.869979876093566ms ************************************** Shape = (256, 1048576) curr LayerNorm forward: 5.185673736967146ms apex LayerNorm forward: 6.3868385690730065ms curr LayerNorm backward: 13.942008479032665ms apex LayerNorm backward: 15.469660016940907ms ************************************** Shape = (512, 524288) curr LayerNorm forward: 4.672068868065253ms apex LayerNorm forward: 4.717993081081659ms curr LayerNorm backward: 13.46354596503079ms apex LayerNorm backward: 14.04774487693794ms ************************************** Shape = (1024, 262144) curr LayerNorm forward: 4.547273400006816ms apex LayerNorm forward: 5.378365494078025ms curr LayerNorm backward: 13.425063178874552ms apex LayerNorm backward: 14.235145597020164ms ************************************** Shape = (2048, 131072) curr LayerNorm forward: 4.526399010093883ms apex LayerNorm forward: 4.775081946980208ms curr LayerNorm backward: 13.222738380078226ms apex LayerNorm backward: 13.59594238596037ms ************************************** Shape = (4096, 65536) curr LayerNorm forward: 4.28789056581445ms apex LayerNorm forward: 4.48913648002781ms curr LayerNorm backward: 13.026655421825126ms apex LayerNorm backward: 13.57052089786157ms ************************************** Shape = (8192, 32768) curr LayerNorm forward: 4.243518367875367ms apex LayerNorm forward: 4.34588153520599ms curr LayerNorm backward: 13.140627697808668ms apex LayerNorm backward: 13.49891544203274ms ************************************** Shape = (16384, 16384) curr LayerNorm forward: 4.181216162163764ms apex LayerNorm forward: 4.268723972840235ms curr LayerNorm backward: 13.035593512002379ms apex LayerNorm backward: 13.463351831072941ms ************************************** Shape = (32768, 8192) curr LayerNorm forward: 4.097899778978899ms apex LayerNorm forward: 4.109480210812762ms curr LayerNorm backward: 13.041268918896094ms apex LayerNorm backward: 13.586135944118723ms Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm" Differential Revision: D17462420 fbshipit-source-id: ab8859272d21b94d9a2849a2255b9b5859ab6e38
bb13141
to
90ca0f6
Compare
This pull request was exported from Phabricator. Differential Revision: D17462420 |
Summary: Pull Request resolved: pytorch/pytorch#27634 Add fused layer norm impl on CUDA in PyTorch Performance benchmark compare to apex.FusedLayerNorm on a V100 machine. ************************************** Shape = (128, 2097152) curr LayerNorm forward: 7.252584544941783ms apex LayerNorm forward: 10.366813436849043ms curr LayerNorm backward: 15.568048988003284ms apex LayerNorm backward: 20.869979876093566ms ************************************** Shape = (256, 1048576) curr LayerNorm forward: 5.185673736967146ms apex LayerNorm forward: 6.3868385690730065ms curr LayerNorm backward: 13.942008479032665ms apex LayerNorm backward: 15.469660016940907ms ************************************** Shape = (512, 524288) curr LayerNorm forward: 4.672068868065253ms apex LayerNorm forward: 4.717993081081659ms curr LayerNorm backward: 13.46354596503079ms apex LayerNorm backward: 14.04774487693794ms ************************************** Shape = (1024, 262144) curr LayerNorm forward: 4.547273400006816ms apex LayerNorm forward: 5.378365494078025ms curr LayerNorm backward: 13.425063178874552ms apex LayerNorm backward: 14.235145597020164ms ************************************** Shape = (2048, 131072) curr LayerNorm forward: 4.526399010093883ms apex LayerNorm forward: 4.775081946980208ms curr LayerNorm backward: 13.222738380078226ms apex LayerNorm backward: 13.59594238596037ms ************************************** Shape = (4096, 65536) curr LayerNorm forward: 4.28789056581445ms apex LayerNorm forward: 4.48913648002781ms curr LayerNorm backward: 13.026655421825126ms apex LayerNorm backward: 13.57052089786157ms ************************************** Shape = (8192, 32768) curr LayerNorm forward: 4.243518367875367ms apex LayerNorm forward: 4.34588153520599ms curr LayerNorm backward: 13.140627697808668ms apex LayerNorm backward: 13.49891544203274ms ************************************** Shape = (16384, 16384) curr LayerNorm forward: 4.181216162163764ms apex LayerNorm forward: 4.268723972840235ms curr LayerNorm backward: 13.035593512002379ms apex LayerNorm backward: 13.463351831072941ms ************************************** Shape = (32768, 8192) curr LayerNorm forward: 4.097899778978899ms apex LayerNorm forward: 4.109480210812762ms curr LayerNorm backward: 13.041268918896094ms apex LayerNorm backward: 13.586135944118723ms Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm" Reviewed By: houseroad Differential Revision: D17462420 fbshipit-source-id: d4a67d160bb4eff73ffac64af46c56c3845cf211
This pull request has been merged in 8b87f9a. |
just a naming thing (and it appears this name probably just comes from apex), but it's strange to refer to an operator as "fused" but not specify what is fused. |
Summary: Pull Request resolved: pytorch#27634 Add fused layer norm impl on CUDA in PyTorch Performance benchmark compare to apex.FusedLayerNorm on a V100 machine. ************************************** Shape = (128, 2097152) curr LayerNorm forward: 7.252584544941783ms apex LayerNorm forward: 10.366813436849043ms curr LayerNorm backward: 15.568048988003284ms apex LayerNorm backward: 20.869979876093566ms ************************************** Shape = (256, 1048576) curr LayerNorm forward: 5.185673736967146ms apex LayerNorm forward: 6.3868385690730065ms curr LayerNorm backward: 13.942008479032665ms apex LayerNorm backward: 15.469660016940907ms ************************************** Shape = (512, 524288) curr LayerNorm forward: 4.672068868065253ms apex LayerNorm forward: 4.717993081081659ms curr LayerNorm backward: 13.46354596503079ms apex LayerNorm backward: 14.04774487693794ms ************************************** Shape = (1024, 262144) curr LayerNorm forward: 4.547273400006816ms apex LayerNorm forward: 5.378365494078025ms curr LayerNorm backward: 13.425063178874552ms apex LayerNorm backward: 14.235145597020164ms ************************************** Shape = (2048, 131072) curr LayerNorm forward: 4.526399010093883ms apex LayerNorm forward: 4.775081946980208ms curr LayerNorm backward: 13.222738380078226ms apex LayerNorm backward: 13.59594238596037ms ************************************** Shape = (4096, 65536) curr LayerNorm forward: 4.28789056581445ms apex LayerNorm forward: 4.48913648002781ms curr LayerNorm backward: 13.026655421825126ms apex LayerNorm backward: 13.57052089786157ms ************************************** Shape = (8192, 32768) curr LayerNorm forward: 4.243518367875367ms apex LayerNorm forward: 4.34588153520599ms curr LayerNorm backward: 13.140627697808668ms apex LayerNorm backward: 13.49891544203274ms ************************************** Shape = (16384, 16384) curr LayerNorm forward: 4.181216162163764ms apex LayerNorm forward: 4.268723972840235ms curr LayerNorm backward: 13.035593512002379ms apex LayerNorm backward: 13.463351831072941ms ************************************** Shape = (32768, 8192) curr LayerNorm forward: 4.097899778978899ms apex LayerNorm forward: 4.109480210812762ms curr LayerNorm backward: 13.041268918896094ms apex LayerNorm backward: 13.586135944118723ms Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm" Reviewed By: houseroad Differential Revision: D17462420 fbshipit-source-id: d4a67d160bb4eff73ffac64af46c56c3845cf211
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR #68238](#68238 (comment)) on AMD GPUs. This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs` (=`config_m`) is larger than 512 on AMD GPUs. There are a few PRs for LayerNorm kernel: - #26201 - #27634 - #68238 Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100. --- **Current** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892 50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886 200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827 802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946 200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349 1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753 6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429 6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245 200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878 1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751 6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313 6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982 200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007 1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991 6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504 6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133 200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015 1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778 6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987 6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025 200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655 1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685 6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635 6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141 200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034 1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433 6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462 6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524 128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092 256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371 512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902 1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192 2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191 4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751 8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646 16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408 32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271 </body> </html> --------- **At this PR** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl63 {color:windowtext;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283 50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595 200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579 802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404 200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602 1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742 6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279 6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426 200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018 1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206 6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572 6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635 200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216 1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936 6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273 6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545 200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545 1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204 6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119 6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208 200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859 1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583 6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796 6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055 200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695 1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633 6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289 6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694 128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699 256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936 512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083 1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117 2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845 4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392 8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296 16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113 32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514 </body> </html> --------- **Performance Improvement (%)** <html xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=OneNote.File> <meta name=Generator content="Microsoft OneNote 15"> </head> <body lang=en-US style='font-family:Calibri;font-size:11.0pt'> <!--StartFragment--> <div style='direction:ltr'> M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32 -- | -- | -- | -- 50432 | 384 | 32.178 | 22.049 50176 | 384 | 29.231 | 19.536 200704 | 192 | 44.188 | 43.962 802816 | 64 | 52.119 | 54.100 200 | 256 | -5.750 | -0.206 1000 | 256 | 0.031 | -0.797 6000 | 256 | 3.566 | 5.621 6272 | 256 | 3.865 | 4.836 200 | 512 | -1.615 | -1.010 1000 | 512 | -1.270 | 0.208 6000 | 512 | 3.534 | 5.581 6272 | 512 | 7.905 | 7.483 200 | 1024 | -2.883 | 0.254 1000 | 1024 | -0.767 | 0.493 6000 | 1024 | 0.237 | -2.381 6272 | 1024 | 3.840 | -1.707 200 | 1536 | -0.127 | -1.340 1000 | 1536 | -0.711 | -0.992 6000 | 1536 | -0.209 | -4.728 6272 | 1536 | 0.508 | -0.846 200 | 2048 | -1.262 | -1.176 1000 | 2048 | -0.358 | 0.312 6000 | 2048 | 8.350 | 6.487 6272 | 2048 | 1.588 | 5.713 200 | 3072 | 0.223 | -0.848 1000 | 3072 | -0.773 | -5.743 6000 | 3072 | 3.570 | -3.783 6272 | 3072 | 4.962 | -4.092 128 | 2097152 | -4.266 | 0.348 256 | 1048576 | 0.397 | 0.185 512 | 524288 | 17.325 | 16.605 1024 | 262144 | 23.070 | 19.195 2048 | 131072 | 27.469 | 24.605 4096 | 65536 | 32.023 | 27.465 8192 | 32768 | 24.459 | 28.274 16384 | 16384 | 21.439 | 9.514 32768 | 8192 | 6.818 | 0.491 </div> <!--EndFragment--> </body> </html> --------- **Benchmark script of this PR** ``` # Ref: # 1. #26201 # 2. #68238 from distutils.command.config import config import torch from torch.nn import LayerNorm import timeit number_runs = 1000 # TODO: Modify this to save time! def test_forward(layer_norm_cuda, input_cuda): layer_norm_cuda(input_cuda); torch.cuda.synchronize() def test_backward(out_cuda, layer_norm_grad_cuda, create_graph): out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize() def test_fwdbwd(input_cuda, layer_norm_cuda, gO): input_cuda.grad = None layer_norm_cuda.zero_grad(set_to_none=True) out = layer_norm_cuda(input_cuda) out.backward(gO) torch.cuda.synchronize() def benchmark(config_m, config_n): print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)") if len(config_m) != len(config_n): print("Please make sure the lengths of config_m and config_m are the same.") for i in range(len(config_m)): normalized_shape = config_n[i] results = [config_m[i], config_n[i]] for dtype in (torch.half, torch.float): if dtype == torch.half: layer_norm_cuda = LayerNorm(normalized_shape).half().cuda() else: layer_norm_cuda = LayerNorm(normalized_shape).cuda() input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True) # print("cuda forward:") result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs) results.append(result_fwd / number_runs * 1000) gO = torch.rand_like(input_cuda) result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs) results.append(result_fwdbwd / number_runs * 1000) print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5])) print("Times are in microseconds (us).") # CVT config_m_cvt = [50432, 50176, 200704, 802816] config_n_cvt = [384, 384, 192, 64] # #68238 (comment) config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272] config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072] # #27634 config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192] config_m = config_m_cvt + config_m_68238 + config_m_27634 config_n = config_n_cvt + config_n_68238 + config_n_27634 benchmark(config_m, config_n) ``` CC: @jeffdaily Pull Request resolved: #87635 Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs. This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs` (=`config_m`) is larger than 512 on AMD GPUs. There are a few PRs for LayerNorm kernel: - pytorch#26201 - pytorch#27634 - pytorch#68238 Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100. --- **Current** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892 50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886 200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827 802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946 200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349 1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753 6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429 6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245 200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878 1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751 6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313 6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982 200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007 1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991 6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504 6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133 200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015 1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778 6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987 6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025 200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655 1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685 6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635 6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141 200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034 1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433 6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462 6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524 128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092 256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371 512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902 1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192 2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191 4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751 8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646 16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408 32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271 </body> </html> --------- **At this PR** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl63 {color:windowtext;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283 50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595 200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579 802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404 200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602 1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742 6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279 6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426 200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018 1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206 6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572 6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635 200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216 1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936 6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273 6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545 200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545 1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204 6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119 6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208 200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859 1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583 6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796 6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055 200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695 1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633 6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289 6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694 128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699 256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936 512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083 1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117 2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845 4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392 8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296 16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113 32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514 </body> </html> --------- **Performance Improvement (%)** <html xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=OneNote.File> <meta name=Generator content="Microsoft OneNote 15"> </head> <body lang=en-US style='font-family:Calibri;font-size:11.0pt'> <!--StartFragment--> <div style='direction:ltr'> M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32 -- | -- | -- | -- 50432 | 384 | 32.178 | 22.049 50176 | 384 | 29.231 | 19.536 200704 | 192 | 44.188 | 43.962 802816 | 64 | 52.119 | 54.100 200 | 256 | -5.750 | -0.206 1000 | 256 | 0.031 | -0.797 6000 | 256 | 3.566 | 5.621 6272 | 256 | 3.865 | 4.836 200 | 512 | -1.615 | -1.010 1000 | 512 | -1.270 | 0.208 6000 | 512 | 3.534 | 5.581 6272 | 512 | 7.905 | 7.483 200 | 1024 | -2.883 | 0.254 1000 | 1024 | -0.767 | 0.493 6000 | 1024 | 0.237 | -2.381 6272 | 1024 | 3.840 | -1.707 200 | 1536 | -0.127 | -1.340 1000 | 1536 | -0.711 | -0.992 6000 | 1536 | -0.209 | -4.728 6272 | 1536 | 0.508 | -0.846 200 | 2048 | -1.262 | -1.176 1000 | 2048 | -0.358 | 0.312 6000 | 2048 | 8.350 | 6.487 6272 | 2048 | 1.588 | 5.713 200 | 3072 | 0.223 | -0.848 1000 | 3072 | -0.773 | -5.743 6000 | 3072 | 3.570 | -3.783 6272 | 3072 | 4.962 | -4.092 128 | 2097152 | -4.266 | 0.348 256 | 1048576 | 0.397 | 0.185 512 | 524288 | 17.325 | 16.605 1024 | 262144 | 23.070 | 19.195 2048 | 131072 | 27.469 | 24.605 4096 | 65536 | 32.023 | 27.465 8192 | 32768 | 24.459 | 28.274 16384 | 16384 | 21.439 | 9.514 32768 | 8192 | 6.818 | 0.491 </div> <!--EndFragment--> </body> </html> --------- **Benchmark script of this PR** ``` from distutils.command.config import config import torch from torch.nn import LayerNorm import timeit number_runs = 1000 # TODO: Modify this to save time! def test_forward(layer_norm_cuda, input_cuda): layer_norm_cuda(input_cuda); torch.cuda.synchronize() def test_backward(out_cuda, layer_norm_grad_cuda, create_graph): out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize() def test_fwdbwd(input_cuda, layer_norm_cuda, gO): input_cuda.grad = None layer_norm_cuda.zero_grad(set_to_none=True) out = layer_norm_cuda(input_cuda) out.backward(gO) torch.cuda.synchronize() def benchmark(config_m, config_n): print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)") if len(config_m) != len(config_n): print("Please make sure the lengths of config_m and config_m are the same.") for i in range(len(config_m)): normalized_shape = config_n[i] results = [config_m[i], config_n[i]] for dtype in (torch.half, torch.float): if dtype == torch.half: layer_norm_cuda = LayerNorm(normalized_shape).half().cuda() else: layer_norm_cuda = LayerNorm(normalized_shape).cuda() input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True) # print("cuda forward:") result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs) results.append(result_fwd / number_runs * 1000) gO = torch.rand_like(input_cuda) result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs) results.append(result_fwdbwd / number_runs * 1000) print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5])) print("Times are in microseconds (us).") config_m_cvt = [50432, 50176, 200704, 802816] config_n_cvt = [384, 384, 192, 64] config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272] config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072] config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192] config_m = config_m_cvt + config_m_68238 + config_m_27634 config_n = config_n_cvt + config_n_68238 + config_n_27634 benchmark(config_m, config_n) ``` CC: @jeffdaily Pull Request resolved: pytorch#87635 Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs` (=`config_m` in our benchmark script) is large and `bs` (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs. This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs` (=`config_m`) is larger than 512 on AMD GPUs. There are a few PRs for LayerNorm kernel: - pytorch#26201 - pytorch#27634 - pytorch#68238 Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100. --- **Current** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892 50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886 200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827 802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946 200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349 1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753 6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429 6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245 200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878 1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751 6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313 6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982 200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007 1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991 6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504 6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133 200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015 1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778 6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987 6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025 200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655 1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685 6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635 6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141 200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034 1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433 6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462 6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524 128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092 256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371 512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902 1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192 2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191 4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751 8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646 16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408 32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271 </body> </html> --------- **At this PR** <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=Excel.Sheet> <meta name=Generator content="Microsoft Excel 15"> <link id=Main-File rel=Main-File href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm"> <link rel=File-List href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml"> <!--table {mso-displayed-decimal-separator:"\."; mso-displayed-thousand-separator:"\,";} @page {mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D"; margin:.75in .7in .75in .7in; mso-header-margin:.3in; mso-footer-margin:.3in;} tr {mso-height-source:auto;} col {mso-width-source:auto;} br {mso-data-placement:same-cell;} td {padding-top:1px; padding-right:1px; padding-left:1px; mso-ignore:padding; color:black; font-size:11.0pt; font-weight:400; font-style:normal; text-decoration:none; font-family:Calibri, sans-serif; mso-font-charset:0; mso-number-format:General; text-align:general; vertical-align:bottom; border:none; mso-background-source:auto; mso-pattern:auto; mso-protection:locked visible; white-space:nowrap; mso-rotate:0;} .xl63 {color:windowtext;} --> </head> <body link="#0563C1" vlink="#954F72"> M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float) -- | -- | -- | -- | -- | -- 50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283 50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595 200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579 802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404 200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602 1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742 6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279 6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426 200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018 1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206 6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572 6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635 200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216 1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936 6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273 6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545 200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545 1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204 6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119 6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208 200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859 1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583 6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796 6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055 200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695 1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633 6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289 6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694 128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699 256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936 512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083 1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117 2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845 4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392 8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296 16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113 32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514 </body> </html> --------- **Performance Improvement (%)** <html xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=ProgId content=OneNote.File> <meta name=Generator content="Microsoft OneNote 15"> </head> <body lang=en-US style='font-family:Calibri;font-size:11.0pt'> <!--StartFragment--> <div style='direction:ltr'> M | N | fwdbwd, torch.float16 | fwdbwd, torch.float32 -- | -- | -- | -- 50432 | 384 | 32.178 | 22.049 50176 | 384 | 29.231 | 19.536 200704 | 192 | 44.188 | 43.962 802816 | 64 | 52.119 | 54.100 200 | 256 | -5.750 | -0.206 1000 | 256 | 0.031 | -0.797 6000 | 256 | 3.566 | 5.621 6272 | 256 | 3.865 | 4.836 200 | 512 | -1.615 | -1.010 1000 | 512 | -1.270 | 0.208 6000 | 512 | 3.534 | 5.581 6272 | 512 | 7.905 | 7.483 200 | 1024 | -2.883 | 0.254 1000 | 1024 | -0.767 | 0.493 6000 | 1024 | 0.237 | -2.381 6272 | 1024 | 3.840 | -1.707 200 | 1536 | -0.127 | -1.340 1000 | 1536 | -0.711 | -0.992 6000 | 1536 | -0.209 | -4.728 6272 | 1536 | 0.508 | -0.846 200 | 2048 | -1.262 | -1.176 1000 | 2048 | -0.358 | 0.312 6000 | 2048 | 8.350 | 6.487 6272 | 2048 | 1.588 | 5.713 200 | 3072 | 0.223 | -0.848 1000 | 3072 | -0.773 | -5.743 6000 | 3072 | 3.570 | -3.783 6272 | 3072 | 4.962 | -4.092 128 | 2097152 | -4.266 | 0.348 256 | 1048576 | 0.397 | 0.185 512 | 524288 | 17.325 | 16.605 1024 | 262144 | 23.070 | 19.195 2048 | 131072 | 27.469 | 24.605 4096 | 65536 | 32.023 | 27.465 8192 | 32768 | 24.459 | 28.274 16384 | 16384 | 21.439 | 9.514 32768 | 8192 | 6.818 | 0.491 </div> <!--EndFragment--> </body> </html> --------- **Benchmark script of this PR** ``` # Ref: # 1. pytorch#26201 # 2. pytorch#68238 from distutils.command.config import config import torch from torch.nn import LayerNorm import timeit number_runs = 1000 # TODO: Modify this to save time! def test_forward(layer_norm_cuda, input_cuda): layer_norm_cuda(input_cuda); torch.cuda.synchronize() def test_backward(out_cuda, layer_norm_grad_cuda, create_graph): out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize() def test_fwdbwd(input_cuda, layer_norm_cuda, gO): input_cuda.grad = None layer_norm_cuda.zero_grad(set_to_none=True) out = layer_norm_cuda(input_cuda) out.backward(gO) torch.cuda.synchronize() def benchmark(config_m, config_n): print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)") if len(config_m) != len(config_n): print("Please make sure the lengths of config_m and config_m are the same.") for i in range(len(config_m)): normalized_shape = config_n[i] results = [config_m[i], config_n[i]] for dtype in (torch.half, torch.float): if dtype == torch.half: layer_norm_cuda = LayerNorm(normalized_shape).half().cuda() else: layer_norm_cuda = LayerNorm(normalized_shape).cuda() input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True) # print("cuda forward:") result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs) results.append(result_fwd / number_runs * 1000) gO = torch.rand_like(input_cuda) result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs) results.append(result_fwdbwd / number_runs * 1000) print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5])) print("Times are in microseconds (us).") # CVT config_m_cvt = [50432, 50176, 200704, 802816] config_n_cvt = [384, 384, 192, 64] # pytorch#68238 (comment) config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272] config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072] # pytorch#27634 config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768] config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192] config_m = config_m_cvt + config_m_68238 + config_m_27634 config_n = config_n_cvt + config_n_68238 + config_n_27634 benchmark(config_m, config_n) ``` CC: @jeffdaily Pull Request resolved: pytorch#87635 Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
Summary:
Add fused layer norm impl on CUDA in PyTorch
The second or higher order gradient part is copied from #26201. Thanks @zasdfgbnm provide a good example for dealing with higher order gradients.
The benchmark results for this PR vs original LayerNorm as well as apex.FusedLayerNorm are shown below. We can see some improvements compare to apex.FusedLayerNorm especially on common batch sizes such as 128 or 256.
Performance benchmark for original LayerNorm vs apex.FusedLayerNorm on a V100 machine
Performance benchmark for this PR vs apex.FusedLayerNorm on a V100 machine.
Performance benchmark script:
https://gist.github.com/BIT-silence/cc4221b51519acc5545b2cb8ce6599bc
Differential Revision: D17462420