Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved tf.image.adjust_gamma for uint8 images #26508

Merged
merged 4 commits into from
May 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 15 additions & 19 deletions tensorflow/python/ops/image_ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1587,50 +1587,46 @@ def adjust_contrast(images, contrast_factor):
@tf_export('image.adjust_gamma')
def adjust_gamma(image, gamma=1, gain=1):
"""Performs Gamma Correction on the input image.

Also known as Power Law Transform. This function transforms the
input image pixelwise according to the equation `Out = In**gamma`
after scaling each pixel to the range 0 to 1.

Also known as Power Law Transform. This function converts the
input images at first to float representation, then transforms them
pixelwise according to the equation `Out = gain * In**gamma`,
and then converts the back to the original data type.
Args:
image : A Tensor.
image : RGB image or images to adjust.
gamma : A scalar or tensor. Non negative real number.
gain : A scalar or tensor. The constant multiplier.

Returns:
A Tensor. Gamma corrected output image.

A Tensor. A Gamma-adjusted tensor of the same shape and type as `image`.
Raises:
ValueError: If gamma is negative.

Notes:
For gamma greater than 1, the histogram will shift towards left and
the output image will be darker than the input image.
For gamma less than 1, the histogram will shift towards right and
the output image will be brighter than the input image.

References:
[1] http://en.wikipedia.org/wiki/Gamma_correction
"""

with ops.name_scope(None, 'adjust_gamma', [image, gamma, gain]) as name:
# Convert pixel value to DT_FLOAT for computing adjusted image.
img = ops.convert_to_tensor(image, name='img', dtype=dtypes.float32)
# Keep image dtype for computing the scale of corresponding dtype.
image = ops.convert_to_tensor(image, name='image')
# Remember original dtype to so we can convert back if needed
orig_dtype = image.dtype

if orig_dtype in [dtypes.float16, dtypes.float32]:
flt_image = image
else:
flt_image = convert_image_dtype(image, dtypes.float32)

assert_op = _assert(gamma >= 0, ValueError,
'Gamma should be a non-negative real number.')
if assert_op:
gamma = control_flow_ops.with_dependencies(assert_op, gamma)

# scale = max(dtype) - min(dtype).
scale = constant_op.constant(
image.dtype.limits[1] - image.dtype.limits[0], dtype=dtypes.float32)
# According to the definition of gamma correction.
adjusted_img = (img / scale)**gamma * scale * gain
adjusted_img = gain * flt_image**gamma

return adjusted_img
return convert_image_dtype(adjusted_img, orig_dtype, saturate=True)


@tf_export('image.convert_image_dtype')
Expand Down
143 changes: 74 additions & 69 deletions tensorflow/python/ops/image_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,118 +239,123 @@ def testShapeInference(self):


class AdjustGamma(test_util.TensorFlowTestCase):

def test_adjust_gamma_one(self):
"""Same image should be returned for gamma equal to one"""
@test_util.run_deprecated_v1
def test_adjust_gamma_less_zero_float32(self):
"""White image should be returned for gamma equal to zero"""
with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_data = np.random.uniform(0, 1.0, (8, 8))
x_np = np.array(x_data, dtype=np.float32)

x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_gamma(x, gamma=1)

y_tf = self.evaluate(y)
y_np = x_np

self.assertAllClose(y_tf, y_np, 1e-6)
err_msg = "Gamma should be a non-negative real number"
with self.assertRaisesRegexp(ValueError, err_msg):
image_ops.adjust_gamma(x, gamma=-1)
zimmerrol marked this conversation as resolved.
Show resolved Hide resolved

def test_adjust_gamma_less_zero(self):
@test_util.run_deprecated_v1
def test_adjust_gamma_less_zero_uint8(self):
"""White image should be returned for gamma equal to zero"""
with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)
x_np = np.array(x_data, dtype=np.uint8)

x = constant_op.constant(x_np, shape=x_np.shape)

err_msg = "Gamma should be a non-negative real number."

try:
err_msg = "Gamma should be a non-negative real number"
with self.assertRaisesRegexp(ValueError, err_msg):
image_ops.adjust_gamma(x, gamma=-1)
except Exception as e:
if err_msg not in str(e):
raise
else:
raise AssertionError("Exception not raised: %s" % err_msg)

@test_util.run_deprecated_v1
def test_adjust_gamma_less_zero_tensor(self):
"""White image should be returned for gamma equal to zero"""
with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_data = np.random.uniform(0, 1.0, (8, 8))
x_np = np.array(x_data, dtype=np.float32)

x = constant_op.constant(x_np, shape=x_np.shape)
y = constant_op.constant(-1.0, dtype=dtypes.float32)

image = image_ops.adjust_gamma(x, gamma=y)

err_msg = "Gamma should be a non-negative real number."
try:
err_msg = "Gamma should be a non-negative real number"
with self.assertRaisesRegexp(errors.InvalidArgumentError, err_msg):
self.evaluate(image)
except Exception as e:
if err_msg not in str(e):
raise
else:
raise AssertionError("Exception not raised: %s" % err_msg)

