Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions test/_assert_utils.py

This file was deleted.

8 changes: 5 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import __main__
import random
import inspect
import functools

from numbers import Number
from torch._six import string_classes
Expand All @@ -17,8 +18,6 @@
import numpy as np
from PIL import Image

from _assert_utils import assert_equal

IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
Expand Down Expand Up @@ -268,14 +267,17 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
return batch_tensor


assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)


def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):
np_pil_image = np.array(pil_image)
if np_pil_image.ndim == 2:
np_pil_image = np_pil_image[:, :, None]
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
if msg is None:
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
assert_equal(tensor.cpu(), pil_tensor, check_stride=False, msg=msg)
assert_equal(tensor.cpu(), pil_tensor, msg=msg)


def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean",
Expand Down
3 changes: 1 addition & 2 deletions test/test_datasets_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from torchvision.datasets.video_utils import VideoClips, unfold
from torchvision import get_video_backend

from common_utils import get_tmp_dir
from _assert_utils import assert_equal
from common_utils import get_tmp_dir, assert_equal


@contextlib.contextmanager
Expand Down
9 changes: 4 additions & 5 deletions test/test_datasets_video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from torchvision import io
from torchvision.datasets.video_utils import VideoClips, unfold

from common_utils import get_tmp_dir
from _assert_utils import assert_equal
from common_utils import get_tmp_dir, assert_equal


@contextlib.contextmanager
Expand Down Expand Up @@ -41,22 +40,22 @@ def test_unfold(self):
[0, 1, 2],
[3, 4, 5],
])
assert_equal(r, expected, check_stride=False)
assert_equal(r, expected)

r = unfold(a, 3, 2, 1)
expected = torch.tensor([
[0, 1, 2],
[2, 3, 4],
[4, 5, 6]
])
assert_equal(r, expected, check_stride=False)
assert_equal(r, expected)

r = unfold(a, 3, 2, 2)
expected = torch.tensor([
[0, 2, 4],
[2, 4, 6],
])
assert_equal(r, expected, check_stride=False)
assert_equal(r, expected)

@pytest.mark.skipif(not io.video._av_available(), reason="this test requires av")
def test_video_clips(self):
Expand Down
14 changes: 3 additions & 11 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
_assert_equal_tensor_to_pil,
_assert_approx_equal_tensor_to_pil,
_test_fn_on_batch,
assert_equal,
)
from _assert_utils import assert_equal

from typing import Dict, List, Sequence, Tuple

Expand Down Expand Up @@ -187,11 +187,7 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn):
tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
)
if config is not None:
assert_equal(
torch.rot90(tensor, **config),
out_tensor,
check_stride=False,
)
assert_equal(torch.rot90(tensor, **config), out_tensor)

if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
Expand Down Expand Up @@ -856,7 +852,6 @@ def test_resized_crop(device, mode):
assert_equal(
expected_out_tensor,
out_tensor,
check_stride=False,
msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10]),
)

Expand Down Expand Up @@ -1001,10 +996,7 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn):
).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)

out = fn(tensor, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(
out, true_out, rtol=0.0, atol=1.0, check_stride=False,
msg="{}, {}".format(ksize, sigma)
)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma))


@pytest.mark.parametrize('device', cpu_and_gpu())
Expand Down
5 changes: 2 additions & 3 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import torch
from PIL import Image, __version__ as PILLOW_VERSION
import torchvision.transforms.functional as F
from common_utils import get_tmp_dir, needs_cuda
from _assert_utils import assert_equal
from common_utils import get_tmp_dir, needs_cuda, assert_equal

from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
Expand Down Expand Up @@ -280,7 +279,7 @@ def test_read_1_bit_png(shape):
img.save(image_path)
img1 = read_image(image_path)
img2 = normalize_dimensions(torch.as_tensor(pixels * 255, dtype=torch.uint8))
assert_equal(img1, img2, check_stride=False)
assert_equal(img1, img2)


@pytest.mark.parametrize('shape', [
Expand Down
3 changes: 1 addition & 2 deletions test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import warnings
from urllib.error import URLError

from common_utils import get_tmp_dir
from _assert_utils import assert_equal
from common_utils import get_tmp_dir, assert_equal


try:
Expand Down
3 changes: 1 addition & 2 deletions test/test_models_detection_anchor_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from common_utils import TestCase
from _assert_utils import assert_equal
from common_utils import TestCase, assert_equal
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator
from torchvision.models.detection.image_list import ImageList
import pytest
Expand Down
2 changes: 1 addition & 1 deletion test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead

import pytest
from _assert_utils import assert_equal
from common_utils import assert_equal


class TestModelsDetectionNegativeSamples:
Expand Down
2 changes: 1 addition & 1 deletion test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torchvision.models.detection.transform import GeneralizedRCNNTransform
import pytest
from torchvision.models.detection import backbone_utils
from _assert_utils import assert_equal
from common_utils import assert_equal


class TestModelsDetectionUtils:
Expand Down
3 changes: 1 addition & 2 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
except ImportError:
onnxruntime = None

from common_utils import set_rng_seed
from _assert_utils import assert_equal
from common_utils import set_rng_seed, assert_equal
import io
import torch
from torchvision import ops
Expand Down
3 changes: 1 addition & 2 deletions test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from common_utils import needs_cuda, cpu_and_gpu
from _assert_utils import assert_equal
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
import math
from abc import ABC, abstractmethod
import pytest
Expand Down
Loading