Skip to content

Commit

Permalink
#tf-data Support global shuffle for the Tensor slices dataset.
Browse files Browse the repository at this point in the history
FUTURE_COPYBARA_INTEGRATE_REVIEW=#62782 from DanielYang59:warn-shuffle 134fc24
PiperOrigin-RevId: 615218481
  • Loading branch information
yangustc07 authored and tensorflower-gardener committed Mar 13, 2024
1 parent ac89199 commit a0c45e2
Show file tree
Hide file tree
Showing 15 changed files with 141 additions and 77 deletions.
Expand Up @@ -130,62 +130,6 @@ enum OpSet {
STABLEHLO = 4;
}

// Defines various calibration options.
// NEXT ID: 3
// TODO(b/326751656): remove after all the dependencies moved to
// stableho.quantization.CalibrationOptions.
message CalibrationOptions {
// Configurations for calibration methods.
// NEXT ID: 7
enum CalibrationMethod {
CALIBRATION_METHOD_UNSPECIFIED = 0;
// Use the min, max values of all sample datasets.
CALIBRATION_METHOD_MIN_MAX = 1;
// Use the average of min, max values in each sample dataset.
CALIBRATION_METHOD_AVERAGE_MIN_MAX = 2;
// Use the min/max percentile value of histogram.
CALIBRATION_METHOD_HISTOGRAM_PERCENTILE = 3;
// Use the histogram mid values that minimize MSE error.
// This is very slow algorithm because it computes all errors for all
// histogram mid value pairs. Therefore the value of num_bins is recommended
// to be 256 or less.
CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE = 4;
// Use the histogram mid values that minimize MSE error.
// This is an algorithm that finds the bin with the max frequency in the
// histogram and expands the range.
CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY = 5;
// Use the histogram mid values that minimize MSE error. This is an
// algorithm that starts with the center in thehistogram and expands the
// range.
CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC = 6;
}

// Parameters required for calibration.
// NEXT ID: 4
message CalibrationParameters {
// The number of bins when histogram is initialized. It can be increased
// because histogram is dynamically expanded by sample inputs.
// initial_num_bins is 256 by default.
int32 initial_num_bins = 1;
// min_percentile is only used in HISTOGRAM_PERCENTILE.
// min_percentile is 0.001 by default.
float min_percentile = 2;
// max_percentile is only used in HISTOGRAM_PERCENTILE.
// max_percentile is 99.999 by default.
float max_percentile = 3;
}

// Determines how to calibrate.
// The default calibration method is MIN_MAX.
CalibrationMethod calibration_method = 1;

// Defines the parameters required for calibration. Parameters such as the
// number of bins in the histogram and percentile belong to it.
// MIN_MAX and AVERAGE_MIN_MAX don't require this parameter and methods
// starting with HISTOGRAM require this parameter.
CalibrationParameters calibration_parameters = 2;
}

// The data format of each sample in the representative dataset.
message RepresentativeDataSample {
map<string, TensorProto> tensor_proto_inputs = 2;
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/core/kernels/data/BUILD
Expand Up @@ -1357,6 +1357,9 @@ tf_kernel_library(
"//tensorflow/core/data:dataset_utils",
"//tensorflow/core/data:name_utils",
"//tensorflow/core/data:split_utils",
"@com_google_absl//absl/status",
"@local_tsl//tsl/platform:mutex",
"@local_tsl//tsl/platform:thread_annotations",
],
)

Expand Down
38 changes: 38 additions & 0 deletions tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/name_utils.h"
#include "tensorflow/core/data/split_utils.h"
Expand All @@ -25,6 +26,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/util/batch_util.h"
#include "tsl/platform/mutex.h"
#include "tsl/platform/thread_annotations.h"