def test_adjust_gamma_zero(self):
"""White image should be returned for gamma equal to zero"""
def _test_adjust_gamma_uint8(self, gamma):
"""Verifying the output with expected results for gamma
correction for uint8 images"""
with self.cached_session():
x_data = np.random.uniform(0, 255, (8, 8))
x_np = np.array(x_data, dtype=np.float32)

x_np = np.random.uniform(0, 255, (8, 8)).astype(np.uint8)
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_gamma(x, gamma=0)

y_tf = self.evaluate(y)
y = image_ops.adjust_gamma(x, gamma=gamma)
y_tf = np.trunc(y.eval())

dtype = x.dtype.as_numpy_dtype
y_np = np.array([dtypes.dtype_range[dtype][1]] * x_np.size)
y_np = y_np.reshape((8, 8))
# calculate gamma correction using numpy
# firstly, transform uint8 to float representation
# then perform correction
y_np = np.power(x_np / 255.0, gamma)
# convert correct numpy image back to uint8 type
y_np = np.trunc(np.clip(y_np * 255.5, 0, 255.0))

self.assertAllClose(y_tf, y_np, 1e-6)

@test_util.run_deprecated_v1
def test_adjust_gamma_less_one(self):
def _test_adjust_gamma_float32(self, gamma):
"""Verifying the output with expected results for gamma
correction with gamma equal to half"""
correction for float32 images"""
with self.cached_session():
x_np = np.arange(0, 255, 4, np.uint8).reshape(8, 8)
y = image_ops.adjust_gamma(x_np, gamma=0.5)
y_tf = np.trunc(y.eval())
x_np = np.random.uniform(0, 1.0, (8, 8))
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_gamma(x, gamma=gamma)
y_tf = y.eval()

y_np = np.array(
[[0, 31, 45, 55, 63, 71, 78, 84], [
90, 95, 100, 105, 110, 115, 119, 123
], [127, 131, 135, 139, 142, 146, 149, 153], [
156, 159, 162, 165, 168, 171, 174, 177
], [180, 183, 186, 188, 191, 194, 196, 199], [
201, 204, 206, 209, 211, 214, 216, 218
], [221, 223, 225, 228, 230, 232, 234, 236],
[238, 241, 243, 245, 247, 249, 251, 253]],
dtype=np.float32)
y_np = np.clip(np.power(x_np, gamma), 0, 1.0)

self.assertAllClose(y_tf, y_np, 1e-6)

@test_util.run_deprecated_v1
def test_adjust_gamma_greater_one(self):
def test_adjust_gamma_one_float32(self):
"""Same image should be returned for gamma equal to one"""
self._test_adjust_gamma_float32(1.0)

@test_util.run_deprecated_v1
def test_adjust_gamma_one_uint8(self):
self._test_adjust_gamma_uint8(1.0)

@test_util.run_deprecated_v1
def test_adjust_gamma_zero_uint8(self):
"""White image should be returned for gamma equal
to zero for uint8 images"""
self._test_adjust_gamma_uint8(gamma=0.0)

@test_util.run_deprecated_v1
def test_adjust_gamma_less_one_uint8(self):
"""Verifying the output with expected results for gamma
correction with gamma equal to two"""
with self.cached_session():
x_np = np.arange(0, 255, 4, np.uint8).reshape(8, 8)
y = image_ops.adjust_gamma(x_np, gamma=2)
y_tf = np.trunc(y.eval())
correction with gamma equal to half for uint8 images"""
self._test_adjust_gamma_uint8(gamma=0.5)

y_np = np.array(
[[0, 0, 0, 0, 1, 1, 2, 3], [4, 5, 6, 7, 9, 10, 12, 14], [
16, 18, 20, 22, 25, 27, 30, 33
], [36, 39, 42, 45, 49, 52, 56, 60], [64, 68, 72, 76, 81, 85, 90, 95],
[100, 105, 110, 116, 121, 127, 132, 138], [
144, 150, 156, 163, 169, 176, 182, 189
], [196, 203, 211, 218, 225, 233, 241, 249]],
dtype=np.float32)
@test_util.run_deprecated_v1
def test_adjust_gamma_greater_one_uint8(self):
"""Verifying the output with expected results for gamma
correction for uint8 images"""
self._test_adjust_gamma_uint8(gamma=1.0)

self.assertAllClose(y_tf, y_np, 1e-6)
@test_util.run_deprecated_v1
def test_adjust_gamma_less_one_float32(self):
"""Verifying the output with expected results for gamma
correction with gamma equal to half for float32 images"""
self._test_adjust_gamma_float32(0.5)

@test_util.run_deprecated_v1
def test_adjust_gamma_greater_one_float32(self):
"""Verifying the output with expected results for gamma
correction with gamma equal to two for float32 images"""
self._test_adjust_gamma_float32(1.0)

@test_util.run_deprecated_v1
def test_adjust_gamma_zero_float32(self):
"""White image should be returned for gamma equal
to zero for float32 images"""
self._test_adjust_gamma_float32(0.0)


class AdjustHueTest(test_util.TensorFlowTestCase):
Expand Down