[aotinductor] add versions for the sdpa shim api #113487
Conversation
In our first implementation of the sdpa shim API, we didn't consider the case where the optional scale argument could be None. This went unnoticed because we always got a default scale for the CUDA backend; the issue was detected with the CPU backend. This PR implements versioning for shim kernels. Currently, only the sdpa API has more than one version, and we expect to maintain different versions for only a very small number of ABI-compatible shim APIs.
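To make the versioning concrete, here is a minimal sketch of the two C-ABI entry points. The parameter lists are abbreviated, the trailing output-handle arguments are elided, and the `return_debug_mask` parameter plus the stand-in typedefs are assumptions for illustration, not the exact shim header. The key difference: v1 takes `scale` by value and so cannot express a missing scale, while v2 takes it as a nullable pointer.

```cpp
#include <cstdint>

// Simplified stand-ins for the real shim types; for illustration only.
using AOTITorchError = int32_t;
using AtenTensorHandle = void*;

extern "C" {

// v1: `scale` is passed by value, so an absent (None) scale cannot be represented.
AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0 /* ...remaining output handles elided... */);

// v2: `scale` is a nullable pointer; nullptr stands for a None (c10::nullopt) scale.
AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,
    int return_debug_mask,
    double* scale,
    AtenTensorHandle* ret0 /* ...remaining output handles elided... */);

}  // extern "C"
```

The generated wrapper code can then keep emitting v1 calls when a concrete scale value is available and fall back to v2 only when the optional scale is None.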
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/113487. ✅ No failures as of commit a3b5d67 with merge base b910d9e. (Generated by Dr. CI.)
@@ -252,14 +252,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
   });
 }

-AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
+static AOTITorchError _aoti_torch__scaled_dot_product_flash_attention_internal(
I don't think we need a new `_internal` function -- we could just have `aoti_torch__scaled_dot_product_flash_attention` call `aoti_torch__scaled_dot_product_flash_attention_v2` while passing the scale param as `&scale`. The v2 function would then call `at::_scaled_dot_product_flash_attention` directly.
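Given declarations along the lines sketched earlier, the suggestion would look roughly like this (a sketch only; the trailing output handles are elided):

```cpp
// Sketch of the suggested forwarding: v1 keeps its old ABI and simply calls v2,
// passing &scale, so only the v2 implementation needs to call
// at::_scaled_dot_product_flash_attention.
extern "C" AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    double scale,
    AtenTensorHandle* ret0 /* ...remaining output handles elided... */) {
  // In the v1 ABI a scale value is always present, so its address is a valid
  // "present" optional for v2.
  return aoti_torch__scaled_dot_product_flash_attention_v2(
      query, key, value, dropout_p, is_causal, return_debug_mask, &scale, ret0);
}
```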
# For sdpa, we need the v2 version only if any optional
# kwarg is missing.
I thought the idea was to make all new models call the v2 version, so that we could delete the old one entirely? Doesn't using the v2 version only if an optional arg is missing make migration harder?
> I thought the idea was to make all new models call the v2 version, so that we could delete the old one entirely? Doesn't using the v2 version only if an optional arg is missing make migration harder?
It's not about new models vs. old models. Given the same model, the optional scale can either be None or have a default value, depending on the backend (e.g. CPU vs. CUDA in our case). Regarding migrating to the v2 version, my understanding is that once the new v2 version becomes available on the serving side, we can always generate the v2 version; at that point, we can safely remove the old one.
Ohh right. We want to keep generating calls to the old (v1) version where possible for now so the CUDA models keep working with old libtorch deployments. Got it
torch/_inductor/ir.py (outdated)
@@ -4278,6 +4302,7 @@ def __init__(
         self.kernel = (
             f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
         )
+        self.abi_compatible_kernel = None
Not sure why the typechecker didn't complain, but shouldn't this line be in the constructor of `ExternKernelAlloc`, since `ExternKernelAlloc.codegen` references it?
Good catch! Fixed. Thanks.
Left a comment about the function signature. The rest LGTM.
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't paid close enough attention to the v1 function signature, but let's use int instead of bool in the v2 function signature to avoid any potential problems from mixing C++ compilers.
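A minimal sketch of what that boundary looks like (types as in the earlier sketch; the trailing output handles are elided and the return code is assumed to follow the usual zero-for-success convention):

```cpp
// Sketch: take flags as plain int in the extern "C" v2 signature and convert
// to bool inside the shim, so the exported ABI never depends on how any
// particular C++ compiler represents bool.
extern "C" AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int is_causal,          // 0 or 1 across the ABI boundary
    int return_debug_mask,  // 0 or 1 across the ABI boundary
    double* scale,
    AtenTensorHandle* ret0 /* ...remaining output handles elided... */) {
  const bool is_causal_arg = is_causal != 0;
  const bool return_debug_mask_arg = return_debug_mask != 0;
  // ... convert the tensor handles, build a c10::optional<double> from `scale`
  //     (nullptr -> c10::nullopt), call at::_scaled_dot_product_flash_attention
  //     with the converted bools, and write the results into ret0 ...
  (void)is_causal_arg;
  (void)return_debug_mask_arg;
  return 0;  // assumed success code
}
```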
Good point. Fixed. Thanks.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: pytorch#113487. Approved by: https://github.com/int3, https://github.com/desertfire
… support) (#114974): This is a backout of #113747, which reverted the above two commits. Now that #113997 has landed, this diff can be landed safely without breaking ABI compatibility. Pull Request resolved: #114974. Approved by: https://github.com/chenyang78
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @kadeng @muchulee8 @aakhundov @ColinPeppler