Add NEON-accelerated int8mm for bfloat16 #125290
Conversation
Apparently `vshlq_u32` is faster than `vcvt_f32_f16`, i.e. the same stories110M model runs at 60 tokens/sec with f16, but at 66 tokens/sec with bf16.
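For reference, here is a minimal sketch (not the exact PyTorch kernel; the helper names are illustrative) of why bf16 is cheap to widen on NEON: a bfloat16 value is just the upper 16 bits of a float32 bit pattern, so widening it only takes a left shift, whereas fp16 needs a real conversion instruction:

```cpp
#include <arm_neon.h>

// bf16 -> f32: widen the 16-bit patterns to 32 bits and shift them into the
// high half. bfloat16 is the top half of an IEEE float32, so no rounding or
// exponent rebiasing is needed.
static inline float32x4_t bf16_to_f32(uint16x4_t bf16_bits) {
    uint32x4_t widened = vmovl_u16(bf16_bits);      // u16 -> u32
    uint32x4_t shifted = vshlq_n_u32(widened, 16);  // move bits to the high half
    return vreinterpretq_f32_u32(shifted);          // reinterpret as float32
}

// f16 -> f32: requires an actual format-conversion instruction.
static inline float32x4_t f16_to_f32(float16x4_t f16) {
    return vcvt_f32_f16(f16);
}
```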
Thank you!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
As apparently `vshlq_u32` is faster than `vcvt_f32_f16`, refactor the NEON `tinygemm_kernel` to rely on `load_as_float32x4` and `load_as_float32x4x2`, and implement them for float16 (using vcvt), bfloat16 (using a left shift), and plain float32 (no conversion needed). As a result, stories110M runs at 60 tokens/sec with f16, 66 tokens/sec with bf16, and 75 tokens/sec with f32, though the higher bandwidth demand starts to favor the reduced-precision floating-point types as model size grows. Pull Request resolved: pytorch#125290 Approved by: https://github.com/mikekgfb
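A hedged sketch of what those dtype-specific load helpers could look like (the real code dispatches on dtype inside the kernel; the suffixed names here are mine, purely for illustration):

```cpp
#include <arm_neon.h>
#include <cstdint>

// f32: a plain vector load, no conversion at all.
static inline float32x4_t load_as_float32x4_f32(const float* ptr) {
    return vld1q_f32(ptr);
}

// f16: load four half floats and convert with vcvt.
static inline float32x4_t load_as_float32x4_f16(const float16_t* ptr) {
    return vcvt_f32_f16(vld1_f16(ptr));
}

// bf16: load four 16-bit patterns and shift them into float32 slots.
static inline float32x4_t load_as_float32x4_bf16(const uint16_t* ptr) {
    uint32x4_t v = vshlq_n_u32(vmovl_u16(vld1_u16(ptr)), 16);
    return vreinterpretq_f32_u32(v);
}
```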
It used to be vectorized only for f16, but there is no reason not to do the same for bf16 or f32. Spiritual follow-up of #125290. Pull Request resolved: pytorch#126512 Approved by: https://github.com/Skylion007
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10