Skip to content

Commit

Permalink
Add the tests for BatchDatasetOp
Browse files Browse the repository at this point in the history
  • Loading branch information
feihugis committed Apr 16, 2019
1 parent dac2bf1 commit b22ebf9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
20 changes: 20 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,26 @@ tf_kernel_library(
],
)

tf_cc_test(
name = "batch_dataset_op_test",
size = "small",
srcs = ["batch_dataset_op_test.cc"],
deps = [
":batch_dataset_op",
":dataset_test_base",
":dataset_utils",
":iterator_ops",
":range_dataset_op",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

tf_kernel_library(
name = "shard_dataset_op",
srcs = ["shard_dataset_op.cc"],
Expand Down
10 changes: 8 additions & 2 deletions tensorflow/core/kernels/data/batch_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,14 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
Node* drop_remainder = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder));
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, batch_size, drop_remainder}, output));
if (type_string() == "BatchDataset") {
TF_RETURN_IF_ERROR(
b->AddDataset(this, {input_graph_node, batch_size}, output));
} else {
TF_RETURN_IF_ERROR(b->AddDataset(
this, {input_graph_node, batch_size, drop_remainder}, output));
}

return Status::OK();
}

Expand Down

0 comments on commit b22ebf9

Please sign in to comment.