Skip to content

Commit

Permalink
Support ExternalStatePolicy in AsGraphDef.
Browse files Browse the repository at this point in the history
Previously AsGraphDef used a boolean to either fail or
ignore external state when serializing a dataset graph.
Now it takes a policy, which can be one of IGNORE,
FAIL, or WARN. This will allow code reuse by other
operations which serialize dataset graphs.

Some code is re-organized to avoid a circular dependency
between dataset_utils and captured_function. The new
serialization_utils implements AsGraphDef with the help
of captured_function for determining whether functions
contain stateful operations.

PiperOrigin-RevId: 278488049
Change-Id: Ifd6390d30c2cd4ddbc0392b4eea479d8d8c881ee
  • Loading branch information
aaudiber authored and tensorflower-gardener committed Nov 5, 2019
1 parent 336925c commit 12ed778
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 91 deletions.
22 changes: 17 additions & 5 deletions tensorflow/core/framework/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,13 +440,23 @@ class IteratorContext {
// Aggregates runtime support needed for dataset and iterator serialization.
class SerializationContext {
public:
// Enum describing what to do during serialization when external state is
// encountered.
//
// The numeric values are spelled out explicitly because an int64 op attribute
// is cast directly to this enum (see the DatasetToGraphOp constructor), so
// they must not be renumbered.
enum class ExternalStatePolicy : int64 {
// Proceed with serialization, but log a warning about what state will be
// lost. This is the default value of `Params::external_state_policy`.
kWarn = 0,
// Proceed with serialization without logging any warning.
kIgnore = 1,
// Fail the serialization with an error.
kFail = 2,
};

struct Params {
std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.

// Indicates whether serialization should check if the dataset depends on
// external state. If the check is enabled and external state is
// encountered, then the serialization will fail.
bool check_external_state = true;
// Indicates what to do if the dataset depends on external state.
ExternalStatePolicy external_state_policy = ExternalStatePolicy::kWarn;

// Indicates whether an attempt to serialize a dataset that does not
// implement serialization should result in an error. If set to `false`, the
Expand All @@ -467,7 +477,9 @@ class SerializationContext {
return params_.input_list;
}

// Returns whether serialization should check if the dataset depends on
// external state and fail when such state is encountered (mirrors
// `Params::check_external_state`).
bool check_external_state() const { return params_.check_external_state; }
// Returns the policy describing what to do when the dataset being serialized
// depends on external state (mirrors `Params::external_state_policy`).
ExternalStatePolicy external_state_policy() const {
return params_.external_state_policy;
}

// Returns whether an attempt to serialize a dataset that does not implement
// serialization should result in an error (mirrors
// `Params::fail_if_unimplemented`).
bool fail_if_unimplemented() const { return params_.fail_if_unimplemented; }

Expand Down
15 changes: 15 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cc_library(
":map_dataset_op",
":name_utils",
":range_dataset_op",
":serialization_utils",
":take_dataset_op",
":tensor_slice_dataset_op",
"//tensorflow/core:core_cpu",
Expand Down Expand Up @@ -123,6 +124,19 @@ cc_library(
],
)

# Dataset graph serialization helpers (AsGraphDef). Kept in its own target,
# separate from dataset_utils, so that it can depend on captured_function
# without creating a circular dependency.
cc_library(
name = "serialization_utils",
srcs = ["serialization_utils.cc"],
hdrs = ["serialization_utils.h"],
deps = [
":captured_function",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/core:status",
],
)

cc_library(
name = "stats_utils",
srcs = ["stats_utils.cc"],
Expand Down Expand Up @@ -1259,6 +1273,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset_utils",
":serialization_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
Expand Down
55 changes: 7 additions & 48 deletions tensorflow/core/kernels/data/dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,6 @@ limitations under the License.

namespace tensorflow {
namespace data {
namespace {
// Appends to `stateful_op_names` an entry for every op in `graph_def` for
// which IsNodeStateful() returns a non-OK status. Both the top-level nodes and
// the bodies of functions whose signature is marked stateful are scanned.
// Always returns Status::OK().
Status FindStatefulOps(const GraphDef& graph_def,
                       std::vector<string>* stateful_op_names) {
  const auto& library = graph_def.library();
  FunctionLibraryDefinition lib_def(OpRegistry::Global(), library);

  // Pass 1: top-level graph nodes. The trailing _Retval op that every dataset
  // graph carries is marked stateful, so it is skipped up front.
  for (const auto& graph_node : graph_def.node()) {
    const bool is_retval =
        graph_node.op() == FunctionLibraryDefinition::kRetOp;
    if (is_retval || IsNodeStateful(lib_def, graph_node).ok()) continue;
    stateful_op_names->push_back(graph_node.op());
  }

  // Pass 2: nodes inside library functions. Functions whose signature is not
  // stateful are skipped without looking at their bodies.
  for (const auto& function_def : library.function()) {
    const auto& signature = function_def.signature();
    if (!signature.is_stateful()) continue;
    for (const auto& body_node : function_def.node_def()) {
      if (IsNodeStateful(lib_def, body_node).ok()) continue;
      stateful_op_names->push_back(
          absl::StrCat(body_node.op(), " in function: ", signature.name()));
    }
  }

  return Status::OK();
}
} // namespace

/* static */ constexpr const char* const DatasetToGraphOp::kAllowStateful;
/* static */ constexpr const char* const
Expand All @@ -77,16 +49,19 @@ DatasetToGraphOp::DatasetToGraphOp(OpKernelConstruction* ctx)
int64 state_change_option;
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kExternalStatePolicy, &state_change_option));
external_state_policy_ = ExternalStatePolicy(state_change_option);
external_state_policy_ =
SerializationContext::ExternalStatePolicy(state_change_option);
}
} else {
if (ctx->HasAttr(kAllowStateful)) {
bool allow_stateful;
OP_REQUIRES_OK(ctx, ctx->GetAttr(kAllowStateful, &allow_stateful));
if (allow_stateful) {
external_state_policy_ = ExternalStatePolicy::kWarn;
external_state_policy_ =
SerializationContext::ExternalStatePolicy::kWarn;
} else {
external_state_policy_ = ExternalStatePolicy::kFail;
external_state_policy_ =
SerializationContext::ExternalStatePolicy::kFail;
}
}
}
Expand All @@ -101,8 +76,7 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
DatasetBase* dataset;
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
SerializationContext::Params params;
params.check_external_state =
(external_state_policy_ == ExternalStatePolicy::kFail);
params.external_state_policy = external_state_policy_;

GraphDef graph_def;
Status s = AsGraphDef(ctx, dataset, SerializationContext(params), &graph_def);
Expand All @@ -113,21 +87,6 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) {
s.error_message()));
return;
}
// In case we allow stateful ops, we walk the graph and find all the stateful
// ops in the Graph. We then log a warning indicating what ops' state we are
// going to throw away.
if (external_state_policy_ == ExternalStatePolicy::kWarn) {
std::vector<string> stateful_op_names;
OP_REQUIRES_OK(ctx, FindStatefulOps(graph_def, &stateful_op_names));
if (!stateful_op_names.empty()) {
LOG(WARNING)
<< "We found the following stateful ops in the dataset "
"construction graph whose state would not be serialized and might "
"cause subtle bugs: "
<< absl::StrJoin(stateful_op_names, ", ");
}
}

