[pytorch] Layer norm backward speed gain with warp shuffles #87445
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87445.
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit e702b0c. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 41ff64f to dda5a87.
This pull request was exported from Phabricator. Differential Revision: D40567574
@pytorchbot start ci

❌ 🤖 pytorchbot command failed. Try:

@pytorchbot --help

PyTorchBot Help: Merge, Revert, Rebase, Label

/ci
This pull request was exported from Phabricator. Differential Revision: D40567574
Force-pushed from dda5a87 to 4fb81c0.
…87445)

Summary: Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` calls along the way, and warp exchanges should be slightly faster than loads from shared memory.

Test Plan:

```
Times below are Forward + Backward on A100

 Size          FP32    Gain    FP16   Gain
  256,  256   101.30     9%   103.9     6%
  512,  256   110.10    -4%   102.9    10%
 1024,  256   104.30     7%   102.4     6%
 2048,  256   107.60     4%   109.7     0%
 4096,  256   116.70     8%   109.1     0%
 6144,  256   106.10     7%   112.8     2%
 8192,  256   106.10     1%   109.7     2%
  256,  512   102.10     3%   108.5     1%
  512,  512   101.50    40%   105.9     4%
 1024,  512   109.70    20%   109.2    -1%
 2048,  512   107.40    24%   107.2     1%
 4096,  512   108.00     6%   110.6    -3%
 6144,  512   103.90    13%   105.8     7%
 8192,  512   138.70    14%   105.6     7%
  256, 1024   106.20     1%   102.9     6%
  512, 1024   104.50     4%   104.2     3%
 1024, 1024   126.90   -15%   103.9    10%
 2048, 1024   127.40   -15%   102.2     6%
 4096, 1024   117.70     6%   102.8    21%
 6144, 1024   165.30    11%   112.2    12%
 8192, 1024   211.90    11%   144.8    13%
  256, 1536   102.80    11%   103.1     6%
  512, 1536   103.30     9%   102.9    18%
 1024, 1536   111.00    -2%   117.2     7%
 2048, 1536   102.30    12%   132.1    -4%
 4096, 1536   165.50     5%   112.9    18%
 6144, 1536   236.60     5%   145.7    12%
 8192, 1536   307.80     5%   186.1    11%
  256, 2048   110.60    -1%   103.8     7%
  512, 2048   105.20     3%   105.6     1%
 1024, 2048   106.70     3%   114.8     3%
 2048, 2048   124.90     5%   109.7     0%
 4096, 2048   231.40     4%   129.9    10%
 6144, 2048   332.80     4%   182.5    11%
 8192, 2048   434.60     4%   235.2    11%
  256, 3072   111.60     8%   110.8     1%
  512, 3072   106.80     1%   104.6    10%
 1024, 3072   104.90     3%   109.9     4%
 2048, 3072   193.80     0%   106.2    10%
 4096, 3072   364.50     0%   187.8     5%
 6144, 3072   538.30     0%   267       5%
 8192, 3072   718.00    -1%   346.7     6%
  256, 4096   103.60     4%   110.2    -1%
  512, 4096   131.40   -11%   117      -7%
 1024, 4096   135.80     1%   104.8     7%
 2048, 4096   268.20     1%   149.4    10%
 4096, 4096   520.70     1%   268.5     9%
 6144, 4096   786.30     0%   389.8     9%
 8192, 4096  1043.50     0%   509      10%
```

Used the following script from ngimel:

```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")

c = Compare(results)
c.print()
```

For numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)


def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )
            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()
            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")
            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100
            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100
            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

Numerical validation results:

```
...
Size (256 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (256 x 6144) mismatch percentage: weight 0.00 bias 0.39
Size (256 x 8192) mismatch percentage: weight 0.00 bias 1.17
...
Size (512 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (512 x 6144) mismatch percentage: weight 0.00 bias 0.39
Size (512 x 8192) mismatch percentage: weight 0.39 bias 0.59
...
Size (1024 x 4096) mismatch percentage: weight 0.00 bias 0.00
Size (1024 x 6144) mismatch percentage: weight 0.00 bias 0.49
Size (1024 x 8192) mismatch percentage: weight 0.10 bias 1.07
...
Size (1536 x 6144) mismatch percentage: weight 0.00 bias 0.20
Size (1536 x 8192) mismatch percentage: weight 0.20 bias 0.78
...
Size (2048 x 6144) mismatch percentage: weight 0.05 bias 0.10
Size (2048 x 8192) mismatch percentage: weight 0.29 bias 0.88
...
Size (3072 x 6144) mismatch percentage: weight 0.00 bias 0.33
Size (3072 x 8192) mismatch percentage: weight 0.16 bias 0.72
...
Size (4096 x 6144) mismatch percentage: weight 0.12 bias 0.10
Size (4096 x 8192) mismatch percentage: weight 0.20 bias 1.03
```

Differential Revision: D40567574

fbshipit-source-id: 56c6bfb215b2818fd959aef7b768de6aec127ffb
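The warp-exchange idea in the summary is the standard `__shfl_sync` pattern. Below is a minimal sketch, assuming a hypothetical demo kernel and names (this is not the actual PyTorch kernel), of how lanes in a warp can read each other's register-resident `mean`/`rstd` values without shared-memory staging or `__syncthreads()`:

```cuda
// Minimal sketch of register-to-register exchange with __shfl_sync
// (hypothetical demo kernel; not the PyTorch implementation).
#include <cuda_runtime.h>

__global__ void broadcast_stats_demo(const float* mean, const float* rstd,
                                     float* out, int rows) {
    const unsigned kFullMask = 0xffffffffu;  // all 32 lanes participate
    int lane = threadIdx.x & 31;             // lane id within the warp
    int row  = blockIdx.x * 32 + lane;       // one row per lane

    // Each lane keeps the statistics of its own row in registers.
    float m = (row < rows) ? mean[row] : 0.f;
    float r = (row < rows) ? rstd[row] : 0.f;

    // Any lane can read another lane's registers directly; here lane i
    // fetches the values owned by lane (i + 1) % 32. No shared-memory
    // staging and no __syncthreads() barrier is required.
    float m_next = __shfl_sync(kFullMask, m, (lane + 1) & 31);
    float r_next = __shfl_sync(kFullMask, r, (lane + 1) & 31);

    if (row < rows) out[row] = m_next * r_next;  // placeholder use of the values
}
```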
This pull request was exported from Phabricator. Differential Revision: D40567574
…87445)

Summary: Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` calls along the way, and warp exchanges should be slightly faster than loads from shared memory.

Test Plan: Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values, and compares the results. The epsilon used for the FP comparison is 0.00001 for the weight and 0.0001 for the bias. Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup.

```
Size (512, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0098 (ms); optimized = 0.0098 (ms); speedup = 0.00%
Size (1024, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0131 (ms); optimized = 0.0123 (ms); speedup = 6.40%
Size (2048, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0198 (ms); optimized = 0.0170 (ms); speedup = 16.41%
Size (4096, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0345 (ms); optimized = 0.0282 (ms); speedup = 22.32%
Size (8192, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0820 (ms); optimized = 0.0804 (ms); speedup = 2.02%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0102 (ms); optimized = 0.0099 (ms); speedup = 3.13%
Size (1024, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0134 (ms); optimized = 0.0128 (ms); speedup = 4.66%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0216 (ms); optimized = 0.0189 (ms); speedup = 14.25%
Size (4096, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0488 (ms); optimized = 0.0459 (ms); speedup = 6.34%
Size (8192, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0973 (ms); optimized = 0.0870 (ms); speedup = 11.84%
Size (512, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0131 (ms); optimized = 0.0106 (ms); speedup = 23.37%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0210 (ms); optimized = 0.0146 (ms); speedup = 43.56%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0420 (ms); optimized = 0.0316 (ms); speedup = 32.81%
Size (4096, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0728 (ms); optimized = 0.0606 (ms); speedup = 20.19%
Size (8192, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.1399 (ms); optimized = 0.1129 (ms); speedup = 23.93%
Size (512, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0181 (ms); optimized = 0.0155 (ms); speedup = 16.77%
Size (1024, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0379 (ms); optimized = 0.0317 (ms); speedup = 19.47%
Size (2048, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0680 (ms); optimized = 0.0609 (ms); speedup = 11.71%
Size (4096, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.1286 (ms); optimized = 0.1179 (ms); speedup = 9.04%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.2567 (ms); optimized = 0.2296 (ms); speedup = 11.80%
Size (512, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0337 (ms); optimized = 0.0291 (ms); speedup = 15.90%
Size (1024, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0628 (ms); optimized = 0.0579 (ms); speedup = 8.48%
Size (2048, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.1154 (ms); optimized = 0.1068 (ms); speedup = 8.04%
Size (4096, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.2207 (ms); optimized = 0.2051 (ms); speedup = 7.61%
Size (8192, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.4504 (ms); optimized = 0.4022 (ms); speedup = 11.98%
Size (512, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.0628 (ms); optimized = 0.0561 (ms); speedup = 11.94%
Size (1024, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.1125 (ms); optimized = 0.1027 (ms); speedup = 9.52%
Size (2048, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.2144 (ms); optimized = 0.1983 (ms); speedup = 8.13%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.4543 (ms); optimized = 0.3920 (ms); speedup = 15.89%
Size (8192, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.9087 (ms); optimized = 0.7878 (ms); speedup = 15.35%
Size (512, 13261); Mismatches: dg = 0 db = 0 out of 13261.
reference = 0.0541 (ms); optimized = 0.0481 (ms); speedup = 12.49%
Size (1024, 13261); Mismatches: dg = 0 db = 0 out of 13261.
reference = 0.0953 (ms); optimized = 0.0895 (ms); speedup = 6.50%
Size (2048, 13261); Mismatches: dg = 0 db = 0 out of 13261.
reference = 0.1766 (ms); optimized = 0.1849 (ms); speedup = -4.49%
Size (4096, 13261); Mismatches: dg = 0 db = 0 out of 13261.
reference = 0.3431 (ms); optimized = 0.3821 (ms); speedup = -10.21%
Size (8192, 13261); Mismatches: dg = 0 db = 0 out of 13261.
reference = 0.6696 (ms); optimized = 0.7788 (ms); speedup = -14.02%
Size (512, 999); Mismatches: dg = 7 db = 0 out of 999.
reference = 0.0086 (ms); optimized = 0.0086 (ms); speedup = 0.00%
Size (1024, 999); Mismatches: dg = 7 db = 0 out of 999.
reference = 0.0122 (ms); optimized = 0.0112 (ms); speedup = 8.94%
Size (2048, 999); Mismatches: dg = 7 db = 0 out of 999.
reference = 0.0193 (ms); optimized = 0.0162 (ms); speedup = 19.29%
Size (4096, 999); Mismatches: dg = 7 db = 0 out of 999.
reference = 0.0493 (ms); optimized = 0.0441 (ms); speedup = 11.78%
Size (8192, 999); Mismatches: dg = 7 db = 0 out of 999.
reference = 0.0947 (ms); optimized = 0.0817 (ms); speedup = 15.90%
Size (512, 667); Mismatches: dg = 0 db = 0 out of 667.
reference = 0.0084 (ms); optimized = 0.0084 (ms); speedup = -0.28%
Size (1024, 667); Mismatches: dg = 0 db = 0 out of 667.
reference = 0.0114 (ms); optimized = 0.0106 (ms); speedup = 7.66%
Size (2048, 667); Mismatches: dg = 0 db = 0 out of 667.
reference = 0.0179 (ms); optimized = 0.0158 (ms); speedup = 13.27%
Size (4096, 667); Mismatches: dg = 0 db = 0 out of 667.
reference = 0.0378 (ms); optimized = 0.0343 (ms); speedup = 10.15%
Size (8192, 667); Mismatches: dg = 0 db = 0 out of 667.
reference = 0.0828 (ms); optimized = 0.0766 (ms); speedup = 8.09%
Size (512, 312); Mismatches: dg = 0 db = 0 out of 312.
reference = 0.0083 (ms); optimized = 0.0082 (ms); speedup = 1.16%
Size (1024, 312); Mismatches: dg = 0 db = 0 out of 312.
reference = 0.0111 (ms); optimized = 0.0103 (ms); speedup = 7.87%
Size (2048, 312); Mismatches: dg = 0 db = 0 out of 312.
reference = 0.0166 (ms); optimized = 0.0141 (ms); speedup = 17.94%
Size (4096, 312); Mismatches: dg = 0 db = 0 out of 312.
reference = 0.0283 (ms); optimized = 0.0230 (ms); speedup = 23.01%
Size (8192, 312); Mismatches: dg = 0 db = 0 out of 312.
reference = 0.0543 (ms); optimized = 0.0517 (ms); speedup = 4.98%

Average speedup = 11.06%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)


def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )
            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()
            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")
            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100
            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100
            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

Differential Revision: D40567574

fbshipit-source-id: 756a0340678e0c8700f53cd15fb64a3243cd39ba
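The mismatch counting in this test plan reduces to an elementwise epsilon comparison. Here is a sketch of such a harness (a hypothetical helper, not the actual benchmark app), producing counts in the format of the log above:

```cuda
// Hypothetical comparison helper in the spirit of the test plan above:
// count elementwise mismatches between reference and optimized outputs,
// with separate epsilons for the weight (dg) and bias (db) gradients.
#include <cmath>
#include <cstdio>

int count_mismatches(const float* ref, const float* opt, int n, float eps) {
    int mismatches = 0;
    for (int i = 0; i < n; ++i) {
        if (std::fabs(ref[i] - opt[i]) >= eps) ++mismatches;
    }
    return mismatches;
}

// Usage mirroring the log format, with eps = 1e-5 (weight) and 1e-4 (bias):
//   int bad_dg = count_mismatches(dg_ref, dg_opt, n, 1e-5f);
//   int bad_db = count_mismatches(db_ref, db_opt, n, 1e-4f);
//   printf("Mismatches: dg = %d db = %d out of %d.\n", bad_dg, bad_db, n);
```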
Force-pushed from 4fb81c0 to 671f11e.
This pull request was exported from Phabricator. Differential Revision: D40567574
Force-pushed from 671f11e to e8c7506.
…87445)

Summary: Pull Request resolved: pytorch#87445

This implementation stores `mean` and `rstd` in registers and exchanges them across a warp using `__shfl_sync`. This avoids some `__syncthreads()` calls along the way, and warp exchanges should be slightly faster than loads from shared memory.

Test Plan: Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values, and compares the results. The epsilon used for the FP comparison is 0.00001 for the weight and 0.0001 for the bias. Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup; the average is 15.6%.

```
Size (512, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0099 (ms); optimized = 0.0093 (ms); bw_opt = 210.87 GB/s; speedup = 6.41%
Size (1024, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0132 (ms); optimized = 0.0118 (ms); bw_opt = 331.96 GB/s; speedup = 11.72%
Size (2048, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0198 (ms); optimized = 0.0169 (ms); bw_opt = 463.30 GB/s; speedup = 17.07%
Size (4096, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0344 (ms); optimized = 0.0285 (ms); bw_opt = 549.16 GB/s; speedup = 20.57%
Size (525, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0114 (ms); optimized = 0.0117 (ms); bw_opt = 171.74 GB/s; speedup = -2.65%
Size (1033, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0148 (ms); optimized = 0.0144 (ms); bw_opt = 274.44 GB/s; speedup = 2.65%
Size (2064, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0215 (ms); optimized = 0.0196 (ms); bw_opt = 402.73 GB/s; speedup = 9.73%
Size (3000, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0276 (ms); optimized = 0.0245 (ms); bw_opt = 467.99 GB/s; speedup = 12.65%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0099 (ms); optimized = 0.0094 (ms); bw_opt = 416.00 GB/s; speedup = 5.06%
Size (1024, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0134 (ms); optimized = 0.0120 (ms); bw_opt = 651.43 GB/s; speedup = 11.51%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0214 (ms); optimized = 0.0186 (ms); bw_opt = 841.44 GB/s; speedup = 15.00%
Size (4096, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0486 (ms); optimized = 0.0465 (ms); bw_opt = 672.98 GB/s; speedup = 4.56%
Size (525, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0117 (ms); optimized = 0.0117 (ms); bw_opt = 343.85 GB/s; speedup = 0.20%
Size (1033, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0153 (ms); optimized = 0.0146 (ms); bw_opt = 540.30 GB/s; speedup = 4.57%
Size (2064, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0235 (ms); optimized = 0.0214 (ms); bw_opt = 736.58 GB/s; speedup = 9.69%
Size (3000, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0354 (ms); optimized = 0.0287 (ms); bw_opt = 798.39 GB/s; speedup = 23.34%
Size (512, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0132 (ms); optimized = 0.0102 (ms); bw_opt = 769.27 GB/s; speedup = 29.51%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0201 (ms); optimized = 0.0142 (ms); bw_opt = 1103.06 GB/s; speedup = 41.68%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0417 (ms); optimized = 0.0311 (ms); bw_opt = 1005.36 GB/s; speedup = 34.02%
Size (4096, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0725 (ms); optimized = 0.0615 (ms); bw_opt = 1017.20 GB/s; speedup = 17.91%
Size (525, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0148 (ms); optimized = 0.0125 (ms); bw_opt = 641.53 GB/s; speedup = 18.10%
Size (1033, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0224 (ms); optimized = 0.0168 (ms); bw_opt = 939.12 GB/s; speedup = 33.19%
Size (2064, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0457 (ms); optimized = 0.0369 (ms); bw_opt = 854.16 GB/s; speedup = 23.84%
Size (3000, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0600 (ms); optimized = 0.0509 (ms); bw_opt = 900.04 GB/s; speedup = 17.85%
Size (512, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0178 (ms); optimized = 0.0152 (ms); bw_opt = 1029.47 GB/s; speedup = 16.93%
Size (1024, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0378 (ms); optimized = 0.0304 (ms); bw_opt = 1029.27 GB/s; speedup = 24.31%
Size (2048, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0680 (ms); optimized = 0.0654 (ms); bw_opt = 956.38 GB/s; speedup = 4.01%
Size (4096, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.1344 (ms); optimized = 0.1161 (ms); bw_opt = 1077.09 GB/s; speedup = 15.75%
Size (525, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0228 (ms); optimized = 0.0173 (ms); bw_opt = 927.61 GB/s; speedup = 31.68%
Size (1033, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0498 (ms); optimized = 0.0361 (ms); bw_opt = 873.82 GB/s; speedup = 37.82%
Size (2064, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.0826 (ms); optimized = 0.0659 (ms); bw_opt = 956.53 GB/s; speedup = 25.36%
Size (3000, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.1114 (ms); optimized = 0.0891 (ms); bw_opt = 1028.16 GB/s; speedup = 25.05%
Size (512, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0342 (ms); optimized = 0.0270 (ms); bw_opt = 1159.26 GB/s; speedup = 26.57%
Size (1024, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0625 (ms); optimized = 0.0579 (ms); bw_opt = 1080.86 GB/s; speedup = 7.99%
Size (2048, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.1160 (ms); optimized = 0.1064 (ms); bw_opt = 1175.72 GB/s; speedup = 9.05%
Size (4096, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.2204 (ms); optimized = 0.2044 (ms); bw_opt = 1223.56 GB/s; speedup = 7.84%
Size (525, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0494 (ms); optimized = 0.0389 (ms); bw_opt = 825.70 GB/s; speedup = 27.04%
Size (1033, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.0762 (ms); optimized = 0.0674 (ms); bw_opt = 936.46 GB/s; speedup = 13.05%
Size (2064, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.1312 (ms); optimized = 0.1142 (ms); bw_opt = 1103.77 GB/s; speedup = 14.89%
Size (3000, 8192); Mismatches: dg = 0 db = 0 out of 8192.
reference = 0.1753 (ms); optimized = 0.1571 (ms); bw_opt = 1165.93 GB/s; speedup = 11.56%
Size (512, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.0425 (ms); optimized = 0.0393 (ms); bw_opt = 972.28 GB/s; speedup = 8.07%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.0739 (ms); optimized = 0.0698 (ms); bw_opt = 1094.07 GB/s; speedup = 5.84%
Size (2048, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.1417 (ms); optimized = 0.1294 (ms); bw_opt = 1179.98 GB/s; speedup = 9.51%
Size (4096, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.2871 (ms); optimized = 0.2532 (ms); bw_opt = 1205.69 GB/s; speedup = 13.39%
Size (525, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.0575 (ms); optimized = 0.0496 (ms); bw_opt = 790.34 GB/s; speedup = 15.96%
Size (1033, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.0888 (ms); optimized = 0.0794 (ms); bw_opt = 970.15 GB/s; speedup = 11.80%
Size (2064, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.1541 (ms); optimized = 0.1362 (ms); bw_opt = 1129.66 GB/s; speedup = 13.13%
Size (3000, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.2098 (ms); optimized = 0.1882 (ms); bw_opt = 1188.13 GB/s; speedup = 11.46%
Size (512, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0097 (ms); optimized = 0.0099 (ms); bw_opt = 193.53 GB/s; speedup = -1.93%
Size (1024, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0133 (ms); optimized = 0.0125 (ms); bw_opt = 305.67 GB/s; speedup = 6.10%
Size (2048, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0202 (ms); optimized = 0.0175 (ms); bw_opt = 437.05 GB/s; speedup = 15.40%
Size (4096, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0360 (ms); optimized = 0.0291 (ms); bw_opt = 525.34 GB/s; speedup = 23.67%
Size (525, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0116 (ms); optimized = 0.0115 (ms); bw_opt = 170.85 GB/s; speedup = 1.04%
Size (1033, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0151 (ms); optimized = 0.0143 (ms); bw_opt = 269.81 GB/s; speedup = 5.50%
Size (2064, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0221 (ms); optimized = 0.0193 (ms); bw_opt = 399.14 GB/s; speedup = 14.44%
Size (3000, 500); Mismatches: dg = 0 db = 0 out of 500.
reference = 0.0286 (ms); optimized = 0.0248 (ms); bw_opt = 451.77 GB/s; speedup = 15.38%
Size (512, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0098 (ms); optimized = 0.0100 (ms); bw_opt = 382.08 GB/s; speedup = -2.14%
Size (1024, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0132 (ms); optimized = 0.0126 (ms); bw_opt = 606.11 GB/s; speedup = 4.54%
Size (2048, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0214 (ms); optimized = 0.0189 (ms); bw_opt = 808.26 GB/s; speedup = 13.24%
Size (4096, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0523 (ms); optimized = 0.0491 (ms); bw_opt = 622.43 GB/s; speedup = 6.56%
Size (525, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0116 (ms); optimized = 0.0116 (ms); bw_opt = 338.56 GB/s; speedup = 0.21%
Size (1033, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0153 (ms); optimized = 0.0144 (ms); bw_opt = 535.51 GB/s; speedup = 6.29%
Size (2064, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0234 (ms); optimized = 0.0208 (ms); bw_opt = 740.78 GB/s; speedup = 12.50%
Size (3000, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0378 (ms); optimized = 0.0365 (ms); bw_opt = 613.16 GB/s; speedup = 3.59%
Size (512, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0133 (ms); optimized = 0.0109 (ms); bw_opt = 700.75 GB/s; speedup = 21.83%
Size (1024, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0206 (ms); optimized = 0.0151 (ms); bw_opt = 1011.46 GB/s; speedup = 36.28%
Size (2048, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0436 (ms); optimized = 0.0354 (ms); bw_opt = 863.24 GB/s; speedup = 23.16%
Size (4096, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0761 (ms); optimized = 0.0625 (ms); bw_opt = 977.57 GB/s; speedup = 21.70%
Size (525, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0149 (ms); optimized = 0.0124 (ms); bw_opt = 632.84 GB/s; speedup = 20.19%
Size (1033, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0224 (ms); optimized = 0.0169 (ms); bw_opt = 912.41 GB/s; speedup = 32.58%
Size (2064, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0480 (ms); optimized = 0.0396 (ms); bw_opt = 777.79 GB/s; speedup = 21.19%
Size (3000, 2001); Mismatches: dg = 0 db = 0 out of 2001.
reference = 0.0623 (ms); optimized = 0.0514 (ms); bw_opt = 870.83 GB/s; speedup = 21.20%
Size (512, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0191 (ms); optimized = 0.0163 (ms); bw_opt = 938.91 GB/s; speedup = 17.11%
Size (1024, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0424 (ms); optimized = 0.0364 (ms); bw_opt = 840.87 GB/s; speedup = 16.58%
Size (2048, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0741 (ms); optimized = 0.0619 (ms); bw_opt = 987.71 GB/s; speedup = 19.68%
Size (4096, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.1414 (ms); optimized = 0.1163 (ms); bw_opt = 1051.44 GB/s; speedup = 21.59%
Size (525, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0242 (ms); optimized = 0.0180 (ms); bw_opt = 872.17 GB/s; speedup = 34.44%
Size (1033, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0521 (ms); optimized = 0.0392 (ms); bw_opt = 787.37 GB/s; speedup = 32.91%
Size (2064, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.0867 (ms); optimized = 0.0667 (ms); bw_opt = 924.25 GB/s; speedup = 30.03%
Size (3000, 4005); Mismatches: dg = 5 db = 0 out of 4005.
reference = 0.1146 (ms); optimized = 0.0899 (ms); bw_opt = 996.25 GB/s; speedup = 27.47%
Size (512, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.0383 (ms); optimized = 0.0337 (ms); bw_opt = 921.03 GB/s; speedup = 13.73%
Size (1024, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.0656 (ms); optimized = 0.0591 (ms); bw_opt = 1048.93 GB/s; speedup = 11.01%
Size (2048, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.1204 (ms); optimized = 0.1118 (ms); bw_opt = 1108.33 GB/s; speedup = 7.68%
Size (4096, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.2281 (ms); optimized = 0.2276 (ms); bw_opt = 1088.79 GB/s; speedup = 0.22%
Size (525, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.0524 (ms); optimized = 0.0436 (ms); bw_opt = 729.98 GB/s; speedup = 20.24%
Size (1033, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.0810 (ms); optimized = 0.0679 (ms); bw_opt = 921.04 GB/s; speedup = 19.31%
Size (2064, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.1343 (ms); optimized = 0.1192 (ms); bw_opt = 1047.94 GB/s; speedup = 12.66%
Size (3000, 8117); Mismatches: dg = 0 db = 0 out of 8117.
reference = 0.1797 (ms); optimized = 0.1699 (ms); bw_opt = 1068.36 GB/s; speedup = 5.77%

Average speedup = 15.56%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)


def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )
            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()
            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")
            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100
            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100
            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

Differential Revision: D40567574

fbshipit-source-id: b23e59f4cbd7d8a8104096af2ae5d9b6bb7e200c
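The `bw_opt` column is an effective-bandwidth figure. One plausible byte accounting (an assumption; the exact formula used by the benchmark app is not shown) reproduces the reported numbers closely, as the sketch below demonstrates:

```cuda
// Back-of-the-envelope effective bandwidth for an (M x N) GammaBeta
// backward pass. Assumed traffic (not confirmed by the PR): dY and X are
// streamed once (M*N floats each), mean/rstd are read (M floats each),
// and dgamma/dbeta are written (N floats each).
#include <cstdio>

double effective_bandwidth_gbs(long long M, long long N, double time_ms) {
    double bytes = 4.0 * (2.0 * M * N + 2.0 * M + 2.0 * N);
    return bytes / (time_ms * 1e-3) / (1024.0 * 1024.0 * 1024.0);
}

int main() {
    // Size (1024, 2048) at 0.0142 ms -> ~1102 GB/s, close to the
    // 1103.06 GB/s reported in the log above.
    printf("%.2f GB/s\n", effective_bandwidth_gbs(1024, 2048, 0.0142));
    return 0;
}
```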
…87445)

Summary: Pull Request resolved: pytorch#87445

Improved native layer norm backward performance.

Rewrote `GammaBetaBackwardCUDAKernel` to use shared memory only for the reduction step, not for loading `mean` and `rstd`. The previous implementation used only `threadIdx.x = 0` to load `mean` and `rstd` into shared memory, and then all threads would access the values in order to do loop unrolling. This approach increased register usage and decreased occupancy, without much benefit from using shared memory (the values were already cached in L1). The new implementation is simpler and uses fewer registers, so occupancy is better.

Added another implementation, `GammaBetaBackwardCUDAKernel_32x32`, which is used only for shapes that divide exactly into (32 x 32) blocks. This permits using warp shuffles to speed up loading `mean` and `rstd` as well as the final reduction stage. The effective bandwidth of this implementation is equal to STREAM Triad.

Also observed that we can get additional benefit by lowering the threshold for calling `GammaBetaBackwardSimpleCUDAKernel` (the simple column-wise reduction implementation) from `512` to `128`.

Test Plan: Wrote a simple CUDA app that calls both the previous implementation of `GammaBetaBackwardCUDAKernel` and the current one with FP32 values, and compares the results. The epsilon used for the FP comparison is 0.00001 for the weight and 0.0001 for the bias. Ran the benchmark for various sizes on an A100 GPU and got the results below. Almost all sizes show a good speedup.

```
Size (32, 32); Mismatches: dg = 0 db = 0 out of 32.
reference = 0.0073 (ms); optimized = 0.0071 (ms); bw_opt = 1.14 GB/s; speedup = 2.68%
Size (64, 32); Mismatches: dg = 0 db = 0 out of 32.
reference = 0.0107 (ms); optimized = 0.0107 (ms); bw_opt = 1.50 GB/s; speedup = 0.22%
Size (256, 128); Mismatches: dg = 0 db = 0 out of 128.
reference = 0.0323 (ms); optimized = 0.0075 (ms); bw_opt = 32.89 GB/s; speedup = 330.16%
Size (512, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0103 (ms); optimized = 0.0089 (ms); bw_opt = 440.54 GB/s; speedup = 15.82%
Size (1024, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0197 (ms); optimized = 0.0136 (ms); bw_opt = 1151.44 GB/s; speedup = 44.91%
Size (2048, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0416 (ms); optimized = 0.0283 (ms); bw_opt = 1105.31 GB/s; speedup = 47.01%
Size (4096, 16384); Mismatches: dg = 0 db = 0 out of 16384.
reference = 0.4420 (ms); optimized = 0.3915 (ms); bw_opt = 1277.58 GB/s; speedup = 12.90%
Size (70000, 64); Mismatches: dg = 0 db = 0 out of 64.
reference = 0.5908 (ms); optimized = 0.6850 (ms); bw_opt = 49.49 GB/s; speedup = -13.75%
Size (131072, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 1.1961 (ms); optimized = 0.9234 (ms); bw_opt = 542.54 GB/s; speedup = 29.53%
Size (1000, 520); Mismatches: dg = 0 db = 0 out of 520.
reference = 0.0132 (ms); optimized = 0.0113 (ms); bw_opt = 343.83 GB/s; speedup = 16.88%
Size (4005, 4005); Mismatches: dg = 0 db = 0 out of 4005.
reference = 0.1441 (ms); optimized = 0.1054 (ms); bw_opt = 1134.36 GB/s; speedup = 36.71%
Size (10000, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.1293 (ms); optimized = 0.1248 (ms); bw_opt = 597.71 GB/s; speedup = 3.63%
Size (1024, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.0738 (ms); optimized = 0.0735 (ms); bw_opt = 1039.40 GB/s; speedup = 0.45%
Size (8192, 4096); Mismatches: dg = 0 db = 0 out of 4096.
reference = 0.2673 (ms); optimized = 0.2223 (ms); bw_opt = 1125.01 GB/s; speedup = 20.25%
Size (10000, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.7331 (ms); optimized = 0.8940 (ms); bw_opt = 833.54 GB/s; speedup = -18.00%
Size (3072, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.2087 (ms); optimized = 0.2364 (ms); bw_opt = 968.64 GB/s; speedup = -11.71%
Size (6144, 10000); Mismatches: dg = 0 db = 0 out of 10000.
reference = 0.4197 (ms); optimized = 0.5118 (ms); bw_opt = 894.63 GB/s; speedup = -18.00%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000.
reference = 0.1480 (ms); optimized = 0.1297 (ms); bw_opt = 1177.68 GB/s; speedup = 14.12%
Size (1024, 20000); Mismatches: dg = 0 db = 0 out of 20000.
reference = 0.1483 (ms); optimized = 0.1278 (ms); bw_opt = 1195.26 GB/s; speedup = 16.04%
Size (512, 1536); Mismatches: dg = 0 db = 0 out of 1536.
reference = 0.0104 (ms); optimized = 0.0091 (ms); bw_opt = 646.72 GB/s; speedup = 14.44%
Size (512, 6144); Mismatches: dg = 0 db = 0 out of 6144.
reference = 0.0219 (ms); optimized = 0.0156 (ms); bw_opt = 1506.30 GB/s; speedup = 40.52%
Size (512, 10240); Mismatches: dg = 0 db = 0 out of 10240.
reference = 0.0424 (ms); optimized = 0.0370 (ms); bw_opt = 1057.84 GB/s; speedup = 14.63%
Size (1000, 1000); Mismatches: dg = 0 db = 0 out of 1000.
reference = 0.0139 (ms); optimized = 0.0119 (ms); bw_opt = 627.51 GB/s; speedup = 16.83%
Size (2000, 2000); Mismatches: dg = 0 db = 0 out of 2000.
reference = 0.0421 (ms); optimized = 0.0412 (ms); bw_opt = 724.10 GB/s; speedup = 2.20%
Size (10240, 10240); Mismatches: dg = 0 db = 0 out of 10240.
reference = 0.7210 (ms); optimized = 0.6098 (ms); bw_opt = 1281.40 GB/s; speedup = 18.24%
Size (384, 128); Mismatches: dg = 0 db = 0 out of 128.
reference = 0.0449 (ms); optimized = 0.0089 (ms); bw_opt = 41.50 GB/s; speedup = 403.48%
Size (2048, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0208 (ms); optimized = 0.0169 (ms); bw_opt = 925.70 GB/s; speedup = 23.13%
Size (267, 513); Mismatches: dg = 0 db = 0 out of 513.
reference = 0.0342 (ms); optimized = 0.0090 (ms); bw_opt = 114.18 GB/s; speedup = 280.64%
Size (67, 123479); Mismatches: dg = 0 db = 0 out of 123479.
reference = 0.0562 (ms); optimized = 0.0552 (ms); bw_opt = 1133.46 GB/s; speedup = 1.81%
Size (1024, 123479); Mismatches: dg = 0 db = 0 out of 123479.
reference = 0.8573 (ms); optimized = 0.9245 (ms); bw_opt = 1020.02 GB/s; speedup = -7.27%
Size (2048, 66679); Mismatches: dg = 0 db = 0 out of 66679.
reference = 0.8778 (ms); optimized = 0.8590 (ms); bw_opt = 1185.05 GB/s; speedup = 2.19%
Size (200, 256); Mismatches: dg = 0 db = 0 out of 256.
reference = 0.0215 (ms); optimized = 0.0066 (ms); bw_opt = 58.49 GB/s; speedup = 226.81%
Size (1000, 256); Mismatches: dg = 0 db = 0 out of 256.
reference = 0.0109 (ms); optimized = 0.0092 (ms); bw_opt = 208.27 GB/s; speedup = 18.65%
Size (6000, 256); Mismatches: dg = 0 db = 0 out of 256.
reference = 0.0394 (ms); optimized = 0.0301 (ms); bw_opt = 381.90 GB/s; speedup = 30.98%
Size (6272, 256); Mismatches: dg = 0 db = 0 out of 256.
reference = 0.0403 (ms); optimized = 0.0300 (ms); bw_opt = 400.48 GB/s; speedup = 34.34%
Size (200, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0218 (ms); optimized = 0.0066 (ms); bw_opt = 116.33 GB/s; speedup = 229.96%
Size (1000, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0110 (ms); optimized = 0.0094 (ms); bw_opt = 407.29 GB/s; speedup = 17.26%
Size (6000, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0535 (ms); optimized = 0.0594 (ms); bw_opt = 386.05 GB/s; speedup = -9.95%
Size (6272, 512); Mismatches: dg = 0 db = 0 out of 512.
reference = 0.0573 (ms); optimized = 0.0387 (ms); bw_opt = 619.62 GB/s; speedup = 48.06%
Size (200, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0221 (ms); optimized = 0.0069 (ms); bw_opt = 222.78 GB/s; speedup = 220.76%
Size (1000, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0113 (ms); optimized = 0.0097 (ms); bw_opt = 787.79 GB/s; speedup = 16.46%
Size (6000, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0723 (ms); optimized = 0.0715 (ms); bw_opt = 640.95 GB/s; speedup = 1.10%
Size (6272, 1024); Mismatches: dg = 0 db = 0 out of 1024.
reference = 0.0751 (ms); optimized = 0.0572 (ms); bw_opt = 837.57 GB/s; speedup = 31.30%
Size (200, 1536); Mismatches: dg = 0 db = 0 out of 1536.
reference = 0.0232 (ms); optimized = 0.0071 (ms); bw_opt = 323.97 GB/s; speedup = 226.51%
Size (1000, 1536); Mismatches: dg = 0 db = 0 out of 1536.
reference = 0.0125 (ms); optimized = 0.0114 (ms); bw_opt = 1005.84 GB/s; speedup = 9.62%
Size (6000, 1536); Mismatches: dg = 0 db = 0 out of 1536.
reference = 0.0807 (ms); optimized = 0.0830 (ms); bw_opt = 828.02 GB/s; speedup = -2.76%
Size (6272, 1536); Mismatches: dg = 0 db = 0 out of 1536.
reference = 0.0836 (ms); optimized = 0.0695 (ms); bw_opt = 1033.62 GB/s; speedup = 20.27%
Size (200, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0224 (ms); optimized = 0.0075 (ms); bw_opt = 408.58 GB/s; speedup = 198.10%
Size (1000, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0165 (ms); optimized = 0.0135 (ms); bw_opt = 1132.42 GB/s; speedup = 22.26%
Size (6000, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.0993 (ms); optimized = 0.0989 (ms); bw_opt = 926.35 GB/s; speedup = 0.41%
Size (6272, 2048); Mismatches: dg = 0 db = 0 out of 2048.
reference = 0.1033 (ms); optimized = 0.0826 (ms); bw_opt = 1159.55 GB/s; speedup = 25.09%
Size (200, 3072); Mismatches: dg = 0 db = 0 out of 3072.
reference = 0.0230 (ms); optimized = 0.0076 (ms); bw_opt = 605.09 GB/s; speedup = 202.51%
Size (1000, 3072); Mismatches: dg = 0 db = 0 out of 3072.
reference = 0.0207 (ms); optimized = 0.0213 (ms); bw_opt = 1076.45 GB/s; speedup = -2.69%
Size (6000, 3072); Mismatches: dg = 0 db = 0 out of 3072.
reference = 0.1198 (ms); optimized = 0.1274 (ms); bw_opt = 1078.58 GB/s; speedup = -5.95%
Size (6272, 3072); Mismatches: dg = 0 db = 0 out of 3072.
reference = 0.1293 (ms); optimized = 0.1189 (ms); bw_opt = 1207.95 GB/s; speedup = 8.76%

Average speedup = 52.88%
```

For additional numerical validation used the following script:

```
def run_model_on_device(fs, X, gO, device_string, numeric_type):
    ln = torch.nn.LayerNorm((fs,), device=device_string, dtype=numeric_type)
    ln.reset_parameters()
    X.grad = None
    ln.zero_grad(set_to_none=True)
    out = ln(X)
    out.backward(gO)
    return (ln.weight.grad, ln.bias.grad)


def run_correctness_test(eps_weight, eps_bias):
    dtype = torch.float
    for fs in (512, 1024, 2048, 4096, 8192, 10000, 500, 1000, 2001, 4005, 8117):
        for bs in (512, 1024, 2048, 4096, 525, 1033, 2064, 3000):
            mean_adjustment = torch.randn(fs, device="cpu", dtype=torch.float)
            X = mean_adjustment * torch.randn(
                bs, fs, device="cpu", dtype=torch.float, requires_grad=True
            )
            X = X.detach().requires_grad_()
            gO = torch.rand_like(X)
            X_gpu = X.to("cuda")
            X_gpu = X_gpu.detach().requires_grad_()
            gO_gpu = gO.to("cuda")
            gO_gpu = gO_gpu.detach().requires_grad_()
            grad_cpu_ref = run_model_on_device(fs, X, gO, "cpu", dtype)
            grad_gpu = run_model_on_device(fs, X_gpu, gO_gpu, "cuda", dtype)
            weight_grad_gpu_target = grad_gpu[0].detach().to("cpu")
            bias_grad_gpu_target = grad_gpu[1].detach().to("cpu")
            weight_delta = torch.abs(grad_cpu_ref[0] - weight_grad_gpu_target)
            weight_mismatches = (weight_delta >= eps_weight).nonzero()
            weight_mismatch_pct = len(weight_mismatches) / len(weight_delta) * 100
            bias_delta = torch.abs(grad_cpu_ref[1] - bias_grad_gpu_target)
            bias_mismatches = (bias_delta >= eps_bias).nonzero()
            bias_mismatch_pct = len(bias_mismatches) / len(bias_delta) * 100
            print(
                "Size ({} x {}) mismatch percentage: weight {:3.2f} bias {:3.2f}".format(
                    fs, bs, weight_mismatch_pct, bias_mismatch_pct
                )
            )
```

The `NVFuserTest.FusionMagicSchedulerLayerNormBackward_CUDA` test also performs additional numerical validation, and it passes.

Reviewed By: ngimel

Differential Revision: D40567574

fbshipit-source-id: c7ab0dff4c375cc0b60bbc2c0e452939d823b91b
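The final reduction stage that the (32 x 32) kernel performs with warp shuffles is the classic shuffle-based tree reduction. A generic sketch follows (this is not the actual `GammaBetaBackwardCUDAKernel_32x32` code):

```cuda
// Classic warp-level tree reduction with shuffles (generic sketch, not
// the actual GammaBetaBackwardCUDAKernel_32x32). Each of the 32 lanes
// contributes one partial sum; after log2(32) = 5 shuffle steps, lane 0
// holds the warp total without touching shared memory.
__device__ float warp_reduce_sum(float val) {
    const unsigned kFullMask = 0xffffffffu;
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_down_sync(kFullMask, val, offset);
    }
    return val;  // meaningful on lane 0
}
```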
This pull request was exported from Phabricator. Differential Revision: D40567574
Force-pushed from e8c7506 to e702b0c.
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @valentinandrei.
Didn't find following labels among repository labels: topic:,performance,topic:,not,user,facing |
Didn't find following labels among repository labels: topic:performance |
Didn't find following labels among repository labels: performance |
@pytorchbot label "topic: performance" |
@pytorchbot revert -m "breaking internal builds due to MS compiler " -c "ghfirst"

❌ 🤖 pytorchbot command failed. Try:

@pytorchbot revert -m "breaking internal builds due to MS compiler " -c ghfirst

@pytorchbot successfully started a revert job. Check the current status here.

@valentinandrei, your PR has been successfully reverted.
…87445)" This reverts commit b6f2833. Reverted #87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
…87445)

Test Plan:

```
Times below are Forward + Backward on A100

 Size          FP32    Gain    FP16   Gain
  256,  256   101.30     9%   103.9     6%
  512,  256   110.10    -4%   102.9    10%
 1024,  256   104.30     7%   102.4     6%
 2048,  256   107.60     4%   109.7     0%
 4096,  256   116.70     8%   109.1     0%
 6144,  256   106.10     7%   112.8     2%
 8192,  256   106.10     1%   109.7     2%
  256,  512   102.10     3%   108.5     1%
  512,  512   101.50    40%   105.9     4%
 1024,  512   109.70    20%   109.2    -1%
 2048,  512   107.40    24%   107.2     1%
 4096,  512   108.00     6%   110.6    -3%
 6144,  512   103.90    13%   105.8     7%
 8192,  512   138.70    14%   105.6     7%
  256, 1024   106.20     1%   102.9     6%
  512, 1024   104.50     4%   104.2     3%
 1024, 1024   126.90   -15%   103.9    10%
 2048, 1024   127.40   -15%   102.2     6%
 4096, 1024   117.70     6%   102.8    21%
 6144, 1024   165.30    11%   112.2    12%
 8192, 1024   211.90    11%   144.8    13%
  256, 1536   102.80    11%   103.1     6%
  512, 1536   103.30     9%   102.9    18%
 1024, 1536   111.00    -2%   117.2     7%
 2048, 1536   102.30    12%   132.1    -4%
 4096, 1536   165.50     5%   112.9    18%
 6144, 1536   236.60     5%   145.7    12%
 8192, 1536   307.80     5%   186.1    11%
  256, 2048   110.60    -1%   103.8     7%
  512, 2048   105.20     3%   105.6     1%
 1024, 2048   106.70     3%   114.8     3%
 2048, 2048   124.90     5%   109.7     0%
 4096, 2048   231.40     4%   129.9    10%
 6144, 2048   332.80     4%   182.5    11%
 8192, 2048   434.60     4%   235.2    11%
  256, 3072   111.60     8%   110.8     1%
  512, 3072   106.80     1%   104.6    10%
 1024, 3072   104.90     3%   109.9     4%
 2048, 3072   193.80     0%   106.2    10%
 4096, 3072   364.50     0%   187.8     5%
 6144, 3072   538.30     0%   267       5%
 8192, 3072   718.00    -1%   346.7     6%
  256, 4096   103.60     4%   110.2    -1%
  512, 4096   131.40   -11%   117      -7%
 1024, 4096   135.80     1%   104.8     7%
 2048, 4096   268.20     1%   149.4    10%
 4096, 4096   520.70     1%   268.5     9%
 6144, 4096   786.30     0%   389.8     9%
 8192, 4096  1043.50     0%   509      10%
```

Used the following script from ngimel:

```
import torch
from torch.utils.benchmark import Compare, Timer

results = []
for dtype in (torch.float, torch.half):
    for fs in (256, 512, 1024, 1536, 2048, 3072, 4096):
        for bs in (256, 512, 1024, 2048, 4096, 6144, 8192):
            ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype)
            X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True)
            gO = torch.rand_like(X)
            stmtfwd = "ln(X)"
            stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)"
            tfwd = Timer(
                stmt=stmtfwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwd, {dtype}",
                globals=globals(),
            )
            tfwdbwd = Timer(
                stmt=stmtfwdbwd,
                label="ln",
                sub_label=f"{bs:5}, {fs:5}",
                description=f"fwdbwd, {dtype}",
                globals=globals(),
            )
            for t in (tfwd, tfwdbwd):
                results.append(t.blocked_autorange())
        print(fs, end="\r")

c = Compare(results)
c.print()
```

Differential Revision: D40567574

Pull Request resolved: pytorch#87445
Approved by: https://github.com/ngimel
…ytorch#87445)" This reverts commit b6f2833. Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
…87445) Test Plan: ``` Times below are Forward + Backward on A100 Size FP32. Gain. FP16. Gain 256, 256 101.30 9% 103.9 6% 512, 256 110.10 -4% 102.9 10% 1024, 256 104.30 7% 102.4 6% 2048, 256 107.60 4% 109.7 0% 4096, 256 116.70 8% 109.1 0% 6144, 256 106.10 7% 112.8 2% 8192, 256 106.10 1% 109.7 2% 256, 512 102.10 3% 108.5 1% 512, 512 101.50 40% 105.9 4% 1024, 512 109.70 20% 109.2 -1% 2048, 512 107.40 24% 107.2 1% 4096, 512 108.00 6% 110.6 -3% 6144, 512 103.90 13% 105.8 7% 8192, 512 138.70 14% 105.6 7% 256, 1024 106.20 1% 102.9 6% 512, 1024 104.50 4% 104.2 3% 1024, 1024 126.90 -15% 103.9 10% 2048, 1024 127.40 -15% 102.2 6% 4096, 1024 117.70 6% 102.8 21% 6144, 1024 165.30 11% 112.2 12% 8192, 1024 211.90 11% 144.8 13% 256, 1536 102.80 11% 103.1 6% 512, 1536 103.30 9% 102.9 18% 1024, 1536 111.00 -2% 117.2 7% 2048, 1536 102.30 12% 132.1 -4% 4096, 1536 165.50 5% 112.9 18% 6144, 1536 236.60 5% 145.7 12% 8192, 1536 307.80 5% 186.1 11% 256, 2048 110.60 -1% 103.8 7% 512, 2048 105.20 3% 105.6 1% 1024, 2048 106.70 3% 114.8 3% 2048, 2048 124.90 5% 109.7 0% 4096, 2048 231.40 4% 129.9 10% 6144, 2048 332.80 4% 182.5 11% 8192, 2048 434.60 4% 235.2 11% 256, 3072 111.60 8% 110.8 1% 512, 3072 106.80 1% 104.6 10% 1024, 3072 104.90 3% 109.9 4% 2048, 3072 193.80 0% 106.2 10% 4096, 3072 364.50 0% 187.8 5% 6144, 3072 538.30 0% 267 5% 8192, 3072 718.00 -1% 346.7 6% 256, 4096 103.60 4% 110.2 -1% 512, 4096 131.40 -11% 117 -7% 1024, 4096 135.80 1% 104.8 7% 2048, 4096 268.20 1% 149.4 10% 4096, 4096 520.70 1% 268.5 9% 6144, 4096 786.30 0% 389.8 9% 8192, 4096 1043.50 0% 509 10% ``` Used the following script from ngimel: ``` import torch from torch.utils.benchmark import Compare, Timer results = [] for dtype in (torch.float, torch.half): for fs in (256, 512, 1024, 1536, 2048, 3072, 4096): for bs in (256, 512, 1024, 2048, 4096, 6144, 8192): ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype) X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True) gO = torch.rand_like(X) stmtfwd = "ln(X)" stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)" tfwd = Timer( stmt=stmtfwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwd, {dtype}", globals=globals(), ) tfwdbwd = Timer( stmt=stmtfwdbwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwdbwd, {dtype}", globals=globals(), ) for t in (tfwd, tfwdbwd): results.append(t.blocked_autorange()) print(fs, end="\r") c = Compare(results) c.print() ``` Differential Revision: D40567574 Pull Request resolved: pytorch#87445 Approved by: https://github.com/ngimel
…ytorch#87445)" This reverts commit b6f2833. Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
…87445) Test Plan: ``` Times below are Forward + Backward on A100 Size FP32. Gain. FP16. Gain 256, 256 101.30 9% 103.9 6% 512, 256 110.10 -4% 102.9 10% 1024, 256 104.30 7% 102.4 6% 2048, 256 107.60 4% 109.7 0% 4096, 256 116.70 8% 109.1 0% 6144, 256 106.10 7% 112.8 2% 8192, 256 106.10 1% 109.7 2% 256, 512 102.10 3% 108.5 1% 512, 512 101.50 40% 105.9 4% 1024, 512 109.70 20% 109.2 -1% 2048, 512 107.40 24% 107.2 1% 4096, 512 108.00 6% 110.6 -3% 6144, 512 103.90 13% 105.8 7% 8192, 512 138.70 14% 105.6 7% 256, 1024 106.20 1% 102.9 6% 512, 1024 104.50 4% 104.2 3% 1024, 1024 126.90 -15% 103.9 10% 2048, 1024 127.40 -15% 102.2 6% 4096, 1024 117.70 6% 102.8 21% 6144, 1024 165.30 11% 112.2 12% 8192, 1024 211.90 11% 144.8 13% 256, 1536 102.80 11% 103.1 6% 512, 1536 103.30 9% 102.9 18% 1024, 1536 111.00 -2% 117.2 7% 2048, 1536 102.30 12% 132.1 -4% 4096, 1536 165.50 5% 112.9 18% 6144, 1536 236.60 5% 145.7 12% 8192, 1536 307.80 5% 186.1 11% 256, 2048 110.60 -1% 103.8 7% 512, 2048 105.20 3% 105.6 1% 1024, 2048 106.70 3% 114.8 3% 2048, 2048 124.90 5% 109.7 0% 4096, 2048 231.40 4% 129.9 10% 6144, 2048 332.80 4% 182.5 11% 8192, 2048 434.60 4% 235.2 11% 256, 3072 111.60 8% 110.8 1% 512, 3072 106.80 1% 104.6 10% 1024, 3072 104.90 3% 109.9 4% 2048, 3072 193.80 0% 106.2 10% 4096, 3072 364.50 0% 187.8 5% 6144, 3072 538.30 0% 267 5% 8192, 3072 718.00 -1% 346.7 6% 256, 4096 103.60 4% 110.2 -1% 512, 4096 131.40 -11% 117 -7% 1024, 4096 135.80 1% 104.8 7% 2048, 4096 268.20 1% 149.4 10% 4096, 4096 520.70 1% 268.5 9% 6144, 4096 786.30 0% 389.8 9% 8192, 4096 1043.50 0% 509 10% ``` Used the following script from ngimel: ``` import torch from torch.utils.benchmark import Compare, Timer results = [] for dtype in (torch.float, torch.half): for fs in (256, 512, 1024, 1536, 2048, 3072, 4096): for bs in (256, 512, 1024, 2048, 4096, 6144, 8192): ln = torch.nn.LayerNorm((fs,), device="cuda", dtype=dtype) X = torch.randn(bs, fs, device="cuda", dtype=dtype, requires_grad=True) gO = torch.rand_like(X) stmtfwd = "ln(X)" stmtfwdbwd = "X.grad=None; ln.zero_grad(set_to_none=True); out = ln(X); out.backward(gO)" tfwd = Timer( stmt=stmtfwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwd, {dtype}", globals=globals(), ) tfwdbwd = Timer( stmt=stmtfwdbwd, label="ln", sub_label=f"{bs:5}, {fs:5}", description=f"fwdbwd, {dtype}", globals=globals(), ) for t in (tfwd, tfwdbwd): results.append(t.blocked_autorange()) print(fs, end="\r") c = Compare(results) c.print() ``` Differential Revision: D40567574 Pull Request resolved: pytorch#87445 Approved by: https://github.com/ngimel
…ytorch#87445)" This reverts commit b6f2833. Reverted pytorch#87445 on behalf of https://github.com/weiwangmeta due to breaking internal builds due to MS compiler
Test Plan:
Used the following script from ngimel:
Differential Revision: D40567574