
Add support for Flash Attention for AMD/ROCm #112997

@chauhang

Description


🚀 The feature, motivation and pitch

Enable support for the Flash Attention and Memory Efficient SDPA kernels on AMD GPUs.

At present, using these gives the warnings below with the latest nightlies (torch==2.2.0.dev20231105+rocm5.6, pytorch-triton-rocm==2.1.0+34f8189eae):

model.py:187: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:253.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
model.py:187: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:291.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
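For reference, a minimal repro along the lines of the call above (tensor shapes are illustrative, not from the original model):

```python
import torch
import torch.nn.functional as F

# Illustrative (batch, heads, seq_len, head_dim) tensors; any shapes
# accepted by SDPA will do.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# On a ROCm build where the flash/mem-efficient kernels are compiled out,
# this call falls back to the math backend and emits the UserWarnings
# quoted above; the result is still correct, just slower.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```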

ROCm already has an implementation of Tri Dao's FlashAttention here: https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support

Alternatives

Users have to manually install the ROCm version of FA and use it in their code, instead of using the native PyTorch APIs.

Additional context

The ROCm build currently has the FA-related flags turned off by default: https://github.com/pytorch/pytorch/blob/main/CMakeLists.txt#L741-L750
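For anyone building from source in the meantime, those flags can be overridden at build time. A sketch of the invocation (flag names taken from the linked CMakeLists.txt range; the exact ROCm build steps may differ):

```shell
# Assumed source-build invocation: USE_FLASH_ATTENTION and
# USE_MEM_EFF_ATTENTION mirror the CMake options linked above,
# and USE_ROCM selects the ROCm/HIP build.
USE_ROCM=1 USE_FLASH_ATTENTION=1 USE_MEM_EFF_ATTENTION=1 python setup.py develop
```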

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang


Labels

    ciflow/rocm (Trigger "default" config CI on ROCm)
    module: rocm (AMD GPU support for PyTorch)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Status: Done
