Skip to content

Commit

Permalink
Add batch support for various image_ops (#14854)
Browse files Browse the repository at this point in the history
* Change fix_image_flip_shape to create shape based on rank

* Refactor duplicate code to _EnsureTensorIs4D

* Convert flip_up_down

* Temporarily comment out ValueError Check

* Add batch support for flip_left_right

* Add batch support for random_flip_left_right

* Add batch support for random_flip_up_down

* Add batch support for transpose_image

* Add batch support for rot90

* Correct comments

* Refactor so as not to introduce new method

* Add tests for batch inputs

* Fix test to expect 3 or 4 dims

* Fix misc Pylint issues in image_ops_impl.py

* Fix misc Pylint issues in image_ops_test.py

* Refactor into _flip_image

* Correct Idempotent to Involution

* Check if >20 images were flipped

* Reverse condition in rot90

* Remove duplicate comment

* Address feedback

* Punctuation
  • Loading branch information
JoshVarty authored and martinwicke committed Dec 7, 2017
1 parent 8d3a25a commit 20aa9e0
Show file tree
Hide file tree
Showing 2 changed files with 322 additions and 116 deletions.
244 changes: 143 additions & 101 deletions tensorflow/python/ops/image_ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,81 @@ def _CheckAtLeast3DImage(image, require_static=True):
return []


def fix_image_flip_shape(image, result):
"""Set the shape to 3 dimensional if we don't know anything else.
def _EnsureTensorIs4D(image):
  """Expands `image` to rank 4 unless it is already a batch of images.

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or
      3-D Tensor of shape `[height, width, channels]`. A Tensor whose static
      rank is unknown is handled like a single 3-D image.

  Returns:
    A tuple `(image, is_batch)`:
      image: the input unchanged if its static rank was 4; otherwise the input
        with a new leading dimension inserted at axis 0, i.e. shape
        `[1, height, width, channels]` for a 3-D input. The dtype is not
        changed.
      is_batch: Python bool. `True` only when the input already had static
        rank 4, `False` when a leading dimension was added here (so callers
        know to squeeze it away again).

  Raises:
    ValueError: if the static rank of `image` is known and is neither 3 nor 4.
  """
  ndims = image.get_shape().ndims

  if ndims == 4:
    return (image, True)

  if ndims is not None and ndims != 3:
    raise ValueError('\'image\' must have either 3 or 4 dimensions.')

  # Rank 3 or statically unknown rank: treat as one image and add the batch
  # dimension in front.
  image = array_ops.expand_dims(image, 0)
  if ndims is None:
    # Nothing is known statically, so only record the rank of the result.
    image.set_shape([None] * 4)

  return (image, False)

def _flip_image(image, axis, random=False, seed=None):
  """Flips image(s) around a given axis.

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or
      3-D Tensor of shape `[height, width, channels]`.
    axis: A Python integer representing the axis on which the image(s)
      will be flipped. Note: The provided axis must be specified relative
      to the shape `[batch, height, width, channels]` as 3-D images will
      be expanded to fit this shape before being flipped.
    random: A boolean representing whether or not we should flip the
      image(s) at random.
    seed: Python integer. Used to create a random seed. See
      tf.set_random_seed for behavior.

  Returns:
    A tensor of the same type and shape as `image`.

  Raises:
    ValueError: if image is not a 3-D or 4-D Tensor.
  """
  image = ops.convert_to_tensor(image, name='image')
  # Keep the caller's tensor so the static shape of the result can be restored
  # from it after the expand/flip/squeeze round trip.
  original_image = image
  image, is_batch = _EnsureTensorIs4D(image)

  image = control_flow_ops.with_dependencies(
      _CheckAtLeast3DImage(image, require_static=False), image)

  flipped = array_ops.reverse(image, [axis])

  if random:
    # The batch size is only needed to draw one random value per image, so it
    # is looked up here rather than unconditionally.
    batch, _, _, _ = _ImageDimensions(image, rank=4)
    uniform_random = random_ops.random_uniform([batch], 0, 1.0, seed=seed)
    mirror_cond = math_ops.less(uniform_random, 0.5)
    # Where the drawn value is below 0.5 the unflipped image is kept,
    # otherwise the flipped one is used.
    flipped = array_ops.where(mirror_cond, x=image, y=flipped)

  if is_batch:
    return fix_image_flip_shape(original_image, flipped, rank=4)

  # A leading dimension was added for a single image; remove it again.
  flipped = array_ops.squeeze(flipped, squeeze_dims=[0])
  return fix_image_flip_shape(original_image, flipped, rank=3)


def fix_image_flip_shape(image, result, rank=3):
"""Set the shape to the original rank if we don't know anything else.
Args:
image: original image size
Expand All @@ -195,171 +268,174 @@ def fix_image_flip_shape(image, result):

image_shape = image.get_shape()
if image_shape == tensor_shape.unknown_shape():
result.set_shape([None, None, None])
result.set_shape([None] * rank)
else:
result.set_shape(image_shape)
return result


def random_flip_up_down(image, seed=None):
"""Randomly flips an image vertically (upside down).
"""Randomly flips image(s) vertically (upside down).
With a 1 in 2 chance, outputs the contents of `image` flipped along the first
dimension, which is `height`. Otherwise output the image as-is.
With a 1 in 2 chance, outputs the contents of `image` flipped along the height
dimension. Otherwise output the image as-is.
Args:
image: A 3-D tensor of shape `[height, width, channels].`
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
A 3-D tensor of the same type and shape as `image`.
A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
lambda: array_ops.reverse(image, [0]),
lambda: image)
return fix_image_flip_shape(image, result)
return _flip_image(image, axis=1, random=True, seed=seed)


