3434from tensorflow .python .compat import compat
3535from tensorflow .python .data .experimental .ops import get_single_element
3636from tensorflow .python .data .ops import dataset_ops
37+ from tensorflow .python .eager import backprop
3738from tensorflow .python .eager import context
3839from tensorflow .python .eager import def_function
3940from tensorflow .python .framework import constant_op
4748from tensorflow .python .ops import array_ops
4849from tensorflow .python .ops import control_flow_ops
4950from tensorflow .python .ops import gen_image_ops
50- from tensorflow .python .ops import gradients
5151from tensorflow .python .ops import image_ops
5252from tensorflow .python .ops import image_ops_impl
5353from tensorflow .python .ops import io_ops
@@ -5453,7 +5453,6 @@ def _RandomImage(self, shape, max_val):
54535453 """Returns an image or image batch with given shape."""
54545454 return np .random .rand (* shape ).astype (np .float32 ) * max_val
54555455
5456- @test_util .run_deprecated_v1
54575456 def testAgainstMatlab (self ):
54585457 """Tests against MS-SSIM computed with Matlab implementation.
54595458
@@ -5462,32 +5461,68 @@ def testAgainstMatlab(self):
54625461 img = self ._LoadTestImages ()
54635462 expected = self ._msssim [np .triu_indices (3 )]
54645463
5465- ph = [array_ops .placeholder (dtype = dtypes .float32 ) for _ in range (2 )]
5466- msssim = image_ops .ssim_multiscale (
5467- * ph , max_val = 1.0 , filter_size = 11 , filter_sigma = 1.5 , k1 = 0.01 , k2 = 0.03 )
5464+ def ssim_func (x ):
5465+ return image_ops .ssim_multiscale (
5466+ * x , max_val = 1.0 , filter_size = 11 , filter_sigma = 1.5 , k1 = 0.01 , k2 = 0.03 )
5467+
54685468 with self .cached_session (use_gpu = True ):
5469- scores = [msssim .eval (dict (zip (ph , t )))
5470- for t in itertools .combinations_with_replacement (img , 2 )]
5469+ scores = [
5470+ self .evaluate (ssim_func (t ))
5471+ for t in itertools .combinations_with_replacement (img , 2 )
5472+ ]
54715473
54725474 self .assertAllClose (expected , np .squeeze (scores ), atol = 1e-4 )
54735475
5474- @test_util .run_deprecated_v1
54755476 def testUnweightedIsDifferentiable (self ):
54765477 img = self ._LoadTestImages ()
5477- ph = [array_ops .placeholder (dtype = dtypes .float32 ) for _ in range (2 )]
5478+
5479+ @def_function .function
5480+ def msssim_func (x1 , x2 , scalar ):
5481+ return image_ops .ssim_multiscale (
5482+ x1 * scalar ,
5483+ x2 * scalar ,
5484+ max_val = 1.0 ,
5485+ power_factors = (1 , 1 , 1 , 1 , 1 ),
5486+ filter_size = 11 ,
5487+ filter_sigma = 1.5 ,
5488+ k1 = 0.01 ,
5489+ k2 = 0.03 )
5490+
54785491 scalar = constant_op .constant (1.0 , dtype = dtypes .float32 )
5479- scaled_ph = [x * scalar for x in ph ]
5480- msssim = image_ops .ssim_multiscale (
5481- * scaled_ph ,
5482- max_val = 1.0 ,
5483- power_factors = (1 , 1 , 1 , 1 , 1 ),
5484- filter_size = 11 ,
5485- filter_sigma = 1.5 ,
5486- k1 = 0.01 ,
5487- k2 = 0.03 )
5488- grads = gradients .gradients (msssim , scalar )
5489- with self .cached_session (use_gpu = True ) as sess :
5490- np_grads = sess .run (grads , feed_dict = {ph [0 ]: img [0 ], ph [1 ]: img [1 ]})
5492+
5493+ with backprop .GradientTape () as tape :
5494+ tape .watch (scalar )
5495+ y = msssim_func (img [0 ], img [1 ], scalar )
5496+
5497+ grad = tape .gradient (y , scalar )
5498+ np_grads = self .evaluate (grad )
5499+ self .assertTrue (np .isfinite (np_grads ).all ())
5500+
5501+ def testUnweightedIsDifferentiableEager (self ):
5502+ if not context .executing_eagerly ():
5503+ self .skipTest ("Eager mode only" )
5504+
5505+ img = self ._LoadTestImages ()
5506+
5507+ def msssim_func (x1 , x2 , scalar ):
5508+ return image_ops .ssim_multiscale (
5509+ x1 * scalar ,
5510+ x2 * scalar ,
5511+ max_val = 1.0 ,
5512+ power_factors = (1 , 1 , 1 , 1 , 1 ),
5513+ filter_size = 11 ,
5514+ filter_sigma = 1.5 ,
5515+ k1 = 0.01 ,
5516+ k2 = 0.03 )
5517+
5518+ scalar = constant_op .constant (1.0 , dtype = dtypes .float32 )
5519+
5520+ with backprop .GradientTape () as tape :
5521+ tape .watch (scalar )
5522+ y = msssim_func (img [0 ], img [1 ], scalar )
5523+
5524+ grad = tape .gradient (y , scalar )
5525+ np_grads = self .evaluate (grad )
54915526 self .assertTrue (np .isfinite (np_grads ).all ())
54925527
54935528 def testBatch (self ):
@@ -5549,7 +5584,6 @@ def testRange(self):
55495584 self .assertTrue (np .all (msssim >= 0.0 ))
55505585 self .assertTrue (np .all (msssim <= 1.0 ))
55515586
5552- @test_util .run_deprecated_v1
55535587 def testInt (self ):
55545588 img1 = self ._RandomImage ((1 , 180 , 240 , 3 ), 255 )
55555589 img2 = self ._RandomImage ((1 , 180 , 240 , 3 ), 255 )
@@ -5563,7 +5597,7 @@ def testInt(self):
55635597 img1 , img2 , 1.0 , filter_size = 11 , filter_sigma = 1.5 , k1 = 0.01 , k2 = 0.03 )
55645598 with self .cached_session (use_gpu = True ):
55655599 self .assertAllClose (
5566- ssim_uint8 . eval ( ), self .evaluate (ssim_float32 ), atol = 0.001 )
5600+ self . evaluate ( ssim_uint8 ), self .evaluate (ssim_float32 ), atol = 0.001 )
55675601
55685602 def testNumpyInput (self ):
55695603 """Test case for GitHub issue 28241."""
0 commit comments