
Conversation

zrphercule
Contributor

Summary:
As described in https://fb.quip.com/oxpiA1uDBjgP

This implements the first parts of the RFC and is a rough draft showing the approach. The idea is that, for the first cut, we maintain very close (in this diff, I believe identical) numerical equivalence to the existing nn.MHA implementation, which is what this diff attempts to do. Once we have a working and adopted native self-attention implementation, subsequent work can explore alternative implementations.

The current implementation is similar to existing dedicated implementations such as LightSeq/FasterTransformer/DeepSpeed, and for MHA it is between 1.2x and 2x faster on both CPU and GPU, depending on the setting. It makes some approximations/restrictions (e.g., it doesn't handle masking in masked softmax), but these shouldn't materially impact performance.
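
For concreteness, the equivalence being targeted can be sketched as a small hand-rolled check against nn.MultiheadAttention; the snippet below is illustrative only (shapes and tolerance are arbitrary choices, and it is not part of this diff):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, embed_dim, num_heads = 2, 16, 64, 8
head_dim = embed_dim // num_heads

mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
x = torch.randn(batch, seq_len, embed_dim)

with torch.no_grad():
    ref, _ = mha(x, x, x, need_weights=False)

    # Recompute self-attention directly from the module's packed projections.
    q, k, v = F.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)

    def split_heads(t):  # (B, S, E) -> (B, H, S, E/H)
        return t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
    out = (attn @ v).transpose(1, 2).reshape(batch, seq_len, embed_dim)
    out = F.linear(out, mha.out_proj.weight, mha.out_proj.bias)

print(torch.allclose(ref, out, atol=1e-5))  # expected: True (up to float tolerance)
```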

This does the first few items:

  • Add native_multi_head_attention(...) and native_multi_head_attention_backward(...) to native_functions.yaml
  • Implement native_multi_head_attention(...) on GPU, extracting bits and pieces out of LS/DS/FT as appropriate
  • Implement native_multi_head_attention(...) on CPU

The backward implementation is still WIP, but the idea would be to:

  • Hook these up in derivatives.yaml
  • Implement native_multi_head_attention_backward(...) on GPU, extracting bits and pieces out of LS/DS (not FT, since it's inference-only)
  • Implement native_multi_head_attention_backward(...) on CPU
  • In torch.nn.functional.multi_head_attention_forward (https://github.com/pytorch/pytorch/blob/23321ba7a3b634ee734455aab4a984689802cad0/torch/nn/functional.py#L4953), add some conditionals to check whether we are being called in a BERT/ViT-style encoder fashion, and invoke the native function directly (see the sketch after this list).
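
A rough sketch of what that dispatch could look like; the fast-path predicate and the native call's name/arguments are assumptions for illustration, not the final API:

```python
import torch

# Hypothetical sketch of the planned dispatch inside
# torch.nn.functional.multi_head_attention_forward; the predicate and the
# native call below are assumptions, not the actual implementation.
def multi_head_attention_forward(query, key, value, *args,
                                 need_weights=True, attn_mask=None, **kwargs):
    # BERT/ViT-style encoder usage: pure self-attention, no attention mask,
    # no per-head weights requested, and (while backward is WIP) no autograd.
    encoder_fast_path = (
        query is key and key is value
        and not need_weights
        and attn_mask is None
        and not torch.is_grad_enabled()
    )
    if encoder_fast_path:
        # Assumed binding name/arguments for the new fused op.
        return torch.native_multi_head_attention(query, key, value, *args), None
    # Otherwise, fall back to the existing composite implementation.
    ...
```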

Test Plan: TODO

Differential Revision: D31829981

@pytorch-probot

pytorch-probot bot commented Jan 4, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/zrphercule/pytorch/blob/45ed5f2773c8b116091ccfa358f38ef77cb857f6/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

| Workflows | Labels (bold enabled) | Status |
| --- | --- | --- |
| **Triggered Workflows** | | |
| linux-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk | ✅ triggered |
| linux-docs | ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-vulkan-bionic-py3.7-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-cuda11.3-py3.7-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-build | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3-clang5-mobile-custom-build-static | ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-clang7-onnx | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| linux-xenial-py3.7-gcc7-no-ops | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit | ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win | ✅ triggered |
| **Skipped Workflows** | | |
| caffe2-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| docker-builds | ciflow/all, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64-custom-ops | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64-full-jit | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-arm64-metal | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-x86-64 | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-x86-64-coreml | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| ios-12-5-1-x86-64-full-jit | ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda10.2-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk | 🚫 skipped |
| linux-docs-push | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| linux-xenial-cuda11.3-py3.7-gcc7-no-ops | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-arm64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-10-15-py3-lite-interpreter-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| macos-11-py3-x86-64 | ciflow/all, ciflow/macos, ciflow/trunk | 🚫 skipped |
| parallelnative-linux-xenial-py3.7-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |
| periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-bionic-cuda11.5-py3.7-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| periodic-win-vs2019-cuda11.5-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build | ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk | 🚫 skipped |

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and triggering the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Contributor

facebook-github-bot commented Jan 4, 2022


💊 CI failures summary and remediations

As of commit 45ed5f2 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D31829981

7 similar comments

…ementation (#70649)

Summary:
Pull Request resolved: #70649

Reviewed By: mikekgfb

Differential Revision: D31829981

fbshipit-source-id: f5db2e758dde4d0b204899b8553110e14c1777ed
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D31829981

@bdhirsh
Contributor

bdhirsh commented Feb 2, 2022

Hi @zrphercule. It looks like this PR adds a new public torch API (torch.native_multi_attention_self_attention), but doesn't include any public documentation for it (e.g. in `torch_docs.py`). For the 1.11 release, we'll want to either (1) add the public docs for this op, or (2) mark this API as private by adding a leading underscore to the name.

Can you do one of those and cherry-pick into the 1.11 release branch?

@zrphercule
Contributor Author

> Hi @zrphercule. It looks like this PR adds a new public torch API (torch.native_multi_attention_self_attention), but doesn't include any public documentation for it (e.g. in `torch_docs.py`). For the 1.11 release, we'll want to either (1) add the public docs for this op, or (2) mark this API as private by adding a leading underscore to the name.
>
> Can you do one of those and cherry-pick into the 1.11 release branch?

Hi @bdhirsh, I think it is better to make this API private for now, since we are still working on some follow-up diffs for it. Will make a PR today, thanks!
