[pytorch] Layer norm backward speed gain with warp shuffles #87445

Closed

Conversation

valentinandrei
Contributor

Test Plan:

```
Times below are Forward + Backward on A100

       Size (bs, fs)    FP32    Gain    FP16    Gain
        256,   256  	101.30	9%	103.9	6%
        512,   256  	110.10	-4%	102.9	10%
       1024,   256  	104.30	7%	102.4	6%
       2048,   256  	107.60	4%	109.7	0%
       4096,   256  	116.70	8%	109.1	0%
       6144,   256  	106.10	7%	112.8	2%
       8192,   256  	106.10	1%	109.7	2%
        256,   512  	102.10	3%	108.5	1%
        512,   512  	101.50	40%	105.9	4%
       1024,   512  	109.70	20%	109.2	-1%
       2048,   512  	107.40	24%	107.2	1%
       4096,   512  	108.00	6%	110.6	-3%
       6144,   512  	103.90	13%	105.8	7%
       8192,   512  	138.70	14%	105.6	7%
        256,  1024  	106.20	1%	102.9	6%
        512,  1024  	104.50	4%	104.2	3%
       1024,  1024  	126.90	-15%	103.9	10%
       2048,  1024  	127.40	-15%	102.2	6%
       4096,  1024  	117.70	6%	102.8	21%
       6144,  1024  	165.30	11%	112.2	12%
       8192,  1024  	211.90	11%	144.8	13%
        256,  1536  	102.80	11%	103.1	6%
        512,  1536  	103.30	9%	102.9	18%
       1024,  1536  	111.00	-2%	117.2	7%
       2048,  1536  	102.30	12%	132.1	-4%
       4096,  1536  	165.50	5%	112.9	18%
       6144,  1536  	236.60	5%	145.7	12%
       8192,  1536  	307.80	5%	186.1	11%
        256,  2048  	110.60	-1%	103.8	7%
        512,  2048  	105.20	3%	105.6	1%
       1024,  2048  	106.70	3%	114.8	3%
       2048,  2048  	124.90	5%	109.7	0%
       4096,  2048  	231.40	4%	129.9	10%
       6144,  2048  	332.80	4%	182.5	11%
       8192,  2048  	434.60	4%	235.2	11%
        256,  3072  	111.60	8%	110.8	1%
        512,  3072  	106.80	1%	104.6	10%
       1024,  3072  	104.90	3%	109.9	4%
       2048,  3072  	193.80	0%	106.2	10%
       4096,  3072  	364.50	0%	187.8	5%
       6144,  3072  	538.30	0%	267	5%
       8192,  3072  	718.00	-1%	346.7	6%
        256,  4096  	103.60	4%	110.2	-1%
        512,  4096  	131.40	-11%	117	-7%
       1024,  4096  	135.80	1%	104.8	7%
       2048,  4096  	268.20	1%	149.4	10%
       4096,  4096  	520.70	1%	268.5	9%
       6144,  4096  	786.30	0%	389.8	9%
       8192,  4096  	1043.50	0%	509	10%
```

Used the following script from ngimel:

```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")
c = Compare(results)
c.print()
```

Differential Revision: D40567574

@pytorch-bot

pytorch-bot bot commented Oct 21, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87445

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e702b0c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D40567574

@efiks
Contributor

efiks commented Oct 21, 2022

@pytorchbot start ci

@pytorch-bot

pytorch-bot bot commented Oct 21, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'start' (choose from 'merge', 'revert', 'rebase', 'label')

usage: @pytorchbot [-h] {merge,revert,rebase,label} ...

Try @pytorchbot --help for more info.

@efiks
Contributor

efiks commented Oct 21, 2022

@pytorchbot --help

@pytorch-bot

pytorch-bot bot commented Oct 21, 2022

PyTorchBot Help

usage: @pytorchbot [-h] {merge,revert,rebase,label} ...

In order to invoke the bot on your PR, include a line that starts with
@pytorchbot anywhere in a comment. That line will form the command; no
multi-line commands are allowed. 

Example:
    Some extra context, blah blah, wow this PR looks awesome

    @pytorchbot merge

optional arguments:
  -h, --help            Show this help message and exit.

command:
  {merge,revert,rebase,label}
    merge               Merge a PR
    revert              Revert a PR
    rebase              Rebase a PR
    label               Add label to a PR

Merge

usage: @pytorchbot merge [-g | -f MESSAGE | -l] [-r [{viable/strict,master}]]

Merge an accepted PR, subject to the rules in .github/merge_rules.json.
By default, this will wait for all required checks (lint, pull) to succeed before merging.

optional arguments:
  -g, --green           Merge when all status checks running on the PR pass. To add status checks, use labels like `ciflow/trunk`.
  -f MESSAGE, --force MESSAGE
                        Merge without checking anything. This requires a reason for auditing purposes, for example:
                        @pytorchbot merge -f 'Minor update to fix lint. Expecting all PR tests to pass'
  -l, --land-checks     [Deprecated - your PR instead now gets the `ciflow/trunk` label on approval] Merge with land time checks. This will create a new branch with your changes rebased on viable/strict and run a majority of trunk tests _before_ landing to increase trunk reliability and decrease risk of revert. The tests added are: pull, Lint and trunk. Note that periodic is excluded.
  -r [{viable/strict,master}], --rebase [{viable/strict,master}]
                        Rebase the PR to re-run checks before merging.  Accepts viable/strict or master as branch options and will default to viable/strict if not specified.

Revert

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Revert a merged PR. This requires that you are a Meta employee.

Example:
  @pytorchbot revert -m="This is breaking tests on trunk. hud.pytorch.org/" -c=nosignal

optional arguments:
  -m MESSAGE, --message MESSAGE
                        The reason you are reverting, will be put in the commit message. Must be longer than 3 words.
  -c {nosignal,ignoredsignal,landrace,weird,ghfirst}, --classification {nosignal,ignoredsignal,landrace,weird,ghfirst}
                        A machine-friendly classification of the revert reason.

Rebase

usage: @pytorchbot rebase [-s | -b BRANCH]

Rebase a PR. Rebasing defaults to the stable viable/strict branch of pytorch.
You, along with any member of the pytorch organization, can rebase your PR.

optional arguments:
  -s, --stable          [DEPRECATED] Rebase onto viable/strict
  -b BRANCH, --branch BRANCH
                        Branch you would like to rebase to

Label

usage: @pytorchbot label labels [labels ...]

Adds label to a PR

positional arguments:
  labels  Labels to add to given Pull Request

@efiks
Contributor

efiks commented Oct 21, 2022

/ci

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D40567574

valentinandrei added a commit to valentinandrei/pytorch that referenced this pull request Oct 22, 2022
…87445)

Summary:
Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` along the way and warp exchanges should be slightly faster than loading from shared memory.
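
To make the mechanism concrete, here is a minimal, self-contained sketch of the warp-exchange idea, assuming a one-warp-wide block. It is illustrative only, not the PR's actual kernel; every name, the launch shape, and the accumulation layout are made up for the example.

```
// Sketch: each lane keeps one row's statistics in registers; __shfl_sync
// lets every lane read any other lane's copy, replacing a shared-memory
// staging buffer and its __syncthreads() barriers.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void gamma_beta_backward_sketch(
    const float* __restrict__ dY,    // [M, N] upstream gradient
    const float* __restrict__ X,     // [M, N] layer norm input
    const float* __restrict__ mean,  // [M] per-row mean
    const float* __restrict__ rstd,  // [M] per-row 1/std
    float* __restrict__ dg,          // [N] dgamma
    float* __restrict__ db,          // [N] dbeta
    int M, int N) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  const int lane = threadIdx.x & 31;
  float sum_dg = 0.f, sum_db = 0.f;

  // Walk the rows 32 at a time; each lane holds the statistics of one row.
  for (int row0 = 0; row0 < M; row0 += 32) {
    const int my_row = row0 + lane;
    const float my_mean = (my_row < M) ? mean[my_row] : 0.f;
    const float my_rstd = (my_row < M) ? rstd[my_row] : 0.f;
    for (int i = 0; i < 32; ++i) {
      const int row = row0 + i;
      if (row >= M) break;  // uniform across the warp
      // Broadcast lane i's registers to all lanes: no shared memory needed.
      const float m = __shfl_sync(0xffffffff, my_mean, i);
      const float r = __shfl_sync(0xffffffff, my_rstd, i);
      if (col < N) {
        const float dy = dY[row * N + col];
        sum_dg += dy * (X[row * N + col] - m) * r;
        sum_db += dy;
      }
    }
  }
  if (col < N) {
    dg[col] = sum_dg;
    db[col] = sum_db;
  }
}

int main() {
  const int M = 64, N = 64;
  float *dY, *X, *mean, *rstd, *dg, *db;
  cudaMallocManaged(&dY, M * N * sizeof(float));
  cudaMallocManaged(&X, M * N * sizeof(float));
  cudaMallocManaged(&mean, M * sizeof(float));
  cudaMallocManaged(&rstd, M * sizeof(float));
  cudaMallocManaged(&dg, N * sizeof(float));
  cudaMallocManaged(&db, N * sizeof(float));
  for (int i = 0; i < M * N; ++i) { dY[i] = 1.f; X[i] = 2.f; }
  for (int i = 0; i < M; ++i) { mean[i] = 1.f; rstd[i] = 0.5f; }
  // One warp per block; each warp covers 32 columns.
  gamma_beta_backward_sketch<<<(N + 31) / 32, 32>>>(dY, X, mean, rstd, dg, db, M, N);
  cudaDeviceSynchronize();
  printf("dg[0] = %.1f db[0] = %.1f\n", dg[0], db[0]);  // expect 32.0 and 64.0
  return 0;
}
```

Each lane reads another lane's registers directly through the shuffle, so there is no shared-memory staging and no block-wide barrier on this path.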

Test Plan:
```
Times below are Forward + Backward on A100

       Size (bs, fs)    FP32    Gain    FP16    Gain
        256,   256  	101.30	9%	103.9	6%
        512,   256  	110.10	-4%	102.9	10%
       1024,   256  	104.30	7%	102.4	6%
       2048,   256  	107.60	4%	109.7	0%
       4096,   256  	116.70	8%	109.1	0%
       6144,   256  	106.10	7%	112.8	2%
       8192,   256  	106.10	1%	109.7	2%
        256,   512  	102.10	3%	108.5	1%
        512,   512  	101.50	40%	105.9	4%
       1024,   512  	109.70	20%	109.2	-1%
       2048,   512  	107.40	24%	107.2	1%
       4096,   512  	108.00	6%	110.6	-3%
       6144,   512  	103.90	13%	105.8	7%
       8192,   512  	138.70	14%	105.6	7%
        256,  1024  	106.20	1%	102.9	6%
        512,  1024  	104.50	4%	104.2	3%
       1024,  1024  	126.90	-15%	103.9	10%
       2048,  1024  	127.40	-15%	102.2	6%
       4096,  1024  	117.70	6%	102.8	21%
       6144,  1024  	165.30	11%	112.2	12%
       8192,  1024  	211.90	11%	144.8	13%
        256,  1536  	102.80	11%	103.1	6%
        512,  1536  	103.30	9%	102.9	18%
       1024,  1536  	111.00	-2%	117.2	7%
       2048,  1536  	102.30	12%	132.1	-4%
       4096,  1536  	165.50	5%	112.9	18%
       6144,  1536  	236.60	5%	145.7	12%
       8192,  1536  	307.80	5%	186.1	11%
        256,  2048  	110.60	-1%	103.8	7%
        512,  2048  	105.20	3%	105.6	1%
       1024,  2048  	106.70	3%	114.8	3%
       2048,  2048  	124.90	5%	109.7	0%
       4096,  2048  	231.40	4%	129.9	10%
       6144,  2048  	332.80	4%	182.5	11%
       8192,  2048  	434.60	4%	235.2	11%
        256,  3072  	111.60	8%	110.8	1%
        512,  3072  	106.80	1%	104.6	10%
       1024,  3072  	104.90	3%	109.9	4%
       2048,  3072  	193.80	0%	106.2	10%
       4096,  3072  	364.50	0%	187.8	5%
       6144,  3072  	538.30	0%	267	5%
       8192,  3072  	718.00	-1%	346.7	6%
        256,  4096  	103.60	4%	110.2	-1%
        512,  4096  	131.40	-11%	117	-7%
       1024,  4096  	135.80	1%	104.8	7%
       2048,  4096  	268.20	1%	149.4	10%
       4096,  4096  	520.70	1%	268.5	9%
       6144,  4096  	786.30	0%	389.8	9%
       8192,  4096  	1043.50	0%	509	10%
```

Used the following script from ngimel:

```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")
c = Compare(results)
c.print()
```

Used the following script for numerical validation:

```
import torch

def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```
Numerical validation results:
```
...
Size (256 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (256 x 6144) mismatch percentage: weight 0.00 bias 0.39
Size (256 x 8192) mismatch percentage: weight 0.00 bias 1.17
...
Size (512 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (512 x 6144) mismatch percentage: weight 0.00 bias 0.39
Size (512 x 8192) mismatch percentage: weight 0.39 bias 0.59
...
Size (1024 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (1024 x 6144) mismatch percentage: weight 0.00 bias 0.49
Size (1024 x 8192) mismatch percentage: weight 0.10 bias 1.07
...
Size (1536 x 6144) mismatch percentage: weight 0.00 bias 0.20
Size (1536 x 8192) mismatch percentage: weight 0.20 bias 0.78
...
Size (2048 x 6144) mismatch percentage: weight 0.05 bias 0.10
Size (2048 x 8192) mismatch percentage: weight 0.29 bias 0.88
...
Size (3072 x 6144) mismatch percentage: weight 0.00 bias 0.33
Size (3072 x 8192) mismatch percentage: weight 0.16 bias 0.72
...
Size (4096 x 6144) mismatch percentage: weight 0.12 bias 0.10
Size (4096 x 8192) mismatch percentage: weight 0.20 bias 1.03
```

Differential Revision: D40567574

fbshipit-source-id: 56c6bfb215b2818fd959aef7b768de6aec127ffb
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D40567574

valentinandrei added a commit to valentinandrei/pytorch that referenced this pull request Oct 23, 2022
…87445)

Summary:
Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` along the way and warp exchanges should be slightly faster than loading from shared memory.

Test Plan:
Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values and compares the results. The epsilon values used for the FP comparison are 0.00001 for the weight and 0.0001 for the bias.
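
The test app itself is not included in the PR; as a rough illustration of the comparison step under those epsilons, the host-side mismatch count could look like the sketch below (buffer names, sizes, and data are placeholders; real inputs would be copied back from the two kernels):

```
#include <cmath>
#include <cstdio>
#include <vector>

// Count elementwise |ref - opt| >= eps, as in the "Mismatches: dg = ... db = ..."
// lines in the results below.
static int count_mismatches(const float* ref, const float* opt, int n, float eps) {
  int mismatches = 0;
  for (int i = 0; i < n; ++i) {
    if (std::fabs(ref[i] - opt[i]) >= eps) ++mismatches;
  }
  return mismatches;
}

int main() {
  const int N = 512;  // feature size; dg and db each have N elements
  std::vector<float> dg_ref(N, 1.f), dg_opt(N, 1.f);  // placeholder data
  std::vector<float> db_ref(N, 1.f), db_opt(N, 1.f);
  const int bad_dg = count_mismatches(dg_ref.data(), dg_opt.data(), N, 1e-5f);
  const int bad_db = count_mismatches(db_ref.data(), db_opt.data(), N, 1e-4f);
  printf("Mismatches: dg = %d db = %d out of %d\n", bad_dg, bad_db, N);
  return 0;
}
```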

Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup.

```
Size (512, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0098 (ms); optimized = 0.0098 (ms); speedup = 0.00%
Size (1024, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0131 (ms); optimized = 0.0123 (ms); speedup = 6.40%
Size (2048, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0198 (ms); optimized = 0.0170 (ms); speedup = 16.41%
Size (4096, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0345 (ms); optimized = 0.0282 (ms); speedup = 22.32%
Size (8192, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0820 (ms); optimized = 0.0804 (ms); speedup = 2.02%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0102 (ms); optimized = 0.0099 (ms); speedup = 3.13%
Size (1024, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0134 (ms); optimized = 0.0128 (ms); speedup = 4.66%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0216 (ms); optimized = 0.0189 (ms); speedup = 14.25%
Size (4096, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0488 (ms); optimized = 0.0459 (ms); speedup = 6.34%
Size (8192, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0973 (ms); optimized = 0.0870 (ms); speedup = 11.84%
Size (512, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0131 (ms); optimized = 0.0106 (ms); speedup = 23.37%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0210 (ms); optimized = 0.0146 (ms); speedup = 43.56%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0420 (ms); optimized = 0.0316 (ms); speedup = 32.81%
Size (4096, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0728 (ms); optimized = 0.0606 (ms); speedup = 20.19%
Size (8192, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1399 (ms); optimized = 0.1129 (ms); speedup = 23.93%
Size (512, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0181 (ms); optimized = 0.0155 (ms); speedup = 16.77%
Size (1024, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0379 (ms); optimized = 0.0317 (ms); speedup = 19.47%
Size (2048, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0680 (ms); optimized = 0.0609 (ms); speedup = 11.71%
Size (4096, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.1286 (ms); optimized = 0.1179 (ms); speedup = 9.04%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2567 (ms); optimized = 0.2296 (ms); speedup = 11.80%
Size (512, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0337 (ms); optimized = 0.0291 (ms); speedup = 15.90%
Size (1024, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0628 (ms); optimized = 0.0579 (ms); speedup = 8.48%
Size (2048, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.1154 (ms); optimized = 0.1068 (ms); speedup = 8.04%
Size (4096, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.2207 (ms); optimized = 0.2051 (ms); speedup = 7.61%
Size (8192, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.4504 (ms); optimized = 0.4022 (ms); speedup = 11.98%
Size (512, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.0628 (ms); optimized = 0.0561 (ms); speedup = 11.94%
Size (1024, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.1125 (ms); optimized = 0.1027 (ms); speedup = 9.52%
Size (2048, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.2144 (ms); optimized = 0.1983 (ms); speedup = 8.13%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4543 (ms); optimized = 0.3920 (ms); speedup = 15.89%
Size (8192, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.9087 (ms); optimized = 0.7878 (ms); speedup = 15.35%
Size (512, 13261); Mismatches: dg = 0 db = 0 out of 13261. reference = 0.0541 (ms); optimized = 0.0481 (ms); speedup = 12.49%
Size (1024, 13261); Mismatches: dg = 0 db = 0 out of 13261. reference = 0.0953 (ms); optimized = 0.0895 (ms); speedup = 6.50%
Size (2048, 13261); Mismatches: dg = 0 db = 0 out of 13261. reference = 0.1766 (ms); optimized = 0.1849 (ms); speedup = -4.49%
Size (4096, 13261); Mismatches: dg = 0 db = 0 out of 13261. reference = 0.3431 (ms); optimized = 0.3821 (ms); speedup = -10.21%
Size (8192, 13261); Mismatches: dg = 0 db = 0 out of 13261. reference = 0.6696 (ms); optimized = 0.7788 (ms); speedup = -14.02%
Size (512, 999); Mismatches: dg = 7 db = 0 out of 999. reference = 0.0086 (ms); optimized = 0.0086 (ms); speedup = 0.00%
Size (1024, 999); Mismatches: dg = 7 db = 0 out of 999. reference = 0.0122 (ms); optimized = 0.0112 (ms); speedup = 8.94%
Size (2048, 999); Mismatches: dg = 7 db = 0 out of 999. reference = 0.0193 (ms); optimized = 0.0162 (ms); speedup = 19.29%
Size (4096, 999); Mismatches: dg = 7 db = 0 out of 999. reference = 0.0493 (ms); optimized = 0.0441 (ms); speedup = 11.78%
Size (8192, 999); Mismatches: dg = 7 db = 0 out of 999. reference = 0.0947 (ms); optimized = 0.0817 (ms); speedup = 15.90%
Size (512, 667); Mismatches: dg = 0 db = 0 out of 667. reference = 0.0084 (ms); optimized = 0.0084 (ms); speedup = -0.28%
Size (1024, 667); Mismatches: dg = 0 db = 0 out of 667. reference = 0.0114 (ms); optimized = 0.0106 (ms); speedup = 7.66%
Size (2048, 667); Mismatches: dg = 0 db = 0 out of 667. reference = 0.0179 (ms); optimized = 0.0158 (ms); speedup = 13.27%
Size (4096, 667); Mismatches: dg = 0 db = 0 out of 667. reference = 0.0378 (ms); optimized = 0.0343 (ms); speedup = 10.15%
Size (8192, 667); Mismatches: dg = 0 db = 0 out of 667. reference = 0.0828 (ms); optimized = 0.0766 (ms); speedup = 8.09%
Size (512, 312); Mismatches: dg = 0 db = 0 out of 312. reference = 0.0083 (ms); optimized = 0.0082 (ms); speedup = 1.16%
Size (1024, 312); Mismatches: dg = 0 db = 0 out of 312. reference = 0.0111 (ms); optimized = 0.0103 (ms); speedup = 7.87%
Size (2048, 312); Mismatches: dg = 0 db = 0 out of 312. reference = 0.0166 (ms); optimized = 0.0141 (ms); speedup = 17.94%
Size (4096, 312); Mismatches: dg = 0 db = 0 out of 312. reference = 0.0283 (ms); optimized = 0.0230 (ms); speedup = 23.01%
Size (8192, 312); Mismatches: dg = 0 db = 0 out of 312. reference = 0.0543 (ms); optimized = 0.0517 (ms); speedup = 4.98%
Average speedup = 11.06%
```

Used the following script for additional numerical validation:

```
import torch

def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

Differential Revision: D40567574

fbshipit-source-id: 756a0340678e0c8700f53cd15fb64a3243cd39ba
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D40567574

valentinandrei added a commit to valentinandrei/pytorch that referenced this pull request Oct 24, 2022
…87445)

Summary:
Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` along the way and warp exchanges should be slightly faster than loading from shared memory.

Test Plan:
Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values and compares the results. The epsilon values used for the FP comparison are 0.00001 for the weight and 0.0001 for the bias.

Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup; the average is 15.6%.

```
Size (512, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0099 (ms); optimized = 0.0093 (ms); bw_opt = 210.87 GB/s; speedup = 6.41%
Size (1024, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0132 (ms); optimized = 0.0118 (ms); bw_opt = 331.96 GB/s; speedup = 11.72%
Size (2048, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0198 (ms); optimized = 0.0169 (ms); bw_opt = 463.30 GB/s; speedup = 17.07%
Size (4096, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0344 (ms); optimized = 0.0285 (ms); bw_opt = 549.16 GB/s; speedup = 20.57%
Size (525, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0114 (ms); optimized = 0.0117 (ms); bw_opt = 171.74 GB/s; speedup = -2.65%
Size (1033, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0148 (ms); optimized = 0.0144 (ms); bw_opt = 274.44 GB/s; speedup = 2.65%
Size (2064, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0215 (ms); optimized = 0.0196 (ms); bw_opt = 402.73 GB/s; speedup = 9.73%
Size (3000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0276 (ms); optimized = 0.0245 (ms); bw_opt = 467.99 GB/s; speedup = 12.65%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0099 (ms); optimized = 0.0094 (ms); bw_opt = 416.00 GB/s; speedup = 5.06%
Size (1024, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0134 (ms); optimized = 0.0120 (ms); bw_opt = 651.43 GB/s; speedup = 11.51%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0214 (ms); optimized = 0.0186 (ms); bw_opt = 841.44 GB/s; speedup = 15.00%
Size (4096, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0486 (ms); optimized = 0.0465 (ms); bw_opt = 672.98 GB/s; speedup = 4.56%
Size (525, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0117 (ms); optimized = 0.0117 (ms); bw_opt = 343.85 GB/s; speedup = 0.20%
Size (1033, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0153 (ms); optimized = 0.0146 (ms); bw_opt = 540.30 GB/s; speedup = 4.57%
Size (2064, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0235 (ms); optimized = 0.0214 (ms); bw_opt = 736.58 GB/s; speedup = 9.69%
Size (3000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0354 (ms); optimized = 0.0287 (ms); bw_opt = 798.39 GB/s; speedup = 23.34%
Size (512, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0132 (ms); optimized = 0.0102 (ms); bw_opt = 769.27 GB/s; speedup = 29.51%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0201 (ms); optimized = 0.0142 (ms); bw_opt = 1103.06 GB/s; speedup = 41.68%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0417 (ms); optimized = 0.0311 (ms); bw_opt = 1005.36 GB/s; speedup = 34.02%
Size (4096, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0725 (ms); optimized = 0.0615 (ms); bw_opt = 1017.20 GB/s; speedup = 17.91%
Size (525, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0148 (ms); optimized = 0.0125 (ms); bw_opt = 641.53 GB/s; speedup = 18.10%
Size (1033, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0168 (ms); bw_opt = 939.12 GB/s; speedup = 33.19%
Size (2064, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0457 (ms); optimized = 0.0369 (ms); bw_opt = 854.16 GB/s; speedup = 23.84%
Size (3000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0600 (ms); optimized = 0.0509 (ms); bw_opt = 900.04 GB/s; speedup = 17.85%
Size (512, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0178 (ms); optimized = 0.0152 (ms); bw_opt = 1029.47 GB/s; speedup = 16.93%
Size (1024, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0378 (ms); optimized = 0.0304 (ms); bw_opt = 1029.27 GB/s; speedup = 24.31%
Size (2048, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0680 (ms); optimized = 0.0654 (ms); bw_opt = 956.38 GB/s; speedup = 4.01%
Size (4096, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.1344 (ms); optimized = 0.1161 (ms); bw_opt = 1077.09 GB/s; speedup = 15.75%
Size (525, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0228 (ms); optimized = 0.0173 (ms); bw_opt = 927.61 GB/s; speedup = 31.68%
Size (1033, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0498 (ms); optimized = 0.0361 (ms); bw_opt = 873.82 GB/s; speedup = 37.82%
Size (2064, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.0826 (ms); optimized = 0.0659 (ms); bw_opt = 956.53 GB/s; speedup = 25.36%
Size (3000, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.1114 (ms); optimized = 0.0891 (ms); bw_opt = 1028.16 GB/s; speedup = 25.05%
Size (512, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0342 (ms); optimized = 0.0270 (ms); bw_opt = 1159.26 GB/s; speedup = 26.57%
Size (1024, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0625 (ms); optimized = 0.0579 (ms); bw_opt = 1080.86 GB/s; speedup = 7.99%
Size (2048, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.1160 (ms); optimized = 0.1064 (ms); bw_opt = 1175.72 GB/s; speedup = 9.05%
Size (4096, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.2204 (ms); optimized = 0.2044 (ms); bw_opt = 1223.56 GB/s; speedup = 7.84%
Size (525, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0494 (ms); optimized = 0.0389 (ms); bw_opt = 825.70 GB/s; speedup = 27.04%
Size (1033, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.0762 (ms); optimized = 0.0674 (ms); bw_opt = 936.46 GB/s; speedup = 13.05%
Size (2064, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.1312 (ms); optimized = 0.1142 (ms); bw_opt = 1103.77 GB/s; speedup = 14.89%
Size (3000, 8192); Mismatches: dg = 0 db = 0 out of 8192. reference = 0.1753 (ms); optimized = 0.1571 (ms); bw_opt = 1165.93 GB/s; speedup = 11.56%
Size (512, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0425 (ms); optimized = 0.0393 (ms); bw_opt = 972.28 GB/s; speedup = 8.07%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0739 (ms); optimized = 0.0698 (ms); bw_opt = 1094.07 GB/s; speedup = 5.84%
Size (2048, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.1417 (ms); optimized = 0.1294 (ms); bw_opt = 1179.98 GB/s; speedup = 9.51%
Size (4096, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2871 (ms); optimized = 0.2532 (ms); bw_opt = 1205.69 GB/s; speedup = 13.39%
Size (525, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0575 (ms); optimized = 0.0496 (ms); bw_opt = 790.34 GB/s; speedup = 15.96%
Size (1033, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0888 (ms); optimized = 0.0794 (ms); bw_opt = 970.15 GB/s; speedup = 11.80%
Size (2064, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.1541 (ms); optimized = 0.1362 (ms); bw_opt = 1129.66 GB/s; speedup = 13.13%
Size (3000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2098 (ms); optimized = 0.1882 (ms); bw_opt = 1188.13 GB/s; speedup = 11.46%
Size (512, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0097 (ms); optimized = 0.0099 (ms); bw_opt = 193.53 GB/s; speedup = -1.93%
Size (1024, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0133 (ms); optimized = 0.0125 (ms); bw_opt = 305.67 GB/s; speedup = 6.10%
Size (2048, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0202 (ms); optimized = 0.0175 (ms); bw_opt = 437.05 GB/s; speedup = 15.40%
Size (4096, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0360 (ms); optimized = 0.0291 (ms); bw_opt = 525.34 GB/s; speedup = 23.67%
Size (525, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0116 (ms); optimized = 0.0115 (ms); bw_opt = 170.85 GB/s; speedup = 1.04%
Size (1033, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0151 (ms); optimized = 0.0143 (ms); bw_opt = 269.81 GB/s; speedup = 5.50%
Size (2064, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0221 (ms); optimized = 0.0193 (ms); bw_opt = 399.14 GB/s; speedup = 14.44%
Size (3000, 500); Mismatches: dg = 0 db = 0 out of 500. reference = 0.0286 (ms); optimized = 0.0248 (ms); bw_opt = 451.77 GB/s; speedup = 15.38%
Size (512, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0098 (ms); optimized = 0.0100 (ms); bw_opt = 382.08 GB/s; speedup = -2.14%
Size (1024, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0132 (ms); optimized = 0.0126 (ms); bw_opt = 606.11 GB/s; speedup = 4.54%
Size (2048, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0214 (ms); optimized = 0.0189 (ms); bw_opt = 808.26 GB/s; speedup = 13.24%
Size (4096, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0523 (ms); optimized = 0.0491 (ms); bw_opt = 622.43 GB/s; speedup = 6.56%
Size (525, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0116 (ms); optimized = 0.0116 (ms); bw_opt = 338.56 GB/s; speedup = 0.21%
Size (1033, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0153 (ms); optimized = 0.0144 (ms); bw_opt = 535.51 GB/s; speedup = 6.29%
Size (2064, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0234 (ms); optimized = 0.0208 (ms); bw_opt = 740.78 GB/s; speedup = 12.50%
Size (3000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0378 (ms); optimized = 0.0365 (ms); bw_opt = 613.16 GB/s; speedup = 3.59%
Size (512, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0133 (ms); optimized = 0.0109 (ms); bw_opt = 700.75 GB/s; speedup = 21.83%
Size (1024, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0206 (ms); optimized = 0.0151 (ms); bw_opt = 1011.46 GB/s; speedup = 36.28%
Size (2048, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0436 (ms); optimized = 0.0354 (ms); bw_opt = 863.24 GB/s; speedup = 23.16%
Size (4096, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0761 (ms); optimized = 0.0625 (ms); bw_opt = 977.57 GB/s; speedup = 21.70%
Size (525, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0149 (ms); optimized = 0.0124 (ms); bw_opt = 632.84 GB/s; speedup = 20.19%
Size (1033, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0224 (ms); optimized = 0.0169 (ms); bw_opt = 912.41 GB/s; speedup = 32.58%
Size (2064, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0480 (ms); optimized = 0.0396 (ms); bw_opt = 777.79 GB/s; speedup = 21.19%
Size (3000, 2001); Mismatches: dg = 0 db = 0 out of 2001. reference = 0.0623 (ms); optimized = 0.0514 (ms); bw_opt = 870.83 GB/s; speedup = 21.20%
Size (512, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0191 (ms); optimized = 0.0163 (ms); bw_opt = 938.91 GB/s; speedup = 17.11%
Size (1024, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0424 (ms); optimized = 0.0364 (ms); bw_opt = 840.87 GB/s; speedup = 16.58%
Size (2048, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0741 (ms); optimized = 0.0619 (ms); bw_opt = 987.71 GB/s; speedup = 19.68%
Size (4096, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.1414 (ms); optimized = 0.1163 (ms); bw_opt = 1051.44 GB/s; speedup = 21.59%
Size (525, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0242 (ms); optimized = 0.0180 (ms); bw_opt = 872.17 GB/s; speedup = 34.44%
Size (1033, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0521 (ms); optimized = 0.0392 (ms); bw_opt = 787.37 GB/s; speedup = 32.91%
Size (2064, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.0867 (ms); optimized = 0.0667 (ms); bw_opt = 924.25 GB/s; speedup = 30.03%
Size (3000, 4005); Mismatches: dg = 5 db = 0 out of 4005. reference = 0.1146 (ms); optimized = 0.0899 (ms); bw_opt = 996.25 GB/s; speedup = 27.47%
Size (512, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.0383 (ms); optimized = 0.0337 (ms); bw_opt = 921.03 GB/s; speedup = 13.73%
Size (1024, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.0656 (ms); optimized = 0.0591 (ms); bw_opt = 1048.93 GB/s; speedup = 11.01%
Size (2048, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.1204 (ms); optimized = 0.1118 (ms); bw_opt = 1108.33 GB/s; speedup = 7.68%
Size (4096, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.2281 (ms); optimized = 0.2276 (ms); bw_opt = 1088.79 GB/s; speedup = 0.22%
Size (525, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.0524 (ms); optimized = 0.0436 (ms); bw_opt = 729.98 GB/s; speedup = 20.24%
Size (1033, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.0810 (ms); optimized = 0.0679 (ms); bw_opt = 921.04 GB/s; speedup = 19.31%
Size (2064, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.1343 (ms); optimized = 0.1192 (ms); bw_opt = 1047.94 GB/s; speedup = 12.66%
Size (3000, 8117); Mismatches: dg = 0 db = 0 out of 8117. reference = 0.1797 (ms); optimized = 0.1699 (ms); bw_opt = 1068.36 GB/s; speedup = 5.77%

Average speedup = 15.56%
```

Used the following script for additional numerical validation:

```
import torch

def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

Differential Revision: D40567574

fbshipit-source-id: b23e59f4cbd7d8a8104096af2ae5d9b6bb7e200c
pytorch-bot bot added the `ciflow/trunk` (Trigger trunk jobs on your pull request) label Oct 24, 2022
…87445)

Summary:
Pull Request resolved: pytorch#87445

Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, after which all threads would access the values in order to do loop unrolling. That approach increased register usage and decreased occupancy, without much benefit from using shared memory (the values were already cached in L1). The new implementation is simpler and uses fewer registers, so occupancy is better.

Added another implementation, `GammaBetaBackwardCUDAKernel_32x32`, which is used only for shapes that divide exactly into (32 x 32) blocks. This permits using warp shuffles to speed up loading `mean` and `rstd` as well as the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.
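
For reference, the warp-shuffle reduction this relies on follows the standard tree pattern; the sketch below is a generic, self-contained version of that pattern, not the kernel's exact code:

```
#include <cstdio>
#include <cuda_runtime.h>

// Fold 32 per-lane partial sums into lane 0 in five shuffle steps,
// with no shared memory and no __syncthreads().
__device__ float warp_reduce_sum(float v) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    v += __shfl_down_sync(0xffffffff, v, offset);
  }
  return v;  // total is valid in lane 0
}

__global__ void demo(float* out) {
  float total = warp_reduce_sum(static_cast<float>(threadIdx.x));
  if (threadIdx.x == 0) *out = total;
}

int main() {
  float* out;
  cudaMallocManaged(&out, sizeof(float));
  demo<<<1, 32>>>(out);
  cudaDeviceSynchronize();
  printf("warp sum = %.0f\n", *out);  // 0 + 1 + ... + 31 = 496
  return 0;
}
```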

Observed that we can get additional benefit if we lower the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (a simple column-wise reduction implementation) from `512` to `128`.
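
As a hypothetical sketch of what the kernel selection amounts to after this change (the real dispatch logic is more involved, and the predicate shown for the 32x32 path is an assumption):

```
#include <cstdio>

enum class Kernel { Simple, Tiled32x32, General };

constexpr int kSimpleKernelThreshold = 128;  // previously 512

Kernel pick_gamma_beta_backward_kernel(int M, int N) {
  if (M < kSimpleKernelThreshold) return Kernel::Simple;      // col-wise reduction
  if (M % 32 == 0 && N % 32 == 0) return Kernel::Tiled32x32;  // warp-shuffle path
  return Kernel::General;
}

int main() {
  printf("%d %d %d\n",
         (int)pick_gamma_beta_backward_kernel(100, 1024),   // 0: Simple
         (int)pick_gamma_beta_backward_kernel(2048, 2048),  // 1: Tiled32x32
         (int)pick_gamma_beta_backward_kernel(3000, 999));  // 2: General
  return 0;
}
```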

Test Plan:
Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values and compares the results. The epsilon values used for the FP comparison are 0.00001 for the weight and 0.0001 for the bias.

Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

Used the following script for additional numerical validation:

```
import torch

def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

The `NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also performs additional numerical validation, and it passes.

Reviewed By: ngimel

Differential Revision: D40567574

fbshipit-source-id: c7ab0dff4c375cc0b60bbc2c0e452939d823b91b
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D40567574

@valentinandrei
Contributor Author

@pytorchbot merge -g

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@github-actions

Hey @valentinandrei.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@pytorch-bot

pytorch-bot bot commented Oct 25, 2022

Didn't find following labels among repository labels: topic:,performance,topic:,not,user,facing

@pytorch-bot

pytorch-bot bot commented Oct 25, 2022

Didn't find following labels among repository labels: topic:performance

@pytorch-bot

pytorch-bot bot commented Oct 25, 2022

Didn't find following labels among repository labels: performance

@valentinandrei
Contributor Author

@pytorchbot label "topic: performance"

pytorch-bot bot added the `topic: performance` (topic category) label Oct 25, 2022
@weiwangmeta
Contributor

@pytorchbot revert -m "breaking internal builds due to MS compiler " -c "ghfirst"”.

@pytorch-bot

pytorch-bot bot commented Oct 26, 2022

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: argument -c/--classification: invalid choice: 'ghfirst”.' (choose from 'nosignal', 'ignoredsignal', 'landrace', 'weird', 'ghfirst')

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@weiwangmeta
Contributor

@pytorchbot revert -m "breaking internal builds due to MS compiler " -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@valentinandrei your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Oct 26, 2022
…87445)"

This reverts commit b6f2833.

Reverted #87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
…87445)

Test Plan:
```
Times below are Forward + Backward on A100

       Size (bs, fs)    FP32    Gain    FP16    Gain
        256,   256  	101.30	9%	103.9	6%
        512,   256  	110.10	-4%	102.9	10%
       1024,   256  	104.30	7%	102.4	6%
       2048,   256  	107.60	4%	109.7	0%
       4096,   256  	116.70	8%	109.1	0%
       6144,   256  	106.10	7%	112.8	2%
       8192,   256  	106.10	1%	109.7	2%
        256,   512  	102.10	3%	108.5	1%
        512,   512  	101.50	40%	105.9	4%
       1024,   512  	109.70	20%	109.2	-1%
       2048,   512  	107.40	24%	107.2	1%
       4096,   512  	108.00	6%	110.6	-3%
       6144,   512  	103.90	13%	105.8	7%
       8192,   512  	138.70	14%	105.6	7%
        256,  1024  	106.20	1%	102.9	6%
        512,  1024  	104.50	4%	104.2	3%
       1024,  1024  	126.90	-15%	103.9	10%
       2048,  1024  	127.40	-15%	102.2	6%
       4096,  1024  	117.70	6%	102.8	21%
       6144,  1024  	165.30	11%	112.2	12%
       8192,  1024  	211.90	11%	144.8	13%
        256,  1536  	102.80	11%	103.1	6%
        512,  1536  	103.30	9%	102.9	18%
       1024,  1536  	111.00	-2%	117.2	7%
       2048,  1536  	102.30	12%	132.1	-4%
       4096,  1536  	165.50	5%	112.9	18%
       6144,  1536  	236.60	5%	145.7	12%
       8192,  1536  	307.80	5%	186.1	11%
        256,  2048  	110.60	-1%	103.8	7%
        512,  2048  	105.20	3%	105.6	1%
       1024,  2048  	106.70	3%	114.8	3%
       2048,  2048  	124.90	5%	109.7	0%
       4096,  2048  	231.40	4%	129.9	10%
       6144,  2048  	332.80	4%	182.5	11%
       8192,  2048  	434.60	4%	235.2	11%
        256,  3072  	111.60	8%	110.8	1%
        512,  3072  	106.80	1%	104.6	10%
       1024,  3072  	104.90	3%	109.9	4%
       2048,  3072  	193.80	0%	106.2	10%
       4096,  3072  	364.50	0%	187.8	5%
       6144,  3072  	538.30	0%	267	5%
       8192,  3072  	718.00	-1%	346.7	6%
        256,  4096  	103.60	4%	110.2	-1%
        512,  4096  	131.40	-11%	117	-7%
       1024,  4096  	135.80	1%	104.8	7%
       2048,  4096  	268.20	1%	149.4	10%
       4096,  4096  	520.70	1%	268.5	9%
       6144,  4096  	786.30	0%	389.8	9%
       8192,  4096  	1043.50	0%	509	10%
```

Used the following script from ngimel:

```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")
c = Compare(results)
c.print()
```

Differential Revision: D40567574

Pull Request resolved: pytorch#87445
Approved by: https://github.com/ngimel
sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
…ytorch#87445)"

This reverts commit b6f2833.

Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
…87445)

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
…ytorch#87445)"

This reverts commit b6f2833.

Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…87445)

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…ytorch#87445)"

This reverts commit b6f2833.

Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
Labels
ciflow/trunk (Trigger trunk jobs on your pull request) · fb-exported · Merged · release notes: cuda (release notes category) · Reverted · topic: performance (topic category)

6 participants