Skip to content

Add fused layer norm impl on CUDA in PyTorch #27634

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

xiaomengy
Copy link
Contributor

@xiaomengy xiaomengy commented Oct 9, 2019

Summary:
Add fused layer norm impl on CUDA in PyTorch

The second or higher order gradient part is copied from #26201. Thanks @zasdfgbnm provide a good example for dealing with higher order gradients.

The benchmark results for this PR vs original LayerNorm as well as apex.FusedLayerNorm are shown below. We can see some improvements compare to apex.FusedLayerNorm especially on common batch sizes such as 128 or 256.

Performance benchmark for original LayerNorm vs apex.FusedLayerNorm on a V100 machine

**************************************
Shape = (128, 2097152)
  curr LayerNorm forward: 11.19029480102472ms
  apex LayerNorm forward: 10.364733319962397ms
  curr LayerNorm backward: 27.58014868805185ms
  apex LayerNorm backward: 20.948871094966307ms
**************************************
Shape = (256, 1048576)
  curr LayerNorm forward: 10.13169369683601ms
  apex LayerNorm forward: 6.387616391992196ms
  curr LayerNorm backward: 26.80497354711406ms
  apex LayerNorm backward: 15.511555707082152ms
**************************************
Shape = (512, 524288)
  curr LayerNorm forward: 9.748662753030658ms
  apex LayerNorm forward: 4.730404008179903ms
  curr LayerNorm backward: 26.34211923298426ms
  apex LayerNorm backward: 14.006239669863135ms
**************************************
Shape = (1024, 262144)
  curr LayerNorm forward: 8.120901573915035ms
  apex LayerNorm forward: 5.376706059090793ms
  curr LayerNorm backward: 25.534225502051413ms
  apex LayerNorm backward: 14.285464284941554ms
**************************************
Shape = (2048, 131072)
  curr LayerNorm forward: 7.154732370981947ms
  apex LayerNorm forward: 4.762221570825204ms
  curr LayerNorm backward: 25.278658757917583ms
  apex LayerNorm backward: 13.5387117639184ms
**************************************
Shape = (4096, 65536)
  curr LayerNorm forward: 6.96491337986663ms
  apex LayerNorm forward: 4.47763676289469ms
  curr LayerNorm backward: 25.304558301810175ms
  apex LayerNorm backward: 13.59957629814744ms
**************************************
Shape = (8192, 32768)
  curr LayerNorm forward: 6.964353390969336ms
  apex LayerNorm forward: 4.3496931828558445ms
  curr LayerNorm backward: 25.336067016003653ms
  apex LayerNorm backward: 13.46561166504398ms
**************************************
Shape = (16384, 16384)
  curr LayerNorm forward: 6.937638100003824ms
  apex LayerNorm forward: 4.273262439062819ms
  curr LayerNorm backward: 25.28494614805095ms
  apex LayerNorm backward: 13.464738595997915ms
**************************************
Shape = (32768, 8192)
  curr LayerNorm forward: 6.977811902062967ms
  apex LayerNorm forward: 4.115052119130269ms
  curr LayerNorm backward: 25.34099874086678ms
  apex LayerNorm backward: 13.600329003995284ms

Performance benchmark for this PR vs apex.FusedLayerNorm on a V100 machine.

**************************************
Shape = (128, 2097152)
  curr LayerNorm forward: 7.252584544941783ms
  apex LayerNorm forward: 10.366813436849043ms
  curr LayerNorm backward: 15.568048988003284ms
  apex LayerNorm backward: 20.869979876093566ms
**************************************
Shape = (256, 1048576)
  curr LayerNorm forward: 5.185673736967146ms
  apex LayerNorm forward: 6.3868385690730065ms
  curr LayerNorm backward: 13.942008479032665ms
  apex LayerNorm backward: 15.469660016940907ms
**************************************
Shape = (512, 524288)
  curr LayerNorm forward: 4.672068868065253ms
  apex LayerNorm forward: 4.717993081081659ms
  curr LayerNorm backward: 13.46354596503079ms
  apex LayerNorm backward: 14.04774487693794ms
