diff --git a/test/test_transforms.py b/test/test_transforms.py index 346ddd90511..794e7a07c07 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1345,20 +1345,33 @@ def test_random_grayscale(self): def test_random_erasing(self): """Unit tests for random erasing transform""" - img = torch.rand([3, 224, 224]) + img = torch.rand([3, 60, 60]) # Test Set 1: Erasing with int value - img_re = transforms.RandomErasing(value=0)(img) - assert img_re.size(0) == 3 - - # Test Set 2: Erasing with random value + img_re = transforms.RandomErasing(value=0.2) + i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value) + img_output = F.erase(img, i, j, h, w, v) + assert img_output.size(0) == 3 + + # Test Set 2: Check if the unerased region is preserved + orig_unerased = img.clone() + orig_unerased[:, i:i + h, j:j + w] = 0 + output_unerased = img_output.clone() + output_unerased[:, i:i + h, j:j + w] = 0 + assert torch.equal(orig_unerased, output_unerased) + + # Test Set 3: Erasing with random value img_re = transforms.RandomErasing(value='random')(img) assert img_re.size(0) == 3 - # Test Set 3: Erasing with tuple value + # Test Set 4: Erasing with tuple value img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img) assert img_re.size(0) == 3 + # Test Set 5: Testing the inplace behaviour + img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img) + assert torch.equal(img_re, img) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8388f49b0f9..b543b14f15e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1): return img -def erase(img, i, j, h, w, v): +def erase(img, i, j, h, w, v, inplace=False): """ Erase the input Tensor Image with given value. Args: @@ -816,6 +816,7 @@ def erase(img, i, j, h, w, v): h (int): Height of the erased region. w (int): Width of the erased region. v: Erasing value. + inplace(bool,optional): For in-place operations. By default is set False. Returns: Tensor Image: Erased image. @@ -823,5 +824,8 @@ def erase(img, i, j, h, w, v): if not isinstance(img, torch.Tensor): raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + if not inplace: + img = img.clone() + img[:, i:i + h, j:j + w] = v return img diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 5639288494c..0b50144c07f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1196,6 +1196,8 @@ class RandomErasing(object): erase all pixels. If a tuple of length 3, it is used to erase R, G, B channels respectively. If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace.Default set to False. + Returns: Erased Image. # Examples: @@ -1207,7 +1209,7 @@ class RandomErasing(object): >>> ]) """ - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0): + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False): assert isinstance(value, (numbers.Number, str, tuple, list)) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("range should be of kind (min, max)") @@ -1218,6 +1220,7 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0): self.scale = scale self.ratio = ratio self.value = value + self.inplace = inplace @staticmethod def get_params(img, scale, ratio, value=0): @@ -1261,5 +1264,5 @@ def __call__(self, img): """ if random.uniform(0, 1) < self.p: x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) - return F.erase(img, x, y, h, w, v) + return F.erase(img, x, y, h, w, v, self.inplace) return img