🚀 The feature, motivation and pitch
Enable support for the FlashAttention and memory-efficient SDPA kernels on AMD GPUs.
At present, attempting to use these emits the warnings below with the latest nightlies (torch==2.2.0.dev20231105+rocm5.6, pytorch-triton-rocm==2.1.0+34f8189eae):
model.py:187: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:253.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
model.py:187: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:291.)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
ROCm already has an implementation of Tri Dao's FlashAttention here: https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support
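For reference, here is a minimal pure-Python sketch (my own illustration, not PyTorch code) of the computation F.scaled_dot_product_attention performs, without masking or dropout. The fused FlashAttention and memory-efficient backends produce the same result but avoid materializing the full seq_len x seq_len score matrix, which is why the fallback math path is slower and more memory-hungry:

```python
import math

def sdpa_reference(q, k, v):
    """Unfused scaled dot-product attention on plain lists of rows.

    q, k, v: seq_len x head_dim matrices as lists of lists.
    Computes softmax(q @ k^T / sqrt(head_dim)) @ v, the same result
    that the fused SDPA kernels compute in a single pass.
    """
    d = len(q[0])
    # scores[i][j] = <q_i, k_j> / sqrt(d) -- the full matrix that
    # FlashAttention avoids materializing
    scores = [[sum(qi * kj for qi, kj in zip(qrow, krow)) / math.sqrt(d)
               for krow in k] for qrow in q]
    out = []
    for row in scores:
        # numerically stabilized row-wise softmax
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        weights = [e / z for e in exps]
        # output row is a convex combination of the value rows
        out.append([sum(w * vrow[t] for w, vrow in zip(weights, v))
                    for t in range(len(v[0]))])
    return out
```

Because the attention weights in each row sum to one, every output row is a convex combination of the rows of v; a query attends most strongly to the key it aligns with best.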
Alternatives
Users have to manually install the ROCm fork of FlashAttention and call it directly in their code, instead of using the native PyTorch API.
Additional context
The ROCm build currently has the FlashAttention-related flags turned off by default: https://github.com/pytorch/pytorch/blob/main/CMakeLists.txt#L741-L750
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang