Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix security vulnerability with LSTMBlockCellOp
PiperOrigin-RevId: 446028341
  • Loading branch information
sagunb authored and tensorflower-gardener committed May 2, 2022
1 parent 24cf5e1 commit 8034040
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
59 changes: 59 additions & 0 deletions tensorflow/core/kernels/rnn/lstm_ops.cc
Expand Up @@ -416,6 +416,65 @@ class LSTMBlockCellOp : public OpKernel {

const Device& device = ctx->eigen_device<Device>();

// Sanity check that each of the tensors have the required NDIMS.
OP_REQUIRES(ctx, x_tensor->dims() == 2,
errors::InvalidArgument("x_tensor must be rank 2 but is rank ",
x_tensor->dims(), "."));
OP_REQUIRES(
ctx, cs_prev_tensor->dims() == 2,
errors::InvalidArgument("cs_prev_tensor must be rank 2 but is rank ",
cs_prev_tensor->dims(), "."));
OP_REQUIRES(
ctx, h_prev_tensor->dims() == 2,
errors::InvalidArgument("h_prev_tensor must be rank 2 but is rank ",
h_prev_tensor->dims(), "."));
OP_REQUIRES(ctx, w_tensor->dims() == 2,
errors::InvalidArgument("w_tensor must be rank 2 but is rank ",
w_tensor->dims(), "."));
OP_REQUIRES(
ctx, wci_tensor->dims() == 1,
errors::InvalidArgument("wci_tensor must be rank 1 but is rank ",
wci_tensor->dims(), "."));
OP_REQUIRES(
ctx, wcf_tensor->dims() == 1,
errors::InvalidArgument("wcf_tensor must be rank 1 but is rank ",
wci_tensor->dims(), "."));
OP_REQUIRES(
ctx, wco_tensor->dims() == 1,
errors::InvalidArgument("wco_tensor must be rank 1 but is rank ",
wco_tensor->dims(), "."));
OP_REQUIRES(ctx, b_tensor->dims() == 1,
errors::InvalidArgument("b_tensor must be rank 1 but is rank ",
b_tensor->dims(), "."));
OP_REQUIRES(ctx, xh_tensor.dims() == 2,
errors::InvalidArgument("xh_tensor must be rank 2 but is rank ",
xh_tensor.dims(), "."));
OP_REQUIRES(ctx, i_tensor->dims() == 2,
errors::InvalidArgument("i_tensor must be rank 2 but is rank ",
i_tensor->dims(), "."));
OP_REQUIRES(ctx, cs_tensor->dims() == 2,
errors::InvalidArgument("cs_tensor must be rank 2 but is rank ",
cs_tensor->dims(), "."));
OP_REQUIRES(ctx, f_tensor->dims() == 2,
errors::InvalidArgument("f_tensor must be rank 2 but is rank ",
f_tensor->dims(), "."));
OP_REQUIRES(ctx, o_tensor->dims() == 2,
errors::InvalidArgument("o_tensor must be rank 2 but is rank ",
o_tensor->dims(), "."));
OP_REQUIRES(ctx, ci_tensor->dims() == 2,
errors::InvalidArgument("ci_tensor must be rank 2 but is rank ",
ci_tensor->dims(), "."));
OP_REQUIRES(ctx, co_tensor->dims() == 2,
errors::InvalidArgument("co_tensor must be rank 2 but is rank ",
co_tensor->dims(), "."));
OP_REQUIRES(
ctx, gates_tensor.dims() == 2,
errors::InvalidArgument("gates_tensor must be rank 2 but is rank ",
gates_tensor.dims(), "."));
OP_REQUIRES(ctx, h_tensor->dims() == 2,
errors::InvalidArgument("h_tensor must be rank 2 but is rank ",
h_tensor->dims(), "."));

functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS, gate_layout>(
batch_size, input_size, cell_size)(
ctx, device, forget_bias_, cell_clip_, use_peephole_,
Expand Down
31 changes: 31 additions & 0 deletions tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py
Expand Up @@ -33,6 +33,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_rnn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
Expand Down Expand Up @@ -1323,6 +1324,36 @@ def testDynamicEquivalentToStaticRNN(self):
def testDynamicEquivalentToStaticRNNWithSequenceLength(self):
self._testDynamicEquivalentToStaticRNN(use_sequence_length=True)

@test_util.run_in_graph_and_eager_modes
def testLSTMBlockCellErrorHandling(self):
forget_bias = 1
cell_clip = 0
use_peephole = False
x = constant_op.constant(0.837607, shape=[28, 29], dtype=dtypes.float32)
cs_prev = constant_op.constant(0, shape=[28, 17], dtype=dtypes.float32)
h_prev = constant_op.constant(
0.592631638, shape=[28, 17], dtype=dtypes.float32)
w = constant_op.constant(0.887386262, shape=[46, 68], dtype=dtypes.float32)
wci = constant_op.constant(0, shape=[], dtype=dtypes.float32)
wcf = constant_op.constant(0, shape=[17], dtype=dtypes.float32)
wco = constant_op.constant(
0.592631638, shape=[28, 17], dtype=dtypes.float32)
b = constant_op.constant(0.75259006, shape=[68], dtype=dtypes.float32)
with self.assertRaises(errors_impl.InvalidArgumentError):
self.evaluate(
gen_rnn_ops.lstm_block_cell(
x=x,
cs_prev=cs_prev,
h_prev=h_prev,
w=w,
wci=wci,
wcf=wcf,
wco=wco,
b=b,
forget_bias=forget_bias,
cell_clip=cell_clip,
use_peephole=use_peephole))


class BidirectionalRNNTest(test.TestCase):

Expand Down

0 comments on commit 8034040

Please sign in to comment.