diff --git a/test/test_transforms.py b/test/test_transforms.py index b3d426e5221..8abf9e88db3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1959,10 +1959,12 @@ def test_autoaugment(self): def test_random_erasing(self): img = torch.ones(3, 128, 128) - t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1. / 3., 3. / 1.)) + t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) aspect_ratio = h / w - self.assertTrue(aspect_ratio > 1. / 3. and aspect_ratio < 3. / 1.) + # Add some tolerance due to the rounding and int conversion used in the transform + tol = 0.05 + self.assertTrue(1 / 3 - tol <= aspect_ratio <= 3 + tol) aspect_ratios = [] random.seed(42)