diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 01044957857106..2cb5c69584faf3 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -191,8 +191,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader explicit Reader(const Params& params, int64 start_index); - ~Reader() override; - Status Initialize(IteratorContext* ctx) override; Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -212,7 +210,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); - DatasetBase* input_ TF_GUARDED_BY(mu_); + DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr; std::unique_ptr instantiated_reader_func_ TF_GUARDED_BY(mu_); @@ -451,7 +449,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal( bool* end_of_sequence) { mutex_lock l(mu_); if (iterator_ == nullptr) { - TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr)); + Status s = InitializeIterator(ctx, /*reader=*/nullptr); + if (!s.ok()) { + iterator_.reset(); + return s; + } } index_++; return iterator_->GetNext(ctx, out_tensors, end_of_sequence); @@ -530,8 +532,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params, int64 start_index) : DatasetIterator(params), start_index_(start_index) {} -SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); } - Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( IteratorContext* ctx) { mutex_lock l(mu_); @@ -578,11 +578,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( "reader_func returns more than one argument."); } TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_)); - - // We need to take a reference here as we will use the input_ and - // its iterator. - input_->Ref(); - return input_->MakeIterator(ctx, this, prefix(), &input_impl_); } diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index b1fa780f6b3c80..c1e3b1e9fa0b83 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -356,6 +356,19 @@ def make_dataset(): num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) + @combinations.generate(test_base.default_test_combinations()) + def testRepeatAndPrefetch(self): + """This test reproduces github.com/tensorflow/tensorflow/issues/48903.""" + dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32)) + dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) + dataset = dataset.shuffle(buffer_size=16) + dataset = dataset.batch(16) + dataset = dataset.repeat() + dataset = dataset.prefetch(1) + next_element = self.getNext(dataset) + for _ in range(30): + self.evaluate(next_element()) + class LegacySnapshotDatasetTest( reader_dataset_ops_test_base.TFRecordDatasetTestBase,