Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.TensorListResize vulnerability with non-scalar input.
Check that the size input is valid.
Add graph/eager unit tests. Graph mode was already ok but eager mode was not.

Note: This fix will have to be cherry picked in r2.10, r2.9, and r2.8.
PiperOrigin-RevId: 477002316
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Sep 26, 2022
1 parent cc0ba05 commit 888e34b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/list_kernels.cc
Expand Up @@ -375,6 +375,8 @@ class TensorListResize : public OpKernel {
void Compute(OpKernelContext* c) override {
const TensorList* input_list = nullptr;
OP_REQUIRES_OK(c, GetInputList(c, 0, &input_list));
OP_REQUIRES(c, TensorShapeUtils::IsScalar(c->input(1).shape()),
errors::InvalidArgument("size must be a scalar"));
int32_t size = c->input(1).scalar<int32>()();
OP_REQUIRES(
c, size >= 0,
Expand Down
Expand Up @@ -1658,6 +1658,15 @@ def testResizeWithInvalidSizeFails(self):
l = list_ops.tensor_list_resize(l, -1)
self.evaluate(l)

@test_util.run_in_graph_and_eager_modes
def testResizeWithNonScalarFails(self):
l = list_ops.tensor_list_from_tensor([3, 4, 5], element_shape=[])
size = np.zeros([0, 2, 3, 3])
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r"Shape must be rank 0 but is rank \d+|"
r"\w+ must be a scalar"):
self.evaluate(gen_list_ops.TensorListResize(input_handle=l, size=size))

@test_util.run_deprecated_v1
@test_util.enable_control_flow_v2
def testSkipEagerResizeGrad(self):
Expand Down

0 comments on commit 888e34b

Please sign in to comment.