Faster int8 quantized #125704
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125704
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit ffcd8c5 with merge base e9c5f1c. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@cccclai FYI
@pytorchbot merge -f "Lint and MPS tests are green"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Or my journey to learn how to write fast Metal kernels (more details will be posted [here](https://github.com/malfet/llm_experiments/tree/main/metal-perf)).

Using gpt-fast as a benchmark (by running `python generate.py --checkpoint_path checkpoints/stories110M/model_int8.pth --device mps`): before the change, on an M2 Pro I get 50 tokens per sec.

After adding a very naive kernel

```metal
template<typename T>
kernel void int8pack_mm(
    constant T * A         [[buffer(0)]],
    constant char * B      [[buffer(1)]],
    constant T * scales    [[buffer(2)]],
    device T * outputData  [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Each thread computes a single element of the M x N output.
  const uint lda = sizes.y;
  const uint ldc = sizes.z;
  const uint m = thread_index / sizes.z; // 0..sizes.x-1
  const uint n = thread_index % sizes.z; // 0..sizes.z-1
  constant T *A_ptr = A + m * lda;
  constant char *B_ptr = B + n * lda;

  // Scalar dot product over the K dimension (sizes.y).
  float rc = 0.0;
  for (uint k = 0; k < sizes.y; k++) {
    const auto a_val = float(A_ptr[k]);
    const auto b_val = float(B_ptr[k]);
    rc += a_val * b_val;
  }
  outputData[thread_index] = T(rc * float(scales[n]));
}
```

perf dropped to a sad 15 tokens per second.

Replacing the inner loop with vectorized operations

```metal
float rc = 0.0;
for (uint k = 0; k < sizes.y / 4; k++) {
  const auto a_val = float4(A_ptr[k]);
  const auto b_val = float4(B_ptr[k]);
  rc += dot(a_val, b_val);
}
```
Perf jumps back up to 53 tokens per second, but it's a bit of a lie when it comes to llama2-7B perf.
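To make the vectorized variant concrete, here is a minimal sketch of the full kernel, assuming `sizes.y` is a multiple of 4, the rows are suitably aligned for vector loads, and the row pointers are reinterpreted as 4-wide vectors. The pointer setup and the `int8pack_mm_vec4` name are my illustration, not necessarily the exact code in this PR:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical full version of the vectorized kernel: the row pointers are
// reinterpreted as 4-wide vectors so each iteration loads four elements of A
// and four of B and reduces them with a single dot().
// Assumes sizes.y % 4 == 0 and vector-aligned rows.
template<typename T>
kernel void int8pack_mm_vec4(
    constant T * A         [[buffer(0)]],
    constant char * B      [[buffer(1)]],
    constant T * scales    [[buffer(2)]],
    device T * outputData  [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]],
    uint thread_index [[thread_position_in_grid]]) {
  const uint lda = sizes.y;
  const uint m = thread_index / sizes.z;
  const uint n = thread_index % sizes.z;
  constant vec<T, 4> *A_ptr = (constant vec<T, 4> *)(A + m * lda);
  constant char4 *B_ptr = (constant char4 *)(B + n * lda);

  float rc = 0.0;
  for (uint k = 0; k < sizes.y / 4; k++) {
    const auto a_val = float4(A_ptr[k]); // widen T4  -> float4
    const auto b_val = float4(B_ptr[k]); // widen char4 -> float4
    rc += dot(a_val, b_val);
  }
  outputData[thread_index] = T(rc * float(scales[n]));
}
```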
The next step in unlocking the performance was to replace the 1D grid with a 2D one, but limit the threadgroup size to a single row, which results in much better data locality (unfortunately no longer observable with `stories110M`, as its small model size and the Python runtime overhead hide the perf gain).
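A rough sketch of that 2D-grid variant, under my assumptions about the indexing (the `int8pack_mm_2d` name is hypothetical and the inner loop is kept scalar for brevity); the host side would dispatch an N x M grid with the threadgroup capped to a single output row:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical 2D-grid variant: thread_index.x walks the output columns (n),
// thread_index.y selects the output row (m). With the threadgroup limited to a
// single row, threads in a group share the same row of A, improving locality.
template<typename T>
kernel void int8pack_mm_2d(
    constant T * A         [[buffer(0)]],
    constant char * B      [[buffer(1)]],
    constant T * scales    [[buffer(2)]],
    device T * outputData  [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]],
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint lda = sizes.y;
  const uint ldc = sizes.z;
  const uint m = thread_index.y; // 0..sizes.x-1
  const uint n = thread_index.x; // 0..sizes.z-1
  constant T *A_ptr = A + m * lda;
  constant char *B_ptr = B + n * lda;

  // Kept scalar here for brevity; in practice this would be the vectorized
  // float4/dot() loop shown earlier.
  float rc = 0.0;
  for (uint k = 0; k < sizes.y; k++) {
    rc += float(A_ptr[k]) * float(B_ptr[k]);
  }
  outputData[m * ldc + n] = T(rc * float(scales[n]));
}
```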
There were several unsuccessful attempts at caching inputs in thread-local memory or using `float4x4` to speed up the computation. But the key to unlocking the perf was a comment in https://github.com/ml-explore/mlx/blob/631dfbe67309fb630795cd612739cbe54c75e222/mlx/backend/metal/kernels/gemv.metal#L184, which hinted at exploiting both SIMD groups and thread-local caches. That resulted in a 5x jump in performance compared to the initial vectorization approach and a 3x perf jump in the end-to-end llama7b test.
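To illustrate the hint from the MLX gemv kernel, here is a minimal sketch of the SIMD-group reduction pattern. This is my reconstruction of the general idea, not the kernel that actually landed; the `int8pack_mm_simdgroup` name and the one-SIMD-group-per-output-element layout are assumptions:

```metal
#include <metal_stdlib>
using namespace metal;

// Hypothetical illustration of the SIMD-group idea: one SIMD group cooperates on a
// single output element. Each lane accumulates a strided slice of the K dimension,
// then simd_sum() combines the per-lane partial sums and lane 0 writes the result.
// Assumes the threadgroup is dispatched as exactly one SIMD group per output element.
template<typename T>
kernel void int8pack_mm_simdgroup(
    constant T * A         [[buffer(0)]],
    constant char * B      [[buffer(1)]],
    constant T * scales    [[buffer(2)]],
    device T * outputData  [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]],
    uint2 group_pos [[threadgroup_position_in_grid]],
    uint simd_lane  [[thread_index_in_simdgroup]],
    uint simd_width [[threads_per_simdgroup]]) {
  const uint lda = sizes.y;
  const uint ldc = sizes.z;
  const uint m = group_pos.y; // output row
  const uint n = group_pos.x; // output column
  constant T *A_ptr = A + m * lda;
  constant char *B_ptr = B + n * lda;

  // Each lane handles every simd_width-th element of the K dimension.
  float partial = 0.0;
  for (uint k = simd_lane; k < sizes.y; k += simd_width) {
    partial += float(A_ptr[k]) * float(B_ptr[k]);
  }

  // Reduce the per-lane partial sums across the SIMD group.
  const float rc = simd_sum(partial);
  if (simd_lane == 0) {
    outputData[m * ldc + n] = T(rc * float(scales[n]));
  }
}
```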