Skip to content

Commit

Permalink
Implement RandomApply which applies a transformation with a given pro…
Browse files Browse the repository at this point in the history
…bability
  • Loading branch information
timofurrer committed Apr 19, 2020
1 parent d1c7a4a commit 40772b1
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
28 changes: 28 additions & 0 deletions detectron2/data/transforms/transform_gen.py
Expand Up @@ -21,6 +21,7 @@
from .transform import ExtentTransform, ResizeTransform, RotationTransform

__all__ = [
"RandomApply",
"RandomBrightness",
"RandomContrast",
"RandomCrop",
Expand Down Expand Up @@ -113,6 +114,33 @@ def __repr__(self):
__str__ = __repr__


class RandomApply(TransformGen):
"""
Randomly apply the wrapper transformation with a given probability.
"""

def __init__(self, transform, prob=0.5):
"""
Args:
prob (float): probability between 0.0 and 1.0 that
the wrapper transformation is applied
"""
super().__init__()
assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})"
self.prob = prob
self.transform = transform

def get_transform(self, img):
do = self._rand_range() < self.prob
if do:
if isinstance(self.transform, TransformGen):
return self.transform.get_transform(img)
else:
return self.transform
else:
return NoOpTransform()


class RandomFlip(TransformGen):
"""
Flip the image horizontally or vertically with the given probability.
Expand Down
60 changes: 60 additions & 0 deletions tests/test_data_transform.py
Expand Up @@ -4,6 +4,7 @@
import logging
import numpy as np
import unittest
from unittest import mock

from detectron2.config import get_cfg
from detectron2.data import detection_utils
Expand Down Expand Up @@ -78,3 +79,62 @@ def test_print_transform_gen(self):

t = T.RandomFlip()
self.assertTrue(str(t) == "RandomFlip()")

def test_random_apply_prob_out_of_range_check(self):
# GIVEN
test_probabilities = {0.0: True, 0.5: True, 1.0: True, -0.01: False, 1.01: False}

# WHEN
for given_probability, is_valid in test_probabilities.items():
# THEN
if not is_valid:
self.assertRaises(AssertionError, T.RandomApply, None, prob=given_probability)
else:
T.RandomApply(None, prob=given_probability)

def test_random_apply_wrapping_transform_gen_probability_occured_evaluation(self):
# GIVEN
test_probability = 0.001
transform_mock = mock.MagicMock(name="MockTransform", spec=T.TransformGen)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=test_probability)

# WHEN
with mock.patch.object(random_apply, "_rand_range") as rand_range_mock:
rand_range_mock.return_value = 0.0001
transform = random_apply.get_transform(image_mock)

# THEN
transform_mock.get_transform.assert_called_once_with(image_mock)
assert transform is not transform_mock

def test_random_apply_wrapping_std_transform_probability_occured_evaluation(self):
# GIVEN
test_probability = 0.001
transform_mock = mock.MagicMock(name="MockTransform", spec=T.Transform)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=test_probability)

# WHEN
with mock.patch.object(random_apply, "_rand_range") as rand_range_mock:
rand_range_mock.return_value = 0.0001
transform = random_apply.get_transform(image_mock)

# THEN
assert transform is transform_mock

def test_random_apply_probability_not_occured_evaluation(self):
# GIVEN
test_probability = 0.001
transform_mock = mock.MagicMock(name="MockTransform", spec=T.TransformGen)
image_mock = mock.MagicMock(name="MockImage")
random_apply = T.RandomApply(transform_mock, prob=test_probability)

# WHEN
with mock.patch.object(random_apply, "_rand_range") as rand_range_mock:
rand_range_mock.return_value = 0.9
transform = random_apply.get_transform(image_mock)

# THEN
transform_mock.get_transform.assert_not_called()
assert isinstance(transform, T.NoOpTransform)

0 comments on commit 40772b1

Please sign in to comment.