Skip to content

Commit

Permalink
Fix random_patch.
Browse files Browse the repository at this point in the history
  • Loading branch information
IanTayler authored and vierja committed Sep 5, 2017
1 parent 4879a7f commit 2feaa30
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 45 deletions.
30 changes: 24 additions & 6 deletions luminoth/datasets/object_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import tensorflow as tf

from luminoth.utils.image import (
resize_image, flip_image
resize_image, flip_image, random_patch
)

DATA_AUGMENTATION_STRATEGIES = {
'flip': flip_image,
'patch': random_patch,
}


Expand Down Expand Up @@ -90,19 +91,36 @@ def _augment(self, image, bboxes, default_prob=0.5):
aug_fn = DATA_AUGMENTATION_STRATEGIES[aug_type]

random_number = tf.random_uniform([])
prob = aug_config.pop('prob', default_prob)
prob = tf.to_float(aug_config.pop('prob', default_prob))
apply_aug_strategy = tf.less(random_number, prob)

augmented = tf.cond(
apply_aug_strategy,
lambda: aug_fn(image, bboxes, **aug_config),
lambda: aug_fn(image, aug_config, bboxes),
lambda: {'image': image, 'bboxes': bboxes}
)

applied_data_augmentation.append({aug_type: apply_aug_strategy})
update_condition = tf.greater(
tf.gather(tf.shape(augmented['bboxes']), 0),
0
)
image = tf.cond(
update_condition,
lambda: augmented['image'],
lambda: image
)
# Hot fix. This works because bboxes is either always or never
# None in a single training session.
if bboxes is not None:
bboxes = tf.cond(
update_condition,
# TODO: find out why we're sometimes getting float
# bboxes.
lambda: tf.to_int32(augmented['bboxes']),
lambda: bboxes
)

image = augmented['image']
bboxes = augmented['bboxes']
applied_data_augmentation.append({aug_type: apply_aug_strategy})

return image, bboxes, applied_data_augmentation

Expand Down
4 changes: 4 additions & 0 deletions luminoth/models/fasterrcnn/base_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ dataset:
left_right: True
up_down: False
prob: 0.5
- patch:
min_height: 400
min_width: 400
prob: 0.2

network:
# Total number of classes to predict
Expand Down
112 changes: 75 additions & 37 deletions luminoth/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,38 @@
from luminoth.utils.bbox_transform_tf import clip_boxes


def adjust_bboxes(bboxes, old_height, old_width, new_height, new_width):
    """Rescales the bboxes of an image that has been resized.

    Args:
        bboxes: Tensor with shape (num_bboxes, 5). Each row holds
            (x_min, y_min, x_max, y_max, label).
        old_height: Float. Height of the image before resizing.
        old_width: Float. Width of the image before resizing.
        new_height: Float. Height of the image after resizing.
        new_width: Float. Width of the image after resizing.

    Returns:
        Tensor with shape (num_bboxes, 5) holding the rescaled points and
        the label, all cast to int32.
    """
    def _rescale(points, old_size, new_size):
        # Normalize against the old size first, then scale to the new size
        # and cast the absolute coordinate back to an integer.
        return tf.to_int32((points / old_size) * new_size)

    # The arithmetic is done in float; the columns are split so that each
    # coordinate can be scaled along its own axis.
    left, top, right, bottom, category = tf.unstack(
        tf.to_float(bboxes), axis=1
    )

    rescaled_columns = [
        _rescale(left, old_width, new_width),
        _rescale(top, old_height, new_height),
        _rescale(right, old_width, new_width),
        _rescale(bottom, old_height, new_height),
        tf.to_int32(category),  # The label is untouched, only cast back.
    ]

    # Re-assemble the columns into a [num_bboxes, 5] tensor.
    return tf.stack(rescaled_columns, axis=1)


