Skip to content

Commit

Permalink
Fix security vulnerability with UnbatchGradKernel
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 460992964
  • Loading branch information
sagunb authored and tensorflow-jenkins committed Aug 19, 2022
1 parent f9764a2 commit c122773
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tensorflow/core/kernels/batch_kernels.cc
Expand Up @@ -879,8 +879,13 @@ class UnbatchGradResource : public ResourceBase {
const Tensor& data_t = context->input(0);
const Tensor& batch_index_t = context->input(1);
const Tensor& grad_t = context->input(2);
const Tensor& batch_key_t = context->input(3);

mutex_lock ml(mu_);
if (batch_key_t.NumElements() != 1) {
return errors::InvalidArgument("Expected `id` to be scalar. Received ",
batch_key_t.DebugString());
}

const int64_t batch_key = context->input(3).scalar<int64_t>()();
// Mark our tensor as available.
Expand All @@ -896,6 +901,11 @@ class UnbatchGradResource : public ResourceBase {
"batch_index is empty while the tensor isn't.");
}
std::unordered_set<int64_t> missing_tensors;
if (batch_index_t.NumElements() != batch_index_t.dim_size(0) * 3) {
return errors::InvalidArgument(
"batch_index should contain ", batch_index_t.dim_size(0) * 3,
" elements. Received ", batch_index_t.NumElements());
}
const auto batch_index =
batch_index_t.shaped<int64_t, 2>({batch_index_t.dim_size(0), 3});
for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
Expand Down
53 changes: 53 additions & 0 deletions tensorflow/python/ops/batch_ops_test.py
Expand Up @@ -24,7 +24,9 @@

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
Expand All @@ -34,6 +36,7 @@
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import variables
Expand Down Expand Up @@ -561,6 +564,56 @@ def worker():
# The thread's call should hit the timeout, and thus get 0 results.
self.assertEqual(len(thread_results), 0)

def testUnbatchGradInvalidId(self):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(
gen_batch_ops.unbatch_grad(
original_input=constant_op.constant([1]),
batch_index=constant_op.constant([
[0, 0, 0],
], dtype=dtypes.int64),
grad=constant_op.constant([
1,
]),
id=constant_op.constant([
1,
1,
], dtype=dtypes.int64)))

def testUnbatchGradInvalidBatchId(self):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(
gen_batch_ops.unbatch_grad(
original_input=constant_op.constant([1]),
batch_index=constant_op.constant([
[0, 0],
], dtype=dtypes.int64),
grad=constant_op.constant([
1,
]),
id=constant_op.constant([
1,
], dtype=dtypes.int64)))

def testUnbatchGradInvalidArgs(self):
original_input = random_ops.random_uniform(
shape=(3, 1), dtype=dtypes.float64, maxval=None)
batch_index = random_ops.random_uniform(
shape=(3, 1), dtype=dtypes.int64, maxval=65536)
grad = random_ops.random_uniform(
shape=(3, 1), dtype=dtypes.float64, maxval=None)
batch_id = random_ops.random_uniform(
shape=(3, 1), dtype=dtypes.int64, maxval=65536)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(
gen_batch_ops.unbatch_grad(
original_input=original_input,
batch_index=batch_index,
grad=grad,
id=batch_id,
container="",
shared_name="",
name=""))

if __name__ == "__main__":
test.main()

0 comments on commit c122773

Please sign in to comment.