Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add size input validation to BlockLSTMGradV2.
Invalid sizes lead to a security vulnerability crash.
The added size checks are copied from the shape function assigned
in `REGISTER_OP`.

PiperOrigin-RevId: 462886105
  • Loading branch information
cantonios authored and tensorflower-gardener committed Jul 24, 2022
1 parent 552bfce commit 2a458fc
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tensorflow/core/kernels/rnn/lstm_ops.cc
Expand Up @@ -1138,19 +1138,30 @@ class BlockLSTMGradOp : public OpKernel {

const Tensor* x;
OP_REQUIRES_OK(ctx, ctx->input("x", &x));
OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
OP_REQUIRES(
ctx, x->dims() == 3,
errors::InvalidArgument("x must be rank 3 but is rank ", x->dims()));
const int64_t timelen = x->dim_size(0);
const int64_t batch_size = x->dim_size(1);
const int64_t input_size = x->dim_size(2);

const Tensor* cs_prev_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
errors::InvalidArgument("cs_prev must be rank 2 but is rank ",
cs_prev_tensor->dims()));

const Tensor* h_prev_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
errors::InvalidArgument("h_prev must be rank 2 but is rank ",
h_prev_tensor->dims()));

const Tensor* w_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
OP_REQUIRES(ctx, w_tensor->dims() == 2,
errors::InvalidArgument("w must be rank 2 but is rank ",
w_tensor->dims()));
const int64_t cell_size = w_tensor->dim_size(1) / 4;
OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
errors::InvalidArgument(
Expand All @@ -1159,15 +1170,27 @@ class BlockLSTMGradOp : public OpKernel {

const Tensor* wci_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
OP_REQUIRES(ctx, wci_tensor->dims() == 1,
errors::InvalidArgument("wci must be rank 1 but is rank ",
wci_tensor->dims()));

const Tensor* wcf_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
errors::InvalidArgument("wcf must be rank 1 but is rank ",
wcf_tensor->dims()));

const Tensor* wco_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
OP_REQUIRES(ctx, wco_tensor->dims() == 1,
errors::InvalidArgument("wco must be rank 1 but is rank ",
wco_tensor->dims()));

const Tensor* b_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
OP_REQUIRES(ctx, b_tensor->dims() == 1,
errors::InvalidArgument("b must be rank 1 but is rank ",
b_tensor->dims()));
OP_REQUIRES(
ctx, cell_size == b_tensor->dim_size(0) / 4,
errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
Expand Down
52 changes: 52 additions & 0 deletions tensorflow/python/kernel_tests/nn_ops/rnn_cell_test.py
Expand Up @@ -1354,6 +1354,58 @@ def testLSTMBlockCellErrorHandling(self):
cell_clip=cell_clip,
use_peephole=use_peephole))

@test_util.run_in_graph_and_eager_modes
def testLSTMBlockCellGradErrorHandling(self):
use_peephole = False
seq_len_max = constant_op.constant(1, shape=[], dtype=dtypes.int64)
x = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
cs_prev = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
h_prev = constant_op.constant(
0.504355371, shape=[1, 1], dtype=dtypes.float32)
w = constant_op.constant(0.504355371, shape=[1, 1], dtype=dtypes.float32)
wci = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
wcf = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
wco = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
b = constant_op.constant(0.504355371, shape=[1], dtype=dtypes.float32)
i = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
cs = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
f = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
o = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
ci = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
co = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
h = constant_op.constant(0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
cs_grad = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
h_grad = constant_op.constant(
0.504355371, shape=[1, 1, 1], dtype=dtypes.float32)
with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
"must be rank"):
self.evaluate(
gen_rnn_ops.block_lstm_grad_v2(
seq_len_max=seq_len_max,
x=x,
cs_prev=cs_prev,
h_prev=h_prev,
w=w,
wci=wci,
wcf=wcf,
wco=wco,
b=b,
i=i,
cs=cs,
f=f,
o=o,
ci=ci,
co=co,
h=h,
cs_grad=cs_grad,
h_grad=h_grad,
use_peephole=use_peephole))


class BidirectionalRNNTest(test.TestCase):

Expand Down

0 comments on commit 2a458fc

Please sign in to comment.