def resize_image(image, bboxes=None, min_size=None, max_size=None):
"""
We need to resize image and (optionally) bounding boxes when the biggest
Expand Down Expand Up @@ -42,7 +74,7 @@ def resize_image(image, bboxes=None, min_size=None, max_size=None):
else:
upscale_factor = tf.constant(1.)

if max_size:
if max_size is not None:
# We do the same calculating the downscale factor, to end up with an
# image where the biggest dimension is less than `image_max_size`.
# When the image is small enough the scale factor is 1. (no change)
Expand All @@ -65,24 +97,11 @@ def resize_image(image, bboxes=None, min_size=None, max_size=None):
)

if bboxes is not None:
# We normalize bounding boxes points before modifying the image.
bboxes_float = tf.to_float(bboxes)
x_min, y_min, x_max, y_max, label = tf.unstack(bboxes_float, axis=1)

x_min = x_min / width
y_min = y_min / height
x_max = x_max / width
y_max = y_max / height

# Use new size to scale back the bboxes points to absolute values.
x_min = tf.to_int32(x_min * new_width)
y_min = tf.to_int32(y_min * new_height)
x_max = tf.to_int32(x_max * new_width)
y_max = tf.to_int32(y_max * new_height)
label = tf.to_int32(label) # Cast back to int.

# Concat points and label to return a [num_bboxes, 5] tensor.
bboxes = tf.stack([x_min, y_min, x_max, y_max, label], axis=1)
bboxes = adjust_bboxes(
bboxes,
old_height=height, old_width=width,
new_height=new_height, new_width=new_width
)
return {
'image': image,
'bboxes': bboxes,
Expand All @@ -95,20 +114,27 @@ def resize_image(image, bboxes=None, min_size=None, max_size=None):
}


def flip_image(image, bboxes=None, left_right=True, up_down=False):
def flip_image(image, config, bboxes=None):
"""Flips image on its axis for data augmentation.
Args:
image: Tensor with image of shape (H, W, 3).
config: EasyDict
left_right: Boolean flag to flip the image horizontally
(left to right).
up_down: Boolean flag to flip the image vertically (upside down)
bboxes: Optional Tensor with bounding boxes with shape
(total_bboxes, 5).
left_right: Boolean flag to flip the image horizontally
(left to right).
up_down: Boolean flag to flip the image vertically (upside down)
Returns:
image: Flipped image with the same shape.
bboxes: Tensor with the same shape.
"""
if 'left_right' not in config:
config.left_right = True
if 'up_down' not in config:
config.up_down = False
left_right = config.left_right
up_down = config.up_down

image_shape = tf.shape(image)
height = image_shape[0]
Expand Down Expand Up @@ -151,33 +177,44 @@ def flip_image(image, bboxes=None, left_right=True, up_down=False):
return return_dict


def random_patch(image, bboxes=None, debug=False):
def random_patch(image, config, bboxes=None, debug=False):
"""Gets a random patch from an image.
Args:
image: Tensor with shape (H, W, 3).
config: EasyDict
min_height: Minimum height of the patch.
min_width: Minimum width of the patch.
bboxes: Tensor with the ground-truth boxes. Shaped (total_boxes, 5).
The last element in each box is the category label.
debug: Boolean. If True, random seeds will be set to 0.
Returns:
image: Tensor with shape (H', W', 3), with H' <= H and W' <= W. A
random patch of the input image.
bboxes: Tensor with shape (new_total_boxes, 5), where we keep
bboxes that have their center inside the patch, cropping
them to the patch boundaries.
them to the patch boundaries. If we didn't get any bboxes, then
it's set as -1.
"""
# Get default values if not set.
if 'min_height' not in config:
config.min_height = 400
if 'min_width' not in config:
config.min_width = 400

if debug:
seed = 0
else:
seed = None
# See the documentation on tf.crop_to_bounding_box for the meaning of
# See the documentation on tf.image.crop_to_bounding_box for the meaning of
# these variables.
offset_width = tf.random_uniform(
shape=[],
minval=0,
maxval=tf.subtract(
tf.shape(image)[1],
1
config.min_width
),
dtype=tf.int32,
seed=seed
Expand All @@ -187,14 +224,14 @@ def random_patch(image, bboxes=None, debug=False):
minval=0,
maxval=tf.subtract(
tf.shape(image)[0],
1
config.min_height
),
dtype=tf.int32,
seed=seed
)
target_width = tf.random_uniform(
shape=[],
minval=1,
minval=config.min_width,
maxval=tf.subtract(
tf.shape(image)[1],
offset_width
Expand All @@ -204,7 +241,7 @@ def random_patch(image, bboxes=None, debug=False):
)
target_height = tf.random_uniform(
shape=[],
minval=1,
minval=config.min_height,
maxval=tf.subtract(
tf.shape(image)[0],
offset_height
Expand All @@ -214,15 +251,14 @@ def random_patch(image, bboxes=None, debug=False):
)
new_image = tf.image.crop_to_bounding_box(
image,
offset_height, offset_width,
target_height, target_width
offset_height=offset_height, offset_width=offset_width,
target_height=target_height, target_width=target_width
)

return_dict = {'image': new_image}

# Return if we didn't have bboxes.
if bboxes is None:
return_dict['bboxes'] = tf.constant(-1.)
return return_dict

# Now we will remove all bboxes whose centers are not inside the cropped
Expand Down Expand Up @@ -296,11 +332,13 @@ def random_patch(image, bboxes=None, debug=False):
# Finally, we clip the boxes and add back the labels.
new_bboxes = tf.concat(
[
clip_boxes(
new_bboxes_unclipped[:, :4],
imshape=tf.shape(new_image)[:2]
tf.to_int32(
clip_boxes(
new_bboxes_unclipped,
imshape=tf.shape(new_image)[:2]
),
),
tf.cast(masked_bboxes[:, 4:], tf.float32)
masked_bboxes[:, 4:]
],
axis=1
)
Expand Down
68 changes: 66 additions & 2 deletions luminoth/utils/image_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import numpy as np
import tensorflow as tf

from luminoth.utils.image import resize_image, flip_image
from easydict import EasyDict

from luminoth.utils.image import (
resize_image, flip_image, random_patch
)
from luminoth.utils.test.gt_boxes import generate_gt_boxes


class ImageTest(tf.test.TestCase):
def setUp(self):
    """Builds the random_patch config shared by the tests."""
    patch_settings = {
        'min_height': 400,
        'min_width': 400,
    }
    self._random_patch_config = EasyDict(patch_settings)

def _gen_image(self, *shape):
return np.random.rand(*shape)

Expand Down Expand Up @@ -44,18 +54,30 @@ def _flip_image(self, image_array, boxes_array=None, left_right=False,
feed_dict = {
image: image_array,
}
config = EasyDict({
'left_right': left_right,
'up_down': up_down,
})
if boxes_array is not None:
boxes = tf.placeholder(bboxes_dtype, boxes_array.shape)
feed_dict[boxes] = boxes_array
else:
boxes = None
flipped = flip_image(
image, bboxes=boxes, left_right=left_right, up_down=up_down
image, config, bboxes=boxes,
)
with self.test_session() as sess:
flipped_dict = sess.run(flipped, feed_dict=feed_dict)
return flipped_dict['image'], flipped_dict.get('bboxes')

def _random_patch(self, image, config, bboxes=None):
    """Runs random_patch in a test session and returns its results.

    Args:
        image: Image forwarded to random_patch.
        config: Config forwarded to random_patch.
        bboxes: Optional bboxes forwarded to random_patch.

    Returns:
        Tuple (image, bboxes) taken from the evaluated return dict; bboxes
        is None when the dict has no 'bboxes' key.
    """
    with self.test_session() as sess:
        # debug=True is forwarded so random_patch fixes its random seeds.
        patch_op = random_patch(image, config, bboxes=bboxes, debug=True)
        evaluated = sess.run(patch_op)
        patched_image = evaluated['image']
        patched_bboxes = evaluated.get('bboxes')
        return patched_image, patched_bboxes

def testResizeOnlyImage(self):
# No min or max size, it doesn't change the image.
resized_image, _, scale = self._resize_image(
Expand Down Expand Up @@ -261,6 +283,48 @@ def testFlipBboxesDiffDtype(self):
flipped_boxes_float, flipped_boxes_int
)

def testRandomPatchImageBboxes(self):
    """Tests the integrity of the return values of random_patch
    when bboxes is not None.
    """
    im_shape = (800, 600, 3)
    total_boxes = 20
    # The label value itself is irrelevant for these checks.
    label = 3
    # We use a randomly generated image and bboxes.
    image, bboxes = self._get_image_with_boxes(im_shape, total_boxes)
    # Add a label to each bbox.
    bboxes_w_label = tf.concat(
        [
            bboxes,
            tf.fill((bboxes.shape[0], 1), label)
        ],
        axis=1
    )
    config = self._random_patch_config
    ret_image, ret_bboxes = self._random_patch(
        image, config, bboxes_w_label
    )
    patch_height, patch_width = ret_image.shape[:2]

    # Assertions
    self.assertLessEqual(ret_bboxes.shape[0], total_boxes)
    self.assertTrue(np.all(ret_bboxes >= 0))
    # Boxes are laid out as (x_min, y_min, x_max, y_max, label), so x
    # coordinates are bounded by the patch width and y coordinates by the
    # patch height. Checking every coordinate against the width alone
    # wrongly rejects valid y values on patches taller than wide.
    self.assertTrue(np.all(ret_bboxes[:, [0, 2]] <= patch_width))
    self.assertTrue(np.all(ret_bboxes[:, [1, 3]] <= patch_height))
    # Compare the shapes element-wise: `tuple <= tuple` is a lexicographic
    # comparison and would stop at the first dimension that differs.
    self.assertTrue(
        np.all(np.array(ret_image.shape) <= np.array(im_shape))
    )

def testRandomPatchOnlyImage(self):
    """Tests the integrity of the return values of random_patch
    when bboxes is None.
    """
    im_shape = (600, 800, 3)
    image = self._gen_image(*im_shape)
    config = self._random_patch_config
    ret_image, ret_bboxes = self._random_patch(image, config)

    # Assertions
    # Compare the shapes element-wise: `tuple <= tuple` is a lexicographic
    # comparison, so it would pass as soon as the patch height is smaller
    # than the original one, without ever looking at the width.
    self.assertTrue(
        np.all(np.array(ret_image.shape) <= np.array(im_shape))
    )
    # The helper returns None when the result has no 'bboxes' entry.
    self.assertIsNone(ret_bboxes)


if __name__ == '__main__':
tf.test.main()

0 comments on commit 2feaa30

Please sign in to comment.