
Conversation

swolchok (Contributor) commented May 29, 2024

Stack from ghstack (oldest at bottom):

Summary: This gets us a baseline level of reasonable performance for
bfloat16 matrix-vector and matrix-matrix multiplication on my Apple
M1. I've intentionally left the use of intrinsics as future work.
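
For orientation, the arithmetic such a kernel performs can be sketched in Python. This is a hedged illustration of a bfloat16 dot product with float32 accumulation, not the ATen C++ code this PR adds:

```python
import torch

# Illustration only: a bfloat16 dot product computed the way a scalar CPU
# kernel typically would, upcasting each element to float32, accumulating
# in float32, and rounding the result back to bfloat16 at the end.
def bf16_dot_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16
    acc = (a.to(torch.float32) * b.to(torch.float32)).sum()  # fp32 accumulate
    return acc.to(torch.bfloat16)

a = torch.randn(4096, dtype=torch.bfloat16)
b = torch.randn(4096, dtype=torch.bfloat16)
print(bf16_dot_sketch(a, b))
```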

Test Plan: Used
https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py
(modified to run larger sizes) to benchmark a range of LLM-interesting
matrix-vector and matrix-matrix sizes on my Apple M1 Pro. bfloat16 performance is
improved across the board (except possibly for very small cases) and
now exceeds float32 performance (as it should) for the matrix-vector
cases.
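
For reference, a minimal sketch of the kind of measurement that script performs (sizes and variable names here are illustrative; the authoritative harness is the linked script):

```python
import torch
from torch.utils.benchmark import Timer

# Illustrative trans_b measurement: x @ w.t(), i.e. a transposed second
# operand, matching the m/n/k combinations reported below (k is the
# prompt length; k=1 is the matrix-vector case).
def bench_trans_b(m: int, n: int, k: int, dtype: torch.dtype) -> float:
    w = torch.randn(m, n, dtype=dtype)  # weight, used transposed
    x = torch.randn(k, n, dtype=dtype)  # activations
    t = Timer(stmt="torch.mm(x, w.t())", globals={"x": x, "w": w})
    return t.blocked_autorange().median * 1e6  # median time in usec

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(f"trans_b {dtype} {bench_trans_b(4096, 4096, 1, dtype):8.2f} usec")
```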

Before:

```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.71 usec
trans_b torch.bfloat16    0.81 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.75 usec
trans_b  torch.float16    0.93 usec
trans_b torch.bfloat16    0.98 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2194.31 usec
trans_b  torch.float16  661.27 usec
trans_b torch.bfloat16 3758.42 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 5792.04 usec
trans_b  torch.float16 1789.98 usec
trans_b torch.bfloat16 10120.67 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6101.22 usec
trans_b  torch.float16 1927.34 usec
trans_b torch.bfloat16 10469.47 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 18353.20 usec
trans_b  torch.float16 5161.06 usec
trans_b torch.bfloat16 29601.69 usec

Matrix-matrix (prompt len 4):
m=8, n=128, k=4
====================
trans_b  torch.float32    2.14 usec
trans_b  torch.float16    0.85 usec
trans_b torch.bfloat16    1.19 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.47 usec
trans_b  torch.float16    1.85 usec
trans_b torch.bfloat16    1.75 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4416.40 usec
trans_b  torch.float16 2688.36 usec
trans_b torch.bfloat16 14987.33 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6140.24 usec
trans_b  torch.float16 7467.26 usec
trans_b torch.bfloat16 40295.52 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6143.10 usec
trans_b  torch.float16 7298.04 usec
trans_b torch.bfloat16 41393.43 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17650.72 usec
trans_b  torch.float16 21346.63 usec
trans_b torch.bfloat16 116849.98 usec

Matrix-matrix (prompt len 8):
m=8, n=128, k=8
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    1.03 usec
trans_b torch.bfloat16    1.69 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.05 usec
trans_b  torch.float16    3.08 usec
trans_b torch.bfloat16    2.95 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2323.99 usec
trans_b  torch.float16 5265.45 usec
trans_b torch.bfloat16 29942.40 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6202.01 usec
trans_b  torch.float16 14677.90 usec
trans_b torch.bfloat16 80625.18 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6112.05 usec
trans_b  torch.float16 14340.52 usec
trans_b torch.bfloat16 82799.99 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 17650.65 usec
trans_b  torch.float16 42551.43 usec
trans_b torch.bfloat16 236081.08 usec

Matrix-matrix (prompt len 16):
m=8, n=128, k=16
====================
trans_b  torch.float32    1.26 usec
trans_b  torch.float16    1.34 usec
trans_b torch.bfloat16    2.69 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.60 usec
trans_b  torch.float16    5.81 usec
trans_b torch.bfloat16    5.34 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2328.05 usec
trans_b  torch.float16 10526.58 usec
trans_b torch.bfloat16 60028.28 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6243.35 usec
trans_b  torch.float16 28505.08 usec
trans_b torch.bfloat16 163670.15 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 5870.11 usec
trans_b  torch.float16 28597.89 usec
trans_b torch.bfloat16 165404.88 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 17746.27 usec
trans_b  torch.float16 83393.87 usec
trans_b torch.bfloat16 472313.13 usec

Matrix-matrix (prompt len 32):
m=8, n=128, k=32
====================
trans_b  torch.float32    1.35 usec
trans_b  torch.float16    2.01 usec
trans_b torch.bfloat16    4.68 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.19 usec
trans_b  torch.float16   10.98 usec
trans_b torch.bfloat16   10.13 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2525.29 usec
trans_b  torch.float16 23106.71 usec
trans_b torch.bfloat16 122987.04 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6131.34 usec
trans_b  torch.float16 57537.41 usec
trans_b torch.bfloat16 327825.00 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6395.01 usec
trans_b  torch.float16 57456.33 usec
trans_b torch.bfloat16 331325.58 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 19078.68 usec
trans_b  torch.float16 167735.08 usec
trans_b torch.bfloat16 975736.88 usec

Matrix-matrix (prompt len 128):
m=8, n=128, k=128
====================
trans_b  torch.float32    2.40 usec
trans_b  torch.float16    6.07 usec
trans_b torch.bfloat16   16.83 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.78 usec
trans_b  torch.float16   40.35 usec
trans_b torch.bfloat16   37.21 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4827.60 usec
trans_b  torch.float16 84341.24 usec
trans_b torch.bfloat16 478917.75 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 11879.96 usec
trans_b  torch.float16 226484.33 usec
trans_b torch.bfloat16 1289465.50 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 10707.75 usec
trans_b  torch.float16 229200.58 usec
trans_b torch.bfloat16 1327416.67 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33306.32 usec
trans_b  torch.float16 662898.21 usec
trans_b torch.bfloat16 3815866.63 usec
```

After:

```
Matrix-vector:
m=8, n=128, k=1
====================
trans_b  torch.float32    0.77 usec
trans_b  torch.float16    0.72 usec
trans_b torch.bfloat16    0.77 usec
m=128, n=8, k=1
====================
trans_b  torch.float32    0.73 usec
trans_b  torch.float16    0.93 usec
trans_b torch.bfloat16    1.56 usec
m=4096, n=4096, k=1
====================
trans_b  torch.float32 2195.22 usec
trans_b  torch.float16  675.40 usec
trans_b torch.bfloat16 1038.29 usec
m=11008, n=4096, k=1
====================
trans_b  torch.float32 5980.27 usec
trans_b  torch.float16 1806.08 usec
trans_b torch.bfloat16 2756.46 usec
m=4096, n=11008, k=1
====================
trans_b  torch.float32 6339.95 usec
trans_b  torch.float16 1844.71 usec
trans_b torch.bfloat16 2726.52 usec
m=32000, n=4096, k=1
====================
trans_b  torch.float32 18137.17 usec
trans_b  torch.float16 6020.75 usec
trans_b torch.bfloat16 8612.89 usec

Matrix-matrix (prompt len 4):
m=8, n=128, k=4
====================
trans_b  torch.float32    2.24 usec
trans_b  torch.float16    0.91 usec
trans_b torch.bfloat16    1.07 usec
m=128, n=8, k=4
====================
trans_b  torch.float32    1.58 usec
trans_b  torch.float16    1.96 usec
trans_b torch.bfloat16    2.11 usec
m=4096, n=4096, k=4
====================
trans_b  torch.float32 4583.43 usec
trans_b  torch.float16 3014.04 usec
trans_b torch.bfloat16 4434.04 usec
m=11008, n=4096, k=4
====================
trans_b  torch.float32 6245.55 usec
trans_b  torch.float16 7513.82 usec
trans_b torch.bfloat16 11207.80 usec
m=4096, n=11008, k=4
====================
trans_b  torch.float32 6096.22 usec
trans_b  torch.float16 7688.82 usec
trans_b torch.bfloat16 11143.72 usec
m=32000, n=4096, k=4
====================
trans_b  torch.float32 17982.88 usec
trans_b  torch.float16 22001.28 usec
trans_b torch.bfloat16 32470.62 usec

Matrix-matrix (prompt len 8):
m=8, n=128, k=8
====================
trans_b  torch.float32    1.05 usec
trans_b  torch.float16    1.02 usec
trans_b torch.bfloat16    1.44 usec
m=128, n=8, k=8
====================
trans_b  torch.float32    2.07 usec
trans_b  torch.float16    3.10 usec
trans_b torch.bfloat16    3.38 usec
m=4096, n=4096, k=8
====================
trans_b  torch.float32 2245.43 usec
trans_b  torch.float16 5597.87 usec
trans_b torch.bfloat16 8775.08 usec
m=11008, n=4096, k=8
====================
trans_b  torch.float32 6227.68 usec
trans_b  torch.float16 15102.41 usec
trans_b torch.bfloat16 22457.37 usec
m=4096, n=11008, k=8
====================
trans_b  torch.float32 6082.16 usec
trans_b  torch.float16 15131.57 usec
trans_b torch.bfloat16 21860.15 usec
m=32000, n=4096, k=8
====================
trans_b  torch.float32 19659.00 usec
trans_b  torch.float16 45075.64 usec
trans_b torch.bfloat16 67746.75 usec

Matrix-matrix (prompt len 16):
m=8, n=128, k=16
====================
trans_b  torch.float32    1.31 usec
trans_b  torch.float16    1.41 usec
trans_b torch.bfloat16    2.04 usec
m=128, n=8, k=16
====================
trans_b  torch.float32    1.66 usec
trans_b  torch.float16    5.76 usec
trans_b torch.bfloat16    6.37 usec
m=4096, n=4096, k=16
====================
trans_b  torch.float32 2271.34 usec
trans_b  torch.float16 11198.46 usec
trans_b torch.bfloat16 16893.54 usec
m=11008, n=4096, k=16
====================
trans_b  torch.float32 6266.85 usec
trans_b  torch.float16 29342.49 usec
trans_b torch.bfloat16 45159.22 usec
m=4096, n=11008, k=16
====================
trans_b  torch.float32 5999.16 usec
trans_b  torch.float16 29157.43 usec
trans_b torch.bfloat16 43295.81 usec
m=32000, n=4096, k=16
====================
trans_b  torch.float32 18028.83 usec
trans_b  torch.float16 89626.88 usec
trans_b torch.bfloat16 128164.62 usec

Matrix-matrix (prompt len 32):
m=8, n=128, k=32
====================
trans_b  torch.float32    1.38 usec
trans_b  torch.float16    2.03 usec
trans_b torch.bfloat16    3.29 usec
m=128, n=8, k=32
====================
trans_b  torch.float32    1.24 usec
trans_b  torch.float16   10.58 usec
trans_b torch.bfloat16   11.97 usec
m=4096, n=4096, k=32
====================
trans_b  torch.float32 2591.56 usec
trans_b  torch.float16 21683.62 usec
trans_b torch.bfloat16 32657.68 usec
m=11008, n=4096, k=32
====================
trans_b  torch.float32 6468.43 usec
trans_b  torch.float16 57811.33 usec
trans_b torch.bfloat16 89263.21 usec
m=4096, n=11008, k=32
====================
trans_b  torch.float32 6034.74 usec
trans_b  torch.float16 59372.56 usec
trans_b torch.bfloat16 88107.85 usec
m=32000, n=4096, k=32
====================
trans_b  torch.float32 18609.27 usec
trans_b  torch.float16 167298.00 usec
trans_b torch.bfloat16 255116.37 usec

Matrix-matrix (prompt len 128):
m=8, n=128, k=128
====================
trans_b  torch.float32    2.44 usec
trans_b  torch.float16    6.11 usec
trans_b torch.bfloat16   10.92 usec
m=128, n=8, k=128
====================
trans_b  torch.float32    1.80 usec
trans_b  torch.float16   40.26 usec
trans_b torch.bfloat16   44.82 usec
m=4096, n=4096, k=128
====================
trans_b  torch.float32 4773.29 usec
trans_b  torch.float16 84458.54 usec
trans_b torch.bfloat16 131248.58 usec
m=11008, n=4096, k=128
====================
trans_b  torch.float32 12249.16 usec
trans_b  torch.float16 234411.87 usec
trans_b torch.bfloat16 351970.71 usec
m=4096, n=11008, k=128
====================
trans_b  torch.float32 11439.24 usec
trans_b  torch.float16 233347.04 usec
trans_b torch.bfloat16 354475.96 usec
m=32000, n=4096, k=128
====================
trans_b  torch.float32 33803.03 usec
trans_b  torch.float16 688157.54 usec
trans_b torch.bfloat16 1048221.42 usec
```

Also ran the stock configuration; it was unchanged, indicating that we need to integrate this path with torch.mv separately, which will come in a follow-up PR.
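
Concretely, the distinction is between the two call forms below (a hedged sketch with illustrative sizes): the mm-based trans_b form above exercises the new kernel, while a plain torch.mv call still took the stock path at this point.

```python
import torch

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
v = torch.randn(4096, dtype=torch.bfloat16)

# The benchmarked trans_b form, expressed via torch.mm:
y_mm = torch.mm(v.unsqueeze(0), w.t()).squeeze(0)

# The "stock configuration": torch.mv, integrated with the fast path in
# the follow-up PR mentioned above.
y_mv = torch.mv(w, v)

# Both compute w @ v; results agree up to bfloat16 rounding.
print((y_mm.float() - y_mv.float()).abs().max())
```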

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

[ghstack-poisoned]
pytorch-bot added the module: cpu (CPU specific problem, e.g., perf, algorithm) label May 29, 2024
pytorch-bot commented May 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127477

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit aa86fe0 with merge base 21144ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorchmergebot pushed a commit that referenced this pull request Jun 3, 2024
…127478)

Summary: The existing code didn't gate the fast path, so the fast path had to duplicate the stock kernel. Now we gate it and delete the duplicate kernel.

Test Plan: Existing tests. Flipped the TORCH_INTERNAL_ASSERT_DEBUG_ONLY to non-debug and forced it to fail (locally) to make sure we had test coverage.
Pull Request resolved: #127478
Approved by: https://github.com/malfet
ghstack dependencies: #127477
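
The shape of that change can be sketched in Python (a hypothetical stand-in for the ATen C++ dispatch; all names here are made up for illustration):

```python
# Hypothetical sketch of the gating pattern: check the fast path's
# preconditions up front and fall back to the single stock kernel
# otherwise, instead of keeping a duplicate of the stock kernel inside
# the fast path.
def fast_path_supported(m: int, n: int, k: int) -> bool:
    # Stand-in gate; the real conditions live in the C++ dispatch.
    return k == 1 and m % 4 == 0

def gemm(m: int, n: int, k: int) -> str:
    if fast_path_supported(m, n, k):
        return "fast kernel"
    return "stock kernel"  # one copy, no duplication

print(gemm(4096, 4096, 1), gemm(4096, 4096, 8))
```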
pytorchmergebot pushed a commit that referenced this pull request Jun 3, 2024
Summary: Used bfloat16 dot support from #127477 to write a bfloat16 transposed fast path and integrated it.

Test Plan: Ran https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py before and after on my Apple M1 Pro.
Before:
```
mv_nt    torch.float32    6.77 usec
mv_nt    torch.float16    8.24 usec
mv_nt   torch.bfloat16  184.74 usec
mv_ta    torch.float32    5.71 usec
mv_ta    torch.float16   27.95 usec
mv_ta   torch.bfloat16   98.06 usec
notrans  torch.float32    5.55 usec
notrans  torch.float16   25.11 usec
notrans torch.bfloat16   63.55 usec
trans_a  torch.float32    5.62 usec
trans_a  torch.float16   74.48 usec
trans_a torch.bfloat16  313.19 usec
trans_b  torch.float32    5.68 usec
trans_b  torch.float16    8.18 usec
trans_b torch.bfloat16   14.96 usec
```

After:
```
mv_nt    torch.float32    5.40 usec
mv_nt    torch.float16    8.25 usec
mv_nt   torch.bfloat16   12.81 usec
mv_ta    torch.float32    5.69 usec
mv_ta    torch.float16   27.94 usec
mv_ta   torch.bfloat16   98.18 usec
notrans  torch.float32    5.60 usec
notrans  torch.float16   25.17 usec
notrans torch.bfloat16   63.22 usec
trans_a  torch.float32    5.61 usec
trans_a  torch.float16   69.32 usec
trans_a torch.bfloat16  316.62 usec
trans_b  torch.float32    5.60 usec
trans_b  torch.float16    8.09 usec
trans_b torch.bfloat16   14.61 usec
```

Note the large improvement in the mv_nt torch.bfloat16 case.
Pull Request resolved: #127484
Approved by: https://github.com/malfet
ghstack dependencies: #127477, #127478
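
One way to exercise a transposed matrix-vector layout of the kind that commit targets (a hedged sketch; the exact variant definitions, mv_nt and friends, live in the linked benchmark script):

```python
import torch

# A "transposed" matrix-vector layout: the matrix argument is a
# transposed, non-contiguous view, which is the case the bfloat16
# transposed fast path covers.
w = torch.randn(4096, 4096, dtype=torch.bfloat16).t()
v = torch.randn(4096, dtype=torch.bfloat16)
y = torch.mv(w, v)
print(y.shape, w.is_contiguous())  # torch.Size([4096]) False
```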
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…ucture (pytorch#127477)

Pull Request resolved: pytorch#127477
Approved by: https://github.com/malfet
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
…ytorch#127478)

Pull Request resolved: pytorch#127478
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#127477
petrex pushed a commit to petrex/pytorch that referenced this pull request Jun 5, 2024
Pull Request resolved: pytorch#127484
Approved by: https://github.com/malfet
ghstack dependencies: pytorch#127477, pytorch#127478
github-actions bot deleted the gh/swolchok/635/head branch July 4, 2024 01:56