Skip to content

Commit

Permalink
#tf-data Re-use GlobalShuffleIterator for the Bag dataset.
Browse files Browse the repository at this point in the history
Added an `AnyContext` class for use cases that need to pass
in either an `OpKernelContext` or `IteratorContext`.

`BagDataset::Get` needs to access `ctx.allocator`
from either an `OpKernelContext` or an `IteratorContext`.

PiperOrigin-RevId: 618331917
  • Loading branch information
yangustc07 authored and tensorflower-gardener committed Mar 23, 2024
1 parent 3bddf6e commit 51871ec
Show file tree
Hide file tree
Showing 11 changed files with 53 additions and 45 deletions.
10 changes: 5 additions & 5 deletions tensorflow/core/data/dataset_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements,
return absl::OkStatus();
}

Status CopyBatch(CopyBatchParams params,
Status CopyBatch(AnyContext ctx,
std::vector<std::vector<Tensor>>&& batch_elements,
bool parallel_copy, std::vector<Tensor>* out_tensors) {
const size_t num_tuple_components = batch_elements.at(0).size();
Expand All @@ -799,7 +799,7 @@ Status CopyBatch(CopyBatchParams params,
TensorShape first_element_shape(first_element.shape());
TensorShape batch_component_shape({num_batch_elements});
batch_component_shape.AppendShape(first_element_shape);
out_tensors->emplace_back(params.allocator, first_element.dtype(),
out_tensors->emplace_back(ctx.allocator, first_element.dtype(),
batch_component_shape);
if (!out_tensors->back().IsInitialized()) {
return errors::ResourceExhausted(
Expand Down Expand Up @@ -837,7 +837,7 @@ Status CopyBatch(CopyBatchParams params,
if (parallel_copy && total_bytes >= (1 << 20)) {
Status status;
mutex status_mu;
const auto num_threads = params.runner_threadpool_size;
const auto num_threads = ctx.runner_threadpool_size;
const auto slice_size = num_batch_elements / num_threads;
int64_t offset = 0;
BlockingCounter counter(num_threads);
Expand All @@ -847,8 +847,8 @@ Status CopyBatch(CopyBatchParams params,
// evenly, the size of some slices is incremented to guarantee their
// sizes add up to the total number of elements.
if (i < num_batch_elements % num_threads) ++length;
(*params.runner)([offset, length, &status, &status_mu, &counter,
&copy_element_fn]() {
(*ctx.runner)([offset, length, &status, &status_mu, &counter,
&copy_element_fn]() {
Status s;
for (size_t j = offset; j < offset + length; ++j) {
s.Update(copy_element_fn(j));
Expand Down
21 changes: 1 addition & 20 deletions tensorflow/core/data/dataset_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,33 +279,14 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements,
IteratorContext* ctx, std::vector<Tensor>* output,
bool* end_of_sequence, std::vector<Tensor>* batch);

// Constructs and stores the parameters for the CopyBatch function.
struct CopyBatchParams {
Allocator* allocator;
std::function<void(std::function<void()>)>* runner;
int64 runner_threadpool_size;

explicit CopyBatchParams(IteratorContext* ctx) {
allocator = ctx->allocator({});
runner = ctx->runner();
runner_threadpool_size = ctx->runner_threadpool_size();
}

explicit CopyBatchParams(OpKernelContext* ctx) {
allocator = ctx->get_allocator({});
runner = ctx->runner();
runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
}
};

// Copies the input elements to a batch.
//
// The `batch_elements` argument contains the individual elements to copy into a
// batch. The `parallel_copy` argument indicates whether to parallelize the
// copy.
// The `out_tensors` argument will be used to store the resulting batch (one for
// each component of the input).
Status CopyBatch(CopyBatchParams params,
Status CopyBatch(AnyContext ctx,
std::vector<std::vector<Tensor>>&& batch_elements,
bool parallel_copy, std::vector<Tensor>* out_tensors);

Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/data/global_shuffle_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ absl::Status GlobalShuffleIterator::GetNext(IteratorContext* ctx,

absl::MutexLock l(&mu_);
int64_t output_index = ctx->index_mapper()(element_count_++);
absl::Status status = dataset_->Get(output_index, out_tensors);
absl::Status status =
dataset_->Get(AnyContext(ctx), output_index, out_tensors);
if (absl::IsOutOfRange(status)) {
*end_of_sequence = true;
return absl::OkStatus();
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/core/framework/dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,8 @@ Status DatasetBase::Get(OpKernelContext* ctx, int64 index,
DebugString());
}

// Default implementation of the context-carrying random-access getter.
// Datasets that support random access (e.g. range/tensor/list datasets)
// override this; the base class reports the capability as unimplemented.
// `ctx` is deliberately unused here — it exists so overrides can reach
// ctx.allocator (and friends) regardless of which context type the caller
// started from.
Status DatasetBase::Get(AnyContext ctx, int64 index,
                        std::vector<Tensor>* out_tensors) const {
  return errors::Unimplemented("Random access is not implemented for dataset ",
                               DebugString());
}
Expand Down
26 changes: 24 additions & 2 deletions tensorflow/core/framework/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,26 @@ class IteratorContext {
MemoryCheckpoint checkpoint_;
};

// Generic context that can be constructed with either an `OpKernelContext` or
// `IteratorContext`.
struct AnyContext {
Allocator* allocator;
std::function<void(std::function<void()>)>* runner;
int64_t runner_threadpool_size;

explicit AnyContext(IteratorContext* ctx) {
allocator = ctx->allocator({});
runner = ctx->runner();
runner_threadpool_size = ctx->runner_threadpool_size();
}

explicit AnyContext(OpKernelContext* ctx) {
allocator = ctx->get_allocator({});
runner = ctx->runner();
runner_threadpool_size = GetRunnerThreadpoolSizeFromOpKernelContext(ctx);
}
};

// Represents the current position in a range of outputs, where the
// range of outputs is typically represented by an `DatasetBase`,
// defined below.
Expand Down Expand Up @@ -1346,9 +1366,11 @@ class DatasetBase : public core::RefCounted {
virtual Status Get(OpKernelContext* ctx, int64 index,
std::vector<Tensor>* out_tensors) const;

// Same as above, but without an `OpKernelContext`. Used to support datasets
// Same as above, but with an `AnyContext`, which can be constructed from
// either an `OpKernelContext` or `IteratorContext`. Used to support datasets
// that provide random access through both the dataset and iterator APIs.
virtual Status Get(int64 index, std::vector<Tensor>* out_tensors) const;
virtual Status Get(AnyContext ctx, int64 index,
std::vector<Tensor>* out_tensors) const;

// Returns true if the dataset and its inputs support random access.
virtual absl::Status RandomIndexingCompatible() const {
Expand Down
8 changes: 3 additions & 5 deletions tensorflow/core/kernels/data/batch_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ class BatchDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(input_->Get(ctx, i, &batch_element_tuple));
batch_elements.emplace_back(std::move(batch_element_tuple));
}
TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx),
std::move(batch_elements), parallel_copy_,
out_tensors));
TF_RETURN_IF_ERROR(CopyBatch(AnyContext(ctx), std::move(batch_elements),
parallel_copy_, out_tensors));
return absl::OkStatus();
}

Expand Down Expand Up @@ -247,8 +246,7 @@ class BatchDatasetOp::Dataset : public DatasetBase {
// respective slice locations. This would require a different GetNext()
// overload that supports zero-copy, and might make sense in an
// optimization pass.
TF_RETURN_IF_ERROR(CopyBatch(CopyBatchParams(ctx),
std::move(batch_elements),
TF_RETURN_IF_ERROR(CopyBatch(AnyContext(ctx), std::move(batch_elements),
dataset()->parallel_copy_, out_tensors));

*end_of_sequence = false;
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/data/experimental/list_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/data/global_shuffle_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
Expand Down Expand Up @@ -96,10 +97,10 @@ class ListDatasetOp::Dataset : public DatasetBase {

absl::Status Get(OpKernelContext* ctx, int64_t index,
std::vector<Tensor>* out_tensors) const override {
return Get(index, out_tensors);
return Get(AnyContext(ctx), index, out_tensors);
}

absl::Status Get(int64_t index,
absl::Status Get(AnyContext ctx, int64_t index,
std::vector<Tensor>* out_tensors) const override {
TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
out_tensors->clear();
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/core/kernels/data/parallel_batch_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase {
Status status;
{
mutex_lock l(result->mu);
status =
CopyBatch(CopyBatchParams(ctx.get()), std::move(batch_elements),
dataset()->parallel_copy_, &result->output);
status = CopyBatch(AnyContext(ctx.get()), std::move(batch_elements),
dataset()->parallel_copy_, &result->output);
result->status.Update(status);

if (result->status.ok()) {
Expand Down
5 changes: 3 additions & 2 deletions tensorflow/core/kernels/data/range_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,11 @@ class RangeDatasetOp::Dataset : public DatasetBase {

Status Get(OpKernelContext* ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
return Get(index, out_tensors);
return Get(AnyContext(ctx), index, out_tensors);
}

Status Get(int64 index, std::vector<Tensor>* out_tensors) const override {
Status Get(AnyContext ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
return ConvertOutputTypes(output_dtypes(), out_tensors,
start_ + (index * step_));
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/data/tensor_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/data/global_shuffle_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
Expand Down Expand Up @@ -88,10 +89,11 @@ class TensorDatasetOp::Dataset : public DatasetBase {

Status Get(OpKernelContext* ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
return Get(index, out_tensors);
return Get(AnyContext(ctx), index, out_tensors);
}

Status Get(int64 index, std::vector<Tensor>* out_tensors) const override {
Status Get(AnyContext ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
*out_tensors = tensors_;
return absl::OkStatus();
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/data/global_shuffle_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
Expand Down Expand Up @@ -100,10 +101,11 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {

Status Get(OpKernelContext* ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
return Get(index, out_tensors);
return Get(AnyContext(ctx), index, out_tensors);
}

Status Get(int64 index, std::vector<Tensor>* out_tensors) const override {
Status Get(AnyContext ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
out_tensors->clear();
out_tensors->reserve(tensors_.size());
Expand Down

0 comments on commit 51871ec

Please sign in to comment.