[ONNX] Add initial support for FP8 ONNX export #107962
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107962

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV — there is 1 currently active SEV. If your PR is affected, please view it below.

❌ 1 New Failure, 2 Unrelated Failures — as of commit 3f8fb7a with merge base c68d0a7:
- NEW FAILURE: the following job has failed:
- FLAKY: the following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@justinchuby please review, as the conflicts were with #107829. Supporting FP8 here means opset 19 in TS.
test/onnx/test_op_consistency.py
Should we test opset18? We don’t support it in onnx.export
We don't, until we port the helper functions' input/attribute changes into symbolic_opset*.py. I think CI would break if we bumped it in this PR.
What do you mean opset 18 is not supported by the onnx exporter? The [col2im op](https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset18.py) was implemented in symbolic_opset18.py, for example, and onnxruntime has a unit test for it.
Maybe some tests need to be skipped for opset 18, but col2im is an op needed by several models out there, including internal customers.
(#107829) So opset18 support is dependent on the various Reduce* ops being updated, because their `axes` attribute was promoted to an input. A good number of ops use reduce directly or implicitly, and fixing all of them would not be realistic for 2.1, if at all. To avoid confusion for users we say opset18 is not supported, so they don't export a model to opset18 only to find more errors down the road. However, users can still choose to ignore the warning if they know what they are doing (e.g. when they need to use col2im).
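The attribute-to-input promotion can be sketched in plain Python. This is a hypothetical illustration, not the real `torch.onnx` graph API: plain dicts stand in for emitted ONNX nodes, and the two function names are made up for the example.

```python
# Hypothetical sketch of the opset 18 breakage discussed above: the Reduce*
# `axes` attribute became an input, so symbolic functions written for
# opset <= 17 would emit an invalid node at opset 18.

def reduce_sum_opset13(data, axes):
    # Through opset 17, `axes` is carried as a node attribute.
    return {"op": "ReduceSum", "inputs": [data], "attrs": {"axes": axes}}

def reduce_sum_opset18(data, axes_input):
    # From opset 18, `axes` is an optional second input tensor, so every
    # symbolic function that emits a Reduce* op, directly or via a helper,
    # has to be rewritten.
    return {"op": "ReduceSum", "inputs": [data, axes_input], "attrs": {}}

print(reduce_sum_opset13("x", [0]))
print(reduce_sum_opset18("x", "axes_const"))
```

This is why the breakage fans out: any helper that lowers through a Reduce* op inherits the signature change.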
Add support for ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 and ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN to enable export of torch models that use FP8 (E4M3 and E5M2) to ONNX (opset 19)
Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
Define constants for the FP8 ONNX types to avoid breaking Meta's internal usage of ONNX, which pre-dates 1.14 and thus does not support FLOAT8 types.
- `TensorProto_DataType_FLOAT8E4M3FN=17`
- `TensorProto_DataType_FLOAT8E5M2=19`

cf. #106379 (comment)
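The same workaround can be sketched in Python (the PR itself does this in C++; the names below are illustrative, not the PR's actual identifiers): hard-code the two enum values instead of reading them from the ONNX protobuf, so nothing requires ONNX >= 1.14 at build or import time.

```python
# Locally defined FP8 enum values, matching what ONNX 1.14 added to
# TensorProto.DataType, so this code does not depend on a new-enough ONNX.
TensorProto_DataType_FLOAT8E4M3FN = 17
TensorProto_DataType_FLOAT8E5M2 = 19

# Illustrative mapping from torch FP8 dtype names to the ONNX enum values.
FP8_TO_ONNX = {
    "float8_e4m3fn": TensorProto_DataType_FLOAT8E4M3FN,
    "float8_e5m2": TensorProto_DataType_FLOAT8E5M2,
}
```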
@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

I've imported this to verify the Meta-internal builds.

@kit1980 do you need to import again? Thanks!

@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 3 checks: pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test (default, 1, 1, linux.2xlarge), pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single / build-and-test (default, 1, 1, linux.2xlarge), Meta Internal-Only Changes Check. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: trunk / linux-focal-rocm5.6-py3.8 / test (default, 3, 3, linux.rocm.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: trunk / linux-focal-rocm5.6-py3.8 / test (default, 3, 3, linux.rocm.gpu), Meta Internal-Only Changes Check. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 mandatory check(s) failed. The first few are:
Dig deeper by viewing the failures on hud.
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 4 checks: pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single / build-and-test (default, 1, 1, linux.2xlarge), pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test (default, 1, 1, linux.2xlarge), trunk / linux-focal-rocm5.6-py3.8 / test (default, 3, 3, linux.rocm.gpu), Meta Internal-Only Changes Check. Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
```
{c10::kDouble, 11},
{c10::kQInt8, 12},
{c10::kQUInt8, 13},
{c10::kQInt32, 14},
```
Is this right?
@BowenBao looks like the enums are wrong for qint8/bfloat16 etc. But that's ok for the release because we don't need this pass for the dynamo exporter.
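For reference, the standard ONNX `TensorProto.DataType` enum (per the ONNX protobuf spec) assigns these values; a small sketch showing where the c10 table above diverges for the quantized torch types:

```python
# Standard ONNX TensorProto.DataType values from the ONNX protobuf spec.
# Note the mismatch with the c10 table above: ONNX assigns 12/13/14 to
# UINT32/UINT64/COMPLEX64, not to torch's qint8/quint8/qint32.
ONNX_DATA_TYPE = {
    "FLOAT": 1, "UINT8": 2, "INT8": 3, "UINT16": 4, "INT16": 5,
    "INT32": 6, "INT64": 7, "STRING": 8, "BOOL": 9, "FLOAT16": 10,
    "DOUBLE": 11, "UINT32": 12, "UINT64": 13, "COMPLEX64": 14,
    "COMPLEX128": 15, "BFLOAT16": 16, "FLOAT8E4M3FN": 17, "FLOAT8E5M2": 19,
}
```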
This PR resurrects @tcherckez-nvidia's #106379 with changes to resolve conflicts against newer `main` and defines our own constants for the new ONNX types to [avoid breaking Meta's internal usage of an old ONNX](#106379 (comment)).
- `::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN=17`
- `::torch::onnx::TensorProto_DataType_FLOAT8E5M2=19`

Pull Request resolved: #107962
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
Co-authored-by: Aaron Bockover <abock@microsoft.com>