namespace tensorflow {
namespace data {
Expand Down Expand Up @@ -96,6 +99,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {

Status Get(OpKernelContext* ctx, int64 index,
std::vector<Tensor>* out_tensors) const override {
return Get(index, out_tensors);
}

Status Get(int64 index, std::vector<Tensor>* out_tensors) const {
TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index));
out_tensors->clear();
out_tensors->reserve(tensors_.size());
Expand All @@ -105,6 +112,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
return absl::OkStatus();
}

absl::Status RandomIndexingCompatible() const override {
return absl::OkStatus();
}

protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Expand Down Expand Up @@ -163,6 +174,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
if (ctx->index_mapper() != nullptr) {
return Get(ctx, out_tensors, end_of_sequence);
}

Tensor split;
TF_RETURN_IF_ERROR(split_provider_->GetNext(&split, end_of_sequence));
if (*end_of_sequence) {
Expand All @@ -178,6 +193,19 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
return absl::OkStatus();
}

absl::Status Get(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
tsl::mutex_lock l(mu_);
int64_t output_index = ctx->index_mapper()(element_count_++);
absl::Status status = dataset()->Get(output_index, out_tensors);
if (absl::IsOutOfRange(status)) {
*end_of_sequence = true;
return absl::OkStatus();
}
*end_of_sequence = false;
return absl::OkStatus();
}

protected:
std::shared_ptr<model::Node> CreateNode(
IteratorContext* ctx, model::Node::Args args) const override {
Expand All @@ -192,12 +220,22 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {

Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
if (ctx->restored_element_count().has_value()) {
tsl::mutex_lock l(mu_);
element_count_ = *(ctx->restored_element_count());
return absl::OkStatus();
}
return split_provider_->Restore(
[this](const std::string& key) { return full_name(key); }, reader);
}

private:
std::shared_ptr<SplitProvider> split_provider_;

mutable tsl::mutex mu_;
// Count of elements produced by this iterator when it runs in the random
// access mode.
int64_t element_count_ TF_GUARDED_BY(mu_) = 0;
};

const std::vector<Tensor> tensors_;
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/data/kernel_tests/BUILD
Expand Up @@ -457,6 +457,7 @@ tf_py_strict_test(
deps = [
":checkpoint_test_base",
":test_base",
"//tensorflow/python/data/experimental/ops:global_shuffle_op",
"//tensorflow/python/data/experimental/ops:random_access",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:options",
Expand Down
73 changes: 73 additions & 0 deletions tensorflow/python/data/kernel_tests/from_tensor_slices_test.py
Expand Up @@ -14,10 +14,12 @@
# ==============================================================================
"""Tests for `tf.data.Dataset.from_tensor_slices()."""
import collections
from typing import Callable, Optional

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.experimental.ops import global_shuffle_op
from tensorflow.python.data.experimental.ops import random_access
from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
Expand Down Expand Up @@ -404,5 +406,76 @@ def testDict(self, verify_fn):
num_outputs=3)


class FromTensorSlicesGlobalShuffleTest(
test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
dataset_range=[10, 100],
repetitions=[1, 3],
seed=[None, 19],
reshuffle_each_iteration=[True, False])))
def testGlobalShuffleTensorSlicesDataset(
self,
dataset_range: int,
repetitions: int,
seed: Optional[int],
reshuffle_each_iteration: bool):
dataset = dataset_ops.Dataset.from_tensor_slices(list(range(dataset_range)))
if repetitions > 1:
dataset = dataset.repeat(repetitions)
dataset = global_shuffle_op._global_shuffle(
dataset, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration)
dataset_output = self.getDatasetOutput(
dataset, requires_initialization=True)

expected = list(range(dataset_range)) * repetitions
self.assertCountEqual(dataset_output, expected)
self.assertNotEqual(dataset_output, expected)
self.assertLen(expected, self.evaluate(dataset.cardinality()))


class FromTensorSlicesGlobalShuffleCheckpointTest(
checkpoint_test_base.CheckpointTestBase, parameterized.TestCase):