**************************************
Shape = (1024, 262144)
  curr LayerNorm forward: 4.547273400006816ms
  apex LayerNorm forward: 5.378365494078025ms
  curr LayerNorm backward: 13.425063178874552ms
  apex LayerNorm backward: 14.235145597020164ms
**************************************
Shape = (2048, 131072)
  curr LayerNorm forward: 4.526399010093883ms
  apex LayerNorm forward: 4.775081946980208ms
  curr LayerNorm backward: 13.222738380078226ms
  apex LayerNorm backward: 13.59594238596037ms
**************************************
Shape = (4096, 65536)
  curr LayerNorm forward: 4.28789056581445ms
  apex LayerNorm forward: 4.48913648002781ms
  curr LayerNorm backward: 13.026655421825126ms
  apex LayerNorm backward: 13.57052089786157ms
**************************************
Shape = (8192, 32768)
  curr LayerNorm forward: 4.243518367875367ms
  apex LayerNorm forward: 4.34588153520599ms
  curr LayerNorm backward: 13.140627697808668ms
  apex LayerNorm backward: 13.49891544203274ms
**************************************
Shape = (16384, 16384)
  curr LayerNorm forward: 4.181216162163764ms
  apex LayerNorm forward: 4.268723972840235ms
  curr LayerNorm backward: 13.035593512002379ms
  apex LayerNorm backward: 13.463351831072941ms
**************************************
Shape = (32768, 8192)
  curr LayerNorm forward: 4.097899778978899ms
  apex LayerNorm forward: 4.109480210812762ms
  curr LayerNorm backward: 13.041268918896094ms
  apex LayerNorm backward: 13.586135944118723ms

Performance benchmark script:
https://gist.github.com/BIT-silence/cc4221b51519acc5545b2cb8ce6599bc

Differential Revision: D17462420

@xiaomengy xiaomengy requested a review from apaszke as a code owner October 9, 2019 21:23
@pytorchbot pytorchbot added oncall: jit Add this issue/PR to JIT oncall triage queue module: cpu CPU specific problem (e.g., perf, algorithm) module: cuda Related to torch.cuda, and CUDA support in general module: internals Related to internal abstractions in c10 and ATen module: nn Related to torch.nn module: operators labels Oct 9, 2019
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@xiaomengy
Copy link
Contributor Author

linked to Issue #27633

@zasdfgbnm
Copy link
Collaborator

Should we benchmark on (M, N) where M could be as small as 8 and as large as 4096, and N being 768, 1024, 1280, 2048? These are the real-world use case I see in https://github.com/huggingface/transformers

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@xiaomengy
Copy link
Contributor Author

Should we benchmark on (M, N) where M could be as small as 8 and as large as 4096, and N being 768, 1024, 1280, 2048? These are the real-world use case I see in https://github.com/huggingface/transformers

The reason I didn't test input with such small size is the overhead in the ops will become comparable large while the result is not very stable.

One of the test with such small input is here https://gist.github.com/BIT-silence/4c762eb6c4db55d58a57a43f492674df
While the timeit result cannot match the profiler result here https://gist.github.com/BIT-silence/27848840dc77f26fc459c8cc5311eb65.
And from the profiler result, we can see that things like empty_like will take a large percentage of time.

Here is the profiler result for the large input https://gist.github.com/BIT-silence/ba01562912bf5815a8059093a87f7c15.
We can see for the large inputs, the numbers from timeit and profiler can matches.

BTW, I found I made a mistake when doing the benchmark that I included the gc.collect time for the 1000 runs, I have changed the benchmark script and update the results above.

Copy link
Collaborator

@ngimel ngimel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, can you please do a minor cleanup?

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

Summary:
Pull Request resolved: pytorch#27634

Add fused layer norm impl on CUDA in PyTorch

Performance benchmark compare to apex.FusedLayerNorm on a V100 machine.

**************************************
Shape = (128, 2097152)
  curr LayerNorm forward: 7.252584544941783ms
  apex LayerNorm forward: 10.366813436849043ms
  curr LayerNorm backward: 15.568048988003284ms
  apex LayerNorm backward: 20.869979876093566ms
