diff --git a/test/test_onnx.py b/test/test_onnx.py index a63c2f09e16..ce0bc5c7b97 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -1,13 +1,6 @@ -# onnxruntime requires python 3.5 or above -try: - # This import should be before that of torch - # see https://github.com/onnx/onnx/issues/2394#issuecomment-581638840 - import onnxruntime -except ImportError: - onnxruntime = None - from common_utils import set_rng_seed, assert_equal import io +import pytest import torch from torchvision import ops from torchvision import models @@ -20,11 +13,13 @@ from collections import OrderedDict from typing import List, Tuple -import pytest from torchvision.ops._register_onnx_ops import _onnx_opset_version +# In environments without onnxruntime we prefer to +# invoke all tests in the repo and have this one skipped rather than fail. +onnxruntime = pytest.importorskip("onnxruntime") + -@pytest.mark.skipif(onnxruntime is None, reason='ONNX Runtime unavailable') class TestONNXExporter: @classmethod def setup_class(cls):