-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Add runtime type checking to export
#83673
Conversation
🔗 Helpful links
✅ No Failures (5 Pending)As of commit de9a878 (more details on the Dr. CI page): Expand to see more💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Please report bugs/suggestions to the (internal) Dr. CI Users group. |
With this we found
|
4f1456b
to
bda1a25
Compare
Supports warning thanks to discussion in beartype/beartype#159 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fantastic!
Mypy should be happy now |
7dde277
to
4aab58e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚀
Ah, please check onnx CI failures, looks related to the return annotation stuff. |
Just realized and fixed |
@pytorchbot merge -g |
@pytorchbot successfully started a merge job. Check the current status here. |
merged 🎉 special thanks to @leycec. Thank you for your generous help! |
Wow! Eternal gratitude to @justinchuby for pushing this magic forward, @BowenBao for the deep review, and the whole PyTorch DevX Team for their continued excellence. Thanks so much, everyone. |
Summary: This PR adds an internal wrapper on the [beartype](https://github.com/beartype/beartype) library to perform runtime type checking in `torch.onnx`. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found. Setting the env var `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS` will turn on the feature. setting `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED` will disable all checks. When not set and `beartype` is installed, a warning message is emitted. Now when users call an api with invalid arguments e.g. ```python torch.onnx.export(conv, y, path, export_params=True, training=False) # traning should take TrainingModel, not bool ``` they get ``` Traceback (most recent call last): File "bisect_m1_error.py", line 63, in <module> main() File "bisect_m1_error.py", line 59, in main reveal_error() File "bisect_m1_error.py", line 32, in reveal_error torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False) File "<beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception raise exception_cls( # type: ignore[misc] beartype.roar.BeartypeCallHintParamViolation: beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">. ``` when `TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK` is not set and `beartype` is installed, a warning message is emitted. ``` >>> torch.onnx.export("foo", "bar", "f") <stdin>:1: CallHintViolationWarning: Traceback (most recent call last): File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings return beartyped(*args, **kwargs) File "<beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception raise exception_cls( # type: ignore[misc] beartype.roar.BeartypeCallHintParamViolation: beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">. Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings return func(*args, **kwargs) File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export _export( File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export with exporter_context(model, training, verbose): File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__ return next(self.gen) File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context with select_model_mode_for_export( File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__ return next(self.gen) File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export originally_training = model.training AttributeError: 'str' object has no attribute 'training' ``` We see the error is caught right when the type mismatch happens, improving from what otherwise would become `AttributeError: 'str' object has no attribute 'training'` Pull Request resolved: #83673 Approved by: https://github.com/BowenBao Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bf25a140f9d948915d52f294b5204196d1ca22e3 Reviewed By: weiwangmeta Differential Revision: D39063732 fbshipit-source-id: 1e7ca60b574bb3ef37150c268d845e15b139831a
This PR adds an internal wrapper on the beartype library to perform runtime type checking in
torch.onnx
. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found.Setting the env var
TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS
will turn on the feature. settingTORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED
will disable all checks. When not set andbeartype
is installed, a warning message is emitted.Now when users call an api with invalid arguments e.g.
they get
when
TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK
is not set andbeartype
is installed, a warning message is emitted.We see the error is caught right when the type mismatch happens, improving from what otherwise would become
AttributeError: 'str' object has no attribute 'training'