Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 30 additions & 70 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,76 +1621,6 @@ def test_gaussian_blur_asserts(self):
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
transforms.GaussianBlur(3, "sigma_string")

def _test_randomness(self, fn, trans, configs):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

for p in [0.5, 0.7]:
for config in configs:
inv_img = fn(img, **config)

num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1

p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_invert(self):
self._test_randomness(
F.invert,
transforms.RandomInvert,
[{}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
self._test_randomness(
F.posterize,
transforms.RandomPosterize,
[{"bits": 4}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
self._test_randomness(
F.solarize,
transforms.RandomSolarize,
[{"threshold": 192}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
F.autocontrast,
transforms.RandomAutocontrast,
[{}]
)

@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
self._test_randomness(
F.equalize,
transforms.RandomEqualize,
[{}]
)

def test_autoaugment(self):
for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]:
Expand Down Expand Up @@ -1834,6 +1764,36 @@ def test_pad_with_mode_F_images(self):
assert_equal(padded_img.size, [edge_size + 2 * pad for edge_size in img.size], check_stride=False)


@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
@pytest.mark.parametrize('fn, trans, config', [
(F.invert, transforms.RandomInvert, {}),
(F.posterize, transforms.RandomPosterize, {"bits": 4}),
(F.solarize, transforms.RandomSolarize, {"threshold": 192}),
(F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}),
(F.autocontrast, transforms.RandomAutocontrast, {}),
(F.equalize, transforms.RandomEqualize, {})])
@pytest.mark.parametrize('p', (.5, .7))
def test_randomness(fn, trans, config, p):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))

inv_img = fn(img, **config)

num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1

p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
assert p_value > 0.0001


def test_adjust_brightness():
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
Expand Down