Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor {Generator, Interleave, PaddedBatch} DatasetOps #29098

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 5 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Expand Up @@ -284,7 +284,9 @@ tf_cc_test(
tf_kernel_library(
name = "padded_batch_dataset_op",
srcs = ["padded_batch_dataset_op.cc"],
hdrs = ["padded_batch_dataset_op.h"],
deps = [
":name_utils",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
Expand Down Expand Up @@ -451,6 +453,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset_utils",
":name_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
Expand Down Expand Up @@ -502,9 +505,11 @@ tf_cc_test(
tf_kernel_library(
name = "interleave_dataset_op",
srcs = ["interleave_dataset_op.cc"],
hdrs = ["interleave_dataset_op.h"],
deps = [
":captured_function",
":dataset_utils",
":name_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
Expand Down
13 changes: 8 additions & 5 deletions tensorflow/core/kernels/data/concatenate_dataset_op.cc
Expand Up @@ -24,11 +24,11 @@ namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Out-of-line definitions of the static constexpr members that
// ConcatenateDatasetOp declares (with initializers) in its class body. Before
// C++17 such a definition is needed whenever a member is ODR-used (for
// example, bound to a const reference); from C++17 on it is redundant but
// still accepted. Each definition has to repeat the declared type exactly, so
// only the `const char* const` form is kept and the earlier `const char X[]`
// definitions of the same members are dropped.
/* static */ constexpr const char* const ConcatenateDatasetOp::kDatasetType;
/* static */ constexpr const char* const ConcatenateDatasetOp::kInputDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kAnotherDataset;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputTypes;
/* static */ constexpr const char* const ConcatenateDatasetOp::kOutputShapes;

constexpr char kIndex[] = "i";
constexpr char kInputImplUninitialized[] = "input_impl_uninitialized";
Expand Down Expand Up @@ -202,6 +202,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
std::vector<PartialTensorShape> output_shapes_;
};

// Constructs the kernel by handing `ctx` to the BinaryDatasetOpKernel base.
ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx)
    : BinaryDatasetOpKernel(ctx) {
  // Intentionally empty: nothing is read from `ctx` at construction time.
}

void ConcatenateDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase* to_concatenate,
DatasetBase** output) {
Expand Down
15 changes: 7 additions & 8 deletions tensorflow/core/kernels/data/concatenate_dataset_op.h
Expand Up @@ -22,14 +22,13 @@ namespace data {

class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
public:
static constexpr const char kDatasetType[] = "Concatenate";
static constexpr const char kInputDataset[] = "input_dataset";
static constexpr const char kAnotherDataset[] = "another_dataset";
static constexpr const char kOutputTypes[] = "output_types";
static constexpr const char kOutputShapes[] = "output_shapes";

explicit ConcatenateDatasetOp(OpKernelConstruction* ctx)
: BinaryDatasetOpKernel(ctx) {}
static constexpr const char* const kDatasetType = "Concatenate";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kAnotherDataset = "another_dataset";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit ConcatenateDatasetOp(OpKernelConstruction* ctx);

protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
Expand Down
23 changes: 16 additions & 7 deletions tensorflow/core/kernels/data/filter_dataset_op.cc
Expand Up @@ -32,13 +32,13 @@ namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Out-of-line definitions of the static constexpr members that
// FilterDatasetOp declares (with initializers) in its class body. Before
// C++17 such a definition is needed whenever a member is ODR-used (for
// example, bound to a const reference); from C++17 on it is redundant but
// still accepted. Each definition has to repeat the declared type exactly, so
// only the `const char* const` form is kept and the earlier `const char X[]`
// definitions of the same members are dropped.
/* static */ constexpr const char* const FilterDatasetOp::kDatasetType;
/* static */ constexpr const char* const FilterDatasetOp::kInputDataset;
/* static */ constexpr const char* const FilterDatasetOp::kOtherArguments;
/* static */ constexpr const char* const FilterDatasetOp::kPredicate;
/* static */ constexpr const char* const FilterDatasetOp::kTarguments;
/* static */ constexpr const char* const FilterDatasetOp::kOutputTypes;
/* static */ constexpr const char* const FilterDatasetOp::kOutputShapes;

constexpr char kInputImplsEmpty[] = "input_impls_empty";
constexpr char kFilteredElements[] = "filtered_elements";
Expand Down Expand Up @@ -228,6 +228,15 @@ class FilterDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<CapturedFunction> captured_func_;
};

// Constructs the kernel: creates the function metadata for the function named
// by kPredicate, then rejects it with InvalidArgument when its short-circuit
// info lists more than one index.
FilterDatasetOp::FilterDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx) {
  OP_REQUIRES_OK(
      ctx, FunctionMetadata::Create(ctx, kPredicate, /*params=*/{},
                                    &func_metadata_));

  // Queried only after the metadata has been created above.
  const auto num_short_circuit_indices =
      func_metadata_->short_circuit_info().indices.size();
  OP_REQUIRES(
      ctx, num_short_circuit_indices <= 1,
      errors::InvalidArgument(
          "predicate function has more than one return value."));
}

void FilterDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
std::unique_ptr<CapturedFunction> captured_func;
Expand Down
25 changes: 9 additions & 16 deletions tensorflow/core/kernels/data/filter_dataset_op.h
Expand Up @@ -23,22 +23,15 @@ namespace data {

class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char kDatasetType[] = "Filter";
static constexpr const char kInputDataset[] = "input_dataset";
static constexpr const char kOtherArguments[] = "other_arguments";
static constexpr const char kPredicate[] = "predicate";
static constexpr const char kTarguments[] = "Targuments";
static constexpr const char kOutputTypes[] = "output_types";
static constexpr const char kOutputShapes[] = "output_shapes";

explicit FilterDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kPredicate, /*params=*/{},
&func_metadata_));
OP_REQUIRES(ctx, func_metadata_->short_circuit_info().indices.size() <= 1,
errors::InvalidArgument(
"predicate function has more than one return value."));
}
static constexpr const char* const kDatasetType = "Filter";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kPredicate = "predicate";
static constexpr const char* const kTarguments = "Targuments";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit FilterDatasetOp(OpKernelConstruction* ctx);

protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
Expand Down
24 changes: 16 additions & 8 deletions tensorflow/core/kernels/data/flat_map_dataset_op.cc
Expand Up @@ -27,13 +27,13 @@ namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Out-of-line definitions of the static constexpr members that
// FlatMapDatasetOp declares (with initializers) in its class body. Before
// C++17 such a definition is needed whenever a member is ODR-used (for
// example, bound to a const reference); from C++17 on it is redundant but
// still accepted. Each definition has to repeat the declared type and name
// exactly, so only the `const char* const` form is kept; the earlier
// `const char X[]` definitions (including the one for the former `kF`, now
// `kFunc`) are dropped.
/* static */ constexpr const char* const FlatMapDatasetOp::kDatasetType;
/* static */ constexpr const char* const FlatMapDatasetOp::kInputDataset;
/* static */ constexpr const char* const FlatMapDatasetOp::kOtherArguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kFunc;
/* static */ constexpr const char* const FlatMapDatasetOp::kTarguments;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputTypes;
/* static */ constexpr const char* const FlatMapDatasetOp::kOutputShapes;

constexpr char kElementIndex[] = "element_index";
constexpr char kCapturedFuncInputsSize[] = "captured_func_inputs_size";
Expand Down Expand Up @@ -92,7 +92,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(b->AddDataset(
this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs.
{std::make_pair(1, other_arguments)}, // Tensor list inputs.
{std::make_pair(kF, f),
{std::make_pair(kFunc, f),
std::make_pair(kTarguments, other_arguments_types_attr)}, // Attrs
output));
return Status::OK();
Expand Down Expand Up @@ -246,6 +246,14 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
const std::vector<PartialTensorShape> output_shapes_;
};

// Constructs the kernel: stores the graph-def version reported by `ctx`,
// creates the function metadata for the function named by kFunc, and reads
// the kOutputTypes / kOutputShapes attrs into the matching members.
FlatMapDatasetOp::FlatMapDatasetOp(OpKernelConstruction* ctx)
    : UnaryDatasetOpKernel(ctx),
      graph_def_version_(ctx->graph_def_version()) {
  OP_REQUIRES_OK(
      ctx, FunctionMetadata::Create(ctx, kFunc, /*params=*/{}, &func_metadata_));

  // Output signature attrs, read in the same order as before (types, shapes).
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void FlatMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) {
std::unique_ptr<CapturedFunction> captured_func;
Expand Down
25 changes: 9 additions & 16 deletions tensorflow/core/kernels/data/flat_map_dataset_op.h
Expand Up @@ -23,22 +23,15 @@ namespace data {

class FlatMapDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char kDatasetType[] = "FlatMap";
static constexpr const char kInputDataset[] = "input_dataset";
static constexpr const char kOtherArguments[] = "other_arguments";
static constexpr const char kF[] = "f";
static constexpr const char kTarguments[] = "Targuments";
static constexpr const char kOutputTypes[] = "output_types";
static constexpr const char kOutputShapes[] = "output_shapes";

explicit FlatMapDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()) {
OP_REQUIRES_OK(
ctx, FunctionMetadata::Create(ctx, kF, /*params=*/{}, &func_metadata_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}
static constexpr const char* const kDatasetType = "FlatMap";
static constexpr const char* const kInputDataset = "input_dataset";
static constexpr const char* const kOtherArguments = "other_arguments";
static constexpr const char* const kFunc = "f";
static constexpr const char* const kTarguments = "Targuments";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit FlatMapDatasetOp(OpKernelConstruction* ctx);

protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/core/kernels/data/flat_map_dataset_op_test.cc
Expand Up @@ -42,7 +42,7 @@ class FlatMapDatasetOpTest : public DatasetOpsTestBase {
NodeDef node_def = test::function::NDef(
kNodeName, name_utils::OpName(FlatMapDatasetOp::kDatasetType),
{FlatMapDatasetOp::kInputDataset},
{{FlatMapDatasetOp::kF, func},
{{FlatMapDatasetOp::kFunc, func},
{FlatMapDatasetOp::kTarguments, {}},
{FlatMapDatasetOp::kOutputTypes, output_types},
{FlatMapDatasetOp::kOutputShapes, output_shapes}});
Expand Down
49 changes: 32 additions & 17 deletions tensorflow/core/kernels/data/generator_dataset_op.cc
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/name_utils.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {
Expand All @@ -29,6 +30,20 @@ namespace data {
// See documentation in ../../ops/dataset_ops.cc for a high-level
// description of the following op.

// Out-of-line definitions of the static constexpr string constants that
// GeneratorDatasetOp declares (with initializers) in its class body. Before
// C++17 such a definition is needed whenever a member is ODR-used (for
// example, bound to a const reference); from C++17 on it is redundant but
// still accepted. No initializer appears here because the values are given at
// the in-class declarations.
/* static */ constexpr const char* const GeneratorDatasetOp::kDatasetType;
/* static */ constexpr const char* const GeneratorDatasetOp::kInitFuncOtherArgs;
/* static */ constexpr const char* const GeneratorDatasetOp::kNextFuncOtherArgs;
/* static */ constexpr const char* const
GeneratorDatasetOp::kFinalizeFuncOtherArgs;
/* static */ constexpr const char* const GeneratorDatasetOp::kInitFunc;
/* static */ constexpr const char* const GeneratorDatasetOp::kNextFunc;
/* static */ constexpr const char* const GeneratorDatasetOp::kFinalizeFunc;
/* static */ constexpr const char* const GeneratorDatasetOp::kTinitFuncArgs;
/* static */ constexpr const char* const GeneratorDatasetOp::kTnextFuncArgs;
/* static */ constexpr const char* const GeneratorDatasetOp::kTfinalizeFuncArgs;
/* static */ constexpr const char* const GeneratorDatasetOp::kOutputTypes;
/* static */ constexpr const char* const GeneratorDatasetOp::kOutputShapes;

class GeneratorDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
Expand All @@ -45,8 +60,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {

std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return absl::make_unique<Iterator>(
Iterator::Params{this, strings::StrCat(prefix, "::Generator")});
return absl::make_unique<Iterator>(Iterator::Params{
this, name_utils::IteratorPrefix(kDatasetType, prefix)});
}

const DataTypeVector& output_dtypes() const override { return output_types_; }
Expand All @@ -55,7 +70,9 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return output_shapes_;
}

string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
string DebugString() const override {
return name_utils::DatasetDebugString(kDatasetType);
}

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
Expand Down Expand Up @@ -155,33 +172,31 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {

// Constructs the kernel: creates the function metadata for the init, next and
// finalize functions, then reads the output type and shape attrs.
//
// Every attr name comes from the class's named constants (kInitFunc,
// kNextFunc, kFinalizeFunc, kOutputTypes, kOutputShapes). Those constants hold
// the same strings as the literals they replace, so the attrs looked up are
// unchanged; the superseded literal-based calls are not repeated, which also
// restores balanced OP_REQUIRES_OK(...) invocations.
GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
    : DatasetOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kInitFunc, /*params=*/{},
                                               &init_func_metadata_));
  OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, kNextFunc, /*params=*/{},
                                               &next_func_metadata_));
  OP_REQUIRES_OK(ctx,
                 FunctionMetadata::Create(ctx, kFinalizeFunc, /*params=*/{},
                                          &finalize_func_metadata_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
}

void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
std::unique_ptr<CapturedFunction> init_func;
OP_REQUIRES_OK(ctx,
CapturedFunction::Create(ctx, init_func_metadata_,
"init_func_other_args", &init_func));
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, init_func_metadata_,
kInitFuncOtherArgs, &init_func));

std::unique_ptr<CapturedFunction> next_func;
OP_REQUIRES_OK(ctx,
CapturedFunction::Create(ctx, next_func_metadata_,
"next_func_other_args", &next_func));
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, next_func_metadata_,
kNextFuncOtherArgs, &next_func));

std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
"finalize_func_other_args",
&finalize_func));
OP_REQUIRES_OK(
ctx, CapturedFunction::Create(ctx, finalize_func_metadata_,
kFinalizeFuncOtherArgs, &finalize_func));

*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/core/kernels/data/generator_dataset_op.h
Expand Up @@ -24,6 +24,22 @@ namespace data {

class GeneratorDatasetOp : public DatasetOpKernel {
public:
static constexpr const char* const kDatasetType = "Generator";
static constexpr const char* const kInitFuncOtherArgs =
"init_func_other_args";
static constexpr const char* const kNextFuncOtherArgs =
"next_func_other_args";
static constexpr const char* const kFinalizeFuncOtherArgs =
"finalize_func_other_args";
static constexpr const char* const kInitFunc = "init_func";
static constexpr const char* const kNextFunc = "next_func";
static constexpr const char* const kFinalizeFunc = "finalize_func";
static constexpr const char* const kTinitFuncArgs = "Tinit_func_args";
static constexpr const char* const kTnextFuncArgs = "Tnext_func_args";
static constexpr const char* const kTfinalizeFuncArgs = "Tfinalize_func_args";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";

explicit GeneratorDatasetOp(OpKernelConstruction* ctx);

void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
Expand Down