
Conversation

[ghstack-poisoned]

pytorch-bot bot commented May 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127435

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit afa8024 with merge base 9a8e810:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@swolchok swolchok requested a review from malfet May 29, 2024 19:56
Contributor

@malfet malfet left a comment


LGTM, though I was always bothered by using signed types as array indices... Consider changing those, if possible.

@swolchok
Contributor Author

> LGTM, though I was always bothered by using signed types as array indices... Consider changing those, if possible.

Signed types are generally better because signed overflow is UB, so the compiler can assume overflow will not happen and optimize accordingly. Unsigned is for when you really need the extra one (1) bit. (See e.g. https://youtu.be/yG1OZ69H_-o?t=2357&si=HLuIdAnutC4ZhfWb)
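
To illustrate the point (a minimal sketch, not code from this PR; function names are made up): with a signed index, overflow of the induction variable is undefined behavior, so the optimizer may assume the loop exits normally and, for example, widen the index for addressing; with an unsigned index, wraparound is defined and that assumption is off the table.

```cpp
#include <cstdint>

// With a signed 32-bit index, overflow of `i += 2` is UB, so the compiler may
// assume the loop terminates normally and promote `i` to a 64-bit induction
// variable for addressing.
float sum_even_signed(const float* a, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; i += 2) {
    acc += a[i];
  }
  return acc;
}

// With an unsigned index, `i += 2` wraps modulo 2^32 by definition, which can
// complicate trip-count analysis and block that widening.
float sum_even_unsigned(const float* a, unsigned n) {
  float acc = 0.0f;
  for (unsigned i = 0; i < n; i += 2) {
    acc += a[i];
  }
  return acc;
}
```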

pytorchmergebot pushed a commit that referenced this pull request May 30, 2024
…7451)

Summary: This doesn't change the overall gemm algorithm away from repeated dot products; it just uses our efficient fp16 dot product developed for the gemv case. It seems to improve performance for every prompt length I tested.
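
For orientation, here is a rough sketch of the "repeated dot products" structure for the transposed-B case, with illustrative names and types (`fp16_dot`, `gemm_transb_fp16`, `_Float16`) rather than the actual ATen code: every output element is one dot product over k, so a faster fp16 dot kernel speeds up every entry.

```cpp
#include <cstdint>

// Illustrative sketch only; the real kernel is vectorized and lives in ATen.
using fp16_t = _Float16;  // assumption: the compiler provides a native half type

// Naive stand-in for the efficient fp16 dot product (accumulates in fp32).
inline float fp16_dot(const fp16_t* x, const fp16_t* y, int64_t k) {
  float acc = 0.0f;
  for (int64_t i = 0; i < k; ++i) {
    acc += static_cast<float>(x[i]) * static_cast<float>(y[i]);
  }
  return acc;
}

// Transposed-B GEMM as repeated dot products: C[i][j] = dot(row i of A, row j of B).
void gemm_transb_fp16(
    const fp16_t* a,  // m x k, row-major
    const fp16_t* b,  // n x k, row-major (second operand stored transposed)
    fp16_t* c,        // m x n, row-major
    int64_t m, int64_t n, int64_t k) {
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      c[i * n + j] = static_cast<fp16_t>(fp16_dot(a + i * k, b + j * k, k));
    }
  }
}
```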

Test Plan: Use https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py , edited to test only the trans_b (really gemm_transa_) case for the sizes outlined in the output.

Before:
```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    0.97 usec
trans_b torch.bfloat16    1.06 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.80 usec
trans_b  torch.float16    0.97 usec
trans_b torch.bfloat16    1.00 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2160.75 usec
trans_b  torch.float16  659.77 usec
trans_b torch.bfloat16 3800.13 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 6343.68 usec
trans_b  torch.float16 1789.42 usec
trans_b torch.bfloat16 10098.34 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6217.20 usec
trans_b  torch.float16 1874.47 usec
trans_b torch.bfloat16 10490.30 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 17934.45 usec
trans_b  torch.float16 5323.81 usec
trans_b torch.bfloat16 29320.80 usec

Matrix-matrix (prompt len 4:
m=8, n=128, k=4
====================
trans_b  torch.float32    2.40 usec
trans_b  torch.float16    1.22 usec
trans_b torch.bfloat16    1.22 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.52 usec
trans_b  torch.float16    1.33 usec
trans_b torch.bfloat16    1.77 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4317.09 usec
trans_b  torch.float16 15541.04 usec
trans_b torch.bfloat16 15032.29 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6191.19 usec
trans_b  torch.float16 40436.29 usec
trans_b torch.bfloat16 40626.93 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6049.22 usec
trans_b  torch.float16 42367.16 usec
trans_b torch.bfloat16 42482.43 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17611.36 usec
trans_b  torch.float16 117368.54 usec
trans_b torch.bfloat16 116958.85 usec

Matrix-matrix (prompt len 8:
m=8, n=128, k=8
====================
trans_b  torch.float32    1.04 usec
trans_b  torch.float16    1.71 usec
trans_b torch.bfloat16    1.74 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.10 usec
trans_b  torch.float16    2.01 usec
trans_b torch.bfloat16    2.91 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2456.23 usec
trans_b  torch.float16 30112.76 usec
trans_b torch.bfloat16 29941.58 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6236.12 usec
trans_b  torch.float16 80361.22 usec
trans_b torch.bfloat16 80466.64 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6236.10 usec
trans_b  torch.float16 82990.74 usec
trans_b torch.bfloat16 83899.80 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 17606.43 usec
trans_b  torch.float16 234397.38 usec
trans_b torch.bfloat16 237057.29 usec

Matrix-matrix (prompt len 16:
m=8, n=128, k=16
====================
trans_b  torch.float32    1.31 usec
trans_b  torch.float16    2.67 usec
trans_b torch.bfloat16    2.72 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.66 usec
trans_b  torch.float16    3.36 usec
trans_b torch.bfloat16    5.18 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2504.24 usec
trans_b  torch.float16 60896.53 usec
trans_b torch.bfloat16 59852.49 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6407.11 usec
trans_b  torch.float16 163294.92 usec
trans_b torch.bfloat16 161199.10 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 6132.30 usec
trans_b  torch.float16 167244.77 usec
trans_b torch.bfloat16 170064.35 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 17635.56 usec
trans_b  torch.float16 475020.00 usec
trans_b torch.bfloat16 476332.29 usec

Matrix-matrix (prompt len 32:
m=8, n=128, k=32
====================
trans_b  torch.float32    1.40 usec
trans_b  torch.float16    4.67 usec
trans_b torch.bfloat16    4.80 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.24 usec
trans_b  torch.float16    6.10 usec
trans_b torch.bfloat16   10.03 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2660.63 usec
trans_b  torch.float16 122436.04 usec
trans_b torch.bfloat16 121687.96 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6405.60 usec
trans_b  torch.float16 324708.42 usec
trans_b torch.bfloat16 324866.67 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6566.74 usec
trans_b  torch.float16 330801.04 usec
trans_b torch.bfloat16 332561.79 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 18610.84 usec
trans_b  torch.float16 944578.75 usec
trans_b torch.bfloat16 940674.33 usec

Matrix-matrix (prompt len 128:
m=8, n=128, k=128
====================
trans_b  torch.float32    2.48 usec
trans_b  torch.float16   16.43 usec
trans_b torch.bfloat16   17.11 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.83 usec
trans_b  torch.float16   22.31 usec
trans_b torch.bfloat16   37.00 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4806.59 usec
trans_b  torch.float16 485338.83 usec
trans_b torch.bfloat16 478835.08 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 12109.51 usec
trans_b  torch.float16 1300928.58 usec
trans_b torch.bfloat16 1293181.63 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 11223.70 usec
trans_b  torch.float16 1326119.92 usec
trans_b torch.bfloat16 1330395.12 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33485.34 usec
trans_b  torch.float16 3869227.17 usec
trans_b torch.bfloat16 3792905.00 usec
```

After:
```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.71 usec
trans_b torch.bfloat16    0.81 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.93 usec
trans_b torch.bfloat16    0.98 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2194.31 usec
trans_b  torch.float16  661.27 usec
trans_b torch.bfloat16 3758.42 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 5792.04 usec
trans_b  torch.float16 1789.98 usec
trans_b torch.bfloat16 10120.67 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6101.22 usec
trans_b  torch.float16 1927.34 usec
trans_b torch.bfloat16 10469.47 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 18353.20 usec
trans_b  torch.float16 5161.06 usec
trans_b torch.bfloat16 29601.69 usec

Matrix-matrix (prompt len 4:
m=8, n=128, k=4
====================
trans_b  torch.float32    2.14 usec
trans_b  torch.float16    0.85 usec
trans_b torch.bfloat16    1.19 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.47 usec
trans_b  torch.float16    1.85 usec
trans_b torch.bfloat16    1.75 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4416.40 usec
trans_b  torch.float16 2688.36 usec
trans_b torch.bfloat16 14987.33 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6140.24 usec
trans_b  torch.float16 7467.26 usec
trans_b torch.bfloat16 40295.52 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6143.10 usec
trans_b  torch.float16 7298.04 usec
trans_b torch.bfloat16 41393.43 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17650.72 usec
trans_b  torch.float16 21346.63 usec
trans_b torch.bfloat16 116849.98 usec

Matrix-matrix (prompt len 8:
m=8, n=128, k=8
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    1.03 usec
trans_b torch.bfloat16    1.69 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.05 usec
trans_b  torch.float16    3.08 usec
trans_b torch.bfloat16    2.95 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2323.99 usec
trans_b  torch.float16 5265.45 usec
trans_b torch.bfloat16 29942.40 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6202.01 usec
trans_b  torch.float16 14677.90 usec
trans_b torch.bfloat16 80625.18 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6112.05 usec
trans_b  torch.float16 14340.52 usec
trans_b torch.bfloat16 82799.99 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 17650.65 usec
trans_b  torch.float16 42551.43 usec
trans_b torch.bfloat16 236081.08 usec

Matrix-matrix (prompt len 16:
m=8, n=128, k=16
====================
trans_b  torch.float32    1.26 usec
trans_b  torch.float16    1.34 usec
trans_b torch.bfloat16    2.69 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.60 usec
trans_b  torch.float16    5.81 usec
trans_b torch.bfloat16    5.34 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2328.05 usec
trans_b  torch.float16 10526.58 usec
trans_b torch.bfloat16 60028.28 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6243.35 usec
trans_b  torch.float16 28505.08 usec
trans_b torch.bfloat16 163670.15 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 5870.11 usec
trans_b  torch.float16 28597.89 usec
trans_b torch.bfloat16 165404.88 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 17746.27 usec
trans_b  torch.float16 83393.87 usec
trans_b torch.bfloat16 472313.13 usec

Matrix-matrix (prompt len 32:
m=8, n=128, k=32
====================
trans_b  torch.float32    1.35 usec
trans_b  torch.float16    2.01 usec
trans_b torch.bfloat16    4.68 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.19 usec
trans_b  torch.float16   10.98 usec
trans_b torch.bfloat16   10.13 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2525.29 usec
trans_b  torch.float16 23106.71 usec
trans_b torch.bfloat16 122987.04 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6131.34 usec
trans_b  torch.float16 57537.41 usec
trans_b torch.bfloat16 327825.00 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6395.01 usec
trans_b  torch.float16 57456.33 usec
trans_b torch.bfloat16 331325.58 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 19078.68 usec
trans_b  torch.float16 167735.08 usec
trans_b torch.bfloat16 975736.88 usec

Matrix-matrix (prompt len 128:
m=8, n=128, k=128
====================
trans_b  torch.float32    2.40 usec
trans_b  torch.float16    6.07 usec
trans_b torch.bfloat16   16.83 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.78 usec
trans_b  torch.float16   40.35 usec
trans_b torch.bfloat16   37.21 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4827.60 usec
trans_b  torch.float16 84341.24 usec
trans_b torch.bfloat16 478917.75 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 11879.96 usec
trans_b  torch.float16 226484.33 usec
trans_b torch.bfloat16 1289465.50 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 10707.75 usec
trans_b  torch.float16 229200.58 usec
trans_b torch.bfloat16 1327416.67 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33306.32 usec
trans_b  torch.float16 662898.21 usec
trans_b torch.bfloat16 3815866.63 usec
```

torch.float16 performance seems to be improved for all except the
m=128, n=8, k=128 case, where it is roughly neutral. This case
motivated the addition of the "first-tier tail fixup" in the dot
kernel.
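
For context, a tiered tail fixup in a dot kernel has roughly the following loop structure; this is a scalar sketch with placeholder block sizes, not the actual SIMD kernel.

```cpp
#include <cstdint>

// Scalar sketch of a dot kernel with tiered tail handling. The main loop
// consumes large unrolled blocks, the first-tier tail consumes any remaining
// vector-width chunks, and a final scalar tail handles the last few elements.
float dot_with_tail_fixup(const float* x, const float* y, int64_t k) {
  constexpr int64_t kBlock = 32;  // placeholder: main-loop unroll size
  constexpr int64_t kVec = 8;     // placeholder: vector width

  float acc = 0.0f;
  int64_t i = 0;
  for (; i + kBlock <= k; i += kBlock) {  // main loop: big unrolled blocks
    for (int64_t j = 0; j < kBlock; ++j) acc += x[i + j] * y[i + j];
  }
  for (; i + kVec <= k; i += kVec) {      // first-tier tail: whole vector-width chunks
    for (int64_t j = 0; j < kVec; ++j) acc += x[i + j] * y[i + j];
  }
  for (; i < k; ++i) {                    // final scalar tail
    acc += x[i] * y[i];
  }
  return acc;
}
```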

Pull Request resolved: #127451
Approved by: https://github.com/malfet
ghstack dependencies: #127435
pytorchmergebot pushed a commit that referenced this pull request May 30, 2024
…32_arith (#127476)

Summary: Preparing to generalize to bf16. (This should not be committed unless the following bf16 PR is committed!)

Test Plan: Spot-checked llm_experiments benchmark result to make sure it didn't regress.
Pull Request resolved: #127476
Approved by: https://github.com/malfet
ghstack dependencies: #127435, #127451
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…rch#127435)

Summary: Refactoring step before we attempt to use these to implement a less bad fp16 GEMM.

Test Plan: Existing tests.
Pull Request resolved: pytorch#127435
Approved by: https://github.com/malfet
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…orch#127451)

Pull Request resolved: pytorch#127451
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#127435
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…32_arith (pytorch#127476)

Pull Request resolved: pytorch#127476
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#127435, pytorch#127451
@github-actions github-actions bot deleted the gh/swolchok/632/head branch June 30, 2024 02:00