-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Extract dot-product functions from fp16_gemv_trans gemv kernels #127435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127435
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ✅ No FailuresAs of commit afa8024 with merge base 9a8e810 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, though I was always bother by using signed types as array indices... Consider changing those, if possible
signed types are generally better because overflow is UB, so the compiler can assume that overflow will not happen and optimize accordingly. unsigned is for when you really need the extra one (1) bit. (see e.g. https://youtu.be/yG1OZ69H_-o?t=2357&si=HLuIdAnutC4ZhfWb) |
…7451) Summary: This doesn't change the overall gemm algorithm away from repeated dot products, just uses our efficient fp16 dot product developed for the gemv case. It seems to improve performance for every prompt length I tested. Test Plan: Use https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py , edited to test only the trans_b (really gemm_transa_) case for the sizes outlined in the output. Before: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.06 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.80 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.00 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2160.75 usec trans_b torch.float16 659.77 usec trans_b torch.bfloat16 3800.13 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 6343.68 usec trans_b torch.float16 1789.42 usec trans_b torch.bfloat16 10098.34 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6217.20 usec trans_b torch.float16 1874.47 usec trans_b torch.bfloat16 10490.30 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 17934.45 usec trans_b torch.float16 5323.81 usec trans_b torch.bfloat16 29320.80 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 1.22 usec trans_b torch.bfloat16 1.22 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.52 usec trans_b torch.float16 1.33 usec trans_b torch.bfloat16 1.77 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4317.09 usec trans_b torch.float16 15541.04 usec trans_b torch.bfloat16 15032.29 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6191.19 usec trans_b torch.float16 40436.29 usec trans_b torch.bfloat16 40626.93 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6049.22 usec trans_b torch.float16 42367.16 usec trans_b torch.bfloat16 42482.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17611.36 usec trans_b torch.float16 117368.54 usec trans_b torch.bfloat16 116958.85 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.04 usec trans_b torch.float16 1.71 usec trans_b torch.bfloat16 1.74 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.10 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 2.91 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2456.23 usec trans_b torch.float16 30112.76 usec trans_b torch.bfloat16 29941.58 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6236.12 usec trans_b torch.float16 80361.22 usec trans_b torch.bfloat16 80466.64 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6236.10 usec trans_b torch.float16 82990.74 usec trans_b torch.bfloat16 83899.80 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17606.43 usec trans_b torch.float16 234397.38 usec trans_b torch.bfloat16 237057.29 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.31 usec trans_b torch.float16 2.67 usec trans_b torch.bfloat16 2.72 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.66 usec trans_b torch.float16 3.36 usec trans_b torch.bfloat16 5.18 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2504.24 usec trans_b torch.float16 60896.53 usec trans_b torch.bfloat16 59852.49 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6407.11 usec trans_b torch.float16 163294.92 usec trans_b torch.bfloat16 161199.10 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 6132.30 usec trans_b torch.float16 167244.77 usec trans_b torch.bfloat16 170064.35 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17635.56 usec trans_b torch.float16 475020.00 usec trans_b torch.bfloat16 476332.29 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.40 usec trans_b torch.float16 4.67 usec trans_b torch.bfloat16 4.80 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.24 usec trans_b torch.float16 6.10 usec trans_b torch.bfloat16 10.03 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2660.63 usec trans_b torch.float16 122436.04 usec trans_b torch.bfloat16 121687.96 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6405.60 usec trans_b torch.float16 324708.42 usec trans_b torch.bfloat16 324866.67 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6566.74 usec trans_b torch.float16 330801.04 usec trans_b torch.bfloat16 332561.79 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 18610.84 usec trans_b torch.float16 944578.75 usec trans_b torch.bfloat16 940674.33 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.48 usec trans_b torch.float16 16.43 usec trans_b torch.bfloat16 17.11 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.83 usec trans_b torch.float16 22.31 usec trans_b torch.bfloat16 37.00 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4806.59 usec trans_b torch.float16 485338.83 usec trans_b torch.bfloat16 478835.08 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 12109.51 usec trans_b torch.float16 1300928.58 usec trans_b torch.bfloat16 1293181.63 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 11223.70 usec trans_b torch.float16 1326119.92 usec trans_b torch.bfloat16 1330395.12 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33485.34 usec trans_b torch.float16 3869227.17 usec trans_b torch.bfloat16 3792905.00 usec ``` After: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.71 usec trans_b torch.bfloat16 0.81 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.93 usec trans_b torch.bfloat16 0.98 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2194.31 usec trans_b torch.float16 661.27 usec trans_b torch.bfloat16 3758.42 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 5792.04 usec trans_b torch.float16 1789.98 usec trans_b torch.bfloat16 10120.67 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6101.22 usec trans_b torch.float16 1927.34 usec trans_b torch.bfloat16 10469.47 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 18353.20 usec trans_b torch.float16 5161.06 usec trans_b torch.bfloat16 29601.69 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.14 usec trans_b torch.float16 0.85 usec trans_b torch.bfloat16 1.19 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.47 usec trans_b torch.float16 1.85 usec trans_b torch.bfloat16 1.75 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4416.40 usec trans_b torch.float16 2688.36 usec trans_b torch.bfloat16 14987.33 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6140.24 usec trans_b torch.float16 7467.26 usec trans_b torch.bfloat16 40295.52 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6143.10 usec trans_b torch.float16 7298.04 usec trans_b torch.bfloat16 41393.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17650.72 usec trans_b torch.float16 21346.63 usec trans_b torch.bfloat16 116849.98 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 1.03 usec trans_b torch.bfloat16 1.69 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.05 usec trans_b torch.float16 3.08 usec trans_b torch.bfloat16 2.95 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2323.99 usec trans_b torch.float16 5265.45 usec trans_b torch.bfloat16 29942.40 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6202.01 usec trans_b torch.float16 14677.90 usec trans_b torch.bfloat16 80625.18 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6112.05 usec trans_b torch.float16 14340.52 usec trans_b torch.bfloat16 82799.99 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17650.65 usec trans_b torch.float16 42551.43 usec trans_b torch.bfloat16 236081.08 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.26 usec trans_b torch.float16 1.34 usec trans_b torch.bfloat16 2.69 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.60 usec trans_b torch.float16 5.81 usec trans_b torch.bfloat16 5.34 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2328.05 usec trans_b torch.float16 10526.58 usec trans_b torch.bfloat16 60028.28 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6243.35 usec trans_b torch.float16 28505.08 usec trans_b torch.bfloat16 163670.15 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 5870.11 usec trans_b torch.float16 28597.89 usec trans_b torch.bfloat16 165404.88 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17746.27 usec trans_b torch.float16 83393.87 usec trans_b torch.bfloat16 472313.13 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.35 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 4.68 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.19 usec trans_b torch.float16 10.98 usec trans_b torch.bfloat16 10.13 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2525.29 usec trans_b torch.float16 23106.71 usec trans_b torch.bfloat16 122987.04 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6131.34 usec trans_b torch.float16 57537.41 usec trans_b torch.bfloat16 327825.00 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6395.01 usec trans_b torch.float16 57456.33 usec trans_b torch.bfloat16 331325.58 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 19078.68 usec trans_b torch.float16 167735.08 usec trans_b torch.bfloat16 975736.88 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 6.07 usec trans_b torch.bfloat16 16.83 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.78 usec trans_b torch.float16 40.35 usec trans_b torch.bfloat16 37.21 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4827.60 usec trans_b torch.float16 84341.24 usec trans_b torch.bfloat16 478917.75 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 11879.96 usec trans_b torch.float16 226484.33 usec trans_b torch.bfloat16 1289465.50 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 10707.75 usec trans_b torch.float16 229200.58 usec trans_b torch.bfloat16 1327416.67 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33306.32 usec trans_b torch.float16 662898.21 usec trans_b torch.bfloat16 3815866.63 usec ``` torch.float16 performance seems to be improved for all except the m=128, n=8, k=128 case, where it is roughly neutral. This case motivated the addition of the "first-tier tail fixup" in the dot kernel. Pull Request resolved: #127451 Approved by: https://github.com/malfet ghstack dependencies: #127435
…32_arith (#127476) Summary: Preparing to generalize to bf16. (This should not be committed unless the following bf16 PR is committed!) Test Plan: Spot-checked llm_experiments benchmark result to make sure it didn't regress. Pull Request resolved: #127476 Approved by: https://github.com/malfet ghstack dependencies: #127435, #127451
…rch#127435) Summary: Refactoring step before we attempt to use these to implement a less bad fp16 GEMM. Test Plan: Existing tests. Pull Request resolved: pytorch#127435 Approved by: https://github.com/malfet
…orch#127451) Summary: This doesn't change the overall gemm algorithm away from repeated dot products, just uses our efficient fp16 dot product developed for the gemv case. It seems to improve performance for every prompt length I tested. Test Plan: Use https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py , edited to test only the trans_b (really gemm_transa_) case for the sizes outlined in the output. Before: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.06 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.80 usec trans_b torch.float16 0.97 usec trans_b torch.bfloat16 1.00 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2160.75 usec trans_b torch.float16 659.77 usec trans_b torch.bfloat16 3800.13 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 6343.68 usec trans_b torch.float16 1789.42 usec trans_b torch.bfloat16 10098.34 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6217.20 usec trans_b torch.float16 1874.47 usec trans_b torch.bfloat16 10490.30 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 17934.45 usec trans_b torch.float16 5323.81 usec trans_b torch.bfloat16 29320.80 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 1.22 usec trans_b torch.bfloat16 1.22 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.52 usec trans_b torch.float16 1.33 usec trans_b torch.bfloat16 1.77 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4317.09 usec trans_b torch.float16 15541.04 usec trans_b torch.bfloat16 15032.29 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6191.19 usec trans_b torch.float16 40436.29 usec trans_b torch.bfloat16 40626.93 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6049.22 usec trans_b torch.float16 42367.16 usec trans_b torch.bfloat16 42482.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17611.36 usec trans_b torch.float16 117368.54 usec trans_b torch.bfloat16 116958.85 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.04 usec trans_b torch.float16 1.71 usec trans_b torch.bfloat16 1.74 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.10 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 2.91 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2456.23 usec trans_b torch.float16 30112.76 usec trans_b torch.bfloat16 29941.58 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6236.12 usec trans_b torch.float16 80361.22 usec trans_b torch.bfloat16 80466.64 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6236.10 usec trans_b torch.float16 82990.74 usec trans_b torch.bfloat16 83899.80 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17606.43 usec trans_b torch.float16 234397.38 usec trans_b torch.bfloat16 237057.29 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.31 usec trans_b torch.float16 2.67 usec trans_b torch.bfloat16 2.72 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.66 usec trans_b torch.float16 3.36 usec trans_b torch.bfloat16 5.18 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2504.24 usec trans_b torch.float16 60896.53 usec trans_b torch.bfloat16 59852.49 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6407.11 usec trans_b torch.float16 163294.92 usec trans_b torch.bfloat16 161199.10 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 6132.30 usec trans_b torch.float16 167244.77 usec trans_b torch.bfloat16 170064.35 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17635.56 usec trans_b torch.float16 475020.00 usec trans_b torch.bfloat16 476332.29 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.40 usec trans_b torch.float16 4.67 usec trans_b torch.bfloat16 4.80 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.24 usec trans_b torch.float16 6.10 usec trans_b torch.bfloat16 10.03 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2660.63 usec trans_b torch.float16 122436.04 usec trans_b torch.bfloat16 121687.96 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6405.60 usec trans_b torch.float16 324708.42 usec trans_b torch.bfloat16 324866.67 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6566.74 usec trans_b torch.float16 330801.04 usec trans_b torch.bfloat16 332561.79 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 18610.84 usec trans_b torch.float16 944578.75 usec trans_b torch.bfloat16 940674.33 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.48 usec trans_b torch.float16 16.43 usec trans_b torch.bfloat16 17.11 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.83 usec trans_b torch.float16 22.31 usec trans_b torch.bfloat16 37.00 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4806.59 usec trans_b torch.float16 485338.83 usec trans_b torch.bfloat16 478835.08 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 12109.51 usec trans_b torch.float16 1300928.58 usec trans_b torch.bfloat16 1293181.63 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 11223.70 usec trans_b torch.float16 1326119.92 usec trans_b torch.bfloat16 1330395.12 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33485.34 usec trans_b torch.float16 3869227.17 usec trans_b torch.bfloat16 3792905.00 usec ``` After: ``` Matrix-vector: m=8, n=128, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.71 usec trans_b torch.bfloat16 0.81 usec m=128, n=8, k=1 ==================== trans_b torch.float32 0.75 usec trans_b torch.float16 0.93 usec trans_b torch.bfloat16 0.98 usec m=4096, n=4096, k=1 ==================== trans_b torch.float32 2194.31 usec trans_b torch.float16 661.27 usec trans_b torch.bfloat16 3758.42 usec m=11008, n=4096, k=1 ==================== trans_b torch.float32 5792.04 usec trans_b torch.float16 1789.98 usec trans_b torch.bfloat16 10120.67 usec m=4096, n=11008, k=1 ==================== trans_b torch.float32 6101.22 usec trans_b torch.float16 1927.34 usec trans_b torch.bfloat16 10469.47 usec m=32000, n=4096, k=1 ==================== trans_b torch.float32 18353.20 usec trans_b torch.float16 5161.06 usec trans_b torch.bfloat16 29601.69 usec Matrix-matrix (prompt len 4: m=8, n=128, k=4 ==================== trans_b torch.float32 2.14 usec trans_b torch.float16 0.85 usec trans_b torch.bfloat16 1.19 usec m=128, n=8, k=4 ==================== trans_b torch.float32 1.47 usec trans_b torch.float16 1.85 usec trans_b torch.bfloat16 1.75 usec m=4096, n=4096, k=4 ==================== trans_b torch.float32 4416.40 usec trans_b torch.float16 2688.36 usec trans_b torch.bfloat16 14987.33 usec m=11008, n=4096, k=4 ==================== trans_b torch.float32 6140.24 usec trans_b torch.float16 7467.26 usec trans_b torch.bfloat16 40295.52 usec m=4096, n=11008, k=4 ==================== trans_b torch.float32 6143.10 usec trans_b torch.float16 7298.04 usec trans_b torch.bfloat16 41393.43 usec m=32000, n=4096, k=4 ==================== trans_b torch.float32 17650.72 usec trans_b torch.float16 21346.63 usec trans_b torch.bfloat16 116849.98 usec Matrix-matrix (prompt len 8: m=8, n=128, k=8 ==================== trans_b torch.float32 1.05 usec trans_b torch.float16 1.03 usec trans_b torch.bfloat16 1.69 usec m=128, n=8, k=8 ==================== trans_b torch.float32 2.05 usec trans_b torch.float16 3.08 usec trans_b torch.bfloat16 2.95 usec m=4096, n=4096, k=8 ==================== trans_b torch.float32 2323.99 usec trans_b torch.float16 5265.45 usec trans_b torch.bfloat16 29942.40 usec m=11008, n=4096, k=8 ==================== trans_b torch.float32 6202.01 usec trans_b torch.float16 14677.90 usec trans_b torch.bfloat16 80625.18 usec m=4096, n=11008, k=8 ==================== trans_b torch.float32 6112.05 usec trans_b torch.float16 14340.52 usec trans_b torch.bfloat16 82799.99 usec m=32000, n=4096, k=8 ==================== trans_b torch.float32 17650.65 usec trans_b torch.float16 42551.43 usec trans_b torch.bfloat16 236081.08 usec Matrix-matrix (prompt len 16: m=8, n=128, k=16 ==================== trans_b torch.float32 1.26 usec trans_b torch.float16 1.34 usec trans_b torch.bfloat16 2.69 usec m=128, n=8, k=16 ==================== trans_b torch.float32 1.60 usec trans_b torch.float16 5.81 usec trans_b torch.bfloat16 5.34 usec m=4096, n=4096, k=16 ==================== trans_b torch.float32 2328.05 usec trans_b torch.float16 10526.58 usec trans_b torch.bfloat16 60028.28 usec m=11008, n=4096, k=16 ==================== trans_b torch.float32 6243.35 usec trans_b torch.float16 28505.08 usec trans_b torch.bfloat16 163670.15 usec m=4096, n=11008, k=16 ==================== trans_b torch.float32 5870.11 usec trans_b torch.float16 28597.89 usec trans_b torch.bfloat16 165404.88 usec m=32000, n=4096, k=16 ==================== trans_b torch.float32 17746.27 usec trans_b torch.float16 83393.87 usec trans_b torch.bfloat16 472313.13 usec Matrix-matrix (prompt len 32: m=8, n=128, k=32 ==================== trans_b torch.float32 1.35 usec trans_b torch.float16 2.01 usec trans_b torch.bfloat16 4.68 usec m=128, n=8, k=32 ==================== trans_b torch.float32 1.19 usec trans_b torch.float16 10.98 usec trans_b torch.bfloat16 10.13 usec m=4096, n=4096, k=32 ==================== trans_b torch.float32 2525.29 usec trans_b torch.float16 23106.71 usec trans_b torch.bfloat16 122987.04 usec m=11008, n=4096, k=32 ==================== trans_b torch.float32 6131.34 usec trans_b torch.float16 57537.41 usec trans_b torch.bfloat16 327825.00 usec m=4096, n=11008, k=32 ==================== trans_b torch.float32 6395.01 usec trans_b torch.float16 57456.33 usec trans_b torch.bfloat16 331325.58 usec m=32000, n=4096, k=32 ==================== trans_b torch.float32 19078.68 usec trans_b torch.float16 167735.08 usec trans_b torch.bfloat16 975736.88 usec Matrix-matrix (prompt len 128: m=8, n=128, k=128 ==================== trans_b torch.float32 2.40 usec trans_b torch.float16 6.07 usec trans_b torch.bfloat16 16.83 usec m=128, n=8, k=128 ==================== trans_b torch.float32 1.78 usec trans_b torch.float16 40.35 usec trans_b torch.bfloat16 37.21 usec m=4096, n=4096, k=128 ==================== trans_b torch.float32 4827.60 usec trans_b torch.float16 84341.24 usec trans_b torch.bfloat16 478917.75 usec m=11008, n=4096, k=128 ==================== trans_b torch.float32 11879.96 usec trans_b torch.float16 226484.33 usec trans_b torch.bfloat16 1289465.50 usec m=4096, n=11008, k=128 ==================== trans_b torch.float32 10707.75 usec trans_b torch.float16 229200.58 usec trans_b torch.bfloat16 1327416.67 usec m=32000, n=4096, k=128 ==================== trans_b torch.float32 33306.32 usec trans_b torch.float16 662898.21 usec trans_b torch.bfloat16 3815866.63 usec ``` torch.float16 performance seems to be improved for all except the m=128, n=8, k=128 case, where it is roughly neutral. This case motivated the addition of the "first-tier tail fixup" in the dot kernel. Pull Request resolved: pytorch#127451 Approved by: https://github.com/malfet ghstack dependencies: pytorch#127435
…32_arith (pytorch#127476) Summary: Preparing to generalize to bf16. (This should not be committed unless the following bf16 PR is committed!) Test Plan: Spot-checked llm_experiments benchmark result to make sure it didn't regress. Pull Request resolved: pytorch#127476 Approved by: https://github.com/malfet ghstack dependencies: pytorch#127435, pytorch#127451
Stack from ghstack (oldest at bottom):
Summary: Refactoring step before we attempt to use these to implement a less bad fp16 GEMM.
Test Plan: Existing tests.