From 2b4c4b3f1e68b0380e614048d33cdebc25b339c9 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Thu, 27 Jun 2019 07:27:50 +0530 Subject: [PATCH 1/5] test improved --- test/test_transforms.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 346ddd90511..58fa036153f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1345,11 +1345,19 @@ def test_random_grayscale(self): def test_random_erasing(self): """Unit tests for random erasing transform""" - img = torch.rand([3, 224, 224]) + img = torch.rand([3, 60, 60]) # Test Set 1: Erasing with int value - img_re = transforms.RandomErasing(value=0)(img) - assert img_re.size(0) == 3 + img_re = transforms.RandomErasing(value=0) + i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value) + + # Check if the unerased region is preserved + img_output = F.erase(img, i, j, h, w, v) + erased_region = torch.zeros([3, 60, 60], dtype=torch.float32) + erased_region[:, i:i + h, j:j + w] = v + + assert torch.equal(img - img_output, erased_region) + assert img_output.size(0) == 3 # Test Set 2: Erasing with random value img_re = transforms.RandomErasing(value='random')(img) From eb21d63a96c0fd49f0ea13cd56206a98ee38eea6 Mon Sep 17 00:00:00 2001 From: Surgan Jandial Date: Thu, 27 Jun 2019 12:31:07 +0530 Subject: [PATCH 2/5] Update test_transforms.py --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 58fa036153f..cb203600d0b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1350,7 +1350,7 @@ def test_random_erasing(self): # Test Set 1: Erasing with int value img_re = transforms.RandomErasing(value=0) i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value) - + # Check if the unerased region is preserved img_output = F.erase(img, i, j, h, w, v) erased_region = torch.zeros([3, 60, 60], dtype=torch.float32) From fcbfa5d196843281968a35a106a038eb579f4a12 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 28 Jun 2019 11:44:04 +0530 Subject: [PATCH 3/5] behaviour changes RandomErasing --- torchvision/transforms/functional.py | 6 +++++- torchvision/transforms/transforms.py | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8388f49b0f9..b543b14f15e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1): return img -def erase(img, i, j, h, w, v): +def erase(img, i, j, h, w, v, inplace=False): """ Erase the input Tensor Image with given value. Args: @@ -816,6 +816,7 @@ def erase(img, i, j, h, w, v): h (int): Height of the erased region. w (int): Width of the erased region. v: Erasing value. + inplace(bool,optional): For in-place operations. By default is set False. Returns: Tensor Image: Erased image. @@ -823,5 +824,8 @@ def erase(img, i, j, h, w, v): if not isinstance(img, torch.Tensor): raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + if not inplace: + img = img.clone() + img[:, i:i + h, j:j + w] = v return img diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 5639288494c..0b50144c07f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1196,6 +1196,8 @@ class RandomErasing(object): erase all pixels. If a tuple of length 3, it is used to erase R, G, B channels respectively. If a str of 'random', erasing each pixel with random values. + inplace: boolean to make this transform inplace.Default set to False. + Returns: Erased Image. # Examples: @@ -1207,7 +1209,7 @@ class RandomErasing(object): >>> ]) """ - def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0): + def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False): assert isinstance(value, (numbers.Number, str, tuple, list)) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("range should be of kind (min, max)") @@ -1218,6 +1220,7 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0): self.scale = scale self.ratio = ratio self.value = value + self.inplace = inplace @staticmethod def get_params(img, scale, ratio, value=0): @@ -1261,5 +1264,5 @@ def __call__(self, img): """ if random.uniform(0, 1) < self.p: x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value) - return F.erase(img, x, y, h, w, v) + return F.erase(img, x, y, h, w, v, self.inplace) return img From c34617c5f1696b778ba4db27a81972f0c7d2985f Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 28 Jun 2019 12:09:30 +0530 Subject: [PATCH 4/5] test fixes --- test/test_transforms.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index cb203600d0b..70188e7161e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1348,25 +1348,29 @@ def test_random_erasing(self): img = torch.rand([3, 60, 60]) # Test Set 1: Erasing with int value - img_re = transforms.RandomErasing(value=0) + img_re = transforms.RandomErasing(value=0.2) i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value) - - # Check if the unerased region is preserved img_output = F.erase(img, i, j, h, w, v) - erased_region = torch.zeros([3, 60, 60], dtype=torch.float32) - erased_region[:, i:i + h, j:j + w] = v - - assert torch.equal(img - img_output, erased_region) assert img_output.size(0) == 3 - # Test Set 2: Erasing with random value + # Test Set 2: Check if the unerased region is preserved + orig_unerased = img.clone() + orig_unerased[:, i:i + h, j:j + w] = 0 + output_unerased = img_output.clone() + output_unerased[:, i:i + h, j:j + w] = 0 + assert torch.equal(orig_unerased, output_unerased) + + # Test Set 3: Erasing with random value img_re = transforms.RandomErasing(value='random')(img) assert img_re.size(0) == 3 - # Test Set 3: Erasing with tuple value + # Test Set 4: Erasing with tuple value img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img) assert img_re.size(0) == 3 + # Test Set 5: Testing the inplace behaviour + img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img) + assert torch.equal(img_re, img) if __name__ == '__main__': unittest.main() From 0acffcf22d688b44279aa78166870c6f36545670 Mon Sep 17 00:00:00 2001 From: surgan12 Date: Fri, 28 Jun 2019 12:35:22 +0530 Subject: [PATCH 5/5] linter fix --- test/test_transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 70188e7161e..794e7a07c07 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1372,5 +1372,6 @@ def test_random_erasing(self): img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img) assert torch.equal(img_re, img) + if __name__ == '__main__': unittest.main()