Release highlight for proposed Feature
Flash decoding optimization based on FlexAttention on X86 CPU
Point(s) of contact
xuan.liao@intel.com guobing.chen@intel.com
Release Mode (pytorch/pytorch features only)
In-tree
Out-Of-Tree Repo
No response
Description and value to the user
Flash decoding is a common technique in LLM inference for speeding up attention and delivering faster generation on long sequences. In PyTorch, the technique already exists for the CUDA path, but not for the CPU path. The proposed feature adds flash decoding support based on FlexAttention to the X86 CPU inductor backend, parallelizing along the KV sequence dimension via partitioning and reduction. The optimization can greatly improve CPU utilization when the original parallelism is insufficient, e.g. a small batch size, head count, or Q sequence length combined with a long KV sequence length. This is expected to help PyTorch users improve performance in the LLM decoding phase, especially with long context lengths.
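As a rough illustration, the sketch below shows a decoding-shaped FlexAttention call compiled with torch.compile on CPU. The shapes and dtype are illustrative assumptions (batch 1, 32 heads, Q length 1, KV length 8192, head dim 128), not the feature's benchmark configuration.

```python
# Minimal sketch of a decode-step FlexAttention call on CPU.
# Shapes are assumptions chosen to mimic the decoding phase:
# a single query token attending over a long KV cache.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, Q_LEN, KV_LEN, D = 1, 32, 1, 8192, 128
q = torch.randn(B, H, Q_LEN, D)   # single decode-step query
k = torch.randn(B, H, KV_LEN, D)  # long KV cache (keys)
v = torch.randn(B, H, KV_LEN, D)  # long KV cache (values)

# Compile FlexAttention; on X86 CPU the inductor backend can then apply
# the flash-decoding optimization (KV-sequence partition + reduction).
compiled_flex_attention = torch.compile(flex_attention)

out = compiled_flex_attention(q, k, v)
print(out.shape)  # torch.Size([1, 32, 1, 128])
```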
Link to design doc, GitHub issues, past submissions, etc
What feedback adopters have provided
Adopters found that it brings good performance for LLM inference with long context lengths. Furthermore, the feature significantly improves tensor-parallelism cases, where the heads are further split by sharding.
Plan for documentations / tutorials
A tutorial is not needed.
Additional context for tutorials
No response
Marketing/Blog Coverage
Yes
Are you requesting other marketing assistance with this feature?
No
Release Version
2.9
OS / Platform / Compute Coverage
Linux only
X86 CPU only
Testing Support (CI, test cases, etc..)
Unit testing is covered by CI.
For an E2E test, one needs to run an LLM model that calls the FlexAttention API with torch.compile, as sketched below.
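A hedged sketch of such an end-to-end check (an assumed workflow, not the project's actual test suite): compare the compiled FlexAttention output on CPU against the eager reference for a decoding-shaped input.

```python
# Assumed E2E-style correctness check: compiled (inductor CPU) vs. eager FlexAttention.
import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 16, 1, 64)     # single decode-step query
k = torch.randn(1, 16, 4096, 64)  # long KV cache (keys)
v = torch.randn(1, 16, 4096, 64)  # long KV cache (values)

eager_out = flex_attention(q, k, v)                     # eager reference
compiled_out = torch.compile(flex_attention)(q, k, v)   # inductor CPU path

# Tolerances are illustrative assumptions for fp32 on CPU.
torch.testing.assert_close(eager_out, compiled_out, atol=1e-4, rtol=1e-4)
print("compiled FlexAttention matches the eager reference")
```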