Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent unitialized variable use in grappler.
PiperOrigin-RevId: 399702928
Change-Id: Id7e75451fbff297692dfb687f60ea04b25c96b24
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 29, 2021
1 parent fad123a commit 68867bf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow/core/grappler/optimizers/auto_parallel.cc
Expand Up @@ -152,7 +152,7 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
TF_RETURN_IF_ERROR(ComputeTransitiveFanin(graph_, item.fetch, &train_nodes));
LOG(INFO) << "Number of training nodes: " << train_nodes.size();

const NodeDef* dequeue_node;
const NodeDef* dequeue_node = nullptr;
for (const auto& train_node : train_nodes) {
if (IsDequeueOp(*train_node)) {
dequeue_node = train_node;
Expand Down
24 changes: 24 additions & 0 deletions tensorflow/core/grappler/optimizers/auto_parallel_test.cc
Expand Up @@ -126,6 +126,30 @@ TEST_F(AutoParallelTest, SimpleParallel) {
EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
}

TEST_F(AutoParallelTest, SimpleParallelNoDequeue) {
tensorflow::Scope s = tensorflow::Scope::DisabledShapeInferenceScope();
Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
Output constant_c = ops::Const(s.WithOpName("constant_c"), 1.0f, {1});
Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
Output add = ops::AddN(s.WithOpName("add"), {constant_a, constant_c});
Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
Output apply_gradient = ops::ApplyGradientDescent(
s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});

GrapplerItem item;
item.init_ops.push_back("assign");
item.fetch.push_back("apply_gradient");
item.init_ops.push_back("assign");
TF_CHECK_OK(s.ToGraphDef(&item.graph));

AutoParallel parallel(2);
GraphDef output;
Status status = parallel.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
}

} // namespace
} // namespace grappler
} // namespace tensorflow

0 comments on commit 68867bf

Please sign in to comment.