Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.LoadAndRemapMatrix vulnerability with invalid `row_rem…
…apping`.

Check that `row_remapping` has the correct dims().

PiperOrigin-RevId: 445522800
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Apr 29, 2022
1 parent 0f0b080 commit 3150642
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/load_and_remap_matrix_op.cc
Expand Up @@ -74,6 +74,11 @@ class LoadAndRemapMatrixOp : public OpKernel {
std::vector<bool> row_id_present;
const Tensor* row_remapping_t;
OP_REQUIRES_OK(context, context->input("row_remapping", &row_remapping_t));
OP_REQUIRES(
context, row_remapping_t->dims() == 1,
errors::InvalidArgument("The `row_remapping` tensor must be 1-D, got "
"a tensor of shape ",
row_remapping_t->shape().DebugString()));
const auto row_remapping = row_remapping_t->vec<int64_t>();
OP_REQUIRES(context, row_remapping.size() == num_rows_,
errors::InvalidArgument(strings::StrCat(
Expand Down
26 changes: 26 additions & 0 deletions tensorflow/python/kernel_tests/io_ops/checkpoint_ops_test.py
Expand Up @@ -227,6 +227,32 @@ def test_load_and_remap_all_missing_rows_and_cols(self):
np.reshape(initializing_values, (num_rows, num_cols)),
self.evaluate(remapped_matrix))

def test_load_and_remap_invalid_dims(self):
ckpt_path = constant_op.constant(
'/tmp/warm_starting_util_test5kl2a3pc/tmpph76tep2/model-0',
shape=[],
dtype=dtypes.string)
old_tensor_name = constant_op.constant(
'/tmp/warm_starting_util_test5kl2a3pc/tmpph76tep2/model-0',
shape=[],
dtype=dtypes.string)
row_remapping = constant_op.constant(0, shape=[], dtype=dtypes.int64)
col_remapping = constant_op.constant(3, shape=[3], dtype=dtypes.int64)
initializing_values = constant_op.constant([],
shape=[0, 1],
dtype=dtypes.float32)
with self.cached_session(), self.assertRaisesRegex(
(ValueError, errors.InvalidArgumentError), 'tensor must be 1-D'):
self.evaluate(
gen_checkpoint_ops.load_and_remap_matrix(
ckpt_path=ckpt_path,
old_tensor_name=old_tensor_name,
row_remapping=row_remapping,
col_remapping=col_remapping,
initializing_values=initializing_values,
num_rows=1,
num_cols=1))

@test_util.run_deprecated_v1
def test_load_and_remap_invalid_remapping(self):
"""Tests that errors are raised when an ID maps to multiple new IDs.
Expand Down

0 comments on commit 3150642

Please sign in to comment.