Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix access to undefined memory during shape inference of Cudnn*.
PiperOrigin-RevId: 400324259
Change-Id: Ie3b7859d19ae24ee9ac2adf413bdc1e851bbc604
  • Loading branch information
tensorflower-gardener committed Oct 2, 2021
1 parent 3f560dc commit af5fceb
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tensorflow/core/ops/cudnn_rnn_ops.cc
Expand Up @@ -81,11 +81,17 @@ REGISTER_OP("CudnnRNN")
.Attr("seed2: int = 0")
.Attr("is_training: bool = true")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
auto input_shape = c->input(0);
auto input_h_shape = c->input(1);
TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));

auto seq_length = c->Dim(input_shape, 0);
auto batch_size = c->Dim(input_shape, 1);
auto num_units = c->Dim(input_h_shape, 2);

string direction;
TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
string rnn_mode;
Expand Down Expand Up @@ -124,8 +130,13 @@ REGISTER_OP("CudnnRNNV2")
.Attr("seed2: int = 0")
.Attr("is_training: bool = true")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
auto input_shape = c->input(0);
auto input_h_shape = c->input(1);
TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));

auto seq_length = c->Dim(input_shape, 0);
auto batch_size = c->Dim(input_shape, 1);
auto num_units = c->Dim(input_h_shape, 2);
Expand Down Expand Up @@ -171,16 +182,26 @@ REGISTER_OP("CudnnRNNV3")
.Attr("is_training: bool = true")
.Attr("time_major: bool = true")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
auto input_shape = c->input(0);
auto input_h_shape = c->input(1);
auto input_c_shape = c->input(2);
TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));

auto max_seq_length = c->Dim(input_shape, 0);
auto batch_size = c->Dim(input_shape, 1);
auto num_units = c->Dim(input_h_shape, 2);

string direction;
TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
string rnn_mode;
TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
if (rnn_mode == "lstm") {
TF_RETURN_IF_ERROR(c->WithRank(input_c_shape, 3, &unused));
}
int dir_count = (direction == "bidirectional") ? 2 : 1;
DimensionHandle output_size;
TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
Expand Down
56 changes: 56 additions & 0 deletions tensorflow/core/ops/cudnn_rnn_ops_test.cc
Expand Up @@ -68,6 +68,11 @@ TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
INFER_ERROR("Shape must be rank 3 ", op, "[];[?,?,?];[?,?,?];[?]");
INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[];[?,?,?];[?]");
// Disabled because the kernel does not check shape of input_c.
// INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[?,?,?];[?];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[?,?,?];[]");
}

TEST(CudnnRNNOpsTest, ForwardV2Lstm_ShapeFn) {
Expand Down Expand Up @@ -100,6 +105,11 @@ TEST(CudnnRNNOpsTest, ForwardV2Lstm_ShapeFn) {
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
INFER_ERROR("Shape must be rank 3 ", op, "[];[?,?,?];[?,?,?];[?]");
INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[];[?,?,?];[?]");
// Disabled because the kernel does not check shape of input_c.
// INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[?,?,?];[?];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[?,?,?];[]");
}

TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
Expand Down Expand Up @@ -137,6 +147,52 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
INFER_ERROR("Shape must be rank 3 ", op, "[];[?,?,?];[?,?,?];[?];[?]");
INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[];[?,?,?];[?];[?]");
INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[?,?,?];[];[?];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[?,?,?];[];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[?,?,?];[?];[]");
}

TEST(CudnnRNNOpsTest, ForwardV3Gru) {
int max_seq_length = 2;
int batch_size = 3;
int num_units = 4;
int num_layers = 5;
int dir_count = 1;
std::vector<int> input_shape = {max_seq_length, batch_size, num_units};
std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
num_units};
std::vector<int> input_c_shape = {num_layers * dir_count, batch_size,
num_units};
std::vector<int> output_shape = {max_seq_length, batch_size,
num_units * dir_count};
std::vector<int> seq_lengths_shape = {batch_size};
auto shape_to_str = [](const std::vector<int>& v) {
return strings::StrCat("[", absl::StrJoin(v, ","), "]");
};
string input_shapes_desc = strings::StrCat(
shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
shape_to_str(input_c_shape), ";", "[?]", ";",
shape_to_str(seq_lengths_shape));
string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;[];?;?";

ShapeInferenceTestOp op("CudnnRNNV3");
TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV3")
.Input({"input", 0, DT_FLOAT})
.Input({"input_h", 0, DT_FLOAT})
.Input({"input_c", 0, DT_FLOAT})
.Input({"params", 0, DT_FLOAT})
.Input({"sequence_lengths", 0, DT_INT32})
.Attr("rnn_mode", "gru")
.Attr("input_mode", "auto_select")
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
INFER_ERROR("Shape must be rank 3 ", op, "[];[?,?,?];[];[?];[?]");
INFER_ERROR("Shape must be rank 3 ", op, "[?,?,?];[];[];[?];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[];[];[?]");
INFER_ERROR("Shape must be rank 1 ", op, "[?,?,?];[?,?,?];[];[?];[]");
}

} // end namespace tensorflow

0 comments on commit af5fceb

Please sign in to comment.