**************************************
Shape = (256, 1048576)
  curr LayerNorm forward: 5.185673736967146ms
  apex LayerNorm forward: 6.3868385690730065ms
  curr LayerNorm backward: 13.942008479032665ms
  apex LayerNorm backward: 15.469660016940907ms
**************************************
Shape = (512, 524288)
  curr LayerNorm forward: 4.672068868065253ms
  apex LayerNorm forward: 4.717993081081659ms
  curr LayerNorm backward: 13.46354596503079ms
  apex LayerNorm backward: 14.04774487693794ms
**************************************
Shape = (1024, 262144)
  curr LayerNorm forward: 4.547273400006816ms
  apex LayerNorm forward: 5.378365494078025ms
  curr LayerNorm backward: 13.425063178874552ms
  apex LayerNorm backward: 14.235145597020164ms
**************************************
Shape = (2048, 131072)
  curr LayerNorm forward: 4.526399010093883ms
  apex LayerNorm forward: 4.775081946980208ms
  curr LayerNorm backward: 13.222738380078226ms
  apex LayerNorm backward: 13.59594238596037ms
**************************************
Shape = (4096, 65536)
  curr LayerNorm forward: 4.28789056581445ms
  apex LayerNorm forward: 4.48913648002781ms
  curr LayerNorm backward: 13.026655421825126ms
  apex LayerNorm backward: 13.57052089786157ms
**************************************
Shape = (8192, 32768)
  curr LayerNorm forward: 4.243518367875367ms
  apex LayerNorm forward: 4.34588153520599ms
  curr LayerNorm backward: 13.140627697808668ms
  apex LayerNorm backward: 13.49891544203274ms
**************************************
Shape = (16384, 16384)
  curr LayerNorm forward: 4.181216162163764ms
  apex LayerNorm forward: 4.268723972840235ms
  curr LayerNorm backward: 13.035593512002379ms
  apex LayerNorm backward: 13.463351831072941ms
**************************************
Shape = (32768, 8192)
  curr LayerNorm forward: 4.097899778978899ms
  apex LayerNorm forward: 4.109480210812762ms
  curr LayerNorm backward: 13.041268918896094ms
  apex LayerNorm backward: 13.586135944118723ms

Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm"

Differential Revision: D17462420

fbshipit-source-id: ab8859272d21b94d9a2849a2255b9b5859ab6e38
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D17462420

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 15, 2019
Summary:
Pull Request resolved: pytorch/pytorch#27634

Add fused layer norm impl on CUDA in PyTorch

Performance benchmark compare to apex.FusedLayerNorm on a V100 machine.

**************************************
Shape = (128, 2097152)
  curr LayerNorm forward: 7.252584544941783ms
  apex LayerNorm forward: 10.366813436849043ms
  curr LayerNorm backward: 15.568048988003284ms
  apex LayerNorm backward: 20.869979876093566ms
**************************************
Shape = (256, 1048576)
  curr LayerNorm forward: 5.185673736967146ms
  apex LayerNorm forward: 6.3868385690730065ms
  curr LayerNorm backward: 13.942008479032665ms
  apex LayerNorm backward: 15.469660016940907ms
**************************************
Shape = (512, 524288)
  curr LayerNorm forward: 4.672068868065253ms
  apex LayerNorm forward: 4.717993081081659ms
  curr LayerNorm backward: 13.46354596503079ms
  apex LayerNorm backward: 14.04774487693794ms
**************************************
Shape = (1024, 262144)
  curr LayerNorm forward: 4.547273400006816ms
  apex LayerNorm forward: 5.378365494078025ms
  curr LayerNorm backward: 13.425063178874552ms
  apex LayerNorm backward: 14.235145597020164ms
**************************************
Shape = (2048, 131072)
  curr LayerNorm forward: 4.526399010093883ms
  apex LayerNorm forward: 4.775081946980208ms
  curr LayerNorm backward: 13.222738380078226ms
  apex LayerNorm backward: 13.59594238596037ms
