-
Notifications
You must be signed in to change notification settings - Fork 74.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support ExternalStatePolicy in AsGraphDef.
Previously AsGraphDef used a boolean to either fail or ignore external state when serializing a dataset graph. Now, it will take a policy which can be either IGNORE, FAIL, or WARN. This will allow code reuse by other operations which serialize dataset graphs. Some code is re-organized to avoid a circular dependency between dataset_utils and captured_function. The new serialization_utils implements AsGraphDef with the help of captured_function for determining whether functions contain stateful operations. PiperOrigin-RevId: 278488049 Change-Id: Ifd6390d30c2cd4ddbc0392b4eea479d8d8c881ee
- Loading branch information
1 parent
336925c
commit 12ed778
Showing
10 changed files
with
170 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#include "tensorflow/core/kernels/data/serialization_utils.h" | ||
|
||
#include "tensorflow/core/graph/graph_def_builder.h" | ||
#include "tensorflow/core/kernels/data/captured_function.h" | ||
|
||
namespace tensorflow { | ||
namespace data { | ||
|
||
namespace { | ||
|
||
// FindStatefulOps searches `graph_def` for all of its stateful ops storing | ||
// their names in `stateful_op_names`. | ||
Status FindStatefulOps(const GraphDef& graph_def, | ||
std::vector<string>* stateful_op_names) { | ||
FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library()); | ||
|
||
// Iterate over all nodes in the graph. | ||
for (const auto& node : graph_def.node()) { | ||
// Each Dataset graph has a _Retval op in the end which is marked stateful | ||
if (node.op() == FunctionLibraryDefinition::kRetOp) continue; | ||
if (!IsNodeStateful(lib_def, node).ok()) { | ||
stateful_op_names->push_back(node.op()); | ||
} | ||
} | ||
|
||
// Iterate over all functions. | ||
for (const auto& fdef : graph_def.library().function()) { | ||
if (!fdef.signature().is_stateful()) continue; | ||
for (const auto& node : fdef.node_def()) { | ||
if (!IsNodeStateful(lib_def, node).ok()) { | ||
stateful_op_names->push_back( | ||
absl::StrCat(node.op(), " in function: ", fdef.signature().name())); | ||
} | ||
} | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
} // namespace | ||
|
||
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, | ||
SerializationContext&& serialization_ctx, | ||
GraphDef* graph_def) { | ||
if (serialization_ctx.external_state_policy() == | ||
SerializationContext::ExternalStatePolicy::kFail) { | ||
TF_RETURN_IF_ERROR(dataset->CheckExternalState()); | ||
} | ||
if (serialization_ctx.external_state_policy() == | ||
SerializationContext::ExternalStatePolicy::kWarn) { | ||
std::vector<string> stateful_op_names; | ||
TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names)); | ||
if (!stateful_op_names.empty()) { | ||
LOG(WARNING) | ||
<< "We found the following stateful ops in the dataset " | ||
"construction graph whose state would not be serialized and might " | ||
"cause subtle bugs: " | ||
<< absl::StrJoin(stateful_op_names, ", "); | ||
} | ||
} | ||
GraphDefBuilder b; | ||
DatasetBase::DatasetGraphDefBuilder db(&b); | ||
Node* output_node = nullptr; | ||
TF_RETURN_IF_ERROR( | ||
db.AddInputDataset(&serialization_ctx, dataset, &output_node)); | ||
// Insert a purely symbolic _Retval node to indicate to consumers which node | ||
// represents `dataset`. | ||
ops::UnaryOp("_Retval", output_node, | ||
b.opts() | ||
.WithName("dataset") | ||
.WithAttr("T", DT_VARIANT) | ||
.WithAttr("index", 0)); | ||
TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def)); | ||
return Status::OK(); | ||
} | ||
|
||
} // namespace data | ||
} // namespace tensorflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
|
||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_ | ||
#define TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_ | ||
|
||
#include "tensorflow/core/framework/dataset.h" | ||
#include "tensorflow/core/lib/core/status.h" | ||
|
||
namespace tensorflow { | ||
namespace data { | ||
|
||
// Returns a GraphDef representation of the given dataset. | ||
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset, | ||
SerializationContext&& serialization_ctx, | ||
GraphDef* graph_def); | ||
|
||
} // namespace data | ||
} // namespace tensorflow | ||
|
||
#endif // TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_ |