Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,20 +1345,33 @@ def test_random_grayscale(self):
def test_random_erasing(self):
"""Unit tests for random erasing transform"""

img = torch.rand([3, 224, 224])
img = torch.rand([3, 60, 60])

# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0)(img)
assert img_re.size(0) == 3

# Test Set 2: Erasing with random value
img_re = transforms.RandomErasing(value=0.2)
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
assert img_output.size(0) == 3

# Test Set 2: Check if the unerased region is preserved
orig_unerased = img.clone()
orig_unerased[:, i:i + h, j:j + w] = 0
output_unerased = img_output.clone()
output_unerased[:, i:i + h, j:j + w] = 0
assert torch.equal(orig_unerased, output_unerased)

# Test Set 3: Erasing with random value
img_re = transforms.RandomErasing(value='random')(img)
assert img_re.size(0) == 3

# Test Set 3: Erasing with tuple value
# Test Set 4: Erasing with tuple value
img_re = transforms.RandomErasing(value=(0.2, 0.2, 0.2))(img)
assert img_re.size(0) == 3

# Test Set 5: Testing the inplace behaviour
img_re = transforms.RandomErasing(value=(0.2), inplace=True)(img)
assert torch.equal(img_re, img)


if __name__ == '__main__':
unittest.main()
6 changes: 5 additions & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ def to_grayscale(img, num_output_channels=1):
return img


def erase(img, i, j, h, w, v):
def erase(img, i, j, h, w, v, inplace=False):
""" Erase the input Tensor Image with given value.

Args:
Expand All @@ -816,12 +816,16 @@ def erase(img, i, j, h, w, v):
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool,optional): For in-place operations. By default is set False.

Returns:
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

if not inplace:
img = img.clone()

img[:, i:i + h, j:j + w] = v
return img
7 changes: 5 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,8 @@ class RandomErasing(object):
erase all pixels. If a tuple of length 3, it is used to erase
R, G, B channels respectively.
If a str of 'random', erasing each pixel with random values.
inplace: boolean to make this transform inplace.Default set to False.

Returns:
Erased Image.
# Examples:
Expand All @@ -1207,7 +1209,7 @@ class RandomErasing(object):
>>> ])
"""

def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0):
def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0, inplace=False):
assert isinstance(value, (numbers.Number, str, tuple, list))
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
warnings.warn("range should be of kind (min, max)")
Expand All @@ -1218,6 +1220,7 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 1. / 0.3), value=0):
self.scale = scale
self.ratio = ratio
self.value = value
self.inplace = inplace

@staticmethod
def get_params(img, scale, ratio, value=0):
Expand Down Expand Up @@ -1261,5 +1264,5 @@ def __call__(self, img):
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
return F.erase(img, x, y, h, w, v)
return F.erase(img, x, y, h, w, v, self.inplace)
return img