[ONNX] Add runtime type checking to export #83673

Closed
justinchuby
Collaborator

@justinchuby justinchuby commented Aug 18, 2022

This PR adds an internal wrapper on the beartype library to perform runtime type checking in torch.onnx. It uses beartype when it is found in the environment and is reduced to a no-op when beartype is not found.
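A minimal sketch of the "use beartype if present, otherwise no-op" idea described above. The names `typechecked` and `scale` are hypothetical and are not taken from this PR; only `beartype.beartype` is a real API.

```python
import importlib.util
from typing import Callable, TypeVar

F = TypeVar("F", bound=Callable)

# Detect the optional dependency without importing it eagerly.
_BEARTYPE_AVAILABLE = importlib.util.find_spec("beartype") is not None


def typechecked(func: F) -> F:
    """Hypothetical decorator: check types with beartype if present, else no-op."""
    if not _BEARTYPE_AVAILABLE:
        # beartype is not installed: hand the function back unchanged.
        return func
    from beartype import beartype  # the real beartype decorator

    return beartype(func)


@typechecked
def scale(x: int, factor: int = 2) -> int:
    # Illustrative function; any annotated callable works the same way.
    return x * factor
```

With beartype installed, a call such as `scale("ab")` raises a beartype violation; without it, the decorator adds no overhead and no checks.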

Setting the env var TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=ERRORS turns the feature on, so that type violations raise errors. Setting TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK=DISABLED disables all checks. When the variable is not set and beartype is installed, a warning message is emitted instead.
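The three modes above could be resolved roughly as follows. This is a sketch only: `CheckState` and `resolve_check_state` are hypothetical names, and treating unrecognized values like an unset variable is an assumption, not something this PR states.

```python
import enum
import os

ENV_VAR = "TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK"


class CheckState(enum.Enum):  # hypothetical name
    DISABLED = enum.auto()
    WARNINGS = enum.auto()
    ERRORS = enum.auto()


def resolve_check_state(environ=None, beartype_installed=True) -> CheckState:
    """Map the environment variable onto one of the three modes."""
    environ = os.environ if environ is None else environ
    if not beartype_installed:
        # Without beartype there is nothing to check with.
        return CheckState.DISABLED
    value = environ.get(ENV_VAR)
    if value == "ERRORS":
        return CheckState.ERRORS
    if value == "DISABLED":
        return CheckState.DISABLED
    # Unset: warn only. Assumption: unrecognized values are treated the same.
    return CheckState.WARNINGS
```

Passing the environment as a parameter keeps the function easy to exercise without mutating `os.environ`.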

Now when users call an API with invalid arguments, e.g.

torch.onnx.export(conv, y, path, export_params=True, training=False)

# training should take TrainingMode, not bool

they get

Traceback (most recent call last):
  File "bisect_m1_error.py", line 63, in <module>
    main()
  File "bisect_m1_error.py", line 59, in main
    reveal_error()
  File "bisect_m1_error.py", line 32, in reveal_error
    torch.onnx.export(conv, y, cpu_model_path, export_params=True, training=False)
  File "<@beartype(torch.onnx.utils.export) at 0x1281f5a60>", line 136, in export
  File "pytorch/venv/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter training=False violates type hint <class 'torch._C._onnx.TrainingMode'>, as False not instance of <protocol "torch._C._onnx.TrainingMode">.

When TORCH_ONNX_EXPERIMENTAL_RUNTIME_TYPE_CHECK is not set and beartype is installed, a warning message is emitted and the call then proceeds unchecked:

>>> torch.onnx.export("foo", "bar", "f")
<stdin>:1: CallHintViolationWarning: Traceback (most recent call last):
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 54, in _coerce_beartype_exceptions_to_warnings
    return beartyped(*args, **kwargs)
  File "<@beartype(torch.onnx.utils.export) at 0x7f1d4ab35280>", line 39, in export
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/site-packages/beartype/_decor/_error/errormain.py", line 301, in raise_pep_call_exception
    raise exception_cls(  # type: ignore[misc]
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter model='foo' violates type hint typing.Union[torch.nn.modules.module.Module, torch.jit._script.ScriptModule, torch.jit.ScriptFunction], as 'foo' not <protocol "torch.jit.ScriptFunction">, <protocol "torch.nn.modules.module.Module">, or <protocol "torch.jit._script.ScriptModule">.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/justinchu/dev/pytorch/torch/onnx/_internal/_beartype.py", line 63, in _coerce_beartype_exceptions_to_warnings
    return func(*args, **kwargs)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 482, in export
    _export(
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 1422, in _export
    with exporter_context(model, training, verbose):
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 177, in exporter_context
    with select_model_mode_for_export(
  File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.9/contextlib.py", line 119, in __enter__
    return next(self.gen)
  File "/home/justinchu/dev/pytorch/torch/onnx/utils.py", line 95, in select_model_mode_for_export
    originally_training = model.training
AttributeError: 'str' object has no attribute 'training'

The type mismatch is now reported right at the call boundary, which is an improvement over the later and less informative AttributeError: 'str' object has no attribute 'training'.
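The warning mode shown in the output above can be approximated with the standard library alone. In this sketch `warn_on_type_mismatch` and `describe` are hypothetical, `CallHintViolationWarning` is a local stand-in for the category seen in the output, and the check covers only plain class annotations, which is far less than beartype handles.

```python
import functools
import inspect
import warnings


class CallHintViolationWarning(UserWarning):
    """Local stand-in for the warning category shown in the output above."""


def warn_on_type_mismatch(func):
    """Hypothetical decorator: warn on a mismatch, then call func anyway."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            hint = signature.parameters[name].annotation
            if hint is inspect.Parameter.empty:
                continue  # unannotated parameter: nothing to check
            # Only plain classes are checked in this sketch.
            if isinstance(hint, type) and not isinstance(value, hint):
                warnings.warn(
                    f"parameter {name}={value!r} violates type hint {hint!r}",
                    category=CallHintViolationWarning,
                    stacklevel=2,
                )
        # Fall through to the unchecked function, as in warning mode.
        return func(*args, **kwargs)

    return wrapper


@warn_on_type_mismatch
def describe(model: dict, path: str) -> str:
    return f"{len(model)} entries -> {path}"
```

Calling `describe("foo", "f")` emits one warning and still runs the function body, mirroring how the export call above continues after the warning.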

@justinchuby justinchuby marked this pull request as ready for review August 18, 2022 16:18
@justinchuby justinchuby added module: onnx Related to torch.onnx release notes: onnx torch.onnx related changes that should show up in the release notes topic: improvements topic category labels Aug 18, 2022
@justinchuby justinchuby marked this pull request as draft August 18, 2022 23:31
@justinchuby justinchuby marked this pull request as ready for review August 22, 2022 20:35
@justinchuby
Collaborator Author

With this check enabled, we found a test that passes a list where a tuple or Tensor is expected:

――――――――――――――――――――――――――― TestONNXOpset.test_topk ――――――――――――――――――――――――――――
[gw0] linux -- Python 3.7.13 /opt/conda/bin/python
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/onnx/test_onnx_opset.py", line 121, in test_topk
    check_onnx_opsets_operator(module, [x, k], ops, opset_versions=[10])
  File "/var/lib/jenkins/workspace/test/onnx/test_onnx_opset.py", line 67, in check_onnx_opsets_operator
    dynamic_axes=dynamic_axes,
  File "<@beartype(torch.onnx.utils.export) at 0x7fb281514680>", line 64, in export
  File "/var/lib/jenkins/.local/lib/python3.7/site-packages/beartype/_decor/_error/errormain.py", line 302, in raise_pep_call_exception
    f'{exception_prefix}violates type hint {repr(hint)}, as '
beartype.roar.BeartypeCallHintParamViolation: @beartyped export() parameter args=[tensor([1., 2., 3., 4., 5.], requires_grad=True), tensor(3)] violates type hint typing.Union[typing.Tuple[typing.Any, ...], torch.Tensor], as [tensor([1., 2., 3., 4., 5.], requires_grad=True), tensor(3)] not tuple or <class "torch.Tensor">.
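The violation above says `args` must be a tuple or a single tensor, not a list. One possible fix on the caller side is to pack list inputs into a tuple before exporting. The helper name `as_export_args` is hypothetical, and the plain numbers below merely stand in for the tensors `x` and `k`.

```python
from typing import Any


def as_export_args(inputs: Any) -> Any:
    """Hypothetical helper: turn a list of inputs into the expected tuple."""
    if isinstance(inputs, list):
        return tuple(inputs)
    # Tuples and single objects (e.g. one tensor) pass through untouched.
    return inputs


inputs = [1.0, 3]  # stand-ins for the tensors x and k in the failing test
args = as_export_args(inputs)
```

Only lists are converted, so a single tensor is never iterated element by element.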

@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 23, 2022
@justinchuby
Collaborator Author

This now supports emitting warnings instead of raising errors, thanks to the discussion in beartype/beartype#159

Collaborator

@BowenBao BowenBao left a comment

Fantastic!

@justinchuby
Collaborator Author

Mypy should be happy now

Collaborator

@BowenBao BowenBao left a comment

🚀

@BowenBao
Collaborator

Ah, please check onnx CI failures, looks related to the return annotation stuff.

@justinchuby
Collaborator Author

Just realized and fixed

@justinchuby
Collaborator Author

@pytorchbot merge -g

@pytorchmergebot
Collaborator

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered with the green (-g) flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@justinchuby
Collaborator Author

merged 🎉 special thanks to @leycec. Thank you for your generous help!

@leycec

leycec commented Aug 25, 2022

Wow! Eternal gratitude to @justinchuby for pushing this magic forward, @BowenBao for the deep review, and the whole PyTorch DevX Team for their continued excellence. Thanks so much, everyone.

facebook-github-bot pushed a commit that referenced this pull request Aug 26, 2022
Summary: identical to the pull request description at the top of this conversation.

Pull Request resolved: #83673
Approved by: https://github.com/BowenBao

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/bf25a140f9d948915d52f294b5204196d1ca22e3

Reviewed By: weiwangmeta

Differential Revision: D39063732

fbshipit-source-id: 1e7ca60b574bb3ef37150c268d845e15b139831a
Labels
cla signed Merged module: onnx Related to torch.onnx open source release notes: onnx torch.onnx related changes that should show up in the release notes topic: improvements topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module