Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pytorch] Layer norm backward speed gain with warp shuffles #87814

Closed

Conversation

valentinandrei
Copy link
Contributor

Summary:
Improved native layer norm backward performance.

Rewrote GammaBetaBackwardCUDAKernel to use shared memory only for the reduction step, but not for loading mean and rstd. The previous implementation used only threadIdx.x = 0 to load mean and rstd into shared memory, and then all threads would access the values in order to do loop unrolling. This approached increased register usage and decreased occupancy, without much benefit from using shared memory (this is because the values were already cached in L1). The new implementation is simpler and register usage is smaller, thus occupancy is better.

Added another implementation called GammaBetaBackwardCUDAKernel_32x32 which is only for shapes dividing exactly to a (32 x 32) block. This permits using warp shuffles for speeding up loading mean and rstd as well as for the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Observed that we can get additional benefit if we lower the threshold for calling GammaBetaBackwardSimpleCUDAKernel (simple col-wise reduction implementation) from 512 to 128.

Test Plan:
Wrote a simple CUDA app that calls the previous implementation of GammaBetaBackwardCUDAKernel and the current one, using FP32 values and compares the results. The epsilon value we used for FP comparison is 0.00001 for the weight and 0.0001 for the bias.
Ran the benchmark for various sizes A100 GPU and got the results below. Almost all sizes show good speedup.

Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%

For additional numerical validation used the following script:

def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )

NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA test also does additional numerical validation and it passes.

Differential Revision: D40730981

Summary:
Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, but not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, and then all threads would access the values in order to do loop unrolling. This approached increased register usage and decreased occupancy, without much benefit from using shared memory (this is because the values were already cached in L1). The new implementation is simpler and register usage is smaller, thus occupancy is better.

Added another implementation called `GammaBetaBackwardCUDAKernel_32x32` which is only for shapes dividing exactly to a (32 x 32) block. This permits using warp shuffles for speeding up loading `mean` and `rstd` as well as for the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Observed that we can get additional benefit if we lower the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (simple col-wise reduction implementation) from `512` to `128`.

Test Plan:
Wrote a simple CUDA app that calls the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one, using FP32 values and compares the results. The epsilon value we used for FP comparison is 0.00001 for the weight and 0.0001 for the bias.
Ran the benchmark for various sizes A100 GPU and got the results below. Almost all sizes show good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

`NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also does additional numerical validation and it passes.

Differential Revision: D40730981

fbshipit-source-id: bec81da3dd2a83b85328153ab3a4eeb85f3e0b60
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 26, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87814

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 5 Pending

As of commit a1de5b1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 26, 2022
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D40730981

@valentinandrei
Copy link
Contributor Author

@pytorchbot label "topic: performance"

@pytorch-bot pytorch-bot bot added the topic: performance topic category label Oct 26, 2022
@weiwangmeta
Copy link
Contributor

@pytorchbot merge

@weiwangmeta
Copy link
Contributor

Skipping reviews as this is largely the same as #87445 that has been reviewed and approved.

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: PR #87814 has not been reviewed yet (Rule superuser)

Details for Dev Infra team Raised by workflow job

Copy link
Contributor

@weiwangmeta weiwangmeta left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 26, 2022
@weiwangmeta
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

sgrigory pushed a commit to sgrigory/pytorch that referenced this pull request Oct 28, 2022
…87814)

Summary:
Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, but not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, and then all threads would access the values in order to do loop unrolling. This approached increased register usage and decreased occupancy, without much benefit from using shared memory (this is because the values were already cached in L1). The new implementation is simpler and register usage is smaller, thus occupancy is better.

Added another implementation called `GammaBetaBackwardCUDAKernel_32x32` which is only for shapes dividing exactly to a (32 x 32) block. This permits using warp shuffles for speeding up loading `mean` and `rstd` as well as for the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Observed that we can get additional benefit if we lower the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (simple col-wise reduction implementation) from `512` to `128`.

Test Plan:
Wrote a simple CUDA app that calls the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one, using FP32 values and compares the results. The epsilon value we used for FP comparison is 0.00001 for the weight and 0.0001 for the bias.
Ran the benchmark for various sizes A100 GPU and got the results below. Almost all sizes show good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

`NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also does additional numerical validation and it passes.

Differential Revision: D40730981

Pull Request resolved: pytorch#87814
Approved by: https://github.com/weiwangmeta
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
…87814)

Summary:
Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, but not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, and then all threads would access the values in order to do loop unrolling. This approached increased register usage and decreased occupancy, without much benefit from using shared memory (this is because the values were already cached in L1). The new implementation is simpler and register usage is smaller, thus occupancy is better.

Added another implementation called `GammaBetaBackwardCUDAKernel_32x32` which is only for shapes dividing exactly to a (32 x 32) block. This permits using warp shuffles for speeding up loading `mean` and `rstd` as well as for the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Observed that we can get additional benefit if we lower the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (simple col-wise reduction implementation) from `512` to `128`.

Test Plan:
Wrote a simple CUDA app that calls the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one, using FP32 values and compares the results. The epsilon value we used for FP comparison is 0.00001 for the weight and 0.0001 for the bias.
Ran the benchmark for various sizes A100 GPU and got the results below. Almost all sizes show good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

`NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also does additional numerical validation and it passes.

Differential Revision: D40730981

Pull Request resolved: pytorch#87814
Approved by: https://github.com/weiwangmeta
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…87814)

Summary:
Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, but not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, and then all threads would access the values in order to do loop unrolling. This approached increased register usage and decreased occupancy, without much benefit from using shared memory (this is because the values were already cached in L1). The new implementation is simpler and register usage is smaller, thus occupancy is better.

Added another implementation called `GammaBetaBackwardCUDAKernel_32x32` which is only for shapes dividing exactly to a (32 x 32) block. This permits using warp shuffles for speeding up loading `mean` and `rstd` as well as for the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Observed that we can get additional benefit if we lower the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (simple col-wise reduction implementation) from `512` to `128`.

Test Plan:
Wrote a simple CUDA app that calls the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one, using FP32 values and compares the results. The epsilon value we used for FP comparison is 0.00001 for the weight and 0.0001 for the bias.
Ran the benchmark for various sizes A100 GPU and got the results below. Almost all sizes show good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32. reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384. reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64. reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520. reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005. reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096. reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000. reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000. reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144. reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000. reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000. reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240. reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128. reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513. reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479. reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679. reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256. reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512. reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024. reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536. reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048. reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072. reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)

def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )

            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()

            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")

            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100

            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100

            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

`NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also does additional numerical validation and it passes.

Differential Revision: D40730981

Pull Request resolved: pytorch#87814
Approved by: https://github.com/weiwangmeta
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request fb-exported Merged release notes: cuda release notes category topic: performance topic category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants