
[RFC] Flash decoding optimization based on FlexAttention on X86 CPU #161757

@Valentine233

Description

Release highlight for proposed Feature

Flash decoding optimization based on FlexAttention on X86 CPU

Point(s) of contact

xuan.liao@intel.com guobing.chen@intel.com

Release Mode (pytorch/pytorch features only)

In-tree

Out-Of-Tree Repo

No response

Description and value to the user

Flash decoding is a technique commonly adopted in LLM inference to speed up attention and thus generation for long sequences. In PyTorch, the technique already exists on the CUDA path, but not on the CPU one. The proposed feature adds the flash decoding optimization, built on FlexAttention, to the X86 CPU inductor backend: it parallelizes over the KV sequence by partitioning it and reducing the partial results. The optimization can greatly improve CPU utilization when the existing parallelism is insufficient, e.g. a small batch size, head count, or Q sequence length combined with a long KV sequence. This is expected to help PyTorch users improve performance in the LLM decoding phase, especially with long context lengths.
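
For a concrete picture, here is a minimal sketch of the decoding-shaped workload this optimization targets, assuming a PyTorch build (2.9 per this RFC) with FlexAttention CPU support; the shapes are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Compile so the inductor CPU backend can generate the flash-decoding kernel.
flex_attention_compiled = torch.compile(flex_attention)

# A decoding-shaped workload: little batch/head/Q parallelism, long KV sequence.
B, H, D = 1, 8, 64
q_len, kv_len = 1, 32768

q = torch.randn(B, H, q_len, D)
k = torch.randn(B, H, kv_len, D)
v = torch.randn(B, H, kv_len, D)

# With so few (B, H, q_len) blocks to parallelize over, the KV dimension is
# partitioned across threads and the partial softmax results are reduced.
out = flex_attention_compiled(q, k, v)
print(out.shape)  # torch.Size([1, 8, 1, 64])
```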

Link to design doc, GitHub issues, past submissions, etc

#159835

What feedback adopters have provided

Adopters found that it brings good performance for LLM inference with long context lengths. Furthermore, the feature significantly improves tensor parallelism cases, where the attention heads are further split by sharding.

Plan for documentations / tutorials

A tutorial is not needed.

Additional context for tutorials

No response

Marketing/Blog Coverage

Yes

Are you requesting other marketing assistance with this feature?

No

Release Version

2.9

OS / Platform / Compute Coverage

Linux only
X86 CPU only

Testing Support (CI, test cases, etc.)

Unit testing is covered by CI.
For an end-to-end test, run an LLM model that calls the FlexAttention API under torch.compile.
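
As an illustration of that pattern, a hedged sketch follows; the FlexSelfAttention module and the shapes are hypothetical stand-ins for a real model's attention layer, while flex_attention and create_block_mask are the documented FlexAttention entry points:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

class FlexSelfAttention(torch.nn.Module):
    # Hypothetical attention block that routes through the FlexAttention API.
    def forward(self, q, k, v, block_mask=None):
        return flex_attention(q, k, v, block_mask=block_mask)

def causal(b, h, q_idx, kv_idx):
    # Each query position attends only to keys at or before its own position.
    return q_idx >= kv_idx

model = torch.compile(FlexSelfAttention())  # inductor CPU backend

B, H, L, D = 1, 8, 4096, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# The block mask must be built for the CPU device (the default is "cuda").
mask = create_block_mask(causal, B, H, L, L, device="cpu")
out = model(q, k, v, block_mask=mask)
```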

Labels

module: sdpa (all things related to torch.nn.functional.scaled_dot_product_attention)
release-feature-request (marks a feature tracked for PyTorch OSS releases)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
