Skip to content

Commit

Permalink
Add return_index_map argument in ssim()
Browse files Browse the repository at this point in the history
  • Loading branch information
CohenAriel committed Aug 9, 2022
1 parent 6c36729 commit 8f5a1b1
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
29 changes: 20 additions & 9 deletions tensorflow/python/ops/image_ops_impl.py
Expand Up @@ -4268,7 +4268,8 @@ def _ssim_per_channel(img1,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
k2=0.03,
return_index_map=False):
"""Computes SSIM index between img1 and img2 per color channel.
This function matches the standard SSIM implementation from:
Expand All @@ -4290,6 +4291,7 @@ def _ssim_per_channel(img1,
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
it would be better if we took the values in the range of 0 < K2 < 0.4).
return_index_map: If True returns local SSIM map instead of the global mean.
Returns:
A pair of tensors containing and channel-wise SSIM and contrast-structure
Expand Down Expand Up @@ -4338,9 +4340,12 @@ def reducer(x):
k2)

# Average over the second and the third from the last: height, width.
axes = constant_op.constant([-3, -2], dtype=dtypes.int32)
ssim_val = math_ops.reduce_mean(luminance * cs, axes)
cs = math_ops.reduce_mean(cs, axes)
if return_index_map:
ssim_val = luminance * cs
else:
axes = constant_op.constant([-3, -2], dtype=dtypes.int32)
ssim_val = math_ops.reduce_mean(luminance * cs, axes)
cs = math_ops.reduce_mean(cs, axes)
return ssim_val, cs


Expand All @@ -4352,7 +4357,8 @@ def ssim(img1,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
k2=0.03,
return_index_map=False):
"""Computes SSIM index between img1 and img2.
This function is based on the standard SSIM implementation from:
Expand Down Expand Up @@ -4405,11 +4411,15 @@ def ssim(img1,
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for lower values, so
it would be better if we took the values in the range of 0 < K2 < 0.4).
return_index_map: If True returns local SSIM map instead of the global mean.
Returns:
A tensor containing an SSIM value for each image in batch. Returned SSIM
values are in range (-1, 1], when pixel values are non-negative. Returns
a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]).
A tensor containing an SSIM value for each image in batch or a tensor
containing an SSIM value for each pixel for each image in batch if
return_index_map is True. Returned SSIM values are in range (-1, 1], when
pixel values are non-negative. Returns a tensor with shape:
broadcast(img1.shape[:-3], img2.shape[:-3]) or broadcast(img1.shape[:-1],
img2.shape[:-1]).
"""
with ops.name_scope(None, 'SSIM', [img1, img2]):
# Convert to tensor if needed.
Expand All @@ -4427,7 +4437,8 @@ def ssim(img1,
img1 = convert_image_dtype(img1, dtypes.float32)
img2 = convert_image_dtype(img2, dtypes.float32)
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2)
filter_sigma, k1, k2,
return_index_map)
# Compute average over color channels.
return math_ops.reduce_mean(ssim_per_channel, [-1])

Expand Down
27 changes: 27 additions & 0 deletions tensorflow/python/ops/image_ops_test.py
Expand Up @@ -5779,6 +5779,33 @@ def testInt(self):
self.assertAllClose(
self.evaluate(ssim_uint8), self.evaluate(ssim_float32), atol=0.001)

def testWithIndexMap(self):
img1 = self._RandomImage((1, 16, 16, 3), 255)
img2 = self._RandomImage((1, 16, 16, 3), 255)

ssim_locals = image_ops.ssim(
img1,
img2,
1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03,
return_index_map=True)
self.assertEqual(ssim_locals.shape, (1, 6, 6))

ssim_global = image_ops.ssim(
img1,
img2,
1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03)

axes = constant_op.constant([-2, -1], dtype=dtypes.int32)
self.assertAllClose(ssim_global, math_ops.reduce_mean(ssim_locals, axes))


class MultiscaleSSIMTest(test_util.TensorFlowTestCase):
"""Tests for MS-SSIM."""
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
Expand Up @@ -254,7 +254,7 @@ tf_module {
}
member_method {
name: "ssim"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\'], "
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\', \'return_index_map\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\', \'False\'], "
}
member_method {
name: "ssim_multiscale"
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
Expand Up @@ -226,7 +226,7 @@ tf_module {
}
member_method {
name: "ssim"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\'], "
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\', \'return_index_map\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\', \'False\'], "
}
member_method {
name: "ssim_multiscale"
Expand Down

0 comments on commit 8f5a1b1

Please sign in to comment.