Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test for RepeatDataset #26415

Merged
merged 4 commits into from Mar 12, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 38 additions & 39 deletions tensorflow/core/kernels/data/repeat_dataset_op_test.cc
Expand Up @@ -49,7 +49,7 @@ class RepeatDatasetOpTest : public DatasetOpsTestBase {
NodeDef node_def_;
};

struct TestParam {
struct TestCase {
std::vector<Tensor> input_tensors;
int64 count;
std::vector<Tensor> expected_outputs;
Expand All @@ -59,8 +59,7 @@ struct TestParam {
std::vector<int> breakpoints;
};

// Test case 1: finite repetition.
TestParam TestCase1() {
TestCase FiniteRepeatTestCase() {
return {
/*input_tensors*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
Expand All @@ -82,8 +81,7 @@ TestParam TestCase1() {
/*breakpoints*/ {0, 1, 3}};
}

// Test case 2: empty repetition.
TestParam TestCase2() {
TestCase EmptyRepeatTestCase() {
return {
/*input_tensors*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 2}, {1, 2, 3, 4}),
Expand All @@ -98,8 +96,7 @@ TestParam TestCase2() {
/*breakpoints*/ {0, 1, 3}};
}

// Test case 3: infinite repetition.
TestParam TestCase3() {
TestCase ForeverRepeatTestCase() {
return {/*input_tensors*/
{DatasetOpsTestBase::CreateTensor<int64>(TensorShape{2, 1}, {1, 2})},
/*count*/ -1,
Expand All @@ -120,6 +117,8 @@ class RepeatDatasetOpTestHelper : public RepeatDatasetOpTest {
if (dataset_) dataset_->Unref();
}

using DatasetOpsTestBase::CreateDataset;

protected:
// Creates `TensorSliceDataset` variant tensor from the input vector of
// tensors.
Expand All @@ -133,7 +132,7 @@ class RepeatDatasetOpTestHelper : public RepeatDatasetOpTest {
return Status::OK();
}

Status CreateDatasetFromTestCase(const TestParam &test_case) {
Status CreateDataset(const TestCase &test_case) {
Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({}));
std::vector<Tensor> input_tensors = test_case.input_tensors;
TF_RETURN_IF_ERROR(CreateTensorSliceDatasetTensor(
Expand All @@ -142,26 +141,26 @@ class RepeatDatasetOpTestHelper : public RepeatDatasetOpTest {
Tensor count = CreateTensor<int64>(TensorShape{}, {test_case.count});
inputs.emplace_back(&tensor_slice_dataset_tensor);
inputs.emplace_back(&count);
std::unique_ptr<OpKernel> dataset_kernel;
TF_RETURN_IF_ERROR(CreateRepeatDatasetKernel(
test_case.expected_output_dtypes, test_case.expected_output_shapes,
&dataset_kernel_));
TF_RETURN_IF_ERROR(CreateRepeatDatasetContext(
dataset_kernel_.get(), &inputs, &dataset_kernel_ctx_));
TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(),
&dataset_kernel));
TF_RETURN_IF_ERROR(CreateRepeatDatasetContext(dataset_kernel.get(), &inputs,
&dataset_kernel_ctx_));
TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel.get(),
dataset_kernel_ctx_.get(), &dataset_));
return Status::OK();
}

Status CreateIteratorFromTestCase(const TestParam &test_case) {
TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case));
Status CreateIterator(const TestCase &test_case) {
TF_RETURN_IF_ERROR(CreateDataset(test_case));
TF_RETURN_IF_ERROR(
CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_));
TF_RETURN_IF_ERROR(
dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_));
return Status::OK();
}

std::unique_ptr<OpKernel> dataset_kernel_;
std::unique_ptr<OpKernelContext> dataset_kernel_ctx_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comment was meant for all of the internal state (dataset_kernel_, dataset_kernel_ctx_, and iterator_ctx_).

While writing the code this way makes it less verbose, it makes it harder to understand what is going on in case someone would like to modify the code.

DatasetBase *dataset_ = nullptr; // owned by this class.
std::unique_ptr<IteratorContext> iterator_ctx_;
Expand All @@ -170,22 +169,22 @@ class RepeatDatasetOpTestHelper : public RepeatDatasetOpTest {

class ParameterizedDatasetTest
: public RepeatDatasetOpTestHelper,
public ::testing::WithParamInterface<TestParam> {};
public ::testing::WithParamInterface<TestCase> {};

TEST_P(ParameterizedDatasetTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateIterator(test_case));

auto expected_outputs_it = test_case.expected_outputs.begin();
bool end_of_sequence = false;
std::vector<Tensor> out_tensors;

if (test_case.count < 0) {
int fake_infinite_repetition = 100;
while (fake_infinite_repetition > 0) {
// We test only a finite number of steps of the infinite sequence.
for (int i = 0; i < 100; ++i) {
TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors,
&end_of_sequence));
for (const auto &tensor : out_tensors) {
Expand All @@ -198,7 +197,6 @@ TEST_P(ParameterizedDatasetTest, GetNext) {
expected_outputs_it = test_case.expected_outputs.begin();
}
}
fake_infinite_repetition--;
}
EXPECT_FALSE(end_of_sequence);
} else {
Expand All @@ -221,7 +219,7 @@ TEST_F(RepeatDatasetOpTestHelper, DatasetName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
TF_ASSERT_OK(CreateDataset(FiniteRepeatTestCase()));

EXPECT_EQ(dataset_->type_string(), kOpName);
}
Expand All @@ -230,8 +228,8 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateDataset(test_case));
TF_EXPECT_OK(VerifyTypesMatch(dataset_->output_dtypes(),
test_case.expected_output_dtypes));
}
Expand All @@ -240,8 +238,8 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateDataset(test_case));
TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(),
test_case.expected_output_shapes));
}
Expand All @@ -250,8 +248,8 @@ TEST_P(ParameterizedDatasetTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateDatasetFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateDataset(test_case));

EXPECT_EQ(dataset_->Cardinality(), GetParam().expected_cardinality);
}
Expand All @@ -260,7 +258,7 @@ TEST_F(RepeatDatasetOpTestHelper, DatasetSave) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1()));
TF_ASSERT_OK(CreateDataset(FiniteRepeatTestCase()));

std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
Expand All @@ -274,8 +272,8 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateIterator(test_case));
TF_EXPECT_OK(VerifyTypesMatch(iterator_->output_dtypes(),
test_case.expected_output_dtypes));
}
Expand All @@ -284,8 +282,8 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateIterator(test_case));
TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(),
test_case.expected_output_shapes));
}
Expand All @@ -294,8 +292,8 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
const TestCase &test_case = GetParam();
TF_ASSERT_OK(CreateIterator(test_case));
if (test_case.count < 0) {
EXPECT_EQ(iterator_->prefix(), "Iterator::ForeverRepeat");
} else if (test_case.count == 0) {
Expand All @@ -309,9 +307,9 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
const TestParam &test_case = GetParam();
const TestCase &test_case = GetParam();
auto expected_outputs_it = test_case.expected_outputs.begin();
TF_ASSERT_OK(CreateIteratorFromTestCase(test_case));
TF_ASSERT_OK(CreateIterator(test_case));

std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
Expand Down Expand Up @@ -359,8 +357,9 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) {
}

INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetTest,
::testing::ValuesIn(std::vector<TestParam>(
{TestCase1(), TestCase2(), TestCase3()})));
::testing::ValuesIn(std::vector<TestCase>(
{FiniteRepeatTestCase(), EmptyRepeatTestCase(),
ForeverRepeatTestCase()})));

} // namespace
} // namespace data
Expand Down