**************************************
Shape = (4096, 65536)
  curr LayerNorm forward: 4.28789056581445ms
  apex LayerNorm forward: 4.48913648002781ms
  curr LayerNorm backward: 13.026655421825126ms
  apex LayerNorm backward: 13.57052089786157ms
**************************************
Shape = (8192, 32768)
  curr LayerNorm forward: 4.243518367875367ms
  apex LayerNorm forward: 4.34588153520599ms
  curr LayerNorm backward: 13.140627697808668ms
  apex LayerNorm backward: 13.49891544203274ms
**************************************
Shape = (16384, 16384)
  curr LayerNorm forward: 4.181216162163764ms
  apex LayerNorm forward: 4.268723972840235ms
  curr LayerNorm backward: 13.035593512002379ms
  apex LayerNorm backward: 13.463351831072941ms
**************************************
Shape = (32768, 8192)
  curr LayerNorm forward: 4.097899778978899ms
  apex LayerNorm forward: 4.109480210812762ms
  curr LayerNorm backward: 13.041268918896094ms
  apex LayerNorm backward: 13.586135944118723ms

Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm"

Reviewed By: houseroad

Differential Revision: D17462420

fbshipit-source-id: d4a67d160bb4eff73ffac64af46c56c3845cf211
@facebook-github-bot
Copy link
Contributor

This pull request has been merged in 8b87f9a.

@xiaomengy xiaomengy deleted the export-D17462420 branch October 15, 2019 05:44
@gchanan
Copy link
Contributor

gchanan commented Oct 16, 2019

just a naming thing (and it appears this name probably just comes from apex), but it's strange to refer to an operator as "fused" but not specify what is fused.

thiagocrepaldi pushed a commit to thiagocrepaldi/pytorch that referenced this pull request Feb 4, 2020
Summary:
Pull Request resolved: pytorch#27634

Add fused layer norm impl on CUDA in PyTorch

Performance benchmark compare to apex.FusedLayerNorm on a V100 machine.

**************************************
Shape = (128, 2097152)
  curr LayerNorm forward: 7.252584544941783ms
  apex LayerNorm forward: 10.366813436849043ms
  curr LayerNorm backward: 15.568048988003284ms
  apex LayerNorm backward: 20.869979876093566ms
**************************************
Shape = (256, 1048576)
  curr LayerNorm forward: 5.185673736967146ms
  apex LayerNorm forward: 6.3868385690730065ms
  curr LayerNorm backward: 13.942008479032665ms
  apex LayerNorm backward: 15.469660016940907ms
**************************************
Shape = (512, 524288)
  curr LayerNorm forward: 4.672068868065253ms
  apex LayerNorm forward: 4.717993081081659ms
  curr LayerNorm backward: 13.46354596503079ms
  apex LayerNorm backward: 14.04774487693794ms
**************************************
Shape = (1024, 262144)
  curr LayerNorm forward: 4.547273400006816ms
  apex LayerNorm forward: 5.378365494078025ms
  curr LayerNorm backward: 13.425063178874552ms
  apex LayerNorm backward: 14.235145597020164ms
**************************************
Shape = (2048, 131072)
  curr LayerNorm forward: 4.526399010093883ms
  apex LayerNorm forward: 4.775081946980208ms
  curr LayerNorm backward: 13.222738380078226ms
  apex LayerNorm backward: 13.59594238596037ms
**************************************
Shape = (4096, 65536)
  curr LayerNorm forward: 4.28789056581445ms
  apex LayerNorm forward: 4.48913648002781ms
  curr LayerNorm backward: 13.026655421825126ms
  apex LayerNorm backward: 13.57052089786157ms
**************************************
Shape = (8192, 32768)
  curr LayerNorm forward: 4.243518367875367ms
  apex LayerNorm forward: 4.34588153520599ms
  curr LayerNorm backward: 13.140627697808668ms
  apex LayerNorm backward: 13.49891544203274ms
