[Data] Add performant way to read large tfrecord datasets #42277
python/ray/data/datasource/tfrecords_datasource.py:

@@ -2,16 +2,26 @@
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np
import pyarrow

from ray.data._internal.dataset_logger import DatasetLogger
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    import pyarrow
    import pandas as pd
    import tensorflow as tf
    from tensorflow_metadata.proto.v0 import schema_pb2

    from ray.data.dataset import Dataset


DEFAULT_BATCH_SIZE = 2048

logger = DatasetLogger(__name__)


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
@@ -23,13 +33,23 @@ def __init__(
        self,
        paths: Union[str, List[str]],
        tf_schema: Optional["schema_pb2.Schema"] = None,
        fast_read: bool = False,

Review comment: Can we change the name …

        batch_size: Optional[int] = None,

Review comment: Let's also add these two parameters to the docstring.

        **file_based_datasource_kwargs,
    ):
        super().__init__(paths, **file_based_datasource_kwargs)

-        self.tf_schema = tf_schema
+        self._tf_schema = tf_schema
+        self._fast_read = fast_read
+        self._batch_size = batch_size or DEFAULT_BATCH_SIZE
    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        if self._fast_read:
            yield from self._fast_read_stream(f, path)
        else:
            yield from self._slow_read_stream(f, path)

    def _slow_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        import pyarrow as pa
        import tensorflow as tf
        from google.protobuf.message import DecodeError
@@ -46,14 +66,64 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
                )

            yield pa.Table.from_pydict(
-                _convert_example_to_dict(example, self.tf_schema)
+                _convert_example_to_dict(example, self._tf_schema)
            )
    def _fast_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        import tensorflow as tf
        from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder
        full_path = self._resolve_full_path(path)

        compression = (self._open_stream_args or {}).get("compression", None)

        if compression:
            compression = compression.upper()

        tf_schema_string = (
            self._tf_schema.SerializeToString() if self._tf_schema else None
        )

        decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
        exception_thrown = None
        try:
            for record in tf.data.TFRecordDataset(
                full_path, compression_type=compression
            ).batch(self._batch_size):
                yield pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
Review comment: DecodeBatch will create the large_list data type, and creating a TensorFlow dataset via the dataset.to_tf method will then fail with the message: NotImplementedError: large_list.
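A possible workaround sketch for the issue flagged above (not part of this diff; the helper name is hypothetical, and it assumes a pyarrow version that supports large_list-to-list casts): cast any large_list columns back to regular list before yielding, so consumers such as Dataset.to_tf that don't support large_list can handle the block.

    import pyarrow as pa

    def _cast_large_lists_to_lists(table: pa.Table) -> pa.Table:
        # Rewrite each large_list<T> field as list<T>; leave other fields as-is.
        fields = [
            field.with_type(pa.list_(field.type.value_type))
            if pa.types.is_large_list(field.type)
            else field
            for field in table.schema
        ]
        return table.cast(pa.schema(fields))

The yield in the loop above could then wrap the decoded table, e.g. yield _cast_large_lists_to_lists(pyarrow.Table.from_batches(...)).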
        except Exception as error:
            logger.get_logger().exception(f"Failed to read TFRecord file {full_path}")
            exception_thrown = error

        # We need this hack of raising the exception outside of the except
        # block because TensorFlow's DataLossError is unpicklable, and even
        # if we raise a RuntimeError inside the except block, Ray keeps the
        # original error attached (via __context__), which still makes it
        # unpicklable.
        if exception_thrown:
            raise RuntimeError(f"Failed to read TFRecord file {full_path}.")
|
||
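As an aside, a minimal sketch of the Python behavior the comment above describes (a plain ValueError stands in for TensorFlow's unpicklable DataLossError): an exception raised while another is being handled carries the original via __context__, while one raised after the except block has exited does not.

    def raise_inside_except():
        try:
            raise ValueError("stand-in for the unpicklable DataLossError")
        except Exception:
            # Raised while the ValueError is being handled, so the new
            # error carries the original via __context__ (and pickling
            # the wrapper would follow that chain).
            raise RuntimeError("wrapper")

    def raise_after_except():
        error = None
        try:
            raise ValueError("stand-in for the unpicklable DataLossError")
        except Exception as e:
            error = e
        if error:
            # No exception is being handled here, so no __context__ chain.
            raise RuntimeError("wrapper")

    try:
        raise_inside_except()
    except RuntimeError as e:
        assert isinstance(e.__context__, ValueError)  # original still attached

    try:
        raise_after_except()
    except RuntimeError as e:
        assert e.__context__ is None  # clean wrapper, safe to serialize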
    def _resolve_full_path(self, relative_path):
        if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
            return f"s3://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
            return f"gs://{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
            return f"hdfs:///{relative_path}"
        if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
            protocol = self._filesystem.handler.fs.protocol
            if isinstance(protocol, (list, tuple)):
                protocol = protocol[0]
            if protocol == "gcs":
                protocol = "gs"
            return f"{protocol}://{relative_path}"

        return relative_path
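Context for this helper: the fast path hands full_path straight to tf.data.TFRecordDataset, which opens files by URI itself rather than through the pyarrow filesystem handle, so the scheme is presumably restored here after path resolution stripped it. Illustrative mapping (hypothetical paths):

    # S3FileSystem:     "bucket/train/part-0.tfrecord" -> "s3://bucket/train/part-0.tfrecord"
    # GcsFileSystem:    "bucket/train/part-0.tfrecord" -> "gs://bucket/train/part-0.tfrecord"
    # HadoopFileSystem: "data/part-0.tfrecord"         -> "hdfs:///data/part-0.tfrecord"
    # local and others: the path is returned unchanged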
def _convert_example_to_dict(
    example: "tf.train.Example",
    tf_schema: Optional["schema_pb2.Schema"],
-) -> Dict[str, "pyarrow.Array"]:
+) -> Dict[str, pyarrow.Array]:
    record = {}
    schema_dict = {}
    # Convert user-specified schema into dict for convenient mapping
@@ -73,7 +143,7 @@ def _convert_example_to_dict(


def _convert_arrow_table_to_examples(
-    arrow_table: "pyarrow.Table",
+    arrow_table: pyarrow.Table,
    tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
    import tensorflow as tf
@@ -118,7 +188,7 @@ def _get_single_true_type(dct) -> str:
def _get_feature_value(
    feature: "tf.train.Feature",
    schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
-) -> "pyarrow.Array":
+) -> pyarrow.Array:
    import pyarrow as pa

    underlying_feature_type = {
@@ -361,6 +431,57 @@ def _read_records(
            raise RuntimeError(error_message) from e


def _infer_schema_and_transform(dataset: "Dataset"):
    list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))

Review comment: Can you add a comment on why this function is needed?

Review comment: @raulchen it will indeed mean an extra pass over the data. It is needed because the tfx-bsl ExamplesToRecordBatchDecoder always returns lists of lists when no schema is provided, and what this function does is infer the schema for those fields that are single-value fields. Performance-wise, our benchmarks of this implementation (we have had it running internally for a while) show more than 15x improvement over the current implementation: some of our datasets take ~30 minutes to load with the Ray-native implementation, compared to less than 2 minutes with this tfx-bsl implementation. Let me know if you need more benchmark numbers; happy to provide more.

Review comment: Thanks for the explanation. It's okay to proceed without more benchmarks for now.
    return dataset.map_batches(
        _unwrap_single_value_lists,
        fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
    )

def _unwrap_single_value_lists(
    batch: Dict[str, np.ndarray], col_lengths: Dict[str, int]
):
    for col in col_lengths:
        if col_lengths[col] == 1:
            batch[col] = np.array(
                [x[0] if isinstance(x, np.ndarray) else x for x in batch[col]]
            )

    return batch
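To make the review discussion above concrete, a minimal sketch of what the unwrap step does (the column name and values are hypothetical, and a plain list stands in for a Ray Data batch column):

    import numpy as np

    # With no tf_schema, the fast path decodes every column as list-typed,
    # so a scalar feature arrives as length-1 lists:
    batch = {"label": [np.array([0]), np.array([1]), np.array([2])]}

    # The aggregate found that "label" lists never exceed length 1,
    # so the wrapping lists are stripped:
    out = _unwrap_single_value_lists(batch, col_lengths={"label": 1})
    # out["label"] is now array([0, 1, 2])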

class _MaxListSize(AggregateFn):
    def __init__(self, columns: List[str]):
        self._columns = columns
        super().__init__(
            init=self._init,
            merge=self._merge,
            accumulate_row=self._accumulate_row,
            finalize=lambda a: a,
            name="max_list_size",
        )

    def _init(self, k: str):
        return {col: 0 for col in self._columns}

    def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
        merged = {}
        for col in self._columns:
            merged[col] = max(acc1[col], acc2[col])

        return merged

    def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
        for k in row:
            value = row[k]
            if value:
                acc[k] = max(len(value), acc[k])

        return acc
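For illustration, a sketch of how this aggregate behaves (assumptions: plain dicts stand in for rows and for Ray's per-block accumulators, and the private methods are called directly, which Ray normally does through the AggregateFn machinery):

    agg = _MaxListSize(columns=["a", "b"])

    acc1 = agg._init(None)                                     # {"a": 0, "b": 0}
    acc1 = agg._accumulate_row(acc1, {"a": [1, 2], "b": [3]})  # {"a": 2, "b": 1}

    acc2 = agg._init(None)
    acc2 = agg._accumulate_row(acc2, {"a": [4], "b": None})    # None skipped -> {"a": 1, "b": 0}

    merged = agg._merge(acc1, acc2)                            # {"a": 2, "b": 1}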

# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72  # noqa: E501
#
# MIT License

python/ray/data/read_api.py:
@@ -1476,6 +1476,8 @@ def read_tfrecords(
    tf_schema: Optional["schema_pb2.Schema"] = None,
    shuffle: Union[Literal["files"], None] = None,
    file_extensions: Optional[List[str]] = None,
    fast_read_batch_size: Optional[int] = None,
    fast_read_auto_infer_schema: bool = True,
) -> Dataset:
    """Create a :class:`~ray.data.Dataset` from TFRecord files that contain
    `tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/Example>`_
@@ -1550,13 +1552,45 @@ def read_tfrecords(
        shuffle: If set to "files", randomly shuffle the input file order before
            the read. Defaults to ``None`` (no shuffling).
        file_extensions: A list of file extensions to filter files by.
        fast_read_batch_size: An int representing the number of consecutive elements
            of this dataset to combine in a single batch when fast_read is used.
        fast_read_auto_infer_schema: Toggles whether schema inference is applied;
            applicable only if fast_read is used and the tf_schema argument is
            missing. Defaults to True.

    Returns:
        A :class:`~ray.data.Dataset` that contains the example features.

    Raises:
        ValueError: If a file contains a message that isn't a ``tf.train.Example``.
    """
    import platform

    fast_read = False
    try:
        from tfx_bsl.cc.tfx_bsl_extension.coders import (  # noqa: F401
            ExamplesToRecordBatchDecoder,
        )

        fast_read = True
    except ModuleNotFoundError:
        if platform.processor() == "arm":
            logger.warning(
                "This function depends on tfx-bsl, which is currently not supported"
                " on devices with Apple silicon (e.g. M1) and requires an"
                " environment with an x86 CPU architecture."
            )

Review comment: When the user has specified …

Review comment: It does not support it right now; I find it a lot safer not to assume it will.

Review comment: So you are suggesting I just remove the whole block of trying to import and logging a warning, which would also remove the fallback?
        else:
            logger.warning(
                "To use TFRecordDatasource with large datasets, please install"
                " the tfx-bsl package with `pip install tfx_bsl --no-dependencies`."
            )
        logger.info(
            "Falling back to a slower strategy for reading TFRecords. This"
            " reading strategy should be avoided when reading large datasets."
        )

    if meta_provider is None:
        meta_provider = get_generic_metadata_provider(
            TFRecordDatasource._FILE_EXTENSIONS
@@ -1573,8 +1607,17 @@ def read_tfrecords(
        shuffle=shuffle,
        include_paths=include_paths,
        file_extensions=file_extensions,
        fast_read=fast_read,
        batch_size=fast_read_batch_size,
    )
-    return read_datasource(datasource, parallelism=parallelism)
+    ds = read_datasource(datasource, parallelism=parallelism)

    if fast_read_auto_infer_schema and fast_read and not tf_schema:
        from ray.data.datasource.tfrecords_datasource import _infer_schema_and_transform

        return _infer_schema_and_transform(ds)

    return ds
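A usage sketch of the new parameters (the bucket path is hypothetical; the fast path is taken only when tfx-bsl is importable, and the schema-inference pass runs because no tf_schema is supplied):

    import ray

    ds = ray.data.read_tfrecords(
        "s3://my-bucket/tfrecords/",       # hypothetical path
        fast_read_batch_size=4096,         # records decoded per DecodeBatch call
        fast_read_auto_infer_schema=True,  # extra pass to unwrap single-value lists
    )
    print(ds.schema())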

@PublicAPI(stability="alpha")
Review comment: This feels like a rather weird way to do it. Is this the intended way to use tfx-bsl? Are there other, more direct ways to import and use the logic in ExamplesToRecordBatchDecoder?

Review comment: More direct ways like adding the dependency like any other dependency, or are you thinking about something else? I tried adding it as a direct dependency in the data test dependencies, but it conflicted with other existing dependencies. We do not really need to bring in the transitive dependencies, since the only thing we need is the ExamplesToRecordBatchDecoder, which is self-contained and doesn't need any extra dependency to work. Another approach we could take is to add this class to the Ray repo itself; I am not really familiar with that part of the Ray codebase, but we would have to add the C class and a Python binding.

Review comment: Could you elaborate on the specifics? cc @can-anyscale // example that we might need multiple constraint files. I am worried that this is a hack and might not be sustainable. Is there any guarantee that when tfx-bsl releases a new version in the future, this will still work? Do we expect all users to install tfx-bsl with --no-dependencies? This is not a test-only dependency but a data dependency; at the end of the day, for it to be useful, it needs to work with other dependencies in key workflows and workloads. Tests and CI are proxies for "Ray works for users". It seems this is only used as part of the internal implementation, not as part of the Ray Data interface? If that is the case, can we fork https://github.com/tensorflow/tfx-bsl, give it another package name, remove the parts we do not need, build it from source, and publish a tfx-bsl-ray package just for this?

Review comment: Or maybe we should try to resolve the dependency conflicts.

Review comment: Yes, we do.

Review comment: :) This is more of a question for the Ray code owners. I am asking @scottjlee to help here.

Review comment: I don't think we can always assume (or force) users to install with the --no-dependencies flag, since they might be using the library's other features. But that would mean the dependency resolution is put on the user, which we should try to avoid.