diff --git a/tensorflow/core/kernels/data_format_ops.cc b/tensorflow/core/kernels/data_format_ops.cc index b52d4d6c888271..00724b630ee9db 100644 --- a/tensorflow/core/kernels/data_format_ops.cc +++ b/tensorflow/core/kernels/data_format_ops.cc @@ -18,16 +18,52 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/data_format_ops.h" + +#include + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +// Ensure that `src` and `dst` define a valid permutation. +// Ops defined in this file assume that user specifies a permutation via two +// string attributes. This check validates that these attributes properly define +// it to prevent security vulnerabilities. +static bool IsValidPermutation(const std::string& src, const std::string& dst) { + if (src.size() != dst.size()) { + return false; + } + + std::map characters; + + // Every character in `src` must be present only once + for (const auto c : src) { + if (characters[c]) { + return false; + } + characters[c] = true; + } + + // Every character in `dst` must show up in `src` exactly once + for (const auto c : dst) { + if (!characters[c]) { + return false; + } + characters[c] = false; + } + + // At this point, characters[] has been switched to true and false exactly + // once for all character in `src` (and `dst`) so we have a valid permutation + return true; +} + template class DataFormatDimMapOp : public OpKernel { public: @@ -38,15 +74,19 @@ class DataFormatDimMapOp : public OpKernel { string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, - errors::InvalidArgument(strings::StrCat( - "Source format must of length 4 or 5, received " + errors::InvalidArgument( + "Source format must be of length 4 or 5, received " "src_format = ", - src_format))); + src_format)); + OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5, + errors::InvalidArgument("Destination format must be of length " + "4 or 5, received dst_format = ", + dst_format)); OP_REQUIRES( - context, dst_format.size() == 4 || dst_format.size() == 5, - errors::InvalidArgument(strings::StrCat( - "Destination format must of length 4 or 5, received dst_format = ", - dst_format))); + context, IsValidPermutation(src_format, dst_format), + errors::InvalidArgument( + "Destination and source format must determine a permutation, got ", + src_format, " and ", dst_format)); dst_idx_ = Tensor(DT_INT32, {static_cast(src_format.size())}); for (int i = 0; i < src_format.size(); ++i) { for (int j = 0; j < dst_format.size(); ++j) { @@ -78,8 +118,22 @@ class DataFormatVecPermuteOp : public OpKernel { : OpKernel(context) { string src_format; OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format)); + OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5, + errors::InvalidArgument( + "Source format must be of length 4 or 5, received " + "src_format = ", + src_format)); string dst_format; OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format)); + OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5, + errors::InvalidArgument("Destination format must be of length " + "4 or 5, received dst_format = ", + dst_format)); + OP_REQUIRES( + context, IsValidPermutation(src_format, dst_format), + errors::InvalidArgument( + "Destination and source format must determine a permutation, got ", + src_format, " and ", dst_format)); src_format_ = src_format; dst_format_ = dst_format; } @@ -127,6 +181,10 @@ class DataFormatVecPermuteOp : public OpKernel { }; keep_only_spatial_dimensions(&src_format_str); keep_only_spatial_dimensions(&dst_format_str); + OP_REQUIRES(context, + src_format_str.size() == 2 && dst_format_str.size() == 2, + errors::InvalidArgument( + "Format specifier must contain H and W for 2D case")); } ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx); diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 851bfcb66de3c8..aaf2f77fb2975a 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util @@ -1260,6 +1261,7 @@ def testDNHWCtoWHDCN(self): y_val = self.evaluate(y) self.assertAllEqual(y_val, y_val_expected) + @test_util.disable_xla("XLA catches the error and rethrows as different one") def testArbitraryASCII(self): x_val = [-4, -3, -2, -1, 0, 1, 2, 3] y_val_expected = [3, 2, 1, 0, 3, 2, 1, 0] @@ -1269,6 +1271,46 @@ def testArbitraryASCII(self): y_val = self.evaluate(y) self.assertAllEqual(y_val, y_val_expected) + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testInvalidLength(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Source format must be of length 4 or 5"): + op = nn_ops.data_format_dim_map( + x, src_format="12345678", dst_format="87654321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateSrc(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateDst(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testExtraSpecifiers(self): + x = [-4, -3, -2, -1, 0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321") + with test_util.use_gpu(): + self.evaluate(op) + class DataFormatVectorPermuteTest(test_lib.TestCase): @@ -1370,6 +1412,60 @@ def testNCHWToNHWC2D(self): y_val = self.evaluate(y) self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]]) + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testInvalidLength(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Source format must be of length 4 or 5"): + op = nn_ops.data_format_vec_permute( + x, src_format="12345678", dst_format="87654321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateSrc(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1233", dst_format="4321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testDuplicateDst(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1234", dst_format="3321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testExtraSpecifiers(self): + x = [0, 1, 2, 3] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Destination and source format must determine a permutation"): + op = nn_ops.data_format_vec_permute( + x, src_format="1234", dst_format="5321") + with test_util.use_gpu(): + self.evaluate(op) + + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def test2DNoWH(self): + x = [[0, 1], [2, 3]] + with self.assertRaisesRegex( + errors.InvalidArgumentError, + "Format specifier must contain H and W for 2D case"): + op = nn_ops.data_format_vec_permute( + x, src_format="1234", dst_format="4321") + with test_util.use_gpu(): + self.evaluate(op) + @test_util.run_all_in_graph_and_eager_modes class AvgPoolTest(test_lib.TestCase):