Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added default parameter for tf.image.ssim #27076

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 68 additions & 18 deletions tensorflow/python/ops/image_ops_impl.py
Expand Up @@ -2818,12 +2818,13 @@ def psnr(a, b, max_val, name=None):
with ops.control_dependencies(checks):
return array_ops.identity(psnr_val)


_SSIM_K1 = 0.01
_SSIM_K2 = 0.03


def _ssim_helper(x, y, reducer, max_val, compensation=1.0):
def _ssim_helper(x,
y,
reducer,
max_val,
compensation=1.0,
k1=0.01,
k2=0.03):
r"""Helper function for computing SSIM.

SSIM estimates covariances with weighted sums. The default parameters
Expand All @@ -2848,12 +2849,17 @@ def _ssim_helper(x, y, reducer, max_val, compensation=1.0):
max_val: The dynamic range (i.e., the difference between the maximum
possible allowed value and the minimum allowed value).
compensation: Compensation factor. See above.
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for
lower values, so it would be better if we taken the values in range
of 0< K2 <0.4).

Returns:
A pair containing the luminance measure, and the contrast-structure measure.
"""
c1 = (_SSIM_K1 * max_val)**2
c2 = (_SSIM_K2 * max_val)**2

c1 = (k1 * max_val) ** 2
c2 = (k2 * max_val) ** 2

# SSIM luminance measure is
# (2 * mu_x * mu_y + c1) / (mu_x ** 2 + mu_y ** 2 + c1).
Expand Down Expand Up @@ -2894,7 +2900,13 @@ def _fspecial_gauss(size, sigma):
return array_ops.reshape(g, shape=[size, size, 1, 1])


def _ssim_per_channel(img1, img2, max_val=1.0):
def _ssim_per_channel(img1,
img2,
max_val=1.0,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
"""Computes SSIM index between img1 and img2 per color channel.

This function matches the standard SSIM implementation from:
Expand All @@ -2911,13 +2923,19 @@ def _ssim_per_channel(img1, img2, max_val=1.0):
img2: Second image batch.
max_val: The dynamic range of the images (i.e., the difference between the
maximum the and minimum allowed values).
filter_size: Default value 11 (size of gaussian filter).
filter_sigma: Default value 1.5 (width of gaussian filter).
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for
lower values, so it would be better if we taken the values in range
of 0< K2 <0.4).

Returns:
A pair of tensors containing and channel-wise SSIM and contrast-structure
values. The shape is [..., channels].
"""
filter_size = constant_op.constant(11, dtype=dtypes.int32)
filter_sigma = constant_op.constant(1.5, dtype=img1.dtype)
filter_size = constant_op.constant(filter_size, dtype=dtypes.int32)
filter_sigma = constant_op.constant(filter_sigma, dtype=img1.dtype)

shape1, shape2 = array_ops.shape_n([img1, img2])
checks = [
Expand Down Expand Up @@ -2955,7 +2973,8 @@ def reducer(x):
return array_ops.reshape(
y, array_ops.concat([shape[:-3], array_ops.shape(y)[1:]], 0))

luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation)
luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation,
k1, k2)

# Average over the second and the third from the last: height, width.
axes = constant_op.constant([-3, -2], dtype=dtypes.int32)
Expand All @@ -2965,7 +2984,13 @@ def reducer(x):


