Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated RandomErasing code and tests to support batch of tensors (#2702) #2721

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 0 additions & 61 deletions test/test_transforms.py
Expand Up @@ -1624,67 +1624,6 @@ def test_random_grayscale(self):
# Checking if RandomGrayscale can be printed as string
trans3.__repr__()

def test_random_erasing(self):
"""Unit tests for random erasing transform"""
for is_scripted in [False, True]:
torch.manual_seed(12)
img = torch.rand(3, 60, 60)

# Test Set 0: invalid value
random_erasing = transforms.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
img_re = random_erasing(img)

# Test Set 1: Erasing with int value
random_erasing = transforms.RandomErasing(value=0.2)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

i, j, h, w, v = transforms.RandomErasing.get_params(
img, scale=random_erasing.scale, ratio=random_erasing.ratio, value=[random_erasing.value, ]
)
img_output = F.erase(img, i, j, h, w, v)
self.assertEqual(img_output.size(0), 3)

# Test Set 2: Check if the unerased region is preserved
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = random_erasing.value
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 3: Erasing with random value
random_erasing = transforms.RandomErasing(value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)

self.assertEqual(img_re.size(0), 3)

# Test Set 4: Erasing with tuple value
random_erasing = transforms.RandomErasing(value=(0.2, 0.2, 0.2))
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertEqual(img_re.size(0), 3)
true_output = img.clone()
true_output[:, i:i + h, j:j + w] = torch.tensor(random_erasing.value)[:, None, None]
self.assertTrue(torch.equal(true_output, img_output))

# Test Set 5: Testing the inplace behaviour
random_erasing = transforms.RandomErasing(value=(0.2,), inplace=True)
if is_scripted:
random_erasing = torch.jit.script(random_erasing)

img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))

# Test Set 6: Checking when no erased region is selected
img = torch.rand([3, 300, 1])
random_erasing = transforms.RandomErasing(ratio=(0.1, 0.2), value="random")
if is_scripted:
random_erasing = torch.jit.script(random_erasing)
img_re = random_erasing(img)
self.assertTrue(torch.equal(img_re, img))


if __name__ == '__main__':
unittest.main()
24 changes: 24 additions & 0 deletions test/test_transforms_tensor.py
Expand Up @@ -433,6 +433,30 @@ def test_compose(self):
with self.assertRaisesRegex(RuntimeError, r"Could not get name of python class object"):
torch.jit.script(t)

def test_random_erasing(self):
img = torch.rand(3, 60, 60)

# Test Set 0: invalid value
random_erasing = T.RandomErasing(value=(0.1, 0.2, 0.3, 0.4), p=1.0)
with self.assertRaises(ValueError, msg="If value is a sequence, it should have either a single value or 3"):
random_erasing(img)

tensor, _ = self._create_data(24, 32, channels=3, device=self.device)
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)

test_configs = [
{"value": 0.2},
{"value": "random"},
{"value": (0.2, 0.2, 0.2)},
{"value": "random", "ratio": (0.1, 0.2)},
]

for config in test_configs:
fn = T.RandomErasing(**config)
scripted_fn = torch.jit.script(fn)
self._test_transform_vs_scripted(fn, scripted_fn, tensor)
self._test_transform_vs_scripted_on_batch(fn, scripted_fn, batch_tensors)


@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester):
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Expand Up @@ -1024,5 +1024,5 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
if not inplace:
img = img.clone()

img[:, i:i + h, j:j + w] = v
img[..., i:i + h, j:j + w] = v
return img
4 changes: 2 additions & 2 deletions torchvision/transforms/transforms.py
Expand Up @@ -1375,7 +1375,7 @@ def __repr__(self):

class RandomErasing(torch.nn.Module):
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/pdf/1708.04896.pdf
'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896

Args:
p: probability that the random erasing operation will be performed.
Expand Down Expand Up @@ -1439,7 +1439,7 @@ def get_params(
Returns:
tuple: params (i, j, h, w, v) to be passed to ``erase`` for random erasing.
"""
img_c, img_h, img_w = img.shape
img_c, img_h, img_w = img.shape[-3], img.shape[-2], img.shape[-1]
area = img_h * img_w

for _ in range(10):
Expand Down