diff --git a/test/test_transforms.py b/test/test_transforms.py index 6dd4dfc3830..33d81c657da 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1621,76 +1621,6 @@ def test_gaussian_blur_asserts(self): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): transforms.GaussianBlur(3, "sigma_string") - def _test_randomness(self, fn, trans, configs): - random_state = random.getstate() - random.seed(42) - img = transforms.ToPILImage()(torch.rand(3, 16, 18)) - - for p in [0.5, 0.7]: - for config in configs: - inv_img = fn(img, **config) - - num_samples = 250 - counts = 0 - for _ in range(num_samples): - tranformation = trans(p=p, **config) - tranformation.__repr__() - out = tranformation(img) - if out == inv_img: - counts += 1 - - p_value = stats.binom_test(counts, num_samples, p=p) - random.setstate(random_state) - self.assertGreater(p_value, 0.0001) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_invert(self): - self._test_randomness( - F.invert, - transforms.RandomInvert, - [{}] - ) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_posterize(self): - self._test_randomness( - F.posterize, - transforms.RandomPosterize, - [{"bits": 4}] - ) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_solarize(self): - self._test_randomness( - F.solarize, - transforms.RandomSolarize, - [{"threshold": 192}] - ) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_adjust_sharpness(self): - self._test_randomness( - F.adjust_sharpness, - transforms.RandomAdjustSharpness, - [{"sharpness_factor": 2.0}] - ) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_autocontrast(self): - self._test_randomness( - F.autocontrast, - transforms.RandomAutocontrast, - [{}] - ) - - @unittest.skipIf(stats is None, 'scipy.stats not available') - def test_random_equalize(self): - self._test_randomness( - F.equalize, - transforms.RandomEqualize, - [{}] - ) - def test_autoaugment(self): for policy in transforms.AutoAugmentPolicy: for fill in [None, 85, (128, 128, 128)]: @@ -1834,6 +1764,36 @@ def test_pad_with_mode_F_images(self): assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False) +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") +@pytest.mark.parametrize('fn, trans, config', [ + (F.invert, transforms.RandomInvert, {}), + (F.posterize, transforms.RandomPosterize, {"bits": 4}), + (F.solarize, transforms.RandomSolarize, {"threshold": 192}), + (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), + (F.autocontrast, transforms.RandomAutocontrast, {}), + (F.equalize, transforms.RandomEqualize, {})]) +@pytest.mark.parametrize('p', (.5, .7)) +def test_randomness(fn, trans, config, p): + random_state = random.getstate() + random.seed(42) + img = transforms.ToPILImage()(torch.rand(3, 16, 18)) + + inv_img = fn(img, **config) + + num_samples = 250 + counts = 0 + for _ in range(num_samples): + tranformation = trans(p=p, **config) + tranformation.__repr__() + out = tranformation(img) + if out == inv_img: + counts += 1 + + p_value = stats.binom_test(counts, num_samples, p=p) + random.setstate(random_state) + assert p_value > 0.0001 + + def test_adjust_brightness(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]