@tf_export('image.ssim')
def ssim(img1, img2, max_val):
def ssim(img1,
img2,
max_val,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
"""Computes SSIM index between img1 and img2.

This function is based on the standard SSIM implementation from:
Expand All @@ -2990,12 +3015,14 @@ def ssim(img1, img2, max_val):
im1 = tf.decode_png('path/to/im1.png')
im2 = tf.decode_png('path/to/im2.png')
# Compute SSIM over tf.uint8 Tensors.
ssim1 = tf.image.ssim(im1, im2, max_val=255)
ssim1 = tf.image.ssim(im1, im2, max_val=255, filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)

# Compute SSIM over tf.float32 Tensors.
im1 = tf.image.convert_image_dtype(im1, tf.float32)
im2 = tf.image.convert_image_dtype(im2, tf.float32)
ssim2 = tf.image.ssim(im1, im2, max_val=1.0)
ssim2 = tf.image.ssim(im1, im2, max_val=1.0, filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)
# ssim1 and ssim2 both have type tf.float32 and are almost equal.
```

Expand All @@ -3004,6 +3031,12 @@ def ssim(img1, img2, max_val):
img2: Second image batch.
max_val: The dynamic range of the images (i.e., the difference between the
maximum the and minimum allowed values).
filter_size: Default value 11 (size of gaussian filter).
filter_sigma: Default value 1.5 (width of gaussian filter).
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for
lower values, so it would be better if we taken the values in range
of 0< K2 <0.4).

Returns:
A tensor containing an SSIM value for each image in batch. Returned SSIM
Expand All @@ -3020,7 +3053,8 @@ def ssim(img1, img2, max_val):
max_val = convert_image_dtype(max_val, dtypes.float32)
img1 = convert_image_dtype(img1, dtypes.float32)
img2 = convert_image_dtype(img2, dtypes.float32)
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val)
ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size,
filter_sigma, k1, k2)
# Compute average over color channels.
return math_ops.reduce_mean(ssim_per_channel, [-1])

Expand All @@ -3030,7 +3064,14 @@ def ssim(img1, img2, max_val):


@tf_export('image.ssim_multiscale')
def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS):
def ssim_multiscale(img1,
img2,
max_val,
power_factors=_MSSSIM_WEIGHTS,
filter_size=11,
filter_sigma=1.5,
k1=0.01,
k2=0.03):
"""Computes the MS-SSIM between img1 and img2.

This function assumes that `img1` and `img2` are image batches, i.e. the last
Expand All @@ -3054,6 +3095,12 @@ def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS):
resolution's weight and each increasing scale corresponds to the image
being downsampled by 2. Defaults to (0.0448, 0.2856, 0.3001, 0.2363,
0.1333), which are the values obtained in the original paper.
filter_size: Default value 11 (size of gaussian filter).
filter_sigma: Default value 1.5 (width of gaussian filter).
k1: Default value 0.01
k2: Default value 0.03 (SSIM is less sensitivity to K2 for
lower values, so it would be better if we taken the values in range
of 0< K2 <0.4).

Returns:
A tensor containing an MS-SSIM value for each image in batch. The values
Expand Down Expand Up @@ -3124,7 +3171,10 @@ def do_pad(images, remainder):
]

# Overwrite previous ssim value since we only need the last one.
ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val)
ssim_per_channel, cs = _ssim_per_channel(*imgs, max_val=max_val,
filter_size=filter_size,
filter_sigma=filter_sigma,
k1=k1, k2=k2)
mcs.append(nn_ops.relu(cs))

# Remove the cs score for the last scale. In the MS-SSIM calculation,
Expand Down
47 changes: 32 additions & 15 deletions tensorflow/python/ops/image_ops_test.py
Expand Up @@ -4679,7 +4679,8 @@ def testAgainstMatlab(self):
expected = self._ssim[np.triu_indices(3)]

ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
ssim = image_ops.ssim(*ph, max_val=1.0)
ssim = image_ops.ssim(*ph, max_val=1.0, filter_size=11, filter_sigma=1.5,
k1=0.01, k2=0.03)
with self.cached_session(use_gpu=True):
scores = [ssim.eval(dict(zip(ph, t)))
for t in itertools.combinations_with_replacement(img, 2)]
Expand All @@ -4693,8 +4694,9 @@ def testBatch(self):
img1 = np.concatenate(img1)
img2 = np.concatenate(img2)

ssim = image_ops.ssim(constant_op.constant(img1),
constant_op.constant(img2), 1.0)
ssim = image_ops.ssim(constant_op.constant(img1),constant_op.constant(img2),
1.0, filter_size=11, filter_sigma=1.5, k1=0.01,
k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)

Expand All @@ -4706,7 +4708,8 @@ def testBroadcast(self):
img1 = array_ops.expand_dims(img, axis=0) # batch dims: 1, 2.
img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1.

ssim = image_ops.ssim(img1, img2, 1.0)
ssim = image_ops.ssim(img1, img2, 1.0, filter_size=11, filter_sigma=1.5,
k1=0.01, k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)

Expand All @@ -4720,8 +4723,9 @@ def testNegative(self):
img1 = img1.reshape((1, 16, 16, 1))
img2 = img2.reshape((1, 16, 16, 1))

ssim = image_ops.ssim(constant_op.constant(img1),
constant_op.constant(img2), 255)
ssim = image_ops.ssim(constant_op.constant(img1),constant_op.constant(img2),
255, filter_size=11, filter_sigma=1.5, k1=0.01,
k2=0.03)
with self.cached_session(use_gpu=True):
self.assertLess(ssim.eval(), 0)

