diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 6aeb6f40f1c4a4..c2ebb870f2c146 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -3060,12 +3060,11 @@ def ssim_multiscale(img1, img2, max_val, power_factors=_MSSSIM_WEIGHTS): are in range [0, 1]. Returns a tensor with shape: broadcast(img1.shape[:-3], img2.shape[:-3]). """ - # Shape checking. - shape1 = img1.get_shape().with_rank_at_least(3) - shape2 = img2.get_shape().with_rank_at_least(3) - shape1[-3:].merge_with(shape2[-3:]) - with ops.name_scope(None, 'MS-SSIM', [img1, img2]): + # Convert to tensor if needed. + img1 = ops.convert_to_tensor(img1, name='img1') + img2 = ops.convert_to_tensor(img2, name='img2') + # Shape checking. shape1, shape2, checks = _verify_compatible_image_shapes(img1, img2) with ops.control_dependencies(checks): img1 = array_ops.identity(img1) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 784df0fe0ea025..7daf3ba50e37a2 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -4860,6 +4860,12 @@ def testInt(self): self.assertAllClose( ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001) + def testNumpyInput(self): + """Test case for GitHub issue 28241.""" + image = np.random.random([512, 512, 1]) + score_tensor = image_ops.ssim_multiscale(image, image, max_val=1.0) + with self.cached_session(use_gpu=True): + _ = self.evaluate(score_tensor) class ImageGradientsTest(test_util.TensorFlowTestCase):