Skip to content

Commit

Permalink
Add RaggedTensor support to tf.image.resize -- allows resizing a batc…
Browse files Browse the repository at this point in the history
…h of images that have different sizes to all have the same size. For now, this uses `tf.map_fn`, but if this proves to be too slow/inefficient, then we could look into other solutions.

PiperOrigin-RevId: 401762053
Change-Id: Ia3e825219f9c2449bb0e8bfab8c1ac48833ac815
  • Loading branch information
edloper authored and tensorflower-gardener committed Oct 8, 2021
1 parent bab047f commit 545fd4f
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tensorflow/python/ops/ragged/BUILD
Expand Up @@ -27,6 +27,7 @@ py_library(
":ragged_functional_ops",
":ragged_gather_ops",
":ragged_getitem",
":ragged_image_ops",
":ragged_map_ops",
":ragged_math_ops",
":ragged_operators",
Expand Down Expand Up @@ -238,6 +239,7 @@ py_library(
":ragged_functional_ops",
":ragged_gather_ops",
":ragged_getitem",
":ragged_image_ops",
":ragged_map_ops",
":ragged_math_ops",
":ragged_operators",
Expand Down Expand Up @@ -417,6 +419,18 @@ py_library(
],
)

# Library target for ragged_image_ops.py, which holds the RaggedTensor
# dispatchers for the image-resize ops.
py_library(
name = "ragged_image_ops",
srcs = ["ragged_image_ops.py"],
srcs_version = "PY3",
# NOTE(review): ragged_image_ops.py also imports dtypes, ops, tensor_shape,
# tensor_util, array_ops, control_flow_ops, math_ops and dispatch. Those are
# presumably reached transitively through the targets below -- confirm whether
# they should be listed explicitly.
deps = [
":ragged_tensor",
"//tensorflow/python:image_ops",
"//tensorflow/python:map_fn",
"//tensorflow/python/framework:tensor_spec",
],
)

py_library(
name = "ragged_map_ops",
srcs = ["ragged_map_ops.py"],
Expand Down Expand Up @@ -1097,6 +1111,18 @@ py_test(
],
)

# Test target for ragged_resize_image_op_test.py, which exercises the
# image-resize ops on RaggedTensor inputs.
py_test(
name = "ragged_resize_image_op_test",
srcs = ["ragged_resize_image_op_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
# NOTE(review): presumably ":ragged" is the umbrella target that pulls in
# ragged_image_ops (and so registers the dispatchers) -- confirm.
":ragged", # fixdeps: keep
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)

py_test(
name = "ragged_map_fn_op_test",
size = "small",
Expand Down
98 changes: 98 additions & 0 deletions tensorflow/python/ops/ragged/ragged_image_ops.py
@@ -0,0 +1,98 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image operations for RaggedTensors."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch


@dispatch.dispatch_for_api(image_ops.resize_images_v2)
def resize_images_v2(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethod.BILINEAR,
                     preserve_aspect_ratio=False,
                     antialias=False,
                     name=None):
  """Handles `image_ops.resize_images_v2` calls whose `images` is ragged.

  Args:
    images: The `RaggedTensor` batch of images to resize.
    size: The target size, forwarded to `_resize_images`.
    method: Resize method forwarded to `image_ops.resize_images_v2`.
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images_v2`.
    antialias: Forwarded to `image_ops.resize_images_v2`.
    name: Optional name for the enclosing name scope (defaults to
      "RaggedResizeImages").

  Returns:
    Whatever `_resize_images` returns for the given batch.
  """
  # Options handed through unchanged to the per-image resize op.
  resize_options = {
      "method": method,
      "preserve_aspect_ratio": preserve_aspect_ratio,
      "antialias": antialias,
  }
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(image_ops.resize_images_v2, images, size,
                          **resize_options)


@dispatch.dispatch_for_api(image_ops.resize_images)
def resize_images_v1(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethodV1.BILINEAR,
                     align_corners=False,
                     preserve_aspect_ratio=False,
                     name=None):
  """Handles `image_ops.resize_images` (TF1 API) calls whose `images` is ragged.

  Args:
    images: The `RaggedTensor` batch of images to resize.
    size: The target size, forwarded to `_resize_images`.
    method: Resize method forwarded to `image_ops.resize_images`.
    align_corners: Forwarded to `image_ops.resize_images`.
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images`.
    name: Optional name for the enclosing name scope (defaults to
      "RaggedResizeImages").

  Returns:
    Whatever `_resize_images` returns for the given batch.
  """
  # Options handed through unchanged to the per-image resize op.
  resize_options = {
      "method": method,
      "preserve_aspect_ratio": preserve_aspect_ratio,
      "align_corners": align_corners,
  }
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(image_ops.resize_images, images, size,
                          **resize_options)


def _resize_images(resize_op, images, size, **kwargs):
  """Resizes every image in a ragged batch and stacks the results densely.

  Each row of `images` is converted to a dense tensor (if it is still ragged)
  and passed through `resize_op`; the per-image results are combined with
  `map_fn`.

  Args:
    resize_op: The dense resize function, called as
      `resize_op(image, size, **kwargs)` for each image.
    images: A `RaggedTensor` of rank 4 holding the batch of images.
    size: A value convertible to an int32 tensor with the two target spatial
      dimensions.
    **kwargs: Extra keyword arguments forwarded to `resize_op`.

  Returns:
    A dense `Tensor` whose dimensions are the batch size, the two entries of
    `size`, and the innermost (channel) dimension of `images`. Its dtype is
    the dtype of `images` for nearest-neighbor resizing and float32 otherwise.

  Raises:
    ValueError: If `images` does not have rank 4.
  """
  if images.shape.rank != 4:
    raise ValueError(
        "tf.image.resize: images.shape.rank must be 4 if images is ragged.")

  # Determine the output shape (excluding the batch dimension).
  static_batch_size = tensor_shape.dimension_value(images.shape[0])
  size = ops.convert_to_tensor(size, dtypes.int32, "size")
  size_as_shape = tensor_util.constant_value_as_shape(size).with_rank(2)
  out_shape = size_as_shape + images.shape[-1:]

  # The dense resize ops keep the input dtype for nearest-neighbor resizing and
  # return float32 for every other method.  The map_fn signature and the
  # empty-batch result must follow the same rule; otherwise a non-float32
  # batch resized with the nearest-neighbor method hits a dtype mismatch.
  nearest_methods = (image_ops.ResizeMethod.NEAREST_NEIGHBOR,
                     image_ops.ResizeMethodV1.NEAREST_NEIGHBOR)
  if kwargs.get("method") in nearest_methods:
    out_dtype = images.dtype
  else:
    out_dtype = dtypes.float32
  out_spec = tensor_spec.TensorSpec(out_shape, out_dtype)

  # NOTE(review): `out_spec` assumes every resized image has exactly the shape
  # given by `size`; with preserve_aspect_ratio=True the per-image result may
  # presumably be smaller -- confirm whether that combination is supported.

  def resize_one(image):
    # A row may itself still be ragged (e.g. ragged width); densify it first.
    if isinstance(image, ragged_tensor.RaggedTensor):
      image = image.to_tensor()
    return resize_op(image, size, **kwargs)

  def resize_with_map():
    return map_fn.map_fn_v2(resize_one, images, fn_output_signature=out_spec)

  def empty_result():
    # Builds a zero-row result directly instead of mapping over no images.
    channels = array_ops.shape(images.flat_values)[-1:]
    return array_ops.zeros(
        array_ops.concat([[0], size, channels], axis=0), dtype=out_dtype)

  if static_batch_size == 0:
    return empty_result()
  elif static_batch_size is not None:
    return resize_with_map()
  else:
    # Batch size is only known at run time: pick the branch dynamically.
    empty_batch = math_ops.equal(images.nrows(), 0)
    return control_flow_ops.cond(empty_batch, empty_result, resize_with_map)
1 change: 1 addition & 0 deletions tensorflow/python/ops/ragged/ragged_ops.py
Expand Up @@ -34,6 +34,7 @@
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_getitem
from tensorflow.python.ops.ragged import ragged_image_ops
from tensorflow.python.ops.ragged import ragged_map_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_operators
Expand Down
136 changes: 136 additions & 0 deletions tensorflow/python/ops/ragged/ragged_resize_image_op_test.py
@@ -0,0 +1,136 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RaggedTensor dispatch of tf.image.resize."""

from absl.testing import parameterized

from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest


@test_util.run_all_in_graph_and_eager_modes
class RaggedResizeImageOpTest(test_util.TensorFlowTestCase,
                              parameterized.TestCase):
  """Tests the image-resize ops on ragged batches of differently sized images."""

  def make_image_batch(self, sizes, channels):
    """Builds a ragged image batch with one image per entry of `sizes`.

    Args:
      sizes: Sequence of 2-element sizes giving the two leading dimensions of
        each image; an empty sequence yields an empty batch.
      channels: Size of the innermost dimension of every image.

    Returns:
      A `RaggedTensor` holding the stacked images.
    """
    if not sizes:
      # No images: wrap an empty dense batch as a ragged tensor.
      empty_batch = array_ops.zeros([0, 5, 5, channels])
      return ragged_tensor.RaggedTensor.from_tensor(empty_batch, ragged_rank=2)

    image_list = []
    for (w, h) in sizes:
      # Fill each image with consecutive float values so images are distinct.
      pixel_values = math_ops.range(w * h * channels * 1.0)
      image_list.append(array_ops.reshape(pixel_values, [w, h, channels]))
    return ragged_concat_ops.stack(image_list)

  @parameterized.parameters([
      dict(src_sizes=[], dst_size=(4, 4), v1=True),
      dict(src_sizes=[], dst_size=(4, 4), v1=False),
      dict(src_sizes=[(2, 2)], dst_size=(4, 4), v1=True),
      dict(src_sizes=[(2, 2)], dst_size=(4, 4), v1=False),
      dict(src_sizes=[(2, 8), (3, 5), (10, 10)], dst_size=(5, 5), v1=True),
      dict(src_sizes=[(2, 8), (3, 5), (10, 10)], dst_size=(5, 5), v1=False),
  ])
  def testResize(self, src_sizes, dst_size, v1=False):
    """Resizing a ragged batch matches resizing each image on its own."""
    if v1:
      resize = image_ops.resize_images
    else:
      resize = image_ops.resize_images_v2

    # Build the ragged input batch.
    channels = 3
    images = self.make_image_batch(src_sizes, channels)
    expected_shape = [len(src_sizes), *dst_size, channels]

    # Resize the whole batch in one call.
    resized_images = resize(images, dst_size)
    self.assertIsInstance(resized_images, ops.Tensor)
    self.assertEqual(resized_images.shape.as_list(), expected_shape)

    # Every image must match what the non-batch version of tf.image.resize
    # produces for it.
    for i, _ in enumerate(src_sizes):
      actual = resized_images[i]
      expected = resize(images[i].to_tensor(), dst_size)
      self.assertAllClose(actual, expected)

  @parameterized.parameters([
      dict(src_shape=[None, None, None, None], src_sizes=[], dst_size=(4, 4)),
      dict(src_shape=[None, None, None, 3], src_sizes=[], dst_size=(4, 4)),
      dict(src_shape=[0, None, None, None], src_sizes=[], dst_size=(4, 4)),
      dict(src_shape=[0, None, None, 3], src_sizes=[], dst_size=(4, 4)),
      dict(
          src_shape=[None, None, None, None],
          src_sizes=[(2, 2)],
          dst_size=(4, 4)),
      dict(
          src_shape=[None, None, None, None],
          src_sizes=[(2, 8), (3, 5), (10, 10)],
          dst_size=(5, 5)),
      dict(
          src_shape=[None, None, None, 1],
          src_sizes=[(2, 8), (3, 5), (10, 10)],
          dst_size=(5, 5)),
      dict(
          src_shape=[3, None, None, 1],
          src_sizes=[(2, 8), (3, 5), (10, 10)],
          dst_size=(5, 5)),
  ])
  def testResizeWithPartialStaticShape(self, src_shape, src_sizes, dst_size):
    """Resizing works when only part of the input shape is known statically."""
    channels = src_shape[-1] if src_shape[-1] else 3
    images = self.make_image_batch(src_sizes, channels)
    rt_spec = ragged_tensor.RaggedTensorSpec(
        src_shape, ragged_rank=images.ragged_rank)
    expected_shape = [len(src_sizes), *dst_size, channels]

    # Tracing through @tf.function with `rt_spec` erases static shape
    # information beyond what `src_shape` declares.
    @def_function.function(input_signature=[rt_spec])
    def do_resize(images):
      return image_ops.resize_images_v2(images, dst_size)

    resized_images = do_resize(images)
    self.assertIsInstance(resized_images, ops.Tensor)
    self.assertTrue(resized_images.shape.is_compatible_with(expected_shape))

    # Every image must match what the non-batch version of tf.image.resize
    # produces for it.
    for i, _ in enumerate(src_sizes):
      actual = resized_images[i]
      expected = image_ops.resize_images_v2(images[i].to_tensor(), dst_size)
      self.assertAllClose(actual, expected)

  def testSizeIsTensor(self):
    """The target size may be passed as a tensor rather than a Python value."""

    @def_function.function
    def do_resize(images, new_size):
      return image_ops.resize_images_v2(images, new_size)

    src_images = self.make_image_batch([[5, 8], [3, 2], [10, 4]], 3)
    new_size = constant_op.constant([2, 2])
    resized_images = do_resize(src_images, new_size)
    self.assertIsInstance(resized_images, ops.Tensor)
    self.assertTrue(resized_images.shape.is_compatible_with([3, 2, 2, 3]))

  def testBadRank(self):
    """A ragged input that is not rank 4 is rejected with a ValueError."""
    rank3_images = ragged_tensor.RaggedTensor.from_tensor(
        array_ops.zeros([5, 5, 3]))
    with self.assertRaisesRegex(ValueError, 'rank must be 4'):
      image_ops.resize_images_v2(rank3_images, [10, 10])


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  googletest.main()

0 comments on commit 545fd4f

Please sign in to comment.