**************************************
Shape = (16384, 16384)
  curr LayerNorm forward: 4.181216162163764ms
  apex LayerNorm forward: 4.268723972840235ms
  curr LayerNorm backward: 13.035593512002379ms
  apex LayerNorm backward: 13.463351831072941ms
**************************************
Shape = (32768, 8192)
  curr LayerNorm forward: 4.097899778978899ms
  apex LayerNorm forward: 4.109480210812762ms
  curr LayerNorm backward: 13.041268918896094ms
  apex LayerNorm backward: 13.586135944118723ms

Test Plan: buck test mode/dev-nosan caffe2/test:nn -- "LayerNorm"

Reviewed By: houseroad

Differential Revision: D17462420

fbshipit-source-id: d4a67d160bb4eff73ffac64af46c56c3845cf211
pytorchmergebot pushed a commit that referenced this pull request Nov 22, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR #68238](#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- #26201
- #27634
- #68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```
# Ref:
#       1. #26201
#       2. #68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# #68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# #27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: #87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 1, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
We observed that the native PyTorch LayerNormBackwardKernelImplInternal has suboptimal performance for certain input sizes on AMD GPUs especially when `fs`  (=`config_m` in our benchmark script) is large and `bs`  (=`config_n` in our benchmark script) is small (commonly seen in [the CvT model](https://arxiv.org/abs/2103.15808)) in the benchmark script of [PR pytorch#68238](pytorch#68238 (comment)) on AMD GPUs.

This PR is to replace `GammaBetaBackwardCUDAKernel` with the Apex layernorm backward kernel with some ROCm-specific parameter tuning when `fs`  (=`config_m`) is larger than 512 on AMD GPUs.

There are a few PRs for LayerNorm kernel:
- pytorch#26201
- pytorch#27634
- pytorch#68238

Therefore, we have tested and compared the kernel before and at this PR with the input shapes in the last two PRs along with those commonly used in the CvT model on AMD MI100.

---
**Current**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">
<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.387256 | 1.372758 | 0.378975 | 1.47892
50176 | 384 | 0.38231 | 1.362416 | 0.378084 | 1.473886
200704 | 192 | 0.997859 | 4.315875 | 0.989306 | 4.560827
802816 | 64 | 3.671828 | 16.68013 | 3.613515 | 16.827946
200 | 256 | 0.066503 | 0.332096 | 0.071422 | 0.325349
1000 | 256 | 0.071848 | 0.333355 | 0.073038 | 0.334753
6000 | 256 | 0.086334 | 0.345139 | 0.086834 | 0.347429
6272 | 256 | 0.088601 | 0.347906 | 0.087855 | 0.351245
200 | 512 | 0.071626 | 0.329726 | 0.073798 | 0.326878
1000 | 512 | 0.073975 | 0.330226 | 0.074166 | 0.332751
6000 | 512 | 0.099617 | 0.362367 | 0.100095 | 0.378313
6272 | 512 | 0.100378 | 0.358066 | 0.099857 | 0.395982
200 | 1024 | 0.072954 | 0.326382 | 0.073899 | 0.333007
1000 | 1024 | 0.0743 | 0.325532 | 0.071126 | 0.330991
6000 | 1024 | 0.127025 | 0.390084 | 0.128692 | 0.471504
6272 | 1024 | 0.130704 | 0.403536 | 0.135244 | 0.487133
200 | 1536 | 0.070331 | 0.339169 | 0.070086 | 0.331015
1000 | 1536 | 0.075085 | 0.330042 | 0.076295 | 0.328778
6000 | 1536 | 0.148889 | 0.44949 | 0.155781 | 0.659987
6272 | 1536 | 0.154939 | 0.478871 | 0.17673 | 0.716025
200 | 2048 | 0.070269 | 0.335585 | 0.072804 | 0.334655
1000 | 2048 | 0.080094 | 0.326991 | 0.080426 | 0.32685
6000 | 2048 | 0.187888 | 0.623023 | 0.245762 | 0.981635
6272 | 2048 | 0.195431 | 0.65244 | 0.262574 | 1.008141
200 | 3072 | 0.068205 | 0.339428 | 0.073068 | 0.344034
1000 | 3072 | 0.087554 | 0.328899 | 0.09218 | 0.346433
6000 | 3072 | 0.240352 | 0.905058 | 0.368135 | 1.280462
6272 | 3072 | 0.26179 | 0.959387 | 0.387782 | 1.476524
128 | 2097152 | 5.905976 | 22.724793 | 10.287974 | 30.242092
256 | 1048576 | 4.561596 | 19.554308 | 10.223171 | 29.42371
512 | 524288 | 4.146751 | 22.7247 | 11.404285 | 39.175902
1024 | 262144 | 5.193135 | 23.403325 | 11.334512 | 38.947192
2048 | 131072 | 4.992907 | 23.377801 | 11.400286 | 40.889191
4096 | 65536 | 5.429488 | 24.275701 | 11.196778 | 41.4751
8192 | 32768 | 5.35758 | 21.360312 | 10.535418 | 42.875646
16384 | 16384 | 5.44947 | 20.852605 | 10.357685 | 34.603408
32768 | 8192 | 4.688925 | 17.379392 | 9.635596 | 31.188271

</body>

</html>

---------
**At this PR**
<html xmlns:v="urn:schemas-microsoft-com:vml"
xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:x="urn:schemas-microsoft-com:office:excel"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=Excel.Sheet>
<meta name=Generator content="Microsoft Excel 15">
<link id=Main-File rel=Main-File
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip.htm">
<link rel=File-List
href="file:///C:/Users/hubertlu/AppData/Local/Temp/msohtmlclip1/01/clip_filelist.xml">

<!--table
	{mso-displayed-decimal-separator:"\.";
	mso-displayed-thousand-separator:"\,";}
@page
	{mso-header-data:"&L&\0022Arial\0022&10&K0000FF \[AMD Official Use Only - General\]&1\#\000D";
	margin:.75in .7in .75in .7in;
	mso-header-margin:.3in;
	mso-footer-margin:.3in;}
tr
	{mso-height-source:auto;}
col
	{mso-width-source:auto;}
br
	{mso-data-placement:same-cell;}
td
	{padding-top:1px;
	padding-right:1px;
	padding-left:1px;
	mso-ignore:padding;
	color:black;
	font-size:11.0pt;
	font-weight:400;
	font-style:normal;
	text-decoration:none;
	font-family:Calibri, sans-serif;
	mso-font-charset:0;
	mso-number-format:General;
	text-align:general;
	vertical-align:bottom;
	border:none;
	mso-background-source:auto;
	mso-pattern:auto;
	mso-protection:locked visible;
	white-space:nowrap;
	mso-rotate:0;}
.xl63
	{color:windowtext;}
-->
</head>

<body link="#0563C1" vlink="#954F72">

M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)
-- | -- | -- | -- | -- | --
50432 | 384 | 0.38797 | 0.93103 | 0.37966 | 1.15283
50176 | 384 | 0.3874 | 0.96417 | 0.38462 | 1.18595
200704 | 192 | 1.00002 | 2.40876 | 0.99224 | 2.55579
802816 | 64 | 3.67348 | 7.98658 | 3.61871 | 7.72404
200 | 256 | 0.07292 | 0.35119 | 0.07195 | 0.32602
1000 | 256 | 0.07354 | 0.33325 | 0.07237 | 0.33742
6000 | 256 | 0.08819 | 0.33283 | 0.08453 | 0.3279
6272 | 256 | 0.0886 | 0.33446 | 0.08774 | 0.33426
200 | 512 | 0.0701 | 0.33505 | 0.07072 | 0.33018
1000 | 512 | 0.07042 | 0.33442 | 0.074 | 0.33206
6000 | 512 | 0.09931 | 0.34956 | 0.09895 | 0.3572
6272 | 512 | 0.10103 | 0.32976 | 0.10041 | 0.36635
200 | 1024 | 0.07144 | 0.33579 | 0.07209 | 0.33216
1000 | 1024 | 0.0736 | 0.32803 | 0.07286 | 0.32936
6000 | 1024 | 0.12584 | 0.38916 | 0.12852 | 0.48273
6272 | 1024 | 0.13053 | 0.38804 | 0.13464 | 0.49545
200 | 1536 | 0.07159 | 0.3396 | 0.07062 | 0.33545
1000 | 1536 | 0.07443 | 0.33239 | 0.07366 | 0.33204
6000 | 1536 | 0.14959 | 0.45043 | 0.15826 | 0.69119
6272 | 1536 | 0.1542 | 0.47644 | 0.18249 | 0.72208
200 | 2048 | 0.07258 | 0.33982 | 0.07412 | 0.33859
1000 | 2048 | 0.0793 | 0.32816 | 0.07864 | 0.32583
6000 | 2048 | 0.18973 | 0.571 | 0.25506 | 0.91796
6272 | 2048 | 0.19719 | 0.64208 | 0.26445 | 0.95055
200 | 3072 | 0.07092 | 0.33867 | 0.07104 | 0.34695
1000 | 3072 | 0.08727 | 0.33144 | 0.09144 | 0.36633
6000 | 3072 | 0.24683 | 0.87275 | 0.37761 | 1.3289
6272 | 3072 | 0.26437 | 0.91178 | 0.38496 | 1.53694
128 | 2097152 | 6.27936 | 23.69425 | 10.40004 | 30.13699
256 | 1048576 | 4.5404 | 19.47675 | 10.28494 | 29.36936
512 | 524288 | 4.13951 | 18.78771 | 10.09557 | 32.67083
1024 | 262144 | 4.47576 | 18.00411 | 9.56488 | 31.47117
2048 | 131072 | 4.28026 | 16.95619 | 9.40297 | 30.82845
4096 | 65536 | 4.2653 | 16.5018 | 9.03315 | 30.08392
8192 | 32768 | 4.25613 | 16.13583 | 8.9258 | 30.75296
16384 | 16384 | 4.20256 | 16.38207 | 9.52587 | 31.31113
32768 | 8192 | 4.20231 | 16.19452 | 9.31478 | 31.03514

</body>

</html>

---------

**Performance Improvement (%)**
<html xmlns:o="urn:schemas-microsoft-com:office:office"
xmlns:dt="uuid:C2F41010-65B3-11d1-A29F-00AA00C14882"
xmlns="http://www.w3.org/TR/REC-html40">

<head>

<meta name=ProgId content=OneNote.File>
<meta name=Generator content="Microsoft OneNote 15">
</head>

<body lang=en-US style='font-family:Calibri;font-size:11.0pt'>
<!--StartFragment-->

<div style='direction:ltr'>

M | N | fwdbwd,   torch.float16 | fwdbwd,   torch.float32
-- | -- | -- | --
50432 | 384 | 32.178 | 22.049
50176 | 384 | 29.231 | 19.536
200704 | 192 | 44.188 | 43.962
802816 | 64 | 52.119 | 54.100
200 | 256 | -5.750 | -0.206
1000 | 256 | 0.031 | -0.797
6000 | 256 | 3.566 | 5.621
6272 | 256 | 3.865 | 4.836
200 | 512 | -1.615 | -1.010
1000 | 512 | -1.270 | 0.208
6000 | 512 | 3.534 | 5.581
6272 | 512 | 7.905 | 7.483
200 | 1024 | -2.883 | 0.254
1000 | 1024 | -0.767 | 0.493
6000 | 1024 | 0.237 | -2.381
6272 | 1024 | 3.840 | -1.707
200 | 1536 | -0.127 | -1.340
1000 | 1536 | -0.711 | -0.992
6000 | 1536 | -0.209 | -4.728
6272 | 1536 | 0.508 | -0.846
200 | 2048 | -1.262 | -1.176
1000 | 2048 | -0.358 | 0.312
6000 | 2048 | 8.350 | 6.487
6272 | 2048 | 1.588 | 5.713
200 | 3072 | 0.223 | -0.848
1000 | 3072 | -0.773 | -5.743
6000 | 3072 | 3.570 | -3.783
6272 | 3072 | 4.962 | -4.092
128 | 2097152 | -4.266 | 0.348
256 | 1048576 | 0.397 | 0.185
512 | 524288 | 17.325 | 16.605
1024 | 262144 | 23.070 | 19.195
2048 | 131072 | 27.469 | 24.605
4096 | 65536 | 32.023 | 27.465
8192 | 32768 | 24.459 | 28.274
16384 | 16384 | 21.439 | 9.514
32768 | 8192 | 6.818 | 0.491

</div>

<!--EndFragment-->
</body>

</html>

---------
**Benchmark script of this PR**
```
# Ref:
#       1. pytorch#26201
#       2. pytorch#68238

from distutils.command.config import config
import torch
from torch.nn import LayerNorm
import timeit

number_runs = 1000  # TODO: Modify this to save time!
def test_forward(layer_norm_cuda, input_cuda):
    layer_norm_cuda(input_cuda); torch.cuda.synchronize()

def test_backward(out_cuda, layer_norm_grad_cuda, create_graph):
    out_cuda.backward(layer_norm_grad_cuda, retain_graph=True, create_graph=create_graph); torch.cuda.synchronize()

def test_fwdbwd(input_cuda, layer_norm_cuda, gO):
    input_cuda.grad = None
    layer_norm_cuda.zero_grad(set_to_none=True)
    out = layer_norm_cuda(input_cuda)
    out.backward(gO)
    torch.cuda.synchronize()

def benchmark(config_m, config_n):

    print("M | N | fwd (half) | fwdbwd (half) | fwd (float) | fwdbwd (float)")
    if len(config_m) != len(config_n):
        print("Please make sure the lengths of config_m and config_m are the same.")

    for i in range(len(config_m)):
        normalized_shape = config_n[i]
        results = [config_m[i], config_n[i]]
        for dtype in (torch.half, torch.float):
            if dtype == torch.half:
                layer_norm_cuda = LayerNorm(normalized_shape).half().cuda()
            else:
                layer_norm_cuda = LayerNorm(normalized_shape).cuda()

            input_cuda = torch.randn(config_m[i], config_n[i], device='cuda', dtype=dtype, requires_grad=True)

            # print("cuda forward:")
            result_fwd = timeit.timeit(lambda: test_forward(layer_norm_cuda, input_cuda), number=number_runs)
            results.append(result_fwd / number_runs * 1000)

            gO = torch.rand_like(input_cuda)

            result_fwdbwd = timeit.timeit(lambda: test_fwdbwd(input_cuda, layer_norm_cuda, gO), number=number_runs)
            results.append(result_fwdbwd / number_runs * 1000)

        print('{:09d}|{:09d}|{:9.5f}|{:9.5f}|{:9.5f}|{:9.5f}'.format(results[0], results[1], results[2], results[3], results[4], results[5]))

    print("Times are in microseconds (us).")

# CVT
config_m_cvt = [50432, 50176, 200704, 802816]
config_n_cvt = [384, 384, 192, 64]

# pytorch#68238 (comment)
config_m_68238 = [200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272, 200, 1000, 6000, 6272]
config_n_68238 = [256,256,256,256,512,512,512,512,1024,1024,1024,1024,1536,1536,1536,1536,2048,2048,2048,2048,3072,3072,3072,3072]

# pytorch#27634
config_m_27634 = [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768]
config_n_27634 = [2097152, 1048576, 524288, 262144, 131072, 65536, 32768, 16384, 8192]

config_m = config_m_cvt + config_m_68238 + config_m_27634
config_n = config_n_cvt + config_n_68238 + config_n_27634

benchmark(config_m, config_n)
```

CC: @jeffdaily

Pull Request resolved: pytorch#87635
Approved by: https://github.com/jataylo, https://github.com/jeffdaily, https://github.com/ezyang
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Merged module: cpu CPU specific problem (e.g., perf, algorithm) module: cuda Related to torch.cuda, and CUDA support in general module: internals Related to internal abstractions in c10 and ATen module: nn Related to torch.nn oncall: jit Add this issue/PR to JIT oncall triage queue
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants