Support fp8 in AOTInductor + support optional<> in C ABI #112527
Conversation
This was originally @ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface can't support optional<> directly, I have created a ShimOptional struct instead. This ShimOptional is used only for non-pointer optional types; pointer optionals can have their nullopt value represented by `nullptr`.

I decided to create a ShimOptional instead of adding an extra `bool` param to the callee because this simplifies things. Having the same number of arguments regardless of whether we are emitting Python / C++ / ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have optional-typed parameters. Previously, they just assumed they would never be passed a `nullopt` / `None` at runtime. Changing them to use ShimOptional now would break ABI stability, so I have created an exclude list for those functions.

Finally, I think the current implementation is kind of messy, pulling in argument type info from a variety of places, and possibly missing some edge cases with const arg codegen. I've left a bunch of FIXME comments; would appreciate feedback on whether I could improve things.

[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112527
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 2b1de23 with merge base 78b8465. This comment was automatically generated by Dr. CI and updates every 15 minutes.
torch/_inductor/ir.py (Outdated)

@@ -3542,7 +3542,8 @@ def apply_constraint(self):
    pass

def codegen_const_args(self):
    return map(V.graph.wrapper_code.val_to_arg_str, self.constant_args)
    # FIXME: separate invocation for cpp arg strs?
maybe this is fine since, looking at its caller codegen_args, it looks like the codegen is meant to target Python only

I'm overall confused by the difference between ExternKernels and FallbackKernels though, and how codegen logic is split between the two. Both of them define codegen_args, but they share a single implementation of codegen_kwargs...
Thanks @int3! I wonder whether we have considered using a pointer to represent optional instead of introducing a boolean value? Code would look cleaner in this case. Regarding existing APIs which are affected by this change: can we add new APIs and deprecate old APIs in parallel? e.g. we could add a new set of APIs in this PR, and then after the PR is released in prod, we remove the legacy branching logic and hard-code the kernel names of these operators with new API names.
Not sure what you have in mind here. You mean something like

Yeah we should do that. I think it might be easier to do it in a follow-up PR though
I also think using pointers might be better. The idea would be that we always pass a pointer to the element type of an optional argument, where (1) we use NULL to represent c10::nullopt; and For example, for

or

Because we always know the element type of the optional argument in the shim, we wouldn't have ambiguity. Moreover, we wouldn't have to implement our own Optional struct to handle various types.
torch/_inductor/ir.py (Outdated)

if hasattr(self, "kwargs_default_value"):
    type_ = self.kwargs_default_value.get(arg_name).get("type")
else:
    type_ = None
We may want to just throw an exception if we couldn't get a valid type?
That creates a lot of runtime errors... the issue is that kwargs_default_value is defined only on FallbackKernel, but this method is defined on ExternKernel. I was hoping @desertfire or @jansel might be able to suggest a better way of getting the type here, and/or if this is a thing that I need to concern myself with for non-fallback ExternKernels
@@ -4166,14 +4179,21 @@ def is_not_write(arg):
    x.name for x in kernel._schema.arguments if x.kwarg_only
]

def is_legacy_abi_kernel(self):
Not sure if we really need this. For these two kernels, all optional arguments have real values, so we wouldn't hit the c10::nullopt path. Moreover, we could add another version of the interface, e.g. aoti_torch__scaled_dot_product_flash_attention_nullopt, to handle the missing nullopt cases. Relying on this is_legacy_abi_kernel looks very hacky to me.
We need to distinguish these two models, but I agree with @chenyang78 that generating these two fallback functions names differently in the wrapper codegen is a cleaner solution.
For these two kernels, all optional arguments have real values, so we wouldn't hit the c10::nullopt path.
We would be changing the way the real values get passed in though (regardless of whether we are doing the ShimOptional or pointer approach).
But okay I am happy to do the other less hacky approach once we figure out the issues with getting the valid type above.
Even if we add new versions of these two interfaces (actually we only need to care about the flash_attention one and the other is not used in prod), to make sure the newly published snapshots work with old predictor binaries, we still need to keep the branch until the new API is released in predictor, correct?
If repeat_interleave_Tensor is not in production, let's fix it by changing its shim API. cc @adnanaziz
I think you meant to tag @aakhundov :)
Just checked with him, he says it's fine to change
AtenTensorHandle scale_a,
AtenTensorHandle scale_b,
AtenTensorHandle scale_result,
bool use_fast_accum,
Let's not use bool in the C interface.
Just curious, why not? Is it because C bools and C++ bools are subtly different?
I don't think this is sufficient. What if the caller passes in a float?

The idea is to use reinterpret_cast so the type doesn't matter. Although I see that I forgot to add that into the codegen. But I see what @chenyang78 meant by using pointers instead... that might be nicer, I'll give it a shot.
Looks like aoti_torch__scaled_dot_product_flash_attention is also using bools
https://stackoverflow.com/questions/40020423/getting-bool-from-c-to-c-and-back sounds like this is relatively safe
As described in that stackoverflow thread, "C's and C++'s bool type are different, but, as long as you stick to the same compiler (in your case, gcc), it should be safe, as this is a reasonable common scenario." But we can't make that assumption here, as we don't know what the model.so was compiled with.
Oh I thought we were only concerned with models compiled under fbcode. Okay, I'll change it.
This was originally ipiszy's PR: #112358

It turns out that we need to add support for optional types in order to support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface can't support optional<> directly, I am passing in optional types via pointer instead. `AtenTensorHandle`s are already pointers, so nothing needs to change there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to the callee because this simplifies things. Having the same number of arguments regardless of whether we are emitting Python / C++ / ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have optional-typed value parameters. Previously, they just assumed they would never be passed a `nullopt` / `None` at runtime. Changing them to use pointer types now would break ABI stability, so I have created an exclude list for those functions.

Finally, I think the current implementation is kind of messy, and only works for FallbackKernels, even though technically ExternKernels could also have the same issue. It also doesn't support optional types nested in lists. I've left FIXME comments for both issues.

[ghstack-poisoned]
torch/_inductor/lowering.py (Outdated)

@@ -2018,7 +2018,6 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(aten.topk)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten._scaled_mm.default)
I feel we probably need to keep this since we haven't added a lowering for _scaled_mm? (I'm a bit confused about this as well.) Have you checked that test/inductor/test_fp8.py can run successfully if we remove this fallback?
yep, I've checked and the test passes
not entirely sure either about how the fallback mechanism works though
@pytorchbot merge
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator. Details for Dev Infra team: Raised by workflow job
Let's see if unlinking works... @pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
are we stuck with ABI stability already?
We do need to maintain backwards compatibility for
Support fp8 in AOTInductor + support optional<> in C ABI (pytorch#112527)

Differential Revision: [D51084289](https://our.internmc.facebook.com/intern/diff/D51084289)
Pull Request resolved: pytorch#112527
Approved by: https://github.com/chenyang78, https://github.com/desertfire
…ytorch#112527)" Summary: ABI breakage issue https://docs.google.com/document/d/1E7mhxOF4SvVO6r5rkrGKci7y4wfTj2RiT8UgeCl2tb0/edit It blocked the QRT of Ads PT2 model and MRS models. Since the lowering service red zone starts Nov 16th, let's mitigate in this way: Let's revert this diff first, then Jing can pick up this patch. After that, let's revert this backout. Test Plan: sandcastle Reviewed By: chenyang78, zoranzhao Differential Revision: D51326063
…ytorch#112527)" Test Plan: sandcastle Differential Revision: D51330618
…112527)" (#113747) Test Plan: sandcastle Differential Revision: D51330618 Pull Request resolved: #113747 Approved by: https://github.com/chenyang78, https://github.com/khabinov
…n C ABI (pytorch#112527)"" Test Plan: sandcastle Differential Revision: D51387179
…pport)" This is a backout of #113747 which reverted the above two commits. Now that #113997 has landed, this diff can be landed safely without breaking ABI compatibility. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
…4974) This is a backout of #113747 which reverted the above two commits. Now that #113997 has landed, this diff can be landed safely without breaking ABI compatibility. Pull Request resolved: #114974 Approved by: https://github.com/chenyang78
Stack from ghstack (oldest at bottom):
This was originally ipiszy's PR: #112358
It turns out that we need to add support for optional types in order to
support fp8 gemm (i.e. scaled_mm). Since our ABI-stable C interface
can't support optional<> directly, I am passing in optional types via
pointer instead.
`AtenTensorHandle`s are already pointers, so nothing needs to change there. Only value types need to change.

We decided on this approach instead of adding an extra `bool` param to the callee because this simplifies things. Having the same number of arguments regardless of whether we are emitting Python / C++ / ABI-compatible C++ makes codegen easier.

There are a number of existing ABI-compatible functions that have optional-typed value parameters. Previously, they just assumed they would never be passed a `nullopt` / `None` at runtime. Changing them to use pointer types now would break ABI stability, so I have created an exclude list for those functions.
Finally, I think the current implementation is kind of messy, and only
works for FallbackKernels, even though technically ExternKernels could
also have the same issue. It also doesn't support optional types nested
in lists. I've left FIXME comments for both issues.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler
Differential Revision: D51084289