diff --git a/tensorflow/core/kernels/image/attention_ops.cc b/tensorflow/core/kernels/image/attention_ops.cc index 9a0513931ce258..60f6d9070dd55d 100644 --- a/tensorflow/core/kernels/image/attention_ops.cc +++ b/tensorflow/core/kernels/image/attention_ops.cc @@ -87,9 +87,10 @@ class ExtractGlimpseOp : public OpKernel { const int64_t output_height = window_size.tensor()(0); const int64_t output_width = window_size.tensor()(1); + TensorShape output_shape = input_shape; - output_shape.set_dim(1, output_height); - output_shape.set_dim(2, output_width); + OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(1, output_height)); + OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(2, output_width)); const Tensor& offsets = context->input(2); OP_REQUIRES(context, offsets.shape().dims() == 2, diff --git a/tensorflow/python/kernel_tests/attention_ops_test.py b/tensorflow/python/kernel_tests/attention_ops_test.py index 804a0b20cc9dd4..34b6da4c6a8e6f 100644 --- a/tensorflow/python/kernel_tests/attention_ops_test.py +++ b/tensorflow/python/kernel_tests/attention_ops_test.py @@ -22,6 +22,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -301,6 +302,15 @@ def testGlimpseNonNormalizedNonCentered(self): np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]), self.evaluate(result2)[0, :, :, 0]) + def testGlimpseNegativeInput(self): + img = np.arange(9).reshape([1,3,3,1]) + with self.test_session(): + with self.assertRaises((errors.InternalError, ValueError)): + result = image_ops.extract_glimpse_v2( + img, size=[1023, -63], offsets=[1023, 63], + centered=False, normalized=False) + self.evaluate(result) + if __name__ == '__main__': test.main()