
Conversation

@abock abock commented Aug 25, 2023

This PR resurrects @tcherckez-nvidia's #106379, with changes to resolve conflicts against a newer main, and defines our own constants for the new ONNX types to avoid breaking Meta's internal usage of an old ONNX:

  • ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN=17
  • ::torch::onnx::TensorProto_DataType_FLOAT8E5M2=19
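For context, these two values come straight from the TensorProto.DataType enum introduced in ONNX 1.14; value 18 is taken by FLOAT8E4M3FNUZ, which this PR does not add, hence the jump from 17 to 19. A small illustrative sketch of the mapping (not code from the PR):

```python
# Illustration only: the ONNX TensorProto.DataType values that the PR's
# C++ constants mirror. 18 belongs to FLOAT8E4M3FNUZ, which this PR does
# not define, which is why E5M2 jumps to 19.
FP8_ONNX_DTYPES = {
    "FLOAT8E4M3FN": 17,  # ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN
    "FLOAT8E5M2": 19,    # ::torch::onnx::TensorProto_DataType_FLOAT8E5M2
}
```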


pytorch-bot bot commented Aug 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/107962

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure, 2 Unrelated Failures

As of commit 3f8fb7a with merge base c68d0a7:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Aug 25, 2023

abock commented Aug 25, 2023

@justinchuby please review, as the conflicts were with #107829. Supporting FP8 here means opset 19 in the TorchScript (TS) exporter.

@justinchuby justinchuby added this to the 2.1.0 milestone Aug 25, 2023
@abock abock self-assigned this Aug 25, 2023
Collaborator

Should we test opset18? We don’t support it in onnx.export

Collaborator

We don't, until we add the helper-function input/attribute changes into symbolic_opset*.py. I think CI would break if we bumped it in this PR.

Collaborator

What do you mean opset 18 is not supported by the ONNX exporter? [col2im op](https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset18.py) was implemented in symbolic_opset18.py, for example, and there is an onnxruntime unit test for it.

Maybe some tests need to be skipped for opset 18, but col2im is an op needed by several models out there, including internal customers.

Collaborator

(#107829) Opset 18 support depends on the various Reduce* ops being updated, because their axes attribute was promoted to an input. A good number of ops use reduce directly or implicitly, and fixing all of them would not be realistic for 2.1, if at all. To avoid confusion we say opset 18 is not supported, so users don't export a model to opset 18 only to find more errors down the road. However, users can still choose to ignore the warning if they know what they are doing (e.g. when they need col2im).
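The breaking change described above can be sketched as follows. This is illustrative Python with a hypothetical helper, not torch.onnx code: through opset 17, Reduce* ops such as ReduceMean carry axes as a node attribute, while opset 18 promotes axes to an optional second input, so every symbolic function emitting such a node has to be rewritten.

```python
# Illustrative sketch (hypothetical helper, not torch.onnx code) of the
# opset 18 breaking change for Reduce* ops such as ReduceMean.
def make_reduce_mean(opset: int, axes: list) -> dict:
    if opset <= 17:
        # axes travels as an attribute on the node itself
        return {"op": "ReduceMean", "inputs": ["x"], "attrs": {"axes": axes}}
    # opset >= 18: axes is an optional second input tensor, so it must be
    # wired into the graph (e.g. as a Constant node) instead of an attribute
    return {"op": "ReduceMean", "inputs": ["x", "axes"], "attrs": {}}
```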

@titaiwangms titaiwangms added module: onnx Related to torch.onnx topic: new features topic category labels Aug 25, 2023
tcherckez-nvidia and others added 6 commits August 25, 2023 19:04
Add support for ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 and ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN to enable export of torch models that use FP8 (E4M3 and E5M2) to ONNX (opset 19)
Co-authored-by: Thiago Crepaldi <thiago.crepaldi@microsoft.com>
Define constants for the FP8 ONNX types to avoid breaking Meta's
internal usage of ONNX which pre-dates 1.14 and thus does not support
FLOAT8 types.

- TensorProto_DataType_FLOAT8E4M3FN=17
- TensorProto_DataType_FLOAT8E5M2=19

cf. #106379 (comment)
@facebook-github-bot

@kit1980 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


kit1980 commented Aug 26, 2023

I've imported this to verify the Meta-internal builds.

@justinchuby justinchuby self-assigned this Sep 8, 2023
@justinchuby

@kit1980 do you need to import again? Thanks!

@justinchuby justinchuby added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 8, 2023
@justinchuby

@pytorchbot merge -i

@pytorchmergebot

Merge started

Your change will be merged while ignoring the following 3 checks: pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single-full-jit / build-and-test (default, 1, 1, linux.2xlarge), pull / linux-focal-py3-clang9-android-ndk-r19c-gradle-custom-build-single / build-and-test (default, 1, 1, linux.2xlarge), Meta Internal-Only Changes Check

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-rocm5.6-py3.8 / test (default, 3, 3, linux.rocm.gpu)

Details for Dev Infra team Raised by workflow job

@justinchuby

@pytorchbot merge -i

@pytorchmergebot

Merge started

Your change will be merged while ignoring the following 2 checks: trunk / linux-focal-rocm5.6-py3.8 / test (default, 3, 3, linux.rocm.gpu), Meta Internal-Only Changes Check


@justinchuby
Copy link
Collaborator

@pytorchbot merge -i

{c10::kDouble, 11},
{c10::kQInt8, 12},
{c10::kQUInt8, 13},
{c10::kQInt32, 14},
Collaborator

Is this right?

@justinchuby justinchuby Sep 9, 2023

@BowenBao looks like the enums are wrong for qint8/bfloat16 etc. But that's ok for the release because we don't need this pass for the dynamo exporter.
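For comparison with the snippet under review, the full TensorProto.DataType enum in ONNX 1.14+ is sketched below (values from the ONNX protobuf definition). Note that ONNX has no quantized integer types, so the 12-14 slots the reviewed map assigns to qint8/quint8/qint32 actually belong to UINT32, UINT64, and COMPLEX64, which is consistent with the "enums are wrong" observation above.

```python
# ONNX TensorProto.DataType values (onnx >= 1.14), for reference.
# ONNX reserves 12-14 for UINT32/UINT64/COMPLEX64 and has no quantized
# integer types at all, so a c10 qint mapping to 12-14 cannot be right.
ONNX_TENSOR_PROTO_DTYPE = {
    "UNDEFINED": 0, "FLOAT": 1, "UINT8": 2, "INT8": 3, "UINT16": 4,
    "INT16": 5, "INT32": 6, "INT64": 7, "STRING": 8, "BOOL": 9,
    "FLOAT16": 10, "DOUBLE": 11, "UINT32": 12, "UINT64": 13,
    "COMPLEX64": 14, "COMPLEX128": 15, "BFLOAT16": 16,
    "FLOAT8E4M3FN": 17, "FLOAT8E4M3FNUZ": 18,
    "FLOAT8E5M2": 19, "FLOAT8E5M2FNUZ": 20,
}
```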

justinchuby pushed a commit that referenced this pull request Sep 9, 2023
This PR resurrects @tcherckez-nvidia's #106379 with changes to resolve conflicts against newer `main` and defines our own constants for the new ONNX types to [avoid breaking Meta's internal usage of an old ONNX](#106379 (comment)).

- `::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN=17`
- `::torch::onnx::TensorProto_DataType_FLOAT8E5M2=19`
Pull Request resolved: #107962
Approved by: https://github.com/justinchuby, https://github.com/titaiwangms
justinchuby commented Sep 9, 2023

@justinchuby justinchuby deleted the abock/onnx-fp8 branch September 9, 2023 04:50
atalman pushed a commit that referenced this pull request Sep 11, 2023

Co-authored-by: Aaron Bockover <abock@microsoft.com>
