Onboard ARM bfloat16 to gemm-by-dot-product-for-gemm_transa_ infrastructure #127477
Closed
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/127477
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit aa86fe0 with merge base 21144ce.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This was referenced May 29, 2024
malfet approved these changes (Jun 3, 2024)
pytorchmergebot pushed a commit that referenced this pull request (Jun 3, 2024):
…127478)

Summary: The existing code didn't gate the fast path, so the fast path had to duplicate the stock kernel. Now we gate it and delete the duplicate kernel.

Test Plan: Existing tests. Flipped the TORCH_INTERNAL_ASSERT_DEBUG_ONLY to non-debug and forced it to fail (locally) to make sure we had test coverage.

Pull Request resolved: #127478
Approved by: https://github.com/malfet
ghstack dependencies: #127477
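As a rough illustration of the gating structure this commit describes (a minimal sketch; `fast_path_applies`, `dot_strided`, and `dot_contiguous` are invented names, not the actual ATen symbols): the dispatcher checks one predicate and otherwise falls back to the single stock kernel, so the fast path no longer needs its own copy of the general-case code.

```cpp
#include <cstdint>
#include <cstdio>

// Invented gate: pretend the fast path only handles unit-stride inputs.
static bool fast_path_applies(int64_t inca, int64_t incb) {
  return inca == 1 && incb == 1;
}

// Stock kernel: handles arbitrary strides.
static float dot_strided(const float* a, int64_t inca,
                         const float* b, int64_t incb, int64_t n) {
  float acc = 0.0f;
  for (int64_t i = 0; i < n; ++i) acc += a[i * inca] * b[i * incb];
  return acc;
}

// Fast path: same math, but only for the contiguous case it is gated on.
static float dot_contiguous(const float* a, const float* b, int64_t n) {
  float acc = 0.0f;
  for (int64_t i = 0; i < n; ++i) acc += a[i] * b[i];
  return acc;
}

// Dispatcher: gate once and otherwise fall back to the stock kernel, so the
// fast path never duplicates the general-case code.
float dot(const float* a, int64_t inca, const float* b, int64_t incb, int64_t n) {
  return fast_path_applies(inca, incb) ? dot_contiguous(a, b, n)
                                       : dot_strided(a, inca, b, incb, n);
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {1, 1, 1, 1};
  std::printf("%g %g\n", dot(a, 1, b, 1, 4), dot(a, 2, b, 2, 2));  // prints: 10 4
}
```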
pytorchmergebot pushed a commit that referenced this pull request (Jun 3, 2024):
Summary: Used bfloat16 dot support from #127477 to write a bfloat16 transposed fast path and integrated it.

Test Plan: Ran https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py before and after on my Apple M1 Pro.

Before:
```
mv_nt    torch.float32     6.77 usec
mv_nt    torch.float16     8.24 usec
mv_nt    torch.bfloat16  184.74 usec
mv_ta    torch.float32     5.71 usec
mv_ta    torch.float16    27.95 usec
mv_ta    torch.bfloat16   98.06 usec
notrans  torch.float32     5.55 usec
notrans  torch.float16    25.11 usec
notrans  torch.bfloat16   63.55 usec
trans_a  torch.float32     5.62 usec
trans_a  torch.float16    74.48 usec
trans_a  torch.bfloat16  313.19 usec
trans_b  torch.float32     5.68 usec
trans_b  torch.float16     8.18 usec
trans_b  torch.bfloat16   14.96 usec
```

After:
```
mv_nt    torch.float32     5.40 usec
mv_nt    torch.float16     8.25 usec
mv_nt    torch.bfloat16   12.81 usec
mv_ta    torch.float32     5.69 usec
mv_ta    torch.float16    27.94 usec
mv_ta    torch.bfloat16   98.18 usec
notrans  torch.float32     5.60 usec
notrans  torch.float16    25.17 usec
notrans  torch.bfloat16   63.22 usec
trans_a  torch.float32     5.61 usec
trans_a  torch.float16    69.32 usec
trans_a  torch.bfloat16  316.62 usec
trans_b  torch.float32     5.60 usec
trans_b  torch.float16     8.09 usec
trans_b  torch.bfloat16   14.61 usec
```

Note large improvement in mv_nt torch.bfloat16 case.

Pull Request resolved: #127484
Approved by: https://github.com/malfet
ghstack dependencies: #127477, #127478
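To picture the transposed fast path this commit describes, here is a minimal float sketch with invented names (the real kernel lives in ATen and calls the bfloat16 dot routine added in #127477): in the transposed layout each output element's operands are contiguous, so the matrix-vector product reduces to one dot product per output element.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Contiguous dot product -- stand-in for the bfloat16 dot routine from #127477,
// which widens bf16 to float internally and accumulates in float.
static float dot(const float* a, const float* b, int64_t n) {
  float acc = 0.0f;
  for (int64_t i = 0; i < n; ++i) acc += a[i] * b[i];
  return acc;
}

// Transposed matrix-vector product: in this layout every output element's
// operands are contiguous, so y[i] is a single dot product over row i of A.
static void mv_trans(const float* a, const float* x, float* y,
                     int64_t m, int64_t k) {
  for (int64_t i = 0; i < m; ++i) {
    y[i] = dot(a + i * k, x, k);
  }
}

int main() {
  const int64_t m = 2, k = 3;
  std::vector<float> a = {1, 2, 3,    // row 0
                          4, 5, 6};   // row 1
  std::vector<float> x = {1, 1, 1};
  std::vector<float> y(m);
  mv_trans(a.data(), x.data(), y.data(), m, k);
  std::printf("%g %g\n", y[0], y[1]);  // prints: 6 15
}
```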
petrex pushed a commit to petrex/pytorch that referenced this pull request (Jun 5, 2024):
Onboard ARM bfloat16 to gemm-by-dot-product-for-gemm_transa_ infrastructure (pytorch#127477)

Summary: This gets us a baseline level of reasonable performance for bfloat16 matrix-vector and matrix-matrix multiplication on my Apple M1. I've intentionally left the use of intrinsics for future work.

Test Plan: Used https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py (modified to run larger sizes) to benchmark a range of LLM-interesting matrix-vector and matrix-matrix sizes on my Apple M1 Pro. bfloat16 performance is improved across the board (except possibly for very small cases) and now exceeds float32 performance (as it should) for the matrix-vector cases.

Before (trans_b timings in usec: torch.float32 | torch.float16 | torch.bfloat16):
```
Matrix-vector:
m=8, n=128, k=1:          0.75 | 0.71 | 0.81
m=128, n=8, k=1:          0.75 | 0.93 | 0.98
m=4096, n=4096, k=1:      2194.31 | 661.27 | 3758.42
m=11008, n=4096, k=1:     5792.04 | 1789.98 | 10120.67
m=4096, n=11008, k=1:     6101.22 | 1927.34 | 10469.47
m=32000, n=4096, k=1:     18353.20 | 5161.06 | 29601.69

Matrix-matrix (prompt len 4):
m=8, n=128, k=4:          2.14 | 0.85 | 1.19
m=128, n=8, k=4:          1.47 | 1.85 | 1.75
m=4096, n=4096, k=4:      4416.40 | 2688.36 | 14987.33
m=11008, n=4096, k=4:     6140.24 | 7467.26 | 40295.52
m=4096, n=11008, k=4:     6143.10 | 7298.04 | 41393.43
m=32000, n=4096, k=4:     17650.72 | 21346.63 | 116849.98

Matrix-matrix (prompt len 8):
m=8, n=128, k=8:          1.05 | 1.03 | 1.69
m=128, n=8, k=8:          2.05 | 3.08 | 2.95
m=4096, n=4096, k=8:      2323.99 | 5265.45 | 29942.40
m=11008, n=4096, k=8:     6202.01 | 14677.90 | 80625.18
m=4096, n=11008, k=8:     6112.05 | 14340.52 | 82799.99
m=32000, n=4096, k=8:     17650.65 | 42551.43 | 236081.08

Matrix-matrix (prompt len 16):
m=8, n=128, k=16:         1.26 | 1.34 | 2.69
m=128, n=8, k=16:         1.60 | 5.81 | 5.34
m=4096, n=4096, k=16:     2328.05 | 10526.58 | 60028.28
m=11008, n=4096, k=16:    6243.35 | 28505.08 | 163670.15
m=4096, n=11008, k=16:    5870.11 | 28597.89 | 165404.88
m=32000, n=4096, k=16:    17746.27 | 83393.87 | 472313.13

Matrix-matrix (prompt len 32):
m=8, n=128, k=32:         1.35 | 2.01 | 4.68
m=128, n=8, k=32:         1.19 | 10.98 | 10.13
m=4096, n=4096, k=32:     2525.29 | 23106.71 | 122987.04
m=11008, n=4096, k=32:    6131.34 | 57537.41 | 327825.00
m=4096, n=11008, k=32:    6395.01 | 57456.33 | 331325.58
m=32000, n=4096, k=32:    19078.68 | 167735.08 | 975736.88

Matrix-matrix (prompt len 128):
m=8, n=128, k=128:        2.40 | 6.07 | 16.83
m=128, n=8, k=128:        1.78 | 40.35 | 37.21
m=4096, n=4096, k=128:    4827.60 | 84341.24 | 478917.75
m=11008, n=4096, k=128:   11879.96 | 226484.33 | 1289465.50
m=4096, n=11008, k=128:   10707.75 | 229200.58 | 1327416.67
m=32000, n=4096, k=128:   33306.32 | 662898.21 | 3815866.63
```

After (trans_b timings in usec: torch.float32 | torch.float16 | torch.bfloat16):
```
Matrix-vector:
m=8, n=128, k=1:          0.77 | 0.72 | 0.77
m=128, n=8, k=1:          0.73 | 0.93 | 1.56
m=4096, n=4096, k=1:      2195.22 | 675.40 | 1038.29
m=11008, n=4096, k=1:     5980.27 | 1806.08 | 2756.46
m=4096, n=11008, k=1:     6339.95 | 1844.71 | 2726.52
m=32000, n=4096, k=1:     18137.17 | 6020.75 | 8612.89

Matrix-matrix (prompt len 4):
m=8, n=128, k=4:          2.24 | 0.91 | 1.07
m=128, n=8, k=4:          1.58 | 1.96 | 2.11
m=4096, n=4096, k=4:      4583.43 | 3014.04 | 4434.04
m=11008, n=4096, k=4:     6245.55 | 7513.82 | 11207.80
m=4096, n=11008, k=4:     6096.22 | 7688.82 | 11143.72
m=32000, n=4096, k=4:     17982.88 | 22001.28 | 32470.62

Matrix-matrix (prompt len 8):
m=8, n=128, k=8:          1.05 | 1.02 | 1.44
m=128, n=8, k=8:          2.07 | 3.10 | 3.38
m=4096, n=4096, k=8:      2245.43 | 5597.87 | 8775.08
m=11008, n=4096, k=8:     6227.68 | 15102.41 | 22457.37
m=4096, n=11008, k=8:     6082.16 | 15131.57 | 21860.15
m=32000, n=4096, k=8:     19659.00 | 45075.64 | 67746.75

Matrix-matrix (prompt len 16):
m=8, n=128, k=16:         1.31 | 1.41 | 2.04
m=128, n=8, k=16:         1.66 | 5.76 | 6.37
m=4096, n=4096, k=16:     2271.34 | 11198.46 | 16893.54
m=11008, n=4096, k=16:    6266.85 | 29342.49 | 45159.22
m=4096, n=11008, k=16:    5999.16 | 29157.43 | 43295.81
m=32000, n=4096, k=16:    18028.83 | 89626.88 | 128164.62

Matrix-matrix (prompt len 32):
m=8, n=128, k=32:         1.38 | 2.03 | 3.29
m=128, n=8, k=32:         1.24 | 10.58 | 11.97
m=4096, n=4096, k=32:     2591.56 | 21683.62 | 32657.68
m=11008, n=4096, k=32:    6468.43 | 57811.33 | 89263.21
m=4096, n=11008, k=32:    6034.74 | 59372.56 | 88107.85
m=32000, n=4096, k=32:    18609.27 | 167298.00 | 255116.37

Matrix-matrix (prompt len 128):
m=8, n=128, k=128:        2.44 | 6.11 | 10.92
m=128, n=8, k=128:        1.80 | 40.26 | 44.82
m=4096, n=4096, k=128:    4773.29 | 84458.54 | 131248.58
m=11008, n=4096, k=128:   12249.16 | 234411.87 | 351970.71
m=4096, n=11008, k=128:   11439.24 | 233347.04 | 354475.96
m=32000, n=4096, k=128:   33803.03 | 688157.54 | 1048221.42
```

Also ran the stock configuration; it was unchanged, indicating that we need to integrate this path with torch.mv separately, which will come in a follow-up PR.

Pull Request resolved: pytorch#127477
Approved by: https://github.com/malfet
Stack from ghstack (oldest at bottom):
Summary: This gets us a baseline level of reasonable performance for
bfloat16 matrix-vector and matrix-matrix multiplication on my Apple
M1. I've intentionally left the use of intrinsics for future work.
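As a rough sketch of the kind of bfloat16 dot product this builds on, assuming a plain scalar loop that widens each bfloat16 element to float and accumulates in float (the `bf16` struct and helper names below are made up for the example and are not the ATen types; as noted above, the real code deliberately avoids intrinsics for now):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Hypothetical 16-bit storage type standing in for a bfloat16 value.
struct bf16 { uint16_t bits; };

// Widen bfloat16 to float: bf16 holds the top 16 bits of an IEEE float32.
static inline float bf16_to_float(bf16 x) {
  uint32_t wide = static_cast<uint32_t>(x.bits) << 16;
  float f;
  std::memcpy(&f, &wide, sizeof(f));
  return f;
}

// Narrow float to bfloat16 by truncation (ignores rounding; fine for this demo).
static inline bf16 float_to_bf16(float f) {
  uint32_t wide;
  std::memcpy(&wide, &f, sizeof(wide));
  return bf16{static_cast<uint16_t>(wide >> 16)};
}

// Scalar bfloat16 dot product with float accumulation -- the shape of the
// building block the gemm-by-dot-product path needs.
float bf16_dot(const bf16* a, const bf16* b, int64_t n) {
  float acc = 0.0f;
  for (int64_t i = 0; i < n; ++i) {
    acc += bf16_to_float(a[i]) * bf16_to_float(b[i]);
  }
  return acc;
}

int main() {
  std::vector<bf16> a(8), b(8);
  for (int i = 0; i < 8; ++i) {
    a[i] = float_to_bf16(1.0f + i);  // 1..8, exactly representable in bf16
    b[i] = float_to_bf16(0.5f);
  }
  std::printf("dot = %g\n", bf16_dot(a.data(), b.data(), 8));  // prints: dot = 18
}
```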
Test Plan: Used
https://github.com/malfet/llm_experiments/blob/main/benchmarks/benchmark_torch_mm.py
(modified to run larger sizes) to benchmark a range of LLM-interesting
matrix-vector and matrix-matrix sizes on my Apple M1 Pro. bfloat16 performance is
improved across the board (except possibly for very small cases) and
now exceeds float32 performance (as it should) for the matrix-vector
cases.
Before:
After:
Also ran the stock configuration; it was unchanged, indicating that we need to integrate this path with torch.mv separately, which will come in a follow-up PR.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10