diff --git a/test/common_utils.py b/test/common_utils.py index a95591ae570..a7534f7d3ed 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -219,8 +219,9 @@ def freeze_rng_state(): def cycle_over(objs): - for idx, obj in enumerate(objs): - yield obj, objs[:idx] + objs[idx + 1:] + for idx, obj1 in enumerate(objs): + for obj2 in objs[:idx] + objs[idx + 1:]: + yield obj1, obj2 def int_dtypes(): diff --git a/test/test_transforms.py b/test/test_transforms.py index 49396e78435..8a0fb0f5ca6 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,12 +1,9 @@ -import itertools import os import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t from torch._utils_internal import get_file_path_2 -from numpy.testing import assert_array_almost_equal -import unittest import math import random import numpy as np @@ -30,126 +27,118 @@ os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') -class Tester(unittest.TestCase): - - def test_convert_image_dtype_float_to_float(self): - for input_dtype, output_dtypes in cycle_over(float_dtypes()): - input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) - for output_dtype in output_dtypes: - with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): - transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(F.convert_image_dtype) - - output_image = transform(input_image) - output_image_script = transform_script(input_image, output_dtype) - - torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) - - actual_min, actual_max = output_image.tolist() - desired_min, desired_max = 0.0, 1.0 - - self.assertAlmostEqual(actual_min, desired_min) - self.assertAlmostEqual(actual_max, desired_max) - - def test_convert_image_dtype_float_to_int(self): - for input_dtype in float_dtypes(): - input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) - for output_dtype in int_dtypes(): - with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): - transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(F.convert_image_dtype) - - if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( - input_dtype == torch.float64 and output_dtype == torch.int64 - ): - with self.assertRaises(RuntimeError): - transform(input_image) - else: - output_image = transform(input_image) - output_image_script = transform_script(input_image, output_dtype) - - torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) - - actual_min, actual_max = output_image.tolist() - desired_min, desired_max = 0, torch.iinfo(output_dtype).max - - self.assertEqual(actual_min, desired_min) - self.assertEqual(actual_max, desired_max) - - def test_convert_image_dtype_int_to_float(self): - for input_dtype in int_dtypes(): - input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) - for output_dtype in float_dtypes(): - with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): - transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(F.convert_image_dtype) - - output_image = transform(input_image) - output_image_script = transform_script(input_image, output_dtype) - - torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) - - actual_min, actual_max = output_image.tolist() - desired_min, desired_max = 0.0, 1.0 - - self.assertAlmostEqual(actual_min, desired_min) - self.assertGreaterEqual(actual_min, desired_min) - self.assertAlmostEqual(actual_max, desired_max) - self.assertLessEqual(actual_max, desired_max) - - def test_convert_image_dtype_int_to_int(self): - for input_dtype, output_dtypes in cycle_over(int_dtypes()): - input_max = torch.iinfo(input_dtype).max - input_image = torch.tensor((0, input_max), dtype=input_dtype) - for output_dtype in output_dtypes: - output_max = torch.iinfo(output_dtype).max - - with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): - transform = transforms.ConvertImageDtype(output_dtype) - transform_script = torch.jit.script(F.convert_image_dtype) - - output_image = transform(input_image) - output_image_script = transform_script(input_image, output_dtype) - - torch.testing.assert_close( - output_image_script, - output_image, - rtol=0.0, - atol=1e-6, - msg="{} vs {}".format(output_image_script, output_image), - ) - - actual_min, actual_max = output_image.tolist() - desired_min, desired_max = 0, output_max - - # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details - if input_max >= output_max: - error_term = 0 - else: - error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) - - self.assertEqual(actual_min, desired_min) - self.assertEqual(actual_max, desired_max + error_term) - - def test_convert_image_dtype_int_to_int_consistency(self): - for input_dtype, output_dtypes in cycle_over(int_dtypes()): - input_max = torch.iinfo(input_dtype).max - input_image = torch.tensor((0, input_max), dtype=input_dtype) - for output_dtype in output_dtypes: - output_max = torch.iinfo(output_dtype).max - if output_max <= input_max: - continue - - with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): - transform = transforms.ConvertImageDtype(output_dtype) - inverse_transfrom = transforms.ConvertImageDtype(input_dtype) - output_image = inverse_transfrom(transform(input_image)) - - actual_min, actual_max = output_image.tolist() - desired_min, desired_max = 0, input_max - - self.assertEqual(actual_min, desired_min) - self.assertEqual(actual_max, desired_max) +class TestConvertImageDtype: + @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes())) + def test_float_to_float(self, input_dtype, output_dtype): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + assert abs(actual_min - desired_min) < 1e-7 + assert abs(actual_max - desired_max) < 1e-7 + + @pytest.mark.parametrize('input_dtype', float_dtypes()) + @pytest.mark.parametrize('output_dtype', int_dtypes()) + def test_float_to_int(self, input_dtype, output_dtype): + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) + transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + + if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( + input_dtype == torch.float64 and output_dtype == torch.int64 + ): + with pytest.raises(RuntimeError): + transform(input_image) + else: + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, torch.iinfo(output_dtype).max + + assert actual_min == desired_min + assert actual_max == desired_max + + @pytest.mark.parametrize('input_dtype', int_dtypes()) + @pytest.mark.parametrize('output_dtype', float_dtypes()) + def test_int_to_float(self, input_dtype, output_dtype): + input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) + transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + torch.testing.assert_close(output_image_script, output_image, rtol=0.0, atol=1e-6) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0.0, 1.0 + + assert abs(actual_min - desired_min) < 1e-7 + assert actual_min >= desired_min + assert abs(actual_max - desired_max) < 1e-7 + assert actual_max <= desired_max + + @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + def test_dtype_int_to_int(self, input_dtype, output_dtype): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) + output_max = torch.iinfo(output_dtype).max + + transform = transforms.ConvertImageDtype(output_dtype) + transform_script = torch.jit.script(F.convert_image_dtype) + + output_image = transform(input_image) + output_image_script = transform_script(input_image, output_dtype) + + torch.testing.assert_close( + output_image_script, + output_image, + rtol=0.0, + atol=1e-6, + msg="{} vs {}".format(output_image_script, output_image), + ) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, output_max + + # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details + if input_max >= output_max: + error_term = 0 + else: + error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) + + assert actual_min == desired_min + assert actual_max == (desired_max + error_term) + + @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + def test_int_to_int_consistency(self, input_dtype, output_dtype): + input_max = torch.iinfo(input_dtype).max + input_image = torch.tensor((0, input_max), dtype=input_dtype) + + output_max = torch.iinfo(output_dtype).max + if output_max <= input_max: + return + + transform = transforms.ConvertImageDtype(output_dtype) + inverse_transfrom = transforms.ConvertImageDtype(input_dtype) + output_image = inverse_transfrom(transform(input_image)) + + actual_min, actual_max = output_image.tolist() + desired_min, desired_max = 0, input_max + + assert actual_min == desired_min + assert actual_max == desired_max @pytest.mark.skipif(accimage is None, reason="accimage not available") @@ -2120,4 +2109,4 @@ def test_random_affine(): if __name__ == '__main__': - unittest.main() + pytest.main([__file__])