def random_flip_left_right(image, seed=None):
"""Randomly flip an image horizontally (left to right).
"""Randomly flip image(s) horizontally (left to right).
With a 1 in 2 chance, outputs the contents of `image` flipped along the
second dimension, which is `width`. Otherwise output the image as-is.
width dimension. Otherwise output the image as-is.
Args:
image: A 3-D tensor of shape `[height, width, channels].`
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
seed: A Python integer. Used to create a random seed. See
@{tf.set_random_seed}
for behavior.
Returns:
A 3-D tensor of the same type and shape as `image`.
A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(mirror_cond,
lambda: array_ops.reverse(image, [1]),
lambda: image)
return fix_image_flip_shape(image, result)
return _flip_image(image, axis=2, random=True, seed=seed)


def flip_left_right(image):
"""Flip an image horizontally (left to right).
Outputs the contents of `image` flipped along the second dimension, which is
`width`.
Outputs the contents of `image` flipped along the width dimension.
See also `reverse()`.
Args:
image: A 3-D tensor of shape `[height, width, channels].`
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
Returns:
A 3-D tensor of the same type and shape as `image`.
A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
return fix_image_flip_shape(image, array_ops.reverse(image, [1]))

return _flip_image(image, axis=2, random=False)

def flip_up_down(image):
"""Flip an image vertically (upside down).
Outputs the contents of `image` flipped along the first dimension, which is
`height`.
Outputs the contents of `image` flipped along the height dimension.
See also `reverse()`.
Args:
image: A 3-D tensor of shape `[height, width, channels].`
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
Returns:
A 3-D tensor of the same type and shape as `image`.
A tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
return _flip_image(image, axis=1, random=False)


