Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 3 additions & 18 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,18 @@
import random
import shutil
import tempfile
from distutils.util import strtobool

import numpy as np
import pytest
import torch
from PIL import Image
from torchvision import io

import __main__ # noqa: 401


def get_bool_env_var(name, *, exist_ok=False, default=False):
value = os.getenv(name)
if value is None:
return default
if exist_ok:
return True
return bool(strtobool(value))


IN_CIRCLE_CI = get_bool_env_var("CIRCLECI")
IN_RE_WORKER = get_bool_env_var("INSIDE_RE_WORKER", exist_ok=True)
IN_FBCODE = get_bool_env_var("IN_FBCODE_TORCHVISION")
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true"
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda."

Expand Down Expand Up @@ -213,7 +202,3 @@ def _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=1e-8, **fn_kwargs):
# scriptable function test
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
torch.testing.assert_close(transformed_batch, s_transformed_batch, rtol=1e-5, atol=scripted_fn_atol)


def run_on_env_var(name, *, skip_reason=None, exist_ok=False, default=False):
return pytest.mark.skipif(not get_bool_env_var(name, exist_ok=exist_ok, default=default), reason=skip_reason)
9 changes: 5 additions & 4 deletions test/test_prototype_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import importlib
import os

import pytest
import test_models as TM
import torch
from common_utils import cpu_and_gpu, run_on_env_var, needs_cuda
from common_utils import cpu_and_gpu, needs_cuda
from torchvision.prototype import models
from torchvision.prototype.models._api import WeightsEnum, Weights
from torchvision.prototype.models._utils import handle_legacy_interface

run_if_test_with_prototype = run_on_env_var(
"PYTORCH_TEST_WITH_PROTOTYPE",
skip_reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
run_if_test_with_prototype = pytest.mark.skipif(
os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
)


Expand Down