Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the flaky issue in ParallelInterleaveDatasetOpTest #27805

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/concatenate_dataset_op_test.cc
Expand Up @@ -536,7 +536,8 @@ TEST_P(ParameterizedConcatenateDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*concatenate_dataset, &iterator));

while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/data/dataset_test_base.cc
Expand Up @@ -177,6 +177,15 @@ Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
return Status::OK();
}

Status DatasetOpsTestBase::RestoreIterator(
IteratorContext* ctx, IteratorStateReader* reader,
const string& output_prefix, const DatasetBase& dataset,
std::unique_ptr<IteratorBase>* iterator) {
TF_RETURN_IF_ERROR(dataset.MakeIterator(ctx, output_prefix, iterator));
TF_RETURN_IF_ERROR((*iterator)->Restore(ctx, reader));
return Status::OK();
}

Status DatasetOpsTestBase::CreateIteratorContext(
OpKernelContext* const op_context,
std::unique_ptr<IteratorContext>* iterator_context) {
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/data/dataset_test_base.h
Expand Up @@ -75,6 +75,15 @@ class DatasetOpsTestBase : public ::testing::Test {
Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
DatasetBase** const dataset);

// Restores the state of the input iterator. It resets the iterator before
// restoring it to make sure the input iterator does not hold any
// resources or tasks. Otherwise, restoring an existing iterator may cause
// the timeout issue or duplicated elements.
Status RestoreIterator(IteratorContext* ctx, IteratorStateReader* reader,
const string& output_prefix,
const DatasetBase& dataset,
std::unique_ptr<IteratorBase>* iterator);

// Creates a new RangeDataset op kernel. `T` specifies the output dtype of the
// op kernel.
template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/flat_map_dataset_op_test.cc
Expand Up @@ -546,7 +546,8 @@ TEST_P(ParameterizedFlatMapDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*flat_map_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/interleave_dataset_op_test.cc
Expand Up @@ -763,7 +763,8 @@ TEST_P(ParameterizedInterleaveDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*interleave_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/map_dataset_op_test.cc
Expand Up @@ -550,7 +550,8 @@ TEST_P(ParameterizedMapDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator",
*map_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
Expand Down
Expand Up @@ -884,8 +884,7 @@ TEST_F(ParallelInterleaveDatasetOpTest, IteratorOutputPrefix) {
EXPECT_EQ(iterator->prefix(), "Iterator::ParallelInterleaveV2");
}

// TODO(b/130309946): Re-enable once deflaked.
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DISABLED_Roundtrip) {
TEST_P(ParameterizedParallelInterleaveDatasetOpTest, Roundtrip) {
int thread_num = 2, cpu_num = 2;
const TestCase &test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
Expand Down Expand Up @@ -937,7 +936,8 @@ TEST_P(ParameterizedParallelInterleaveDatasetOpTest, DISABLED_Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*parallel_interleave_dataset, &iterator));

while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc
Expand Up @@ -755,9 +755,9 @@ TEST_P(ParameterizedParallelMapDatasetOpTest, Roundtrip) {
VariantTensorDataWriter writer(&data);
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());

VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*parallel_map_dataset, &iterator));

while (cur_iteration <= breakpoint) {
std::vector<Tensor> next;
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/data/prefetch_dataset_op_test.cc
Expand Up @@ -586,7 +586,6 @@ TEST_P(ParameterizedPrefetchDatasetOpTest, Roundtrip) {

std::unique_ptr<SerializationContext> serialization_ctx;
TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));

bool end_of_sequence = false;
std::vector<Tensor> out_tensors;
int cur_iteration = 0;
Expand All @@ -598,7 +597,8 @@ TEST_P(ParameterizedPrefetchDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*prefetch_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/range_dataset_op_test.cc
Expand Up @@ -494,7 +494,8 @@ TEST_P(ParameterizedRangeDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_context.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator",
*range_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/repeat_dataset_op_test.cc
Expand Up @@ -547,7 +547,8 @@ TEST_P(ParameterizedDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*repeat_dataset, &iterator));

while (cur_iteration < breakpoint) {
out_tensors.clear();
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/skip_dataset_op_test.cc
Expand Up @@ -553,7 +553,8 @@ TEST_P(ParameterizedSkipDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*skip_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
Expand Up @@ -524,7 +524,8 @@ TEST_P(ParameterizedSparseTensorSliceDatasetOpTest, Roundtrip) {
TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*dataset, &iterator));
}
}

Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/take_dataset_op_test.cc
Expand Up @@ -548,7 +548,8 @@ TEST_P(ParameterizedTakeDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*take_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/tensor_dataset_op_test.cc
Expand Up @@ -492,7 +492,8 @@ TEST_P(ParametrizedTensorDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*tensor_dataset, &iterator));

while (cur_iteration <= breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/tensor_slice_dataset_op_test.cc
Expand Up @@ -650,7 +650,8 @@ TEST_P(ParameterizedTensorSliceDatasetOpTest, Roundtrip) {
TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
TF_ASSERT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader, "Iterator",
*tensor_slice_dataset, &iterator));
}
}

Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/kernels/data/zip_dataset_op_test.cc
Expand Up @@ -527,7 +527,8 @@ TEST_P(ParameterizedZipDatasetOpTest, Roundtrip) {
TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
TF_EXPECT_OK(writer.Flush());
VariantTensorDataReader reader(&data);
TF_EXPECT_OK(iterator->Restore(iterator_ctx.get(), &reader));
TF_EXPECT_OK(RestoreIterator(iterator_ctx.get(), &reader, "Iterator",
*zip_dataset, &iterator));

while (cur_iteration < breakpoint) {
TF_EXPECT_OK(iterator->GetNext(iterator_ctx.get(), &out_tensors,
Expand Down