if (strip_device_assignment_) {
auto library = graph_def.mutable_library();
for (auto& function : (*library->mutable_function())) {
Expand Down
11 changes: 2 additions & 9 deletions tensorflow/core/kernels/data/dataset_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,9 @@ class DatasetToGraphOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;

private:
// Enum describing what to do during serialization when external state is
// encountered.
enum class ExternalStatePolicy {
kWarn,
kIgnore,
kFail,
};

const int op_version_;
ExternalStatePolicy external_state_policy_ = ExternalStatePolicy::kWarn;
SerializationContext::ExternalStatePolicy external_state_policy_ =
SerializationContext::ExternalStatePolicy::kWarn;
bool strip_device_assignment_ = false;
};

Expand Down
22 changes: 0 additions & 22 deletions tensorflow/core/kernels/data/dataset_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,28 +297,6 @@ Status HashFunctionImpl(const FunctionDefLibrary& library,

} // anonymous namespace

// Builds a GraphDef for `dataset` and stores it in `graph_def`.
//
// When `serialization_ctx.check_external_state()` is set, the dataset's
// CheckExternalState() is consulted first and any error it reports is
// returned before the graph is built.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
                  GraphDef* graph_def) {
  if (serialization_ctx.check_external_state()) {
    TF_RETURN_IF_ERROR(dataset->CheckExternalState());
  }

  GraphDefBuilder builder;
  DatasetBase::DatasetGraphDefBuilder dataset_builder(&builder);
  Node* dataset_node = nullptr;
  TF_RETURN_IF_ERROR(dataset_builder.AddInputDataset(&serialization_ctx,
                                                     dataset, &dataset_node));

  // A purely symbolic _Retval node tells consumers which node represents
  // `dataset`.
  const auto retval_opts = builder.opts()
                               .WithName("dataset")
                               .WithAttr("T", DT_VARIANT)
                               .WithAttr("index", 0);
  ops::UnaryOp("_Retval", dataset_node, retval_opts);

  return builder.ToGraphDef(graph_def);
}

Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> register_fn,
std::function<void()>* deregister_fn) {
Expand Down
5 changes: 0 additions & 5 deletions tensorflow/core/kernels/data/dataset_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ class AnonymousResourceOp : public OpKernel {
bool create_deleter_ = true;
};

// Returns a GraphDef representation of the given dataset.
//
// If `serialization_ctx.check_external_state()` is true, the dataset's
// CheckExternalState() is consulted first and any error it returns is
// propagated. On success, `graph_def` holds the dataset graph plus a symbolic
// "_Retval" node named "dataset" that marks which node represents `dataset`.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);

// Registers the given cancellation callback, returning a function that can be
// used to deregister the callback.
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.check_external_state = false;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;

GraphDef graph_def;
OP_REQUIRES_OK(
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/rewrite_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
SerializationContext::Params params;
std::vector<std::pair<string, Tensor>> input_list;
params.input_list = &input_list;
params.check_external_state = false;
params.external_state_policy =
SerializationContext::ExternalStatePolicy::kIgnore;
params.fail_if_unimplemented = false;
params.serialize_data_tensors = false;
SerializationContext serialization_ctx(params);
Expand Down
92 changes: 92 additions & 0 deletions tensorflow/core/kernels/data/serialization_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/data/serialization_utils.h"

#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/captured_function.h"

namespace tensorflow {
namespace data {

namespace {

// FindStatefulOps scans `graph_def` -- its top-level nodes and the bodies of
// functions whose signature is marked stateful -- and appends to
// `stateful_op_names` an entry for every op for which IsNodeStateful() returns
// a non-OK status. Always returns Status::OK().
Status FindStatefulOps(const GraphDef& graph_def,
                       std::vector<string>* stateful_op_names) {
  const auto& library = graph_def.library();
  FunctionLibraryDefinition lib_def(OpRegistry::Global(), library);

  // Pass 1: top-level graph nodes. The trailing _Retval op that every dataset
  // graph carries is marked stateful, so it is skipped up front.
  for (const auto& graph_node : graph_def.node()) {
    const bool is_retval =
        graph_node.op() == FunctionLibraryDefinition::kRetOp;
    if (is_retval || IsNodeStateful(lib_def, graph_node).ok()) continue;
    stateful_op_names->push_back(graph_node.op());
  }

  // Pass 2: nodes inside library functions. Functions whose signature is not
  // stateful are skipped without looking at their bodies.
  for (const auto& function_def : library.function()) {
    const auto& signature = function_def.signature();
    if (!signature.is_stateful()) continue;
    for (const auto& body_node : function_def.node_def()) {
      if (IsNodeStateful(lib_def, body_node).ok()) continue;
      stateful_op_names->push_back(
          absl::StrCat(body_node.op(), " in function: ", signature.name()));
    }
  }
  return Status::OK();
}

} // namespace

// Serializes `dataset` into `graph_def`, honoring the external state policy
// carried by `serialization_ctx`:
//   * kFail:   dataset->CheckExternalState() is consulted before the graph is
//              built and any error it returns is propagated.
//   * kWarn:   the graph is built and then scanned for stateful ops; if any
//              are found, a warning listing them is logged.
//   * kIgnore: the graph is built without any external state check.
//
// On success, `graph_def` holds the dataset graph plus a purely symbolic
// "_Retval" node named "dataset" that marks which node represents `dataset`.
// Any error from the policy check, from adding the dataset to the builder, or
// from materializing the GraphDef is returned to the caller.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
                  SerializationContext&& serialization_ctx,
                  GraphDef* graph_def) {
  const SerializationContext::ExternalStatePolicy policy =
      serialization_ctx.external_state_policy();

  if (policy == SerializationContext::ExternalStatePolicy::kFail) {
    TF_RETURN_IF_ERROR(dataset->CheckExternalState());
  }

  GraphDefBuilder b;
  DatasetBase::DatasetGraphDefBuilder db(&b);
  Node* output_node = nullptr;
  TF_RETURN_IF_ERROR(
      db.AddInputDataset(&serialization_ctx, dataset, &output_node));
  // Insert a purely symbolic _Retval node to indicate to consumers which node
  // represents `dataset`.
  ops::UnaryOp("_Retval", output_node,
               b.opts()
                   .WithName("dataset")
                   .WithAttr("T", DT_VARIANT)
                   .WithAttr("index", 0));
  TF_RETURN_IF_ERROR(b.ToGraphDef(graph_def));

  // The stateful-op scan has to run on the graph that was just built. Scanning
  // `*graph_def` before ToGraphDef() would only inspect whatever the caller
  // passed in (typically a freshly constructed, empty GraphDef), so the
  // warning would never be emitted.
  if (policy == SerializationContext::ExternalStatePolicy::kWarn) {
    std::vector<string> stateful_op_names;
    TF_RETURN_IF_ERROR(FindStatefulOps(*graph_def, &stateful_op_names));
    if (!stateful_op_names.empty()) {
      LOG(WARNING)
          << "We found the following stateful ops in the dataset "
             "construction graph whose state would not be serialized and might "
             "cause subtle bugs: "
          << absl::StrJoin(stateful_op_names, ", ");
    }
  }
  return Status::OK();
}

} // namespace data
} // namespace tensorflow
33 changes: 33 additions & 0 deletions tensorflow/core/kernels/data/serialization_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
namespace data {

// Returns a GraphDef representation of the given dataset.
//
// `serialization_ctx.external_state_policy()` selects how external state is
// handled: with kFail, an error from the dataset's CheckExternalState() is
// returned; with kWarn, a warning naming any stateful ops that were found may
// be logged; with kIgnore, no external state check is performed.
Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);

} // namespace data
} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_DATA_SERIALIZATION_UTILS_H_

0 comments on commit 12ed778

Please sign in to comment.