Skip to content

Commit

Permalink
Update on "[ONNX] Add test_fx_op_consistency.py"
Browse files Browse the repository at this point in the history
Should be merged after #99434

<!--
copilot:all
-->
### <samp>馃 Generated by Copilot at f59c428</samp>

### Summary
馃摑馃攧馃殌

<!--
1.  馃摑 - This emoji represents the improvement of the documentation and type annotation of the ONNX exporter and its tests.
2.  馃攧 - This emoji represents the renaming of the `dont_care` function to `skip` and the update of the references and documentation. This change reflects a refactoring and improvement of the code quality and readability.
3.  馃殌 - This emoji represents the enhancement of the ONNX exporter's ability to handle more input and output types for PyTorch models. This change increases the performance and functionality of the exporter.
-->
This pull request enhances the ONNX exporter and its tests to handle more input and output types for PyTorch models. It improves the type annotation, tolerance handling, and documentation of the exporter and its test functions. It also renames a test function to make the code more consistent and clear.

> _To export PyTorch models with ONNX_
> _We need to handle various contexts_
> _We improved the annotation_
> _And the tolerance function_
> _And renamed `dont_care` to `skip` for the tests_

### Walkthrough
*  Expand the `_InputArgsType` type annotation to include int, float, and bool types, in addition to torch.Tensor, Sequence, and Mapping, to support more types of inputs for PyTorch models ([link](https://github.com/pytorch/pytorch/pull/99465/files?diff=unified&w=0#diff-c8fa56eefd7f98fb4f9739d57df57f02ede77e28528133736010a6d06651ebcbL44-R46), [link](https://github.com/pytorch/pytorch/pull/99465/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L197-R199), [link](https://github.com/pytorch/pytorch/pull/99465/files?diff=unified&w=0#diff-0795f54fd1f38cfbf2c4a863a4efc9f40f2ea020a2b1612605c361b8d8d35862L291-R295)).



[ghstack-poisoned]
  • Loading branch information
titaiwangms committed Apr 27, 2023
2 parents 752ce43 + dc8467e commit eb99f60
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
7 changes: 4 additions & 3 deletions test/onnx/test_fx_op_consistency.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Owner(s): ["module: onnx"]

"""Test consistency between the output values of torch.onnx exported operators
"""Test consistency between the output values of torch.onnx FX exported operators
and torch operators given the same inputs.
Usage:
pytest test/onnx/test_fx_op_consistency.py
pytest test/onnx/test_op_consistency.py
To run tests on a specific operator (e.g. torch.ceil):
pytest test/onnx/test_fx_op_consistency.py -k ceil
pytest test/onnx/test_op_consistency.py -k ceil
pytest test/onnx/test_op_consistency.py -k nn_functional_scaled_dot_product_attention
Read more on Running and writing tests:
https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
Expand Down
1 change: 1 addition & 0 deletions test/onnx/test_op_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ def reason_flaky() -> str:
fixme("nn.functional.scaled_dot_product_attention", reason="fixme: ORT crashes on Windows, segfaults randomly on Linux"),
skip("sqrt", dtypes=BOOL_TYPES, reason=reason_onnx_does_not_support("Sqrt")),
skip("stft", opsets=[opsets_before(17)], reason=reason_onnx_does_not_support("STFT")),
skip("tile", opsets=[opsets_before(13)], reason=reason_onnx_does_not_support("Tile")),
fixme("unflatten", opsets=[opsets_before(13)], reason="Helper function is needed to support legacy ops."),
)
# fmt: on
Expand Down

0 comments on commit eb99f60

Please sign in to comment.