diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc b/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc index 1885c50c298cb6..248d2981ffccbb 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op_test.cc @@ -48,10 +48,10 @@ class ConcatenateDatasetOpTest : public DatasetOpsTestBase { const DataTypeVector &output_types, const std::vector &output_shapes, std::unique_ptr *op_kernel) { - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, {"input_dataset", "another_dataset"}, {{"output_types", output_types}, {"output_shapes", output_shapes}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -64,12 +64,9 @@ class ConcatenateDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; -struct TestParam { +struct TestCase { std::vector> input_tensors; std::vector expected_outputs; DataTypeVector expected_output_dtypes; @@ -77,8 +74,9 @@ struct TestParam { int64 expected_cardinality; std::vector breakpoints; }; -TestParam TestCase1() { - // Test case 1: same shape. + +// Test case 1: same shape. +TestCase SameShapeTestCase() { return {/*input_tensors*/ {{DatasetOpsTestBase::CreateTensor(TensorShape{2, 2}, {1, 2, 3, 4}), @@ -104,8 +102,8 @@ TestParam TestCase1() { /*breakpoints*/ {0, 2, 5}}; } -TestParam TestCase2() { - // Test case 2: different shape. +// Test case 2: different shape. +TestCase DifferentShapeTestCase() { return { /*input_tensors*/ {{DatasetOpsTestBase::CreateTensor(TensorShape{2, 3}, @@ -131,64 +129,59 @@ TestParam TestCase2() { /*breakpoints*/ {0, 2, 5}}; } -class ConcatenateDatasetOpTestHelper : public ConcatenateDatasetOpTest { - public: - ~ConcatenateDatasetOpTestHelper() override { - if (dataset_) dataset_->Unref(); - } - - protected: - Status CreateDatasetFromTestCase(const TestParam &test_case) { - std::vector tensor_slice_dataset_tensors; - TF_RETURN_IF_ERROR(CreateTensorSliceDatasetTensors( - test_case.input_tensors, &tensor_slice_dataset_tensors)); - gtl::InlinedVector inputs; - for (auto &tensor : tensor_slice_dataset_tensors) { - inputs.emplace_back(&tensor); - } - TF_RETURN_IF_ERROR(CreateConcatenateDatasetKernel( - test_case.expected_output_dtypes, test_case.expected_output_shapes, - &dataset_kernel_)); - TF_RETURN_IF_ERROR(CreateConcatenateDatasetContext( - dataset_kernel_.get(), &inputs, &dataset_kernel_ctx_)); - TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(), - dataset_kernel_ctx_.get(), &dataset_)); - return Status::OK(); - } - - Status CreateIteratorFromTestCase(const TestParam &test_case) { - TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case)); - TF_RETURN_IF_ERROR( - CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR( - dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_)); - return Status::OK(); - } - - std::unique_ptr dataset_kernel_; - std::unique_ptr dataset_kernel_ctx_; - DatasetBase *dataset_ = nullptr; // owned by this class. - std::unique_ptr iterator_ctx_; - std::unique_ptr iterator_; -}; +// Test case 3: different dtypes +TestCase DifferentDtypeTestCase() { + return {/*input_tensors*/ {{DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {1, 2, 3, 4})}, + {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}}, + /*expected_outputs*/ {}, + /*expected_output_dtypes*/ {DT_INT64}, + /*expected_output_shapes*/ {PartialTensorShape({2})}, + /*expected_cardinality*/ 0, + /*breakpoints*/ {}}; +} -class ParameterizedDatasetTest - : public ConcatenateDatasetOpTestHelper, - public ::testing::WithParamInterface {}; +class ParameterizedConcatenateDatasetOpTest + : public ConcatenateDatasetOpTest, + public ::testing::WithParamInterface {}; -TEST_P(ParameterizedDatasetTest, GetNext) { +TEST_P(ParameterizedConcatenateDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); auto expected_outputs_it = test_case.expected_outputs.begin(); bool end_of_sequence = false; std::vector out_tensors; while (!end_of_sequence) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); if (!end_of_sequence) { for (const auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -200,113 +193,334 @@ TEST_P(ParameterizedDatasetTest, GetNext) { EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -TEST_F(ConcatenateDatasetOpTestHelper, DifferentDtypes) { +TEST_F(ConcatenateDatasetOpTest, DifferentDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TestParam test_case_with_different_dtypes = { - /*input_tensors*/ { - {CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4})}, - {CreateTensor(TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}}, - /*expected_outputs*/ {}, - /*expected_output_dtypes*/ {DT_INT64}, - /*expected_output_shapes*/ {PartialTensorShape({2})}, - /*expected_cardinality*/ 0, - /*breakpoints*/ {}}; - - EXPECT_EQ(CreateDatasetFromTestCase(test_case_with_different_dtypes).code(), + const TestCase &test_case = DifferentDtypeTestCase(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + EXPECT_EQ(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset) + .code(), tensorflow::error::INVALID_ARGUMENT); } -TEST_F(ConcatenateDatasetOpTestHelper, DatasetName) { +TEST_F(ConcatenateDatasetOpTest, DatasetNodeName) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); - EXPECT_EQ(dataset_->type_string(), kOpName); + const TestCase &test_case = SameShapeTestCase(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + + EXPECT_EQ(concatenate_dataset->node_name(), kNodeName); } -TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { +TEST_F(ConcatenateDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); - TF_EXPECT_OK(VerifyTypesMatch(dataset_->output_dtypes(), + + const TestCase &test_case = SameShapeTestCase(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + + EXPECT_EQ(concatenate_dataset->type_string(), kOpName); +} + +TEST_P(ParameterizedConcatenateDatasetOpTest, DatasetOutputDtypes) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + TF_EXPECT_OK(VerifyTypesMatch(concatenate_dataset->output_dtypes(), test_case.expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { +TEST_P(ParameterizedConcatenateDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); - TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), + + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + + TF_EXPECT_OK(VerifyShapesCompatible(concatenate_dataset->output_shapes(), test_case.expected_output_shapes)); } -TEST_P(ParameterizedDatasetTest, Cardinality) { +TEST_P(ParameterizedConcatenateDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); - EXPECT_EQ(dataset_->Cardinality(), GetParam().expected_cardinality); + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + + EXPECT_EQ(concatenate_dataset->Cardinality(), test_case.expected_cardinality); } -TEST_F(ConcatenateDatasetOpTestHelper, DatasetSave) { +TEST_F(ConcatenateDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); + + const TestCase &test_case = SameShapeTestCase(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(concatenate_dataset->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { +TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); - TF_EXPECT_OK(VerifyTypesMatch(iterator_->output_dtypes(), + + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + + TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(), test_case.expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { +TEST_P(ParameterizedConcatenateDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); - TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), + + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), test_case.expected_output_shapes)); } -TEST_F(ConcatenateDatasetOpTestHelper, IteratorOutputPrefix) { +TEST_F(ConcatenateDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1())); - EXPECT_EQ(iterator_->prefix(), "Iterator::Concatenate"); + + const TestCase &test_case = SameShapeTestCase(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); + EXPECT_EQ(iterator->prefix(), "Iterator::Concatenate"); } -TEST_P(ParameterizedDatasetTest, Roundtrip) { +TEST_P(ParameterizedConcatenateDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - const TestParam &test_case = GetParam(); - auto expected_outputs_it = test_case.expected_outputs.begin(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + const TestCase &test_case = GetParam(); + std::vector tensor_slice_dataset_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensors(test_case.input_tensors, + &tensor_slice_dataset_tensors)); + gtl::InlinedVector inputs; + for (auto &tensor : tensor_slice_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateConcatenateDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateConcatenateDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *concatenate_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &concatenate_dataset)); + core::ScopedUnref scoped_unref(concatenate_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK(concatenate_dataset->MakeIterator(iterator_ctx.get(), "Iterator", + &iterator)); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); @@ -314,18 +528,19 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { bool end_of_sequence = false; std::vector out_tensors; int cur_iteration = 0; + auto expected_outputs_it = test_case.expected_outputs.begin(); std::vector breakpoints = GetParam().breakpoints; for (int breakpoint : breakpoints) { VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); TF_EXPECT_OK(writer.Flush()); VariantTensorDataReader reader(&data); - TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader)); + TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader)); while (cur_iteration < breakpoint) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); if (!end_of_sequence) { for (auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -336,7 +551,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { cur_iteration++; } - if (breakpoint >= dataset_->Cardinality()) { + if (breakpoint >= concatenate_dataset->Cardinality()) { EXPECT_TRUE(end_of_sequence); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } else { @@ -345,9 +560,10 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { } } -INSTANTIATE_TEST_SUITE_P( - ConcatenateDatasetOpTest, ParameterizedDatasetTest, - ::testing::ValuesIn(std::vector({TestCase1(), TestCase2()}))); +INSTANTIATE_TEST_SUITE_P(ConcatenateDatasetOpTest, + ParameterizedConcatenateDatasetOpTest, + ::testing::ValuesIn(std::vector( + {SameShapeTestCase(), DifferentShapeTestCase()}))); } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc index 61f314c5d1c524..a4db73771f5e75 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc @@ -41,10 +41,10 @@ class RepeatDatasetOpTest : public DatasetOpsTestBase { const DataTypeVector &output_types, const std::vector &output_shapes, std::unique_ptr *op_kernel) { - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, {"input_dataset", "count"}, {{"output_types", output_types}, {"output_shapes", output_shapes}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -56,9 +56,6 @@ class RepeatDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct TestCase { @@ -123,11 +120,11 @@ TestCase ForeverRepeatTestCase() { /*breakpoints*/ {0, 1, 3}}; } -class ParameterizedDatasetTest +class ParameterizedDatasetOpTest : public RepeatDatasetOpTest, public ::testing::WithParamInterface {}; -TEST_P(ParameterizedDatasetTest, GetNext) { +TEST_P(ParameterizedDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -198,7 +195,38 @@ TEST_P(ParameterizedDatasetTest, GetNext) { } } -TEST_F(RepeatDatasetOpTest, DatasetName) { +TEST_F(RepeatDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = FiniteRepeatTestCase(); + Tensor tensor_slice_dataset_tensor(DT_VARIANT, TensorShape({})); + std::vector inputs_for_tensor_slice_dataset = test_case.input_tensors; + TF_ASSERT_OK(CreateTensorSliceDatasetTensor(&inputs_for_tensor_slice_dataset, + &tensor_slice_dataset_tensor)); + Tensor count = CreateTensor(TensorShape{}, {test_case.count}); + gtl::InlinedVector inputs_for_repeat_dataset; + inputs_for_repeat_dataset.emplace_back(&tensor_slice_dataset_tensor); + inputs_for_repeat_dataset.emplace_back(&count); + + std::unique_ptr repeat_dataset_kernel; + TF_ASSERT_OK(CreateRepeatDatasetKernel(test_case.expected_output_dtypes, + test_case.expected_output_shapes, + &repeat_dataset_kernel)); + std::unique_ptr repeat_dataset_context; + TF_ASSERT_OK(CreateRepeatDatasetContext(repeat_dataset_kernel.get(), + &inputs_for_repeat_dataset, + &repeat_dataset_context)); + DatasetBase *repeat_dataset; + TF_ASSERT_OK(CreateDataset(repeat_dataset_kernel.get(), + repeat_dataset_context.get(), &repeat_dataset)); + core::ScopedUnref scoped_unref(repeat_dataset); + + EXPECT_EQ(repeat_dataset->node_name(), kNodeName); +} + +TEST_F(RepeatDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -229,7 +257,7 @@ TEST_F(RepeatDatasetOpTest, DatasetName) { EXPECT_EQ(repeat_dataset->type_string(), kOpName); } -TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { +TEST_P(ParameterizedDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -259,7 +287,7 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { +TEST_P(ParameterizedDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -289,7 +317,7 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParameterizedDatasetTest, Cardinality) { +TEST_P(ParameterizedDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -354,7 +382,7 @@ TEST_F(RepeatDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { +TEST_P(ParameterizedDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -391,7 +419,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { test_case.expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { +TEST_P(ParameterizedDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -428,7 +456,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { test_case.expected_output_shapes)); } -TEST_P(ParameterizedDatasetTest, IteratorOutputPrefix) { +TEST_P(ParameterizedDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -470,7 +498,7 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputPrefix) { } } -TEST_P(ParameterizedDatasetTest, Roundtrip) { +TEST_P(ParameterizedDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); @@ -550,7 +578,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { } } -INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetTest, +INSTANTIATE_TEST_SUITE_P(RepeatDatasetOpTest, ParameterizedDatasetOpTest, ::testing::ValuesIn(std::vector( {FiniteRepeatTestCase(), EmptyRepeatTestCase(), ForeverRepeatTestCase()}))); diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc index c5421aeb37a07c..91aaaca6d03144 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op_test.cc @@ -13,19 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/function_testlib.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/variant.h" -#include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/kernels/data/dataset_test_base.h" -#include "tensorflow/core/kernels/data/dataset_utils.h" -#include "tensorflow/core/kernels/data/iterator_ops.h" -#include "tensorflow/core/kernels/ops_testutil.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -39,10 +27,10 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase { // Creates a new SparseTensorSliceDataset op kernel. Status CreateSparseTensorSliceDatasetKernel( DataType tvalues, std::unique_ptr *op_kernel) { - node_def_ = test::function::NDef(kNodeName, kOpName, - {"indices", "values", "dense_shape"}, - {{"Tvalues", tvalues}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + NodeDef node_def = test::function::NDef( + kNodeName, kOpName, {"indices", "values", "dense_shape"}, + {{"Tvalues", tvalues}}); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -54,9 +42,6 @@ class SparseTensorSliceDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct SparseTensorParam { @@ -71,123 +56,180 @@ struct TestCase { std::vector breakpoints; }; -std::vector TestCases() { +TestCase TwoDimsTestCase() { return { - {{{DatasetOpsTestBase::CreateTensor({2, 2}, {0, 0, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888, 999})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 1}, {0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888})}, - {DatasetOpsTestBase::CreateTensor({1}, {2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 1}, {1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999})}, - {DatasetOpsTestBase::CreateTensor({1}, {2})}}}, - {0, 1, 2}}, // 2-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor({2, 3}, {0, 0, 0, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888.0, 999.0})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 2}, {0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888.0})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 2}, {1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999.0})}, - {DatasetOpsTestBase::CreateTensor({2}, {2, 2})}}}, - {0, 1, 2}}, // 3-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor({2, 4}, - {0, 0, 0, 0, 1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {"a", "b"})}, - {DatasetOpsTestBase::CreateTensor({4}, {3, 2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 3}, {0, 0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {"a"})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 3}, {1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {"b"})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({0, 3}, {})}, - {DatasetOpsTestBase::CreateTensor({0}, {})}, - {DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}}, - {0, 1, 3}}, // 4-D sparse tensor - {{{DatasetOpsTestBase::CreateTensor( - {2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({2}, {888, 999})}, - {DatasetOpsTestBase::CreateTensor({5}, {3, 2, 2, 2, 2})}}, - {{{DatasetOpsTestBase::CreateTensor({1, 4}, {0, 0, 0, 0})}, - {DatasetOpsTestBase::CreateTensor({1}, {888})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({1, 4}, {1, 1, 1, 1})}, - {DatasetOpsTestBase::CreateTensor({1}, {999})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, - {{DatasetOpsTestBase::CreateTensor({0, 4}, {})}, - {DatasetOpsTestBase::CreateTensor({0}, {})}, - {DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}}, - {0, 1, 3}} // 5-D sparse tensor - }; + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor({2, 2}, + {0, 0, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888, 999}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, {2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 1}, {0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({1}, {2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 1}, {1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {999}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({1}, {2})}}, + /*breakpoints*/ {0, 1, 2}}; +} + +TestCase ThreeDimsTestCase() { + return { + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor({2, 3}, + {0, 0, 0, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888.0, 999.0}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 2}, {0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888.0}), + /*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, {2, 2})}, + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 2}, {1, 1})}, + {/*values*/ DatasetOpsTestBase::CreateTensor({1}, {999.0})}, + {/*dense_shape*/ DatasetOpsTestBase::CreateTensor({2}, + {2, 2})}}}, + /*breakpoints*/ {0, 1, 2}}; +} + +TestCase FourDimsTestCase() { + return { + /*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor( + {2, 4}, {0, 0, 0, 0, 1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {"a", "b"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {3, 2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 3}, {0, 0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {"a"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 3}, {1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {"b"}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({0, 3}, {}), + /*values*/ DatasetOpsTestBase::CreateTensor({0}, {}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({3}, {2, 2, 2})}}, + /*breakpoints*/ {0, 1, 3}}; +} + +TestCase FiveDimsTestCase() { + return {/*input_sparse_tensor*/ + {/*indices*/ DatasetOpsTestBase::CreateTensor( + {2, 5}, {0, 0, 0, 0, 0, 1, 1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({2}, {888, 999}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({5}, {3, 2, 2, 2, 2})}, + /*expected_outputs*/ + {{/*indices*/ DatasetOpsTestBase::CreateTensor({1, 4}, + {0, 0, 0, 0}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {888}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({1, 4}, + {1, 1, 1, 1}), + /*values*/ DatasetOpsTestBase::CreateTensor({1}, {999}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}, + {/*indices*/ DatasetOpsTestBase::CreateTensor({0, 4}, {}), + /*values*/ DatasetOpsTestBase::CreateTensor({0}, {}), + /*dense_shape*/ + DatasetOpsTestBase::CreateTensor({4}, {2, 2, 2, 2})}}, + /*breakpoints*/ {0, 1, 3}}; } -TEST_F(SparseTensorSliceDatasetOpTest, GetNext) { +class ParameterizedSparseTensorSliceDatasetOpTest + : public SparseTensorSliceDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - std::unique_ptr iterator_ctx; - TF_ASSERT_OK( - CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); - std::unique_ptr iterator; - TF_ASSERT_OK( - dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); - bool end_of_sequence = false; - std::vector out_tensors; - int cur_slice = 0; - while (!end_of_sequence) { - TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, - &end_of_sequence)); - if (!end_of_sequence) { - TF_EXPECT_OK( - ExpectEqual(out_tensors[0], expected_outputs[cur_slice].indices)); - TF_EXPECT_OK( - ExpectEqual(out_tensors[1], expected_outputs[cur_slice].values)); - TF_EXPECT_OK(ExpectEqual(out_tensors[2], - expected_outputs[cur_slice].dense_shape)); - cur_slice++; - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + auto expected_outputs_it = expected_outputs.begin(); + while (!end_of_sequence) { + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); + if (!end_of_sequence) { + TF_EXPECT_OK(ExpectEqual(out_tensors[0], expected_outputs_it->indices)); + TF_EXPECT_OK(ExpectEqual(out_tensors[1], expected_outputs_it->values)); + TF_EXPECT_OK( + ExpectEqual(out_tensors[2], expected_outputs_it->dense_shape)); + expected_outputs_it++; } } + EXPECT_EQ(expected_outputs_it, expected_outputs.end()); +} + +TEST_F(SparseTensorSliceDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + EXPECT_EQ(dataset->node_name(), kNodeName); } -TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) { +TEST_F(SparseTensorSliceDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -199,99 +241,90 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetName) { EXPECT_EQ(dataset->type_string(), kOpName); } -TEST_F(SparseTensorSliceDatasetOpTest, DatasetOutputDtypes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - DataTypeVector expected_output_dtypes = { - expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), - expected_outputs[0].dense_shape.dtype()}; - TF_EXPECT_OK( - VerifyTypesMatch(dataset->output_dtypes(), expected_output_dtypes)); - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + DataTypeVector expected_output_dtypes = { + expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), + expected_outputs[0].dense_shape.dtype()}; + TF_EXPECT_OK( + VerifyTypesMatch(dataset->output_dtypes(), expected_output_dtypes)); } -TEST_F(SparseTensorSliceDatasetOpTest, DatasetOutputShapes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - std::vector expected_output_shapes = { - expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), - expected_outputs[0].dense_shape.shape()}; - TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(), - expected_output_shapes)); - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + std::vector expected_output_shapes = { + expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), + expected_outputs[0].dense_shape.shape()}; + TF_EXPECT_OK( + VerifyShapesCompatible(dataset->output_shapes(), expected_output_shapes)); } -TEST_F(SparseTensorSliceDatasetOpTest, Cardinality) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - EXPECT_EQ(dataset->Cardinality(), expected_outputs.size()); - } + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + EXPECT_EQ(dataset->Cardinality(), expected_outputs.size()); } TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { @@ -299,15 +332,16 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -324,82 +358,74 @@ TEST_F(SparseTensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputDtypes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - std::unique_ptr iterator_ctx; - TF_ASSERT_OK( - CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); - std::unique_ptr iterator; - TF_ASSERT_OK( - dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); - DataTypeVector expected_output_dtypes = { - expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), - expected_outputs[0].dense_shape.dtype()}; - TF_EXPECT_OK( - VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes)); - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + DataTypeVector expected_output_dtypes = { + expected_outputs[0].indices.dtype(), expected_outputs[0].values.dtype(), + expected_outputs[0].dense_shape.dtype()}; + TF_EXPECT_OK( + VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes)); } -TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputShapes) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - std::unique_ptr iterator_ctx; - TF_ASSERT_OK( - CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); - std::unique_ptr iterator; - TF_ASSERT_OK( - dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); - std::vector expected_output_shapes = { - expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), - expected_outputs[0].dense_shape.shape()}; - TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), - expected_output_shapes)); - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; + + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); + + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + std::vector expected_output_shapes = { + expected_outputs[0].indices.shape(), expected_outputs[0].values.shape(), + expected_outputs[0].dense_shape.shape()}; + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), + expected_output_shapes)); } TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { @@ -407,15 +433,16 @@ TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - int N = 2; - const int NDIM = 2; - Tensor indices = CreateTensor(TensorShape({N, NDIM}), {0, 0, 1, 1}); - Tensor values = CreateTensor(TensorShape({N}), {888, 999}); - Tensor dense_shape = CreateTensor(TensorShape({NDIM}), {5, 5}); - gtl::InlinedVector inputs = {&indices, &values, &dense_shape}; + const TestCase &test_case = TwoDimsTestCase(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; std::unique_ptr dataset_kernel; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(DT_INT32, &dataset_kernel)); + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); std::unique_ptr dataset_kernel_ctx; TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); @@ -432,79 +459,81 @@ TEST_F(SparseTensorSliceDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), strings::StrCat("Iterator::SparseTensorSlice")); } -TEST_F(SparseTensorSliceDatasetOpTest, Roundtrip) { +TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - for (auto &test_case : TestCases()) { - SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; - std::vector expected_outputs = - test_case.expected_outputs; - std::vector breakpoints = test_case.breakpoints; - DataType tvalues = input_sparse_tensor.values.dtype(); - gtl::InlinedVector inputs = { - &input_sparse_tensor.indices, &input_sparse_tensor.values, - &input_sparse_tensor.dense_shape}; - - std::unique_ptr dataset_kernel; - TF_ASSERT_OK( - CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); - std::unique_ptr dataset_kernel_ctx; - TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( - dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); - DatasetBase *dataset; - TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), - &dataset)); - core::ScopedUnref scoped_unref(dataset); - - std::unique_ptr iterator_ctx; - TF_ASSERT_OK( - CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); - std::unique_ptr iterator; - TF_ASSERT_OK( - dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); - - std::unique_ptr serialization_ctx; - TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); - - int cur_iteration = 0; - bool end_of_sequence = false; - int64 num_slices = input_sparse_tensor.dense_shape.dim_size(0); - std::vector out_tensors; - - for (int breakpoint : breakpoints) { - while (cur_iteration < breakpoint) { - TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, - &end_of_sequence)); - cur_iteration++; - } + const TestCase &test_case = GetParam(); + SparseTensorParam input_sparse_tensor = test_case.input_sparse_tensor; + std::vector expected_outputs = test_case.expected_outputs; + std::vector breakpoints = test_case.breakpoints; + DataType tvalues = input_sparse_tensor.values.dtype(); + gtl::InlinedVector inputs = { + &input_sparse_tensor.indices, &input_sparse_tensor.values, + &input_sparse_tensor.dense_shape}; - if (breakpoint == 0) { - EXPECT_FALSE(end_of_sequence); - } else if (breakpoint <= num_slices) { - for (int i = 0; i < out_tensors.size(); ++i) { - TF_EXPECT_OK(ExpectEqual( - out_tensors[0], expected_outputs[cur_iteration - 1].indices)); - TF_EXPECT_OK(ExpectEqual(out_tensors[1], - expected_outputs[cur_iteration - 1].values)); - TF_EXPECT_OK(ExpectEqual( - out_tensors[2], expected_outputs[cur_iteration - 1].dense_shape)); - } - } else { - EXPECT_TRUE(end_of_sequence); - } + std::unique_ptr dataset_kernel; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetKernel(tvalues, &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateSparseTensorSliceDatasetContext( + dataset_kernel.get(), &inputs, &dataset_kernel_ctx)); + DatasetBase *dataset; + TF_ASSERT_OK( + CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), &dataset)); + core::ScopedUnref scoped_unref(dataset); - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader)); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + std::unique_ptr serialization_ctx; + TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); + + int cur_iteration = 0; + bool end_of_sequence = false; + int64 num_slices = input_sparse_tensor.dense_shape.dim_size(0); + std::vector out_tensors; + + for (int breakpoint : breakpoints) { + while (cur_iteration < breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); + cur_iteration++; + } + + if (breakpoint == 0) { + EXPECT_FALSE(end_of_sequence); + } else if (breakpoint <= num_slices) { + for (int i = 0; i < out_tensors.size(); ++i) { + TF_EXPECT_OK(ExpectEqual(out_tensors[0], + expected_outputs[cur_iteration - 1].indices)); + TF_EXPECT_OK(ExpectEqual(out_tensors[1], + expected_outputs[cur_iteration - 1].values)); + TF_EXPECT_OK(ExpectEqual( + out_tensors[2], expected_outputs[cur_iteration - 1].dense_shape)); + } + } else { + EXPECT_TRUE(end_of_sequence); } + + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader)); } } +INSTANTIATE_TEST_SUITE_P(SparseTensorSliceDatasetOpTest, + ParameterizedSparseTensorSliceDatasetOpTest, + ::testing::ValuesIn(std::vector( + {TwoDimsTestCase(), ThreeDimsTestCase(), + FourDimsTestCase(), FiveDimsTestCase()}))); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op_test.cc b/tensorflow/core/kernels/data/take_dataset_op_test.cc index afe22726552b60..0f2dac9a6e8142 100644 --- a/tensorflow/core/kernels/data/take_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/take_dataset_op_test.cc @@ -38,10 +38,10 @@ class TakeDatasetOpTest : public DatasetOpsTestBase { const DataTypeVector &output_types, const std::vector &output_shapes, std::unique_ptr *op_kernel) { - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, {"input_dataset", "count"}, {{"output_types", output_types}, {"output_shapes", output_shapes}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -53,9 +53,6 @@ class TakeDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct TestCase { diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc index 883440924f06f5..2ef3502607c6df 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc @@ -46,10 +46,10 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase { components.emplace_back(strings::StrCat("component_", i)); } - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, components, {{"Toutput_types", dtypes}, {"output_shapes", shapes}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, tensor_dataset_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, tensor_dataset_kernel)); return Status::OK(); } @@ -63,9 +63,6 @@ class TensorSliceDatasetOpTest : public DatasetOpsTestBase { CreateOpKernelContext(tensor_dataset_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct TestCase { @@ -74,142 +71,189 @@ struct TestCase { std::vector breakpoints; }; -std::vector TestCases() { +TestCase PlainTensorTestCase() { + return {/*components*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {1, 2, 3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), + {37.0, 38.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), + {"a", "b"})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {37.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"a"}), + DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), + DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {38.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"b"})}, + /*breakpoints*/ {0, 1, 3}}; +} + +TestCase NestedTensorTestCase() { return { - // A single tuple of tensors. - {{{DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), - {1, 2, 3, 4}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), - {37.0, 38.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 1}), - {"a", "b"})}}, // components - {{DatasetOpsTestBase::CreateTensor(TensorShape({}), {1}), - DatasetOpsTestBase::CreateTensor(TensorShape({2}), {1, 2}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {37.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {"a"}), - DatasetOpsTestBase::CreateTensor(TensorShape({}), {2}), - DatasetOpsTestBase::CreateTensor(TensorShape({2}), {3, 4}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), {38.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({1}), - {"b"})}}, // expected_outputs - {{0, 1, 3}}}, // breakpoints - // Nested tensors - {{{DatasetOpsTestBase::CreateTensor( - TensorShape({2, 1}), - {DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), - {1.0, 2.0, 3.0, 4.0}), - DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), - {5.0, 6.0, 7.0, 8.0})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor( - TensorShape({1, 2}), {"a", "b"}), - DatasetOpsTestBase::CreateTensor( - TensorShape({1, 2}), {"c", "d"})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({2, 3}), {1, 2, 3, 4, 5, 6})}}, // components - {{DatasetOpsTestBase::CreateTensor( - TensorShape({1}), {DatasetOpsTestBase::CreateTensor( - TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({1}), {DatasetOpsTestBase::CreateTensor( - TensorShape({1, 2}), {"a", "b"})}), - DatasetOpsTestBase::CreateTensor(TensorShape({3}), {1, 2, 3}), - DatasetOpsTestBase::CreateTensor( - TensorShape({1}), {DatasetOpsTestBase::CreateTensor( - TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({1}), {DatasetOpsTestBase::CreateTensor( - TensorShape({1, 2}), {"c", "d"})}), - DatasetOpsTestBase::CreateTensor( - TensorShape({3}), {4, 5, 6})}}, // expected_outputs - {{0, 1, 2}}} // breakpoints - }; + /*components*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 1}), + {DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {1.0, 2.0, 3.0, 4.0}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 2}), + {5.0, 6.0, 7.0, 8.0})}), + DatasetOpsTestBase::CreateTensor( + TensorShape({2, 1}), {DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"a", "b"}), + DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"c", "d"})}), + DatasetOpsTestBase::CreateTensor(TensorShape({2, 3}), + {1, 2, 3, 4, 5, 6})}, + /*expected_outputs*/ + {DatasetOpsTestBase::CreateTensor( + TensorShape({1}), {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {1.0, 2.0, 3.0, 4.0})}), + DatasetOpsTestBase::CreateTensor( + TensorShape({1}), {DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"a", "b"})}), + DatasetOpsTestBase::CreateTensor(TensorShape({3}), {1, 2, 3}), + DatasetOpsTestBase::CreateTensor( + TensorShape({1}), {DatasetOpsTestBase::CreateTensor( + TensorShape({2, 2}), {5.0, 6.0, 7.0, 8.0})}), + DatasetOpsTestBase::CreateTensor( + TensorShape({1}), {DatasetOpsTestBase::CreateTensor( + TensorShape({1, 2}), {"c", "d"})}), + DatasetOpsTestBase::CreateTensor(TensorShape({3}), {4, 5, 6})}, + /*breakpoints*/ {0, 1, 2}}; } -TEST_F(TensorSliceDatasetOpTest, GetNext) { +class ParameterizedTensorSliceDatasetOpTest + : public TensorSliceDatasetOpTest, + public ::testing::WithParamInterface {}; + +TEST_P(ParameterizedTensorSliceDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.push_back(&component); - dtypes.push_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), - &iterator_context)); - std::unique_ptr iterator; - TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), - "Iterator", &iterator)); - bool end_of_sequence = false; - std::vector out_tensors; - int cur_slice = 0; - - while (!end_of_sequence) { - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - for (int i = 0; i < out_tensors.size(); ++i) { - EXPECT_LT(i + num_tensors_per_slice * cur_slice, - expected_outputs.size()); - if (out_tensors[i].dtype() == DT_VARIANT) { - // Currently `ExpectEqual()` does not support the variant tensor - // yet, so we manually cast the variant to numeric/string tensor. - const Tensor *output = - out_tensors[i].scalar()().get(); - const Tensor *expected_output = - expected_outputs[i + num_tensors_per_slice * cur_slice] - .scalar()() - .get(); - TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); - } else { - TF_EXPECT_OK(ExpectEqual( - out_tensors[i], - expected_outputs[i + num_tensors_per_slice * cur_slice])); - } + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), + "Iterator", &iterator)); + bool end_of_sequence = false; + std::vector out_tensors; + int cur_slice = 0; + + while (!end_of_sequence) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + for (int i = 0; i < out_tensors.size(); ++i) { + EXPECT_LT(i + num_tensors_per_slice * cur_slice, expected_outputs.size()); + if (out_tensors[i].dtype() == DT_VARIANT) { + // Currently `ExpectEqual()` does not support the variant tensor + // yet, so we manually cast the variant to numeric/string tensor. + const Tensor *output = out_tensors[i].scalar()().get(); + const Tensor *expected_output = + expected_outputs[i + num_tensors_per_slice * cur_slice] + .scalar()() + .get(); + TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); + } else { + TF_EXPECT_OK(ExpectEqual( + out_tensors[i], + expected_outputs[i + num_tensors_per_slice * cur_slice])); } - out_tensors.clear(); - cur_slice++; } + out_tensors.clear(); + cur_slice++; } } -TEST_F(TensorSliceDatasetOpTest, DatasetName) { +TEST_F(TensorSliceDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + EXPECT_EQ(tensor_slice_dataset->node_name(), kNodeName); +} + +TEST_F(TensorSliceDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - PartialTensorShape({2})}; + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -221,134 +265,129 @@ TEST_F(TensorSliceDatasetOpTest, DatasetName) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); EXPECT_EQ(tensor_slice_dataset->type_string(), kOpName); } -TEST_F(TensorSliceDatasetOpTest, DatasetOutputDtypes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - const DataTypeVector produced_output_dtypes = - tensor_slice_dataset->output_dtypes(); - EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice); - for (int i = 0; i < num_tensors_per_slice; ++i) { - EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype()); - } + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + const DataTypeVector produced_output_dtypes = + tensor_slice_dataset->output_dtypes(); + EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype()); } } -TEST_F(TensorSliceDatasetOpTest, DatasetOutputShapes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - const std::vector produced_output_shapes = - tensor_slice_dataset->output_shapes(); - std::vector expected_output_shapes; - EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice); - for (int i = 0; i < num_tensors_per_slice; ++i) { - EXPECT_TRUE( - produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape())); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + const std::vector produced_output_shapes = + tensor_slice_dataset->output_shapes(); + std::vector expected_output_shapes; + EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + EXPECT_TRUE( + produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape())); } } -TEST_F(TensorSliceDatasetOpTest, Cardinality) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - EXPECT_EQ(tensor_slice_dataset->Cardinality(), - inputs[0].tensor->dim_size(0)); + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + EXPECT_EQ(tensor_slice_dataset->Cardinality(), inputs[0].tensor->dim_size(0)); } TEST_F(TensorSliceDatasetOpTest, DatasetSave) { @@ -356,12 +395,21 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - PartialTensorShape({2})}; + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -373,7 +421,7 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr serialization_context; TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); @@ -384,102 +432,98 @@ TEST_F(TensorSliceDatasetOpTest, DatasetSave) { TF_ASSERT_OK(writer.Flush()); } -TEST_F(TensorSliceDatasetOpTest, IteratorOutputDtypes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), - &iterator_context)); - std::unique_ptr iterator; - TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), - "Iterator", &iterator)); - const DataTypeVector produced_output_dtypes = iterator->output_dtypes(); - - EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice); - for (int i = 0; i < num_tensors_per_slice; ++i) { - EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype()); - } + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), + "Iterator", &iterator)); + const DataTypeVector produced_output_dtypes = iterator->output_dtypes(); + + EXPECT_EQ(produced_output_dtypes.size(), num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + EXPECT_EQ(produced_output_dtypes[i], expected_outputs[i].dtype()); } } -TEST_F(TensorSliceDatasetOpTest, IteratorOutputShapes) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), - &iterator_context)); - std::unique_ptr iterator; - TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), - "Iterator", &iterator)); - const std::vector produced_output_shapes = - iterator->output_shapes(); - EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice); - for (int i = 0; i < num_tensors_per_slice; ++i) { - EXPECT_TRUE( - produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape())); - } + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); + + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), + "Iterator", &iterator)); + const std::vector produced_output_shapes = + iterator->output_shapes(); + EXPECT_EQ(produced_output_shapes.size(), num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + EXPECT_TRUE( + produced_output_shapes[i].IsIdenticalTo(expected_outputs[i].shape())); } } @@ -488,12 +532,21 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - Tensor t1 = CreateTensor(TensorShape({2, 2}), {1, 2, 3, 4}); - Tensor t2 = CreateTensor(TensorShape({2, 2}), {5, 6, 7, 8}); - gtl::InlinedVector inputs = {&t1, &t2}; - DataTypeVector dtypes({DT_INT64, DT_INT64}); - std::vector shapes = {PartialTensorShape({2}), - PartialTensorShape({2})}; + const TestCase &test_case = PlainTensorTestCase(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } std::unique_ptr tensor_slice_dataset_kernel; TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, &tensor_slice_dataset_kernel)); @@ -505,7 +558,7 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), tensor_slice_dataset_context.get(), &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); + core::ScopedUnref scoped_unref(tensor_slice_dataset); std::unique_ptr iterator_context; TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), @@ -516,95 +569,96 @@ TEST_F(TensorSliceDatasetOpTest, IteratorOutputPrefix) { EXPECT_EQ(iterator->prefix(), "Iterator::TensorSlice"); } -TEST_F(TensorSliceDatasetOpTest, Roundtrip) { +TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; - for (auto &test_case : TestCases()) { - std::vector components = test_case.components; - std::vector expected_outputs = test_case.expected_outputs; - std::vector breakpoints = test_case.breakpoints; - size_t num_tensors_per_slice = components.size(); - - TF_ASSERT_OK(InitThreadPool(thread_num)); - TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - - DataTypeVector dtypes; - std::vector shapes; - gtl::InlinedVector inputs; - for (auto &component : components) { - inputs.emplace_back(&component); - dtypes.emplace_back(component.dtype()); - } - for (int i = 0; i < num_tensors_per_slice; ++i) { - shapes.emplace_back(expected_outputs[i].shape()); - } + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - std::unique_ptr tensor_slice_dataset_kernel; - TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, - &tensor_slice_dataset_kernel)); - std::unique_ptr tensor_slice_dataset_context; - TF_ASSERT_OK(CreateTensorSliceDatasetContext( - tensor_slice_dataset_kernel.get(), &inputs, - &tensor_slice_dataset_context)); - DatasetBase *tensor_slice_dataset; - TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), - tensor_slice_dataset_context.get(), - &tensor_slice_dataset)); - core::ScopedUnref scored_unref(tensor_slice_dataset); - - std::unique_ptr iterator_context; - TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), - &iterator_context)); - std::unique_ptr iterator; - TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), - "Iterator", &iterator)); - std::unique_ptr serialization_context; - TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); - - int cur_iteration = 0; - bool end_of_sequence = false; - int64 num_slices = inputs[0].tensor->dim_size(0); - std::vector out_tensors; - - for (int breakpoint : breakpoints) { - while (cur_iteration < breakpoint) { - TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, - &end_of_sequence)); - cur_iteration++; - } + const TestCase &test_case = GetParam(); + const std::vector &expected_outputs = test_case.expected_outputs; + std::vector components = test_case.components; + DataTypeVector dtypes; + gtl::InlinedVector inputs; + for (auto &component : components) { + inputs.emplace_back(&component); + dtypes.emplace_back(component.dtype()); + } + size_t num_tensors_per_slice = components.size(); + std::vector shapes; + shapes.reserve(num_tensors_per_slice); + for (int i = 0; i < num_tensors_per_slice; ++i) { + shapes.emplace_back(expected_outputs[i].shape()); + } + std::unique_ptr tensor_slice_dataset_kernel; + TF_ASSERT_OK(CreateTensorSliceDatasetKernel(dtypes, shapes, + &tensor_slice_dataset_kernel)); + std::unique_ptr tensor_slice_dataset_context; + TF_ASSERT_OK( + CreateTensorSliceDatasetContext(tensor_slice_dataset_kernel.get(), + &inputs, &tensor_slice_dataset_context)); + DatasetBase *tensor_slice_dataset; + TF_ASSERT_OK(CreateDataset(tensor_slice_dataset_kernel.get(), + tensor_slice_dataset_context.get(), + &tensor_slice_dataset)); + core::ScopedUnref scoped_unref(tensor_slice_dataset); - if (breakpoint == 0) { - EXPECT_FALSE(end_of_sequence); - } else if (breakpoint <= num_slices) { - for (int i = 0; i < out_tensors.size(); ++i) { - if (out_tensors[i].dtype() == DT_VARIANT) { - const Tensor *output = - out_tensors[i].scalar()().get(); - const Tensor *expected_output = - expected_outputs[i + - num_tensors_per_slice * (cur_iteration - 1)] - .scalar()() - .get(); - TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); - } else { - TF_EXPECT_OK(ExpectEqual( - out_tensors[i], expected_outputs[i + num_tensors_per_slice * - (cur_iteration - 1)])); - } + std::unique_ptr iterator_context; + TF_ASSERT_OK(CreateIteratorContext(tensor_slice_dataset_context.get(), + &iterator_context)); + std::unique_ptr iterator; + TF_ASSERT_OK(tensor_slice_dataset->MakeIterator(iterator_context.get(), + "Iterator", &iterator)); + std::unique_ptr serialization_context; + TF_ASSERT_OK(CreateSerializationContext(&serialization_context)); + + int cur_iteration = 0; + bool end_of_sequence = false; + int64 num_slices = inputs[0].tensor->dim_size(0); + std::vector out_tensors; + const std::vector &breakpoints = test_case.breakpoints; + for (int breakpoint : breakpoints) { + while (cur_iteration < breakpoint) { + TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors, + &end_of_sequence)); + cur_iteration++; + } + + if (breakpoint == 0) { + EXPECT_FALSE(end_of_sequence); + } else if (breakpoint <= num_slices) { + for (int i = 0; i < out_tensors.size(); ++i) { + if (out_tensors[i].dtype() == DT_VARIANT) { + const Tensor *output = + out_tensors[i].scalar()().get(); + const Tensor *expected_output = + expected_outputs[i + num_tensors_per_slice * (cur_iteration - 1)] + .scalar()() + .get(); + TF_EXPECT_OK(ExpectEqual(*output, *expected_output)); + } else { + TF_EXPECT_OK(ExpectEqual( + out_tensors[i], expected_outputs[i + num_tensors_per_slice * + (cur_iteration - 1)])); } - } else { - EXPECT_TRUE(end_of_sequence); } - - VariantTensorData data; - VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); - TF_ASSERT_OK(writer.Flush()); - VariantTensorDataReader reader(&data); - TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); + } else { + EXPECT_TRUE(end_of_sequence); } + + VariantTensorData data; + VariantTensorDataWriter writer(&data); + TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer)); + TF_ASSERT_OK(writer.Flush()); + VariantTensorDataReader reader(&data); + TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader)); } } +INSTANTIATE_TEST_SUITE_P(TensorSliceDatasetOpTest, + ParameterizedTensorSliceDatasetOpTest, + ::testing::ValuesIn(std::vector( + {PlainTensorTestCase(), NestedTensorTestCase()}))); + } // namespace } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc index 9f9e86a3d088b6..7c51c0443335ed 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc @@ -58,10 +58,10 @@ class ZipDatasetOpTest : public DatasetOpsTestBase { // Create the placeholder names for the input components of `ZipDataset`. input_datasets.emplace_back(strings::StrCat("input_dataset_", i)); } - node_def_ = test::function::NDef( + NodeDef node_def = test::function::NDef( kNodeName, kOpName, input_datasets, {{"output_types", dtypes}, {"output_shapes", output_shapes}, {"N", n}}); - TF_RETURN_IF_ERROR(CreateOpKernel(node_def_, op_kernel)); + TF_RETURN_IF_ERROR(CreateOpKernel(node_def, op_kernel)); return Status::OK(); } @@ -74,9 +74,6 @@ class ZipDatasetOpTest : public DatasetOpsTestBase { TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return Status::OK(); } - - private: - NodeDef node_def_; }; struct TestParam { @@ -85,8 +82,8 @@ struct TestParam { std::vector breakpoints; }; +// Test case 1: the input datasets with same number of outputs. TestParam TestCase1() { - // Test case 1: the input datasets with same number of outputs. return {/*input_range_dataset_params*/ {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 13, 1}}, /*expected_outputs*/ @@ -99,8 +96,8 @@ TestParam TestCase1() { /*breakpoints*/ {0, 1, 4}}; } +// Test case 2: the input datasets with different number of outputs. TestParam TestCase2() { - // Test case 2: the input datasets with different number of outputs. return {/*input_range_dataset_params*/ {RangeDatasetParam{0, 3, 1}, RangeDatasetParam{10, 15, 1}}, /*expected_outputs*/ @@ -113,67 +110,48 @@ TestParam TestCase2() { /*breakpoints*/ {0, 1, 4}}; } -class ZipDatasetOpTestHelper : public ZipDatasetOpTest { - public: - ~ZipDatasetOpTestHelper() override { - if (dataset_) dataset_->Unref(); - } - - protected: - Status CreateDatasetFromTestCase(const TestParam &test_case) { - std::vector range_dataset_tensors; - range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); - TF_RETURN_IF_ERROR(CreateRangeDatasetTensors( - test_case.input_range_dataset_params, &range_dataset_tensors)); - gtl::InlinedVector inputs; - inputs.reserve(range_dataset_tensors.size()); - for (auto &tensor : range_dataset_tensors) { - inputs.emplace_back(&tensor); - } - int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_RETURN_IF_ERROR(CreateZipDatasetKernel({DT_INT64}, - {{num_tensors_per_slice}}, - inputs.size(), &dataset_kernel_)); - TF_RETURN_IF_ERROR(CreateZipDatasetContext(dataset_kernel_.get(), &inputs, - &dataset_kernel_ctx_)); - TF_RETURN_IF_ERROR(CreateDataset(dataset_kernel_.get(), - dataset_kernel_ctx_.get(), &dataset_)); - return Status::OK(); - } - - Status CreateIteratorFromTestCase(const TestParam &test_case) { - TF_RETURN_IF_ERROR(CreateDatasetFromTestCase(test_case)); - TF_RETURN_IF_ERROR( - CreateIteratorContext(dataset_kernel_ctx_.get(), &iterator_ctx_)); - TF_RETURN_IF_ERROR( - dataset_->MakeIterator(iterator_ctx_.get(), "Iterator", &iterator_)); - return Status::OK(); - } - - std::unique_ptr dataset_kernel_; - std::unique_ptr dataset_kernel_ctx_; - DatasetBase *dataset_ = nullptr; // owned by this class. - std::unique_ptr iterator_ctx_; - std::unique_ptr iterator_; -}; - -class ParameterizedDatasetTest - : public ZipDatasetOpTestHelper, +class ParameterizedZipDatasetOpTest + : public ZipDatasetOpTest, public ::testing::WithParamInterface {}; -TEST_P(ParameterizedDatasetTest, GetNext) { +TEST_P(ParameterizedZipDatasetOpTest, GetNext) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); auto expected_outputs_it = test_case.expected_outputs.begin(); bool end_of_sequence = false; std::vector out_tensors; while (!end_of_sequence) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK( + iterator->GetNext(iterator_ctx.get(), &out_tensors, &end_of_sequence)); if (!end_of_sequence) { for (const auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -185,22 +163,92 @@ TEST_P(ParameterizedDatasetTest, GetNext) { EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } -TEST_F(ZipDatasetOpTestHelper, DatasetName) { +TEST_F(ZipDatasetOpTest, DatasetNodeName) { + int thread_num = 2, cpu_num = 2; + TF_ASSERT_OK(InitThreadPool(thread_num)); + TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->node_name(), kNodeName); +} + +TEST_F(ZipDatasetOpTest, DatasetTypeString) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); - EXPECT_EQ(dataset_->type_string(), kOpName); + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->type_string(), kOpName); } -TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { +TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); DataTypeVector expected_output_dtypes; expected_output_dtypes.reserve(num_tensors_per_slice); @@ -209,16 +257,35 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputDtypes) { } TF_EXPECT_OK( - VerifyTypesMatch(dataset_->output_dtypes(), expected_output_dtypes)); + VerifyTypesMatch(zip_dataset->output_dtypes(), expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { +TEST_P(ParameterizedZipDatasetOpTest, DatasetOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); std::vector expected_output_shapes; expected_output_shapes.reserve(num_tensors_per_slice); @@ -226,43 +293,107 @@ TEST_P(ParameterizedDatasetTest, DatasetOutputShapes) { expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); } - TF_EXPECT_OK(VerifyShapesCompatible(dataset_->output_shapes(), + TF_EXPECT_OK(VerifyShapesCompatible(zip_dataset->output_shapes(), expected_output_shapes)); } -TEST_P(ParameterizedDatasetTest, Cardinality) { +TEST_P(ParameterizedZipDatasetOpTest, Cardinality) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateDatasetFromTestCase(test_case)); - - EXPECT_EQ(dataset_->Cardinality(), + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + + EXPECT_EQ(zip_dataset->Cardinality(), test_case.expected_outputs.size() / num_tensors_per_slice); } -TEST_F(ZipDatasetOpTestHelper, DatasetSave) { +TEST_F(ZipDatasetOpTest, DatasetSave) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateDatasetFromTestCase(TestCase1())); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_ASSERT_OK(dataset_->Save(serialization_ctx.get(), &writer)); + TF_ASSERT_OK(zip_dataset->Save(serialization_ctx.get(), &writer)); TF_ASSERT_OK(writer.Flush()); } -TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { +TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputDtypes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); DataTypeVector expected_output_dtypes; expected_output_dtypes.reserve(num_tensors_per_slice); @@ -271,16 +402,40 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputDtypes) { } TF_EXPECT_OK( - VerifyTypesMatch(iterator_->output_dtypes(), expected_output_dtypes)); + VerifyTypesMatch(iterator->output_dtypes(), expected_output_dtypes)); } -TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { +TEST_P(ParameterizedZipDatasetOpTest, IteratorOutputShapes) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; int num_tensors_per_slice = test_case.input_range_dataset_params.size(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); std::vector expected_output_shapes; expected_output_shapes.reserve(num_tensors_per_slice); @@ -288,43 +443,95 @@ TEST_P(ParameterizedDatasetTest, IteratorOutputShapes) { expected_output_shapes.emplace_back(test_case.expected_outputs[i].shape()); } - TF_EXPECT_OK(VerifyShapesCompatible(iterator_->output_shapes(), + TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(), expected_output_shapes)); } -TEST_F(ZipDatasetOpTestHelper, IteratorOutputPrefix) { +TEST_F(ZipDatasetOpTest, IteratorOutputPrefix) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); - TF_ASSERT_OK(CreateIteratorFromTestCase(TestCase1())); - EXPECT_EQ(iterator_->prefix(), "Iterator::Zip"); + + const TestParam &test_case = TestCase1(); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); + + EXPECT_EQ(iterator->prefix(), "Iterator::Zip"); } -TEST_P(ParameterizedDatasetTest, Roundtrip) { +TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) { int thread_num = 2, cpu_num = 2; TF_ASSERT_OK(InitThreadPool(thread_num)); TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num)); + const TestParam &test_case = GetParam(); - auto expected_outputs_it = test_case.expected_outputs.begin(); - TF_ASSERT_OK(CreateIteratorFromTestCase(test_case)); + std::vector range_dataset_tensors; + range_dataset_tensors.reserve(test_case.input_range_dataset_params.size()); + TF_ASSERT_OK(CreateRangeDatasetTensors(test_case.input_range_dataset_params, + &range_dataset_tensors)); + gtl::InlinedVector inputs; + inputs.reserve(range_dataset_tensors.size()); + for (auto &tensor : range_dataset_tensors) { + inputs.emplace_back(&tensor); + } + std::unique_ptr dataset_kernel; + int num_tensors_per_slice = test_case.input_range_dataset_params.size(); + TF_ASSERT_OK(CreateZipDatasetKernel({DT_INT64}, {{num_tensors_per_slice}}, + inputs.size(), &dataset_kernel)); + std::unique_ptr dataset_kernel_ctx; + TF_ASSERT_OK(CreateZipDatasetContext(dataset_kernel.get(), &inputs, + &dataset_kernel_ctx)); + DatasetBase *zip_dataset; + TF_ASSERT_OK(CreateDataset(dataset_kernel.get(), dataset_kernel_ctx.get(), + &zip_dataset)); + core::ScopedUnref scoped_unref(zip_dataset); + std::unique_ptr iterator_ctx; + TF_ASSERT_OK(CreateIteratorContext(dataset_kernel_ctx.get(), &iterator_ctx)); + std::unique_ptr iterator; + TF_ASSERT_OK( + zip_dataset->MakeIterator(iterator_ctx.get(), "Iterator", &iterator)); std::unique_ptr serialization_ctx; TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx)); bool end_of_sequence = false; std::vector out_tensors; + auto expected_outputs_it = test_case.expected_outputs.begin(); int cur_iteration = 0; for (int breakpoint : test_case.breakpoints) { VariantTensorData data; VariantTensorDataWriter writer(&data); - TF_EXPECT_OK(iterator_->Save(serialization_ctx.get(), &writer)); + TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer)); TF_EXPECT_OK(writer.Flush()); VariantTensorDataReader reader(&data); - TF_EXPECT_OK(iterator_->Restore(iterator_ctx_.get(), &reader)); + TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader)); while (cur_iteration < breakpoint) { - TF_EXPECT_OK(iterator_->GetNext(iterator_ctx_.get(), &out_tensors, - &end_of_sequence)); + TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors, + &end_of_sequence)); if (!end_of_sequence) { for (auto &tensor : out_tensors) { EXPECT_NE(expected_outputs_it, test_case.expected_outputs.end()); @@ -335,7 +542,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { cur_iteration++; } - if (breakpoint >= dataset_->Cardinality()) { + if (breakpoint >= zip_dataset->Cardinality()) { EXPECT_TRUE(end_of_sequence); EXPECT_EQ(expected_outputs_it, test_case.expected_outputs.end()); } else { @@ -345,7 +552,7 @@ TEST_P(ParameterizedDatasetTest, Roundtrip) { } INSTANTIATE_TEST_SUITE_P( - ZipDatasetOpTest, ParameterizedDatasetTest, + ZipDatasetOpTest, ParameterizedZipDatasetOpTest, ::testing::ValuesIn(std::vector({TestCase1(), TestCase2()}))); } // namespace