Skip to content

Commit 89d06dc

Browse files
hyeygittensorflower-gardener
authored andcommitted
Eager execution coverage for image_ops_test.py. Removed run_deprecated_v1 decorators.
Part 16 (class MultiscaleSSIMTest) PiperOrigin-RevId: 339293048 Change-Id: Ia4b309c4943333cafaa7425995b3564410a82c2b
1 parent 5ad4ef3 commit 89d06dc

File tree

1 file changed

+57
-23
lines changed

1 file changed

+57
-23
lines changed

tensorflow/python/ops/image_ops_test.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from tensorflow.python.compat import compat
3535
from tensorflow.python.data.experimental.ops import get_single_element
3636
from tensorflow.python.data.ops import dataset_ops
37+
from tensorflow.python.eager import backprop
3738
from tensorflow.python.eager import context
3839
from tensorflow.python.eager import def_function
3940
from tensorflow.python.framework import constant_op
@@ -47,7 +48,6 @@
4748
from tensorflow.python.ops import array_ops
4849
from tensorflow.python.ops import control_flow_ops
4950
from tensorflow.python.ops import gen_image_ops
50-
from tensorflow.python.ops import gradients
5151
from tensorflow.python.ops import image_ops
5252
from tensorflow.python.ops import image_ops_impl
5353
from tensorflow.python.ops import io_ops
@@ -5453,7 +5453,6 @@ def _RandomImage(self, shape, max_val):
54535453
"""Returns an image or image batch with given shape."""
54545454
return np.random.rand(*shape).astype(np.float32) * max_val
54555455

5456-
@test_util.run_deprecated_v1
54575456
def testAgainstMatlab(self):
54585457
"""Tests against MS-SSIM computed with Matlab implementation.
54595458
@@ -5462,32 +5461,68 @@ def testAgainstMatlab(self):
54625461
img = self._LoadTestImages()
54635462
expected = self._msssim[np.triu_indices(3)]
54645463

5465-
ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
5466-
msssim = image_ops.ssim_multiscale(
5467-
*ph, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
5464+
def ssim_func(x):
5465+
return image_ops.ssim_multiscale(
5466+
*x, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
5467+
54685468
with self.cached_session(use_gpu=True):
5469-
scores = [msssim.eval(dict(zip(ph, t)))
5470-
for t in itertools.combinations_with_replacement(img, 2)]
5469+
scores = [
5470+
self.evaluate(ssim_func(t))
5471+
for t in itertools.combinations_with_replacement(img, 2)
5472+
]
54715473

54725474
self.assertAllClose(expected, np.squeeze(scores), atol=1e-4)
54735475

5474-
@test_util.run_deprecated_v1
54755476
def testUnweightedIsDifferentiable(self):
54765477
img = self._LoadTestImages()
5477-
ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
5478+
5479+
@def_function.function
5480+
def msssim_func(x1, x2, scalar):
5481+
return image_ops.ssim_multiscale(
5482+
x1 * scalar,
5483+
x2 * scalar,
5484+
max_val=1.0,
5485+
power_factors=(1, 1, 1, 1, 1),
5486+
filter_size=11,
5487+
filter_sigma=1.5,
5488+
k1=0.01,
5489+
k2=0.03)
5490+
54785491
scalar = constant_op.constant(1.0, dtype=dtypes.float32)
5479-
scaled_ph = [x * scalar for x in ph]
5480-
msssim = image_ops.ssim_multiscale(
5481-
*scaled_ph,
5482-
max_val=1.0,
5483-
power_factors=(1, 1, 1, 1, 1),
5484-
filter_size=11,
5485-
filter_sigma=1.5,
5486-
k1=0.01,
5487-
k2=0.03)
5488-
grads = gradients.gradients(msssim, scalar)
5489-
with self.cached_session(use_gpu=True) as sess:
5490-
np_grads = sess.run(grads, feed_dict={ph[0]: img[0], ph[1]: img[1]})
5492+
5493+
with backprop.GradientTape() as tape:
5494+
tape.watch(scalar)
5495+
y = msssim_func(img[0], img[1], scalar)
5496+
5497+
grad = tape.gradient(y, scalar)
5498+
np_grads = self.evaluate(grad)
5499+
self.assertTrue(np.isfinite(np_grads).all())
5500+
5501+
def testUnweightedIsDifferentiableEager(self):
5502+
if not context.executing_eagerly():
5503+
self.skipTest("Eager mode only")
5504+
5505+
img = self._LoadTestImages()
5506+
5507+
def msssim_func(x1, x2, scalar):
5508+
return image_ops.ssim_multiscale(
5509+
x1 * scalar,
5510+
x2 * scalar,
5511+
max_val=1.0,
5512+
power_factors=(1, 1, 1, 1, 1),
5513+
filter_size=11,
5514+
filter_sigma=1.5,
5515+
k1=0.01,
5516+
k2=0.03)
5517+
5518+
scalar = constant_op.constant(1.0, dtype=dtypes.float32)
5519+
5520+
with backprop.GradientTape() as tape:
5521+
tape.watch(scalar)
5522+
y = msssim_func(img[0], img[1], scalar)
5523+
5524+
grad = tape.gradient(y, scalar)
5525+
np_grads = self.evaluate(grad)
54915526
self.assertTrue(np.isfinite(np_grads).all())
54925527

54935528
def testBatch(self):
@@ -5549,7 +5584,6 @@ def testRange(self):
55495584
self.assertTrue(np.all(msssim >= 0.0))
55505585
self.assertTrue(np.all(msssim <= 1.0))
55515586

5552-
@test_util.run_deprecated_v1
55535587
def testInt(self):
55545588
img1 = self._RandomImage((1, 180, 240, 3), 255)
55555589
img2 = self._RandomImage((1, 180, 240, 3), 255)
@@ -5563,7 +5597,7 @@ def testInt(self):
55635597
img1, img2, 1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03)
55645598
with self.cached_session(use_gpu=True):
55655599
self.assertAllClose(
5566-
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
5600+
self.evaluate(ssim_uint8), self.evaluate(ssim_float32), atol=0.001)
55675601

55685602
def testNumpyInput(self):
55695603
"""Test case for GitHub issue 28241."""

0 commit comments

Comments
 (0)