Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for decoding uint16 PNG images to tf.image.decode_image. #18628

Merged
merged 6 commits into from
Jun 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 14 additions & 8 deletions tensorflow/python/ops/image_ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1483,13 +1483,13 @@ def is_jpeg(contents, name=None):


@tf_export('image.decode_image')
def decode_image(contents, channels=None, name=None):
def decode_image(contents, channels=None, dtype=dtypes.uint8, name=None):
"""Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`,
and `decode_png`.

Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
appropriate operation to convert the input bytes `string` into a `Tensor` of
type `uint8`.
appropriate operation to convert the input bytes `string` into a `Tensor`
of type `dtype`.

Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
opposed to `decode_bmp`, `decode_jpeg` and `decode_png`, which return 3-D
Expand All @@ -1501,10 +1501,11 @@ def decode_image(contents, channels=None, name=None):
contents: 0-D `string`. The encoded image bytes.
channels: An optional `int`. Defaults to `0`. Number of color channels for
the decoded image.
dtype: The desired DType of the returned `Tensor`.
name: A name for the operation (optional)

Returns:
`Tensor` with type `uint8` with shape `[height, width, num_channels]` for
`Tensor` with type `dtype` and shape `[height, width, num_channels]` for
BMP, JPEG, and PNG images and shape `[num_frames, height, width, 3]` for
GIF images.

Expand All @@ -1528,7 +1529,7 @@ def _bmp():
channels_msg = 'Channels must be in (None, 0, 3) when decoding BMP images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_decode, assert_channels]):
return gen_image_ops.decode_bmp(contents)
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype)

def _gif():
# Create assert to make sure that channels is not set to 1
Expand All @@ -1541,7 +1542,7 @@ def _gif():
channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
return gen_image_ops.decode_gif(contents)
return convert_image_dtype(gen_image_ops.decode_gif(contents), dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably remove convert_image_dtype calls from each of these helper functions and instead do it once at the end so that this function will now support types other than uint8 and uint16 for all image formats.


def check_gif():
# Create assert op to check that bytes are GIF decodable
Expand All @@ -1550,7 +1551,11 @@ def check_gif():

def _png():
"""Decodes a PNG image."""
return gen_image_ops.decode_png(contents, channels)
return convert_image_dtype(
gen_image_ops.decode_png(contents, channels,
dtype=dtypes.uint8
if dtype == dtypes.uint8
else dtypes.uint16), dtype)

def check_png():
"""Checks if an image is PNG."""
Expand All @@ -1566,7 +1571,8 @@ def _jpeg():
'images')
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_channels]):
return gen_image_ops.decode_jpeg(contents, channels)
return convert_image_dtype(
gen_image_ops.decode_jpeg(contents, channels), dtype)

# Decode normal JPEG images (start with \xff\xd8\xff\xe0)
# as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1).
Expand Down
83 changes: 83 additions & 0 deletions tensorflow/python/ops/image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3746,5 +3746,88 @@ def testSobelEdges5x3x4x2(self):
self.assertAllClose(expected_batch, actual_sobel)


class DecodeImageTest(test_util.TensorFlowTestCase):

def testJpegUint16(self):
  """decode_image(dtype=uint16) on a JPEG matches decode_jpeg + conversion."""
  jpeg_path = os.path.join("tensorflow/core/lib/jpeg/testdata",
                           "jpeg_merge_test1.jpg")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(jpeg_path)
    # Op under test: the generic decoder asked for uint16 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.uint16)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_jpeg = image_ops.decode_jpeg(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_jpeg, dtypes.uint16)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testPngUint16(self):
  """decode_image(dtype=uint16) on a PNG matches a uint16 decode_png."""
  png_path = os.path.join("tensorflow/core/lib/png/testdata",
                          "lena_rgba.png")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(png_path)
    # Op under test: the generic decoder asked for uint16 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.uint16)
    # Reference: decode_png is asked for uint16 directly, then passed
    # through convert_image_dtype with the same target dtype.
    decoded_png = image_ops.decode_png(encoded, dtype=dtypes.uint16)
    expected_op = image_ops.convert_image_dtype(decoded_png, dtypes.uint16)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testGifUint16(self):
  """decode_image(dtype=uint16) on a GIF matches decode_gif + conversion."""
  gif_path = os.path.join("tensorflow/core/lib/gif/testdata", "scan.gif")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(gif_path)
    # Op under test: the generic decoder asked for uint16 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.uint16)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_gif = image_ops.decode_gif(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_gif, dtypes.uint16)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testBmpUint16(self):
  """decode_image(dtype=uint16) on a BMP matches decode_bmp + conversion."""
  bmp_path = os.path.join("tensorflow/core/lib/bmp/testdata", "lena.bmp")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(bmp_path)
    # Op under test: the generic decoder asked for uint16 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.uint16)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_bmp = image_ops.decode_bmp(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_bmp, dtypes.uint16)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also test the float32 case? I think then we're happy.


def testJpegFloat32(self):
  """decode_image(dtype=float32) on a JPEG matches decode_jpeg + conversion."""
  jpeg_path = os.path.join("tensorflow/core/lib/jpeg/testdata",
                           "jpeg_merge_test1.jpg")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(jpeg_path)
    # Op under test: the generic decoder asked for float32 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.float32)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_jpeg = image_ops.decode_jpeg(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_jpeg, dtypes.float32)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testPngFloat32(self):
  """decode_image(dtype=float32) on a PNG matches uint16 decode + conversion."""
  png_path = os.path.join("tensorflow/core/lib/png/testdata",
                          "lena_rgba.png")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(png_path)
    # Op under test: the generic decoder asked for float32 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.float32)
    # Reference: decode the PNG as uint16 first, then convert the result
    # to float32.
    decoded_png = image_ops.decode_png(encoded, dtype=dtypes.uint16)
    expected_op = image_ops.convert_image_dtype(decoded_png, dtypes.float32)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testGifFloat32(self):
  """decode_image(dtype=float32) on a GIF matches decode_gif + conversion."""
  gif_path = os.path.join("tensorflow/core/lib/gif/testdata", "scan.gif")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(gif_path)
    # Op under test: the generic decoder asked for float32 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.float32)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_gif = image_ops.decode_gif(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_gif, dtypes.float32)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)

def testBmpFloat32(self):
  """decode_image(dtype=float32) on a BMP matches decode_bmp + conversion."""
  bmp_path = os.path.join("tensorflow/core/lib/bmp/testdata", "lena.bmp")
  with self.test_session(use_gpu=True) as sess:
    encoded = io_ops.read_file(bmp_path)
    # Op under test: the generic decoder asked for float32 output.
    actual_op = image_ops.decode_image(encoded, dtype=dtypes.float32)
    # Reference: the format-specific decoder followed by an explicit
    # dtype conversion.
    decoded_bmp = image_ops.decode_bmp(encoded)
    expected_op = image_ops.convert_image_dtype(decoded_bmp, dtypes.float32)
    actual, expected = sess.run([actual_op, expected_op])
    self.assertAllEqual(actual, expected)


# Run the test cases in this file when it is executed directly as a script.
if __name__ == "__main__":
  googletest.main()
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/tensorflow.image.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ tf_module {
}
member_method {
name: "decode_image"
argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
}
member_method {
name: "decode_jpeg"
Expand Down