Skip to content

Commit

Permalink
Merge pull request #49121 from yangustc07/cherrypicks_WBVJW
Browse files Browse the repository at this point in the history
<release 2.4>-<rc2> cherry-pick request: [tf.data] Fix snapshot segfault when using repeat and prefetch.
  • Loading branch information
mihaimaruseac committed May 19, 2021
2 parents 41b3fa3 + 14642ad commit 42db85a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
17 changes: 6 additions & 11 deletions tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc
Expand Up @@ -191,8 +191,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader

explicit Reader(const Params& params, int64 start_index);

~Reader() override;

Status Initialize(IteratorContext* ctx) override;

Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
Expand All @@ -212,7 +210,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader

std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);

DatasetBase* input_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;

std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_);
Expand Down Expand Up @@ -451,7 +449,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
bool* end_of_sequence) {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
Status s = InitializeIterator(ctx, /*reader=*/nullptr);
if (!s.ok()) {
iterator_.reset();
return s;
}
}
index_++;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
Expand Down Expand Up @@ -530,8 +532,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {}

// Drops the reference that Initialize() takes on the dataset returned by
// reader_func. Initialize() only assigns and Ref()s `input_` after several
// fallible steps, so a Reader that is destroyed after a failed (or never-run)
// Initialize() has no dataset to release; the guard avoids dereferencing a
// null pointer in that case.
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() {
  if (input_ != nullptr) {
    input_->Unref();
  }
}

Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
IteratorContext* ctx) {
mutex_lock l(mu_);
Expand Down Expand Up @@ -578,11 +578,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
"reader_func returns more than one argument.");
}
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));

// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();

return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/data/experimental/kernel_tests/snapshot_test.py
Expand Up @@ -356,6 +356,19 @@ def make_dataset():
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())

@combinations.generate(test_base.default_test_combinations())
def testRepeatAndPrefetch(self):
  """Reproduces github.com/tensorflow/tensorflow/issues/48903.

  Builds a snapshot -> shuffle -> batch -> repeat -> prefetch pipeline and
  pulls 30 elements from it.
  """
  source = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
  pipeline = (
      source.apply(snapshot.snapshot(self._snapshot_dir))
      .shuffle(buffer_size=16)
      .batch(16)
      .repeat()
      .prefetch(1))
  get_next = self.getNext(pipeline)
  num_steps = 30
  for _ in range(num_steps):
    self.evaluate(get_next())


class LegacySnapshotDatasetTest(
reader_dataset_ops_test_base.TFRecordDatasetTestBase,
Expand Down

0 comments on commit 42db85a

Please sign in to comment.