def rot90(image, k=1, name=None):
"""Rotate an image counter-clockwise by 90 degrees.
"""Rotate image(s) counter-clockwise by 90 degrees.
Args:
image: A 3-D tensor of shape `[height, width, channels]`.
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
k: A scalar integer. The number of times the image is rotated by 90 degrees.
name: A name for this operation (optional).
Returns:
A rotated 3-D tensor of the same type and shape as `image`.
A rotated tensor of the same type and shape as `image`.
Raises:
ValueError: if the shape of `image` not supported.
"""
with ops.name_scope(name, 'rot90', [image, k]) as scope:
image = ops.convert_to_tensor(image, name='image')
image, is_batch = _EnsureTensorIs4D(image)
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
_CheckAtLeast3DImage(image, require_static=False), image)
k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
k.get_shape().assert_has_rank(0)
k = math_ops.mod(k, 4)

def _rot90():
return array_ops.transpose(array_ops.reverse_v2(image, [1]),
[1, 0, 2])
return array_ops.transpose(array_ops.reverse_v2(image, [2]),
[0, 2, 1, 3])
def _rot180():
return array_ops.reverse_v2(image, [0, 1])
return array_ops.reverse_v2(image, [1, 2])
def _rot270():
return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]),
[1])
return array_ops.reverse_v2(array_ops.transpose(image, [0, 2, 1, 3]),
[2])
cases = [(math_ops.equal(k, 1), _rot90),
(math_ops.equal(k, 2), _rot180),
(math_ops.equal(k, 3), _rot270)]

ret = control_flow_ops.case(cases, default=lambda: image, exclusive=True,
result = control_flow_ops.case(cases, default=lambda: image, exclusive=True,
name=scope)
ret.set_shape([None, None, image.get_shape()[2]])
return ret

shape = image.get_shape()
result.set_shape([shape[0], None, None, shape[3]])

if is_batch == True:
return result

result = array_ops.squeeze(result, squeeze_dims=[0])
return result


def transpose_image(image):
"""Transpose an image by swapping the first and second dimension.
"""Transpose an image by swapping the height and width dimensions.
See also `transpose()`.
Args:
image: 3-D tensor of shape `[height, width, channels]`
image: 4-D Tensor of shape `[batch, height, width, channels]` or
3-D Tensor of shape `[height, width, channels]`.
Returns:
A 3-D tensor of shape `[width, height, channels]`
If `image` was 4-D, a 4-D float Tensor of shape
`[batch, width, height, channels]`
If `image` was 3-D, a 3-D float Tensor of shape
`[width, height, channels]`
Raises:
ValueError: if the shape of `image` not supported.
"""
image = ops.convert_to_tensor(image, name='image')
image, is_batch = _EnsureTensorIs4D(image)
image = control_flow_ops.with_dependencies(
_Check3DImage(image, require_static=False), image)
return array_ops.transpose(image, [1, 0, 2], name='transpose_image')
_CheckAtLeast3DImage(image, require_static=False), image)

result = array_ops.transpose(image, [0, 2, 1, 3], name='transpose_image')

if is_batch:
return result

result = array_ops.squeeze(result, squeeze_dims=[0])
return result


def central_crop(image, central_fraction):
Expand Down Expand Up @@ -445,21 +521,9 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
negative.
"""
image = ops.convert_to_tensor(image, name='image')

is_batch = True
image_shape = image.get_shape()
if image_shape.ndims == 3:
is_batch = False
image = array_ops.expand_dims(image, 0)
elif image_shape.ndims is None:
is_batch = False
image = array_ops.expand_dims(image, 0)
image.set_shape([None] * 4)
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
image, is_batch = _EnsureTensorIs4D(image)

assert_ops = _CheckAtLeast3DImage(image, require_static=False)

batch, height, width, depth = _ImageDimensions(image, rank=4)

after_padding_width = target_width - offset_width - width
Expand Down Expand Up @@ -524,21 +588,9 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
negative, or either `target_height` or `target_width` is not positive.
"""
image = ops.convert_to_tensor(image, name='image')

is_batch = True
image_shape = image.get_shape()
if image_shape.ndims == 3:
is_batch = False
image = array_ops.expand_dims(image, 0)
elif image_shape.ndims is None:
is_batch = False
image = array_ops.expand_dims(image, 0)
image.set_shape([None] * 4)
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
image, is_batch = _EnsureTensorIs4D(image)

assert_ops = _CheckAtLeast3DImage(image, require_static=False)

batch, height, width, depth = _ImageDimensions(image, rank=4)

assert_ops += _assert(offset_width >= 0, ValueError,
Expand Down Expand Up @@ -599,17 +651,7 @@ def resize_image_with_crop_or_pad(image, target_height, target_width):
`[new_height, new_width, channels]`.
"""
image = ops.convert_to_tensor(image, name='image')
image_shape = image.get_shape()
is_batch = True
if image_shape.ndims == 3:
is_batch = False
image = array_ops.expand_dims(image, 0)
elif image_shape.ndims is None:
is_batch = False
image = array_ops.expand_dims(image, 0)
image.set_shape([None] * 4)
elif image_shape.ndims != 4:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
image, is_batch = _EnsureTensorIs4D(image)

assert_ops = _CheckAtLeast3DImage(image, require_static=False)
assert_ops += _assert(target_width > 0, ValueError,
Expand Down
Loading

0 comments on commit 20aa9e0

Please sign in to comment.