
[aotinductor] add versions for the sdpa shim api #113487

Closed
wants to merge 3 commits

Conversation

@chenyang78 (Contributor) commented Nov 10, 2023

Stack from ghstack (oldest at bottom):

In our first implementation of the sdpa shim API, we didn't consider
the case where the optional scale argument could be None. This went
unnoticed because the CUDA backend always supplied a default value for
the argument; the issue surfaced with the CPU backend.

This PR implements versioning for shim kernels. Currently, only the
sdpa API has multiple versions. We expect to maintain only a very small
number of ABI-compatible shim APIs with multiple versions.
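For orientation, here is a rough sketch of how a v1 and a v2 entry point could differ; the parameter lists are simplified and only modeled on the shim declarations, not copied from them, and the shim types are replaced with local stand-ins:

```cpp
// Hedged sketch, not the real shim.h declarations: local stand-ins so the
// signatures are readable on their own.
using AtenTensorHandle = void*;  // opaque tensor handle in the real shim
using AOTITorchError = int;      // 0 == success in this sketch

// v1: the optional scale was flattened to a plain double, so "scale is None"
// cannot be expressed across the C ABI boundary.
AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value,
    double dropout_p, int is_causal,
    double scale,        // always present, even if the model never set it
    AtenTensorHandle* ret);

// v2: the optional scale becomes a nullable pointer, so a missing scale is
// representable (nullptr) without touching the already-shipped v1 ABI.
AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value,
    double dropout_p, int is_causal,
    double* scale,       // nullptr means "use the operator's default"
    AtenTensorHandle* ret);
```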

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @kadeng @muchulee8 @aakhundov @ColinPeppler


pytorch-bot bot commented Nov 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/113487

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a3b5d67 with merge base b910d9e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@@ -252,14 +252,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
});
}

-AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
+static AOTITorchError _aoti_torch__scaled_dot_product_flash_attention_internal(
Contributor

I don't think we need a new _internal function -- we could just have aoti_torch__scaled_dot_product_flash_attention call aoti_torch__scaled_dot_product_flash_attention_v2 while passing the scale param as &scale. The v2 function would then call at::_scaled_dot_product_flash_attention directly.
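For reference, a minimal sketch of the forwarding pattern suggested here, with shortened argument lists and local stand-ins for the shim types; the real signatures may differ:

```cpp
// Hedged sketch of the suggested structure, not the actual shim implementation.
using AtenTensorHandle = void*;  // opaque tensor handle stand-in
using AOTITorchError = int;      // 0 == success in this sketch

// v2 would be the only function that calls into ATen; a null scale maps to
// "scale not provided" (c10::nullopt on the ATen side).
AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value,
    double dropout_p, int is_causal, double* scale,
    AtenTensorHandle* ret) {
  // ... at::_scaled_dot_product_flash_attention(
  //         ..., scale ? c10::optional<double>(*scale) : c10::nullopt) ...
  (void)query; (void)key; (void)value; (void)dropout_p;
  (void)is_causal; (void)scale; (void)ret;
  return 0;  // placeholder success code
}

// v1 keeps its existing ABI and simply forwards to v2, passing the scalar
// scale by address so v2 sees it as a present optional.
AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value,
    double dropout_p, int is_causal, double scale,
    AtenTensorHandle* ret) {
  return aoti_torch__scaled_dot_product_flash_attention_v2(
      query, key, value, dropout_p, is_causal, &scale, ret);
}
```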

Comment on lines +3786 to +3787
# For sdpa, we need the v2 version only if any optional
# kwarg is missing.
Contributor

I thought the idea was to make all new models call the v2 version, so that we could delete the old one entirely? Doesn't using the v2 version only if an optional arg is missing make migration harder?

Contributor Author


I thought the idea was to make all new models call the v2 version, so that we could delete the old one entirely? Doesn't using the v2 version only if an optional arg is missing make migration harder?

It's not about new models vs. old models. For the same model, the optional scale could either be None or have a default value, depending on the backend (e.g. CPU vs. CUDA in our case). Regarding migrating to the v2 version, my understanding is that once the v2 version becomes available on the serving side, we can always generate the v2 call; at that point we can safely remove the old one.

Contributor

Ohh right. We want to keep generating calls to the old (v1) version where possible for now so the CUDA models keep working with old libtorch deployments. Got it
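To make that concrete, here is a hedged sketch of the two kinds of call sites the wrapper codegen could emit for the same model on different backends; the names and argument values are invented for illustration and reuse the stand-in declarations from the sketches above:

```cpp
// Hedged sketch of generated call sites (invented values; builds on the
// stand-in declarations sketched earlier in this thread).

// CUDA trace: a concrete default scale was recorded, so the v1 call can still
// be emitted and older libtorch deployments keep working unchanged.
aoti_torch__scaled_dot_product_flash_attention(
    query, key, value, /*dropout_p=*/0.0, /*is_causal=*/0,
    /*scale=*/0.125, &out);

// CPU trace: the optional scale is genuinely absent (None), which v1 cannot
// express, so only here does the codegen fall back to the v2 call.
aoti_torch__scaled_dot_product_flash_attention_v2(
    query, key, value, /*dropout_p=*/0.0, /*is_causal=*/0,
    /*scale=*/nullptr, &out);
```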

@@ -4278,6 +4302,7 @@ def __init__(
self.kernel = (
f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
)
self.abi_compatible_kernel = None
Contributor

Not sure why the typechecker didn't complain, but shouldn't this line be in the constructor of ExternKernelAlloc, since ExternKernelAlloc.codegen references it?

Contributor Author

Good catch! Fixed. Thanks.

chenyang78 added a commit that referenced this pull request Nov 12, 2023
ghstack-source-id: 8283d5506c21091ddebd3dc18b56b99691a9c933
Pull Request resolved: #113487
Contributor

@desertfire desertfire left a comment
Left a comment about the function signature. The rest LGTM.

AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
bool is_causal,
Contributor

I haven't paid close enough attention to the v1 function signature, but let's use int instead of bool in the v2 function signature to avoid any potential problems from mixing C++ compilers.

Contributor Author

Good point. Fixed. Thanks.
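For context, a hedged sketch of the kind of change this asks for, with an invented function name and simplified parameters: boolean flags crossing the extern "C" shim boundary are carried as int, since bool's size and passing convention are not guaranteed to match across every C++ toolchain:

```cpp
// Hedged sketch with an invented name; not the real v2 signature.
// Flags cross the C ABI boundary as int rather than bool so that code built
// with different C++ compilers agrees on the calling convention.
extern "C" int aoti_sdpa_v2_sketch(
    double dropout_p,
    int is_causal,           // 0/1 flag instead of bool
    int return_debug_mask,   // same idea for any other boolean knob
    const double* scale) {   // optional: nullptr means "not provided"
  bool causal = (is_causal != 0);  // convert back to bool inside the shim
  (void)dropout_p; (void)return_debug_mask; (void)scale; (void)causal;
  return 0;  // placeholder success code
}
```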

chenyang78 added a commit that referenced this pull request Nov 13, 2023
ghstack-source-id: 46e533fccd3ca5e4e1da5421c5b561294fbdecc1
Pull Request resolved: #113487
@chenyang78 (Contributor Author)

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 13, 2023
@chenyang78 chenyang78 added the topic: not user facing topic category label Nov 13, 2023
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Pull Request resolved: pytorch#113487
Approved by: https://github.com/int3, https://github.com/desertfire
@facebook-github-bot facebook-github-bot deleted the gh/chenyang78/3/head branch November 17, 2023 15:30
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

ghstack-source-id: f5a794fdf03bbbccc8ad416d1f3960aa02052c6a
Pull Request resolved: #114974
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 1, 2023
This is a backout of #113747 which reverted the above two commits.

ghstack-source-id: fecfab474deec140005a36f81fdd72b4528f9f26
Pull Request resolved: #114990
int3 added a commit that referenced this pull request Dec 2, 2023
…pport)"


This is a backout of #113747 which reverted the above two commits. Now that
#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
int3 added a commit that referenced this pull request Dec 2, 2023
This is a backout of #113747 which reverted the above two commits.

Pull Request resolved: #114974
ghstack-source-id: fecfab474deec140005a36f81fdd72b4528f9f26
pytorchmergebot pushed a commit that referenced this pull request Dec 2, 2023
…4974)

This is a backout of #113747 which reverted the above two commits. Now that
#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: #114974
Approved by: https://github.com/chenyang78
dmenig pushed a commit to dmenig/pytorch that referenced this pull request Dec 21, 2023
… support) (pytorch#114974)

This is a backout of pytorch#113747 which reverted the above two commits. Now that
pytorch#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: pytorch#114974
Approved by: https://github.com/chenyang78