Expand All @@ -4731,10 +4735,12 @@ def testInt(self):
img2 = self._RandomImage((1, 16, 16, 3), 255)
img1 = constant_op.constant(img1, dtypes.uint8)
img2 = constant_op.constant(img2, dtypes.uint8)
ssim_uint8 = image_ops.ssim(img1, img2, 255)
ssim_uint8 = image_ops.ssim(img1, img2, 255, filter_size=11, filter_sigma=1.5,
k1 = 0.01, k2 = 0.03)
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim(img1, img2, 1.0)
ssim_float32 = image_ops.ssim(img1, img2, 1.0, filter_size=11,
filter_sigma=1.5, k1 = 0.01, k2 = 0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
Expand Down Expand Up @@ -4777,7 +4783,8 @@ def testAgainstMatlab(self):
expected = self._msssim[np.triu_indices(3)]

ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
msssim = image_ops.ssim_multiscale(*ph, max_val=1.0)
msssim = image_ops.ssim_multiscale(*ph, max_val=1.0,filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)
with self.cached_session(use_gpu=True):
scores = [msssim.eval(dict(zip(ph, t)))
for t in itertools.combinations_with_replacement(img, 2)]
Expand All @@ -4791,7 +4798,9 @@ def testUnweightedIsDifferentiable(self):
scalar = constant_op.constant(1.0, dtype=dtypes.float32)
scaled_ph = [x * scalar for x in ph]
msssim = image_ops.ssim_multiscale(*scaled_ph, max_val=1.0,
power_factors=(1, 1, 1, 1, 1))
power_factors=(1, 1, 1, 1, 1),
filter_size=11, filter_sigma=1.5,
k1=0.01, k2=0.03)
grads = gradients.gradients(msssim, scalar)
with self.cached_session(use_gpu=True) as sess:
np_grads = sess.run(grads, feed_dict={ph[0]: img[0], ph[1]: img[1]})
Expand All @@ -4807,7 +4816,9 @@ def testBatch(self):
img2 = np.concatenate(img2)

msssim = image_ops.ssim_multiscale(constant_op.constant(img1),
constant_op.constant(img2), 1.0)
constant_op.constant(img2), 1.0,
filter_size=11, filter_sigma=1.5,
k1=0.01, k2 =0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(msssim), 1e-4)

Expand All @@ -4820,7 +4831,9 @@ def testBroadcast(self):
img1 = array_ops.expand_dims(img, axis=0) # batch dims: 1, 2.
img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1.

score_tensor = image_ops.ssim_multiscale(img1, img2, 1.0)
score_tensor = image_ops.ssim_multiscale(img1, img2, 1.0, filter_size=11,
filter_sigma=1.5, k1=0.01,
k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(score_tensor), 1e-4)

Expand All @@ -4838,7 +4851,9 @@ def testRange(self):
np.full_like(img1, fill_value=255)]

images = [ops.convert_to_tensor(x, dtype=dtypes.float32) for x in images]
msssim_ops = [image_ops.ssim_multiscale(x, y, 1.0)
msssim_ops = [image_ops.ssim_multiscale(x, y, 1.0, filter_size=11,
filter_sigma=1.5, k1=0.01,
k2=0.03)
for x, y in itertools.combinations(images, 2)]
msssim = self.evaluate(msssim_ops)
msssim = np.squeeze(msssim)
Expand All @@ -4852,10 +4867,12 @@ def testInt(self):
img2 = self._RandomImage((1, 180, 240, 3), 255)
img1 = constant_op.constant(img1, dtypes.uint8)
img2 = constant_op.constant(img2, dtypes.uint8)
ssim_uint8 = image_ops.ssim_multiscale(img1, img2, 255)
ssim_uint8 = image_ops.ssim_multiscale(img1, img2, 255, filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim_multiscale(img1, img2, 1.0)
ssim_float32 = image_ops.ssim_multiscale(img1, img2, 1.0, filter_size=11,
filter_sigma=1.5, k1=0.01, k2=0.03)
with self.cached_session(use_gpu=True):
self.assertAllClose(
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
Expand Up @@ -246,11 +246,11 @@ tf_module {
}
member_method {
name: "ssim"
argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\'], "
}
member_method {
name: "ssim_multiscale"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], "
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\', \'11\', \'1.5\', \'0.01\', \'0.03\'], "
}
member_method {
name: "total_variation"
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
Expand Up @@ -218,11 +218,11 @@ tf_module {
}
member_method {
name: "ssim"
argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'11\', \'1.5\', \'0.01\', \'0.03\'], "
}
member_method {
name: "ssim_multiscale"
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], "
argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\', \'filter_size\', \'filter_sigma\', \'k1\', \'k2\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\', \'11\', \'1.5\', \'0.01\', \'0.03\'], "
}
member_method {
name: "total_variation"
Expand Down