diff --git a/test/common_utils.py b/test/common_utils.py index e20e2c658bc..5a853771301 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -4,10 +4,8 @@ import random import shutil import tempfile -from distutils.util import strtobool import numpy as np -import pytest import torch from PIL import Image from torchvision import io @@ -15,18 +13,9 @@ import __main__ # noqa: 401 -def get_bool_env_var(name, *, exist_ok=False, default=False): - value = os.getenv(name) - if value is None: - return default - if exist_ok: - return True - return bool(strtobool(value)) - - -IN_CIRCLE_CI = get_bool_env_var("CIRCLECI") -IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True) -IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION") +IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true" +IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None +IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda." @@ -213,7 +202,3 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs): # scriptable function test s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs) torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol) - - -def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False): - return pytest.mark.skipif(not get_bool_env_var(name, exist_ok=exist_ok, default=default), reason=skip_reason) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 9afde9e3128..41616fce15d 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -1,16 +1,17 @@ import importlib +import os import pytest import test_models as TM import torch -from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda +from common_utils import cpu_and_gpu, needs_cuda from torchvision.prototype import models from torchvision.prototype.models._api import WeightsEnum, Weights from torchvision.prototype.models._utils import handle_legacy_interface -run_if_test_with_prototype = run_on_env_var( - "PYTORCH_TEST_WITH_PROTOTYPE", - skip_reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.", +run_if_test_with_prototype = pytest.mark.skipif( + os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1", + reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.", )