@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
checkpoint_test_base.default_test_combinations(),
combinations.combine(
dataset_range=[10],
repetitions=[1, 3],
reshuffle_each_iteration=[True, False],
symbolic_checkpoint=[True, False])))
def testGlobalShuffleTensorSlicesDataset(
self,
verify_fn: Callable[..., None],
dataset_range: int,
repetitions: int,
reshuffle_each_iteration: bool,
symbolic_checkpoint: bool):

def _build_dataset() -> dataset_ops.Dataset:
dataset = dataset_ops.Dataset.from_tensor_slices(
list(range(dataset_range)))
if repetitions > 1:
dataset = dataset.repeat(repetitions)
dataset = global_shuffle_op._global_shuffle(
dataset, seed=42, reshuffle_each_iteration=reshuffle_each_iteration)
if symbolic_checkpoint:
options = options_lib.Options()
options.experimental_symbolic_checkpoint = symbolic_checkpoint
dataset = dataset.with_options(options)
return dataset

verify_fn(
self,
_build_dataset,
num_outputs=dataset_range * repetitions,
assert_items_equal=reshuffle_each_iteration)


if __name__ == "__main__":
test.main()
10 changes: 7 additions & 3 deletions tensorflow/python/data/ops/dataset_ops.py
Expand Up @@ -1408,7 +1408,7 @@ def enumerate(self, start=0, name=None) -> "DatasetV2":
return Dataset.zip((range_dataset, self), name=name)

def shuffle(
self, buffer_size, seed=None, reshuffle_each_iteration=None, name=None
self, buffer_size, seed=None, reshuffle_each_iteration=True, name=None
) -> "DatasetV2":
"""Randomly shuffles the elements of this dataset.
Expand All @@ -1424,8 +1424,12 @@ def shuffle(
maintaining the 1,000 element buffer.
`reshuffle_each_iteration` controls whether the shuffle order should be
different for each epoch. In TF 1.X, the idiomatic way to create epochs
was through the `repeat` transformation:
different for each epoch. However you should avoid using
`shuffle(reshuffle_each_iteration=True)`, then `take` and `skip` to split
a dataset into training and test sets, which would lead to data leakage (as
the entire dataset would be re-shuffled then re-split after each epoch).
Please use the `tf.keras.utils.split_dataset` method instead. In TF 1.X,
the idiomatic way to create epochs was through the `repeat` transformation:
```python
dataset = tf.data.Dataset.range(3)
Expand Down
21 changes: 11 additions & 10 deletions tensorflow/python/data/ops/shuffle_op.py
Expand Up @@ -26,28 +26,29 @@ def _shuffle( # pylint: disable=unused-private-name
input_dataset,
buffer_size,
seed=None,
reshuffle_each_iteration=None,
name=None):
reshuffle_each_iteration=True,
name=None,
):
return _ShuffleDataset(
input_dataset, buffer_size, seed, reshuffle_each_iteration, name=name)


class _ShuffleDataset(dataset_ops.UnaryUnchangedStructureDataset):
"""A `Dataset` that randomly shuffles the elements of its input."""

def __init__(self,
input_dataset,
buffer_size,
seed=None,
reshuffle_each_iteration=None,
name=None):
def __init__(
self,
input_dataset,
buffer_size,
seed=None,
reshuffle_each_iteration=True,
name=None,
):
"""See `Dataset.shuffle()` for details."""
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
self._seed, self._seed2 = random_seed.get_seed(seed)
if reshuffle_each_iteration is None:
reshuffle_each_iteration = True
self._reshuffle_each_iteration = reshuffle_each_iteration
self._name = name

Expand Down
Expand Up @@ -160,7 +160,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -162,7 +162,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -161,7 +161,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -162,7 +162,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -162,7 +162,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -163,7 +163,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -162,7 +162,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down
Expand Up @@ -163,7 +163,7 @@ tf_class {
}
member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\'], "
}
member_method {
name: "skip"
Expand Down

0 comments on commit a0c45e2

Please sign in to comment.