Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix tf.raw_ops.EmptyTensorList vulnerability with invalid `element_sh…
…ape`.

Check that given `element_shape` is valid.
Add graph/eager unit tests. Graph mode was already ok but eager mode was not.

PiperOrigin-RevId: 461906461
  • Loading branch information
poulsbo authored and tensorflower-gardener committed Jul 19, 2022
1 parent 49b3824 commit c8ba76d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tensorflow/core/kernels/list_kernels.cc
Expand Up @@ -21,7 +21,11 @@ limitations under the License.

#include "tensorflow/core/kernels/list_kernels.h"

#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
Expand All @@ -30,10 +34,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

Expand All @@ -49,6 +49,9 @@ Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) {
return errors::InvalidArgument(
"The only valid scalar shape tensor is the fully unknown shape "
"specified as -1.");
} else if (t.shape().dims() != 1) {
return errors::InvalidArgument("Shape must be at most rank 1 but is rank ",
t.shape().dims());
}
if (t.dtype() == DT_INT32) {
return PartialTensorShape::MakePartialShape(t.vec<int32>().data(),
Expand Down
Expand Up @@ -1458,6 +1458,15 @@ def testConcatWithUninitializedTensorsFailsIfNoInputLengths(self):
t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
self.evaluate(t)

def testEmptyTensorListInvalidShape(self):
with self.assertRaisesRegex((ValueError, errors.InvalidArgumentError),
r"Shape must be at most rank 1 but is rank 2"):
t = gen_list_ops.EmptyTensorList(
element_shape=array_ops.ones(dtype=dtypes.int32, shape=[1, 0]),
max_num_elements=constant_op.constant(1),
element_dtype=dtypes.int32)
self.evaluate(t)

def testEvenSplit(self):

def RunTest(input_tensor, lengths, expected_stacked_output):
Expand Down

0 comments on commit c8ba76d

Please sign in to comment.