Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AutoShardDatasetOp #30930

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions tensorflow/core/kernels/data/experimental/BUILD
Expand Up @@ -45,6 +45,7 @@ tf_cc_test(
tf_kernel_library(
name = "auto_shard_dataset_op",
srcs = ["auto_shard_dataset_op.cc"],
hdrs = ["auto_shard_dataset_op.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
Expand All @@ -57,6 +58,23 @@ tf_kernel_library(
],
)

tf_cc_test(
name = "auto_shard_dataset_op_test",
size = "small",
srcs = ["auto_shard_dataset_op_test.cc"],
deps = [
":auto_shard_dataset_op",
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels/data:dataset_test_base",
"//tensorflow/core/kernels/data:shard_dataset_op",
"//third_party/eigen3",
],
)

tf_kernel_library(
name = "choose_fastest_branch_dataset_op",
srcs = ["choose_fastest_branch_dataset_op.cc"],
Expand Down
Expand Up @@ -16,7 +16,6 @@ limitations under the License.

#include <map>

#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
Expand Down
112 changes: 57 additions & 55 deletions tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.cc
Expand Up @@ -12,74 +12,76 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h"

#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace data {
namespace {

/* static */ constexpr const char* const AutoShardDatasetOp::kDatasetType;
/* static */ constexpr const char* const AutoShardDatasetOp::kInputDataset;
/* static */ constexpr const char* const AutoShardDatasetOp::kNumWorkers;
/* static */ constexpr const char* const AutoShardDatasetOp::kIndex;
/* static */ constexpr const char* const AutoShardDatasetOp::kOutputTypes;
/* static */ constexpr const char* const AutoShardDatasetOp::kOutputShapes;

constexpr char kOptimizerName[] = "tf_auto_shard";

class AutoShardDatasetOp : public UnaryDatasetOpKernel {
public:
explicit AutoShardDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}

protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 index, num_workers;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
OP_REQUIRES(
ctx, num_workers > 0,
errors::InvalidArgument("num_workers must be greater than zero."));

OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "index", &index));
OP_REQUIRES(ctx, index >= 0 && index < num_workers,
errors::InvalidArgument("index must be between 0 and ",
num_workers - 1));

auto config_factory = [num_workers, index]() {
return CreateConfig(num_workers, index);
};

// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/false, output));
}

private:
static RewriterConfig CreateConfig(int64 num_workers, int64 index) {
RewriterConfig rewriter_config;
rewriter_config.set_fail_on_optimizer_errors(true);
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
AttrValue num_workers_attr;
num_workers_attr.set_i(num_workers);
(*custom_optimizer->mutable_parameter_map())["num_workers"] =
num_workers_attr;

AttrValue index_attr;
index_attr.set_i(index);
(*custom_optimizer->mutable_parameter_map())["index"] = index_attr;

return rewriter_config;
}
};
AutoShardDatasetOp::AutoShardDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {}

void AutoShardDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
int64 index, num_workers;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kNumWorkers, &num_workers));
OP_REQUIRES(
ctx, num_workers > 0,
errors::InvalidArgument("num_workers must be greater than zero."));

OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kIndex, &index));
OP_REQUIRES(
ctx, index >= 0 && index < num_workers,
errors::InvalidArgument("index must be between 0 and ", num_workers - 1));

auto config_factory = [num_workers, index]() {
return CreateConfig(num_workers, index);
};

// We only want to optimize functions for some particular datasets like
// FlatMapDataset, InterleaveDataset etc. So we disable generalized
// function optimization and explicitly handle function modifications
// for those datasets in the rewrite.
OP_REQUIRES_OK(ctx,
RewriteDataset(ctx, input, std::move(config_factory),
/*optimize_function_library=*/false, output));
}

RewriterConfig AutoShardDatasetOp::CreateConfig(int64 num_workers,
int64 index) {
RewriterConfig rewriter_config;
rewriter_config.set_fail_on_optimizer_errors(true);
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
auto custom_optimizer = rewriter_config.add_custom_optimizers();
custom_optimizer->set_name(kOptimizerName);
AttrValue num_workers_attr;
num_workers_attr.set_i(num_workers);
(*custom_optimizer->mutable_parameter_map())[kNumWorkers] = num_workers_attr;

AttrValue index_attr;
index_attr.set_i(index);
(*custom_optimizer->mutable_parameter_map())[kIndex] = index_attr;

return rewriter_config;
}

namespace {
REGISTER_KERNEL_BUILDER(Name("AutoShardDataset").Device(DEVICE_CPU),
AutoShardDatasetOp);
REGISTER_KERNEL_BUILDER(Name("ExperimentalAutoShardDataset").Device(DEVICE_CPU),
AutoShardDatasetOp);

} // anonymous namespace
} // namespace data
} // namespace tensorflow
48 changes: 48 additions & 0 deletions tensorflow/core/kernels/data/experimental/auto_shard_dataset_op.h
@@ -0,0 +1,48 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_

#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

// See documentation in ../../ops/experimental_dataset_ops.cc for a high-level
// description of the following op.

class AutoShardDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "AutoShard";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kNumWorkers = "num_workers";
static constexpr const char* const kIndex = "index";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit AutoShardDatasetOp(OpKernelConstruction* ctx);

protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override;

private:
static RewriterConfig CreateConfig(int64 num_workers, int64 index);
};

} // namespace data
} // namespace tensorflow

#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_AUTO_SHARD_DATASET_OP_H_