Skip to content

Commit

Permalink
Merge pull request #52788 from pranve/cherrypick-a871989d7b6c18cdebf2…
Browse files Browse the repository at this point in the history
…fb4f0e5c5b62fbc19edf-on-r2.6

Merge pull request #51658 from yongtang:51618-tf.image.extract_glimpse
  • Loading branch information
mihaimaruseac committed Oct 28, 2021
2 parents d27fbb4 + 95a0f70 commit 7a4c796
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tensorflow/core/kernels/image/attention_ops.cc
Expand Up @@ -85,11 +85,12 @@ class ExtractGlimpseOp : public OpKernel {
"input must be a vector of size 2 (height, width)",
window_size.shape().DebugString()));

const int64 output_height = window_size.tensor<int, 1>()(0);
const int64 output_width = window_size.tensor<int, 1>()(1);
const int64_t output_height = window_size.tensor<int, 1>()(0);
const int64_t output_width = window_size.tensor<int, 1>()(1);

TensorShape output_shape = input_shape;
output_shape.set_dim(1, output_height);
output_shape.set_dim(2, output_width);
OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(1, output_height));
OP_REQUIRES_OK(context, output_shape.SetDimWithStatus(2, output_width));

const Tensor& offsets = context->input(2);
OP_REQUIRES(context, offsets.shape().dims() == 2,
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/kernel_tests/attention_ops_test.py
Expand Up @@ -22,6 +22,7 @@

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import image_ops
Expand Down Expand Up @@ -301,6 +302,18 @@ def testGlimpseNonNormalizedNonCentered(self):
np.asarray([[5, 6, 7], [10, 11, 12], [15, 16, 17]]),
self.evaluate(result2)[0, :, :, 0])

def testGlimpseNegativeInput(self):
img = np.arange(9).reshape([1, 3, 3, 1])
with self.test_session():
with self.assertRaises((errors.InternalError, ValueError)):
result = image_ops.extract_glimpse_v2(
img,
size=[1023, -63],
offsets=[1023, 63],
centered=False,
normalized=False)
self.evaluate(result)


if __name__ == '__main__':
test.main()

0 comments on commit 7a4c796

Please sign in to comment.