diff --git a/test/test_transforms.py b/test/test_transforms.py index 25d61fafeb4..b92b73d21c7 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1955,6 +1955,25 @@ def test_autoaugment(self): img = transform(img) transform.__repr__() + def test_random_erasing(self): + img = torch.ones(3, 128, 128) + + t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1. / 3., 3. / 1.)) + y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + aspect_ratio = h / w + self.assertTrue(aspect_ratio > 1. / 3. and aspect_ratio < 3. / 1.) + + aspect_ratios = [] + random.seed(42) + trial = 1000 + for _ in range(trial): + y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + aspect_ratios.append(h / w) + + count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) + count_smaller_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio < 1]) + self.assertAlmostEqual(count_bigger_then_ones / trial, count_smaller_then_ones / trial, 1) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 30911978558..916956e29fd 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -828,9 +828,9 @@ def get_params( width, height = F._get_image_size(img) area = height * width + log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - log_ratio = torch.log(torch.tensor(ratio)) aspect_ratio = torch.exp( torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) ).item() @@ -1576,9 +1576,12 @@ def get_params( img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1] area = img_h * img_w + log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.empty(1).uniform_(ratio[0], ratio[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() h = int(round(math.sqrt(erase_area * aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio)))