[Data] Add performant way to read large tfrecord datasets #42277

Merged: 40 commits, Feb 29, 2024

Commits (the diff below shows changes from 6 of the 40 commits):
af7b677
feat: add performant way to read large tfrecord datasets
martinbomio Jan 9, 2024
f25fa2e
add tfx-bsl as a test dependency
martinbomio Jan 9, 2024
14ed874
address PR comments
martinbomio Jan 11, 2024
3b8bf91
properly enable/disable fast read on tests
martinbomio Jan 16, 2024
b303cac
resolve absolute path from relative
martinbomio Jan 24, 2024
ddfca72
add tensorflow-io for s3 fs impl
martinbomio Jan 24, 2024
c3beb21
try adding tfx-bsl, cython in data-test-requirements
scottjlee Jan 24, 2024
4ac5221
skip tfx-bsl install in data
scottjlee Jan 25, 2024
552c2e0
Apply suggestions from code review
martinbomio Jan 25, 2024
4a007c6
new tfx-bsl build
scottjlee Jan 25, 2024
3c20737
Merge branch 'martinbomio/fast-tfrecord-read' of https://github.com/m…
scottjlee Jan 25, 2024
bf061db
lint
scottjlee Jan 25, 2024
3d7bd33
fix missing build dependency
scottjlee Jan 25, 2024
b8519a1
add datatfxbsl build
scottjlee Jan 25, 2024
b6982a3
remove workers arg
scottjlee Jan 25, 2024
7934adc
worker config
scottjlee Jan 25, 2024
386be26
data target
scottjlee Jan 26, 2024
1d83dce
try pinning pandas<2
scottjlee Jan 26, 2024
2c49d51
pin pandas to pandas==1.5.3
scottjlee Jan 26, 2024
953af46
comment
scottjlee Jan 26, 2024
c373724
update tag
scottjlee Jan 26, 2024
e19c1fc
add tfxbsl dockerfile
scottjlee Jan 29, 2024
fb6b290
add crc32c
scottjlee Jan 30, 2024
7d7a08f
Merge branch 'master' into martinbomio/fast-tfrecord-read
scottjlee Jan 30, 2024
ca27c49
rewrite unwrap single value function to use pyarrow
martinbomio Feb 5, 2024
aac5ec6
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 6, 2024
7408c10
Merge branch 'martinbomio/fast-tfrecord-read' of https://github.com/m…
scottjlee Feb 7, 2024
030556c
cast large_list to list always on fast read
martinbomio Feb 13, 2024
681f753
move casting to datasource
martinbomio Feb 13, 2024
bca5bff
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 14, 2024
8cf1c2d
clean up docstrings
martinbomio Feb 20, 2024
9aa260b
rename fast_* variables to tfx_
martinbomio Feb 26, 2024
befc187
fix failing tests
martinbomio Feb 26, 2024
413c1f0
add flag in data context to disable using tfx read
martinbomio Feb 26, 2024
1e2c627
disable tfx_read by default
martinbomio Feb 27, 2024
9894178
add TFXReadOptions
martinbomio Feb 28, 2024
bf06a37
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 28, 2024
bf64415
fix build
martinbomio Feb 28, 2024
663c39c
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 28, 2024
550392e
Merge branch 'master' into martinbomio/fast-tfrecord-read
c21 Feb 29, 2024
5 changes: 5 additions & 0 deletions ci/docker/data.build.Dockerfile

@@ -28,6 +28,11 @@
sudo apt-get purge -y mongodb*
sudo apt-get install -y mongodb
sudo rm -rf /var/lib/mongodb/mongod.lock

# Dependency used for the read_tfrecords function.
# Given that we only use the ExamplesToRecordBatchDecoder,
# which is purely C++, we can install it with --no-dependencies.
pip install tfx-bsl==1.14.0 --no-dependencies
Collaborator:
This feels like a rather weird way to do it. Is this the intended way to use tfx-bsl? Are there other, more direct ways to import and use the logic in ExamplesToRecordBatchDecoder?

Contributor Author:
Other, more direct ways as in adding it like any other dependency, or are you thinking about something else?

I tried adding it as a direct dependency in the data test dependencies, but it conflicted with other existing dependencies.

We do not really need to bring in transitive dependencies, since the only thing we use is the ExamplesToRecordBatchDecoder, which is self-contained and doesn't need any extra dependency to work. Another approach we could take is to add this class to the Ray repo itself; I am not really familiar with that part of the Ray codebase, but we would have to add the C++ class and a Python binding.
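
For reference, a minimal sketch of using ExamplesToRecordBatchDecoder on its own, mirroring how this PR calls it internally. The sample records, helper name, and printed output are illustrative only, and exact decoder behavior may vary across tfx-bsl versions:

# Illustrative sketch, not code from this PR.
import pyarrow as pa
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder


def _make_example(label: int) -> bytes:
    # Build and serialize a single-feature tf.train.Example for this demo.
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[label])
                )
            }
        )
    ).SerializeToString()


# Passing None for the schema (as this PR does when tf_schema is absent) makes
# the decoder infer one; columns then come back as list-valued arrays.
decoder = ExamplesToRecordBatchDecoder(None)
batch = decoder.DecodeBatch([_make_example(1), _make_example(0)])
table = pa.Table.from_batches([batch])
print(table.to_pydict())  # expected: {'label': [[1], [0]]}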

Collaborator:
> conflicted with other existing dependencies.

could you elaborate on the specifics?

cc @can-anyscale // example that we might need multiple constraint files.

I am worried that this is a hack and might not be sustainable. Is there any guarantee that this will still work when tfx-bsl publishes a new version in the future? Do we expect all users to install tfx-bsl with the --no-dependencies flag?

This is not a test-only dependency, but a data dependency. At the end of the day, for it to be useful, it needs to work with other dependencies in key workflows / workloads. Tests and CI are proxies for "Ray works for users".

It seems that this is only used as part of the internal implementation, not as part of the Ray Data interface? If that is the case, can we fork https://github.com/tensorflow/tfx-bsl, give it another package name, remove the parts we do not need, build it from source, and publish a tfx-bsl-ray package just for this?

Collaborator:
Or maybe we should try to resolve the dependency conflicts.

Contributor Author:
> Do we expect users to all install tfx-bsl with the --no-dependencies flag?

Yes, we do.

Collaborator:
> Do we expect users to all install tfx-bsl with the --no-dependencies flag?
>
> Yes, we do.

:) This is more a question for the Ray code owners. I am asking @scottjlee to help here.

Contributor:
I don't think we can always assume that users will (or force them to) install with the --no-dependencies flag, since they might be using the library's other features. That would also put the dependency resolution on the user, which we should try to avoid.


if [[ $RAY_CI_JAVA_BUILD == 1 ]]; then
# These packages increase the image size quite a bit, so we only install them
# as needed.
133 changes: 127 additions & 6 deletions python/ray/data/datasource/tfrecords_datasource.py

@@ -2,16 +2,26 @@
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np
import pyarrow

from ray.data._internal.dataset_logger import DatasetLogger
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
import pyarrow
import pandas as pd
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

from ray.data.dataset import Dataset


DEFAULT_BATCH_SIZE = 2048

logger = DatasetLogger(__name__)


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
@@ -23,13 +33,23 @@ def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
fast_read: bool = False,
Contributor:
Can we change the name fast_read to read_with_tfx?
The name fast_read looks a bit vague to me.

batch_size: Optional[int] = None,
Contributor:
Let's also document these two parameters in the docstring.

**file_based_datasource_kwargs,
):
super().__init__(paths, **file_based_datasource_kwargs)

self.tf_schema = tf_schema
self._tf_schema = tf_schema
self._fast_read = fast_read
self._batch_size = batch_size or DEFAULT_BATCH_SIZE

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._fast_read:
yield from self._fast_read_stream(f, path)
else:
yield from self._slow_read_stream(f, path)

def _slow_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
Contributor:
_slow_read_stream -> _default_read_stream

import pyarrow as pa
import tensorflow as tf
from google.protobuf.message import DecodeError
@@ -46,14 +66,64 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
)

yield pa.Table.from_pydict(
_convert_example_to_dict(example, self.tf_schema)
_convert_example_to_dict(example, self._tf_schema)
)

def _fast_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
Contributor:
_fast_read_stream -> tfx_read_stream

import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

full_path = self._resolve_full_path(path)

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._batch_size):
yield pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
Review comment:
DecodeBatch will create the large_list data type, and creating a TensorFlow dataset via the dataset.to_tf method will then fail with the message: NotImplementedError: large_list
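
A later commit in this PR ("cast large_list to list always on fast read") addresses this. As an illustration only (not the PR's exact code), such a cast could look roughly like this with pyarrow:

import pyarrow as pa


def _cast_large_lists_to_lists(table: pa.Table) -> pa.Table:
    # Illustrative sketch: downcast large_list<T> columns to list<T> so that
    # consumers that don't support large_list (e.g. Dataset.to_tf) keep working.
    fields = [
        pa.field(f.name, pa.list_(f.type.value_type))
        if pa.types.is_large_list(f.type)
        else f
        for f in table.schema
    ]
    return table.cast(pa.schema(fields))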

except Exception as error:
logger.get_logger().exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error

# we need this hack where we raise an exception outside of the
# except block because tensorflow's DataLossError is unpicklable, and
# even if we raise a runtime error, Ray keeps information about the
# original error, which makes it unpicklable still.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"

return relative_path


def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, "pyarrow.Array"]:
) -> Dict[str, pyarrow.Array]:
record = {}
schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
Expand All @@ -73,7 +143,7 @@ def _convert_example_to_dict(


def _convert_arrow_table_to_examples(
arrow_table: "pyarrow.Table",
arrow_table: pyarrow.Table,
tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
import tensorflow as tf
@@ -118,7 +188,7 @@ def _get_single_true_type(dct) -> str:
def _get_feature_value(
feature: "tf.train.Feature",
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "pyarrow.Array":
) -> pyarrow.Array:
import pyarrow as pa

underlying_feature_type = {
@@ -361,6 +431,57 @@ def _read_records(
raise RuntimeError(error_message) from e


def _infer_schema_and_transform(dataset: "Dataset"):
list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))
Contributor:
Can you add a comment on why this function is needed?
It seems that we'll read the datasource twice, once for the aggregate and once for map_batches. Won't that be less efficient?

Contributor Author:
@raulchen it will indeed mean an extra pass over the data. The reason it is needed is that the tfx-bsl ExamplesToRecordBatchDecoder always returns lists of lists when no schema is provided, and what this function does is infer which of those fields are single-value fields.

Performance-wise, our benchmarks of this implementation (we have had it running internally for a while) show more than a 15x improvement over the current implementation. Some of our datasets take ~30 minutes to load with the Ray-native implementation, compared to less than 2 minutes with this tfx-bsl implementation. Let me know if you need more benchmark numbers; happy to provide more.
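
As a small illustration of what that inference amounts to (column names and values below are made up): a column whose maximum list length is 1 gets its list wrapper dropped, while genuinely list-valued columns are left alone.

import numpy as np

# Illustrative only: mirrors what _unwrap_single_value_lists does in this PR.
batch = {
    "label": [np.array([1]), np.array([0]), np.array([1])],          # max list size 1
    "tokens": [np.array([3, 4]), np.array([5]), np.array([6, 7])],   # max list size 2
}
# _MaxListSize over the whole dataset would report {"label": 1, "tokens": 2},
# so only "label" is unwrapped.
col_lengths = {"label": 1, "tokens": 2}

for col, max_len in col_lengths.items():
    if max_len == 1:
        batch[col] = np.array(
            [x[0] if isinstance(x, np.ndarray) else x for x in batch[col]]
        )

print(batch["label"])   # -> [1 0 1]
print(batch["tokens"])  # unchanged list-valued column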

Contributor:
Thanks for the explanation. It's okay to proceed without more benchmarks for now.


return dataset.map_batches(
_unwrap_single_value_lists,
fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
)


def _unwrap_single_value_lists(
batch: Dict[str, np.ndarray], col_lengths: Dict[str, int]
):
for col in col_lengths:
if col_lengths[col] == 1:
batch[col] = np.array(
[x[0] if isinstance(x, np.ndarray) else x for x in batch[col]]
)

return batch


class _MaxListSize(AggregateFn):
def __init__(self, columns: List[str]):
self._columns = columns
super().__init__(
init=self._init,
merge=self._merge,
accumulate_row=self._accumulate_row,
finalize=lambda a: a,
name="max_list_size",
)

def _init(self, k: str):
return {col: 0 for col in self._columns}

def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
merged = {}
for col in self._columns:
merged[col] = max(acc1[col], acc2[col])

return merged

def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
for k in row:
value = row[k]
if value:
acc[k] = max(len(value), acc[k])

return acc


# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
#
# MIT License
45 changes: 44 additions & 1 deletion python/ray/data/read_api.py

@@ -1476,6 +1476,8 @@ def read_tfrecords(
tf_schema: Optional["schema_pb2.Schema"] = None,
shuffle: Union[Literal["files"], None] = None,
file_extensions: Optional[List[str]] = None,
fast_read_batch_size: Optional[int] = None,
Contributor:
fast_read_batch_size -> tfx_read_batch_size

fast_read_auto_infer_schema: bool = True,
Contributor:
fast_read_auto_infer_schema -> tfx_read_auto_infer_schema

) -> Dataset:
"""Create a :class:`~ray.data.Dataset` from TFRecord files that contain
`tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/Example>`_
@@ -1550,13 +1552,45 @@
shuffle: If setting to "files", randomly shuffle input files order before read.
Defaults to not shuffle with ``None``.
file_extensions: A list of file extensions to filter files by.
batch_size: An int representing the number of consecutive elements of this
Contributor:
Suggested change:
- batch_size: An int representing the number of consecutive elements of this
+ fast_read_batch_size: An int representing the number of consecutive elements of this

dataset to combine in a single batch when fast_read is used.
fast_read_auto_infer_schema: Toggles the schema inference applied; applicable
only if fast_read is used and tf_schema argument is missing.
Defaults to True.

Returns:
A :class:`~ray.data.Dataset` that contains the example features.

Raises:
ValueError: If a file contains a message that isn't a ``tf.train.Example``.
"""
import platform

fast_read = False
Contributor:
fast_read -> read_with_tfx


try:
from tfx_bsl.cc.tfx_bsl_extension.coders import ( # noqa: F401
ExamplesToRecordBatchDecoder,
)

fast_read = True
except ModuleNotFoundError:
if platform.processor() == "arm":
logger.warning(
"This function depends on tfx-bsl which is currently not supported"
" on devices with Apple silicon (e.g. M1) and requires an"
" environment with x86 CPU architecture."
)
Contributor:
When the user has specified tfx_read_options but tfx_bsl isn't installed, it'd be better to just throw an exception.
And I guess it's okay not to check the platform here, as tfx_bsl may support ARM in the future.
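
A minimal sketch of the fail-fast behavior being suggested here (tfx_read_options is the options object introduced in a later revision of this PR; the helper name and error message are illustrative):

# Illustrative sketch of the suggestion, not code from this PR.
def _require_tfx_bsl(tfx_read_options) -> None:
    if tfx_read_options is None:
        return
    try:
        from tfx_bsl.cc.tfx_bsl_extension.coders import (  # noqa: F401
            ExamplesToRecordBatchDecoder,
        )
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "tfx_read_options was specified, but tfx-bsl is not installed. "
            "Install it with `pip install tfx_bsl --no-dependencies`."
        ) from e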

Contributor Author:
It does not support it right now; I find it a lot safer not to assume it will.

Contributor Author:
> when user has specified tfx_read_options but tfx_bsl isn't installed, it'd be better to just throw an exception.

So you are suggesting I just remove the whole block that tries the import and logs a warning, which would also remove the fallback?

else:
logger.warning(
"To use TFRecordDatasource with large datasets, please install"
" tfx-bsl package with pip install tfx_bsl --no-dependencies`."
)
logger.info(
"Falling back to slower strategy for reading tf.records. This"
martinbomio marked this conversation as resolved.
Show resolved Hide resolved
"reading strategy should be avoided when reading large datasets."
)

if meta_provider is None:
meta_provider = get_generic_metadata_provider(
TFRecordDatasource._FILE_EXTENSIONS
Expand All @@ -1573,8 +1607,17 @@ def read_tfrecords(
shuffle=shuffle,
include_paths=include_paths,
file_extensions=file_extensions,
fast_read=fast_read,
batch_size=fast_read_batch_size,
)
return read_datasource(datasource, parallelism=parallelism)
ds = read_datasource(datasource, parallelism=parallelism)

if fast_read_auto_infer_schema and fast_read and not tf_schema:
from ray.data.datasource.tfrecords_datasource import _infer_schema_and_transform

return _infer_schema_and_transform(ds)

return ds


@PublicAPI(stability="alpha")
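
For completeness, a usage sketch against this revision of the API (the path is a placeholder, and these parameter names — fast_read_batch_size, fast_read_auto_infer_schema — were renamed to tfx_*-style options in later commits of this PR):

import ray

# With tfx-bsl installed (pip install tfx_bsl --no-dependencies), read_tfrecords
# takes the faster tfx-bsl decoding path; otherwise it logs a warning and falls
# back to the existing pure-Python reader.
ds = ray.data.read_tfrecords(
    "s3://example-bucket/path/to/tfrecords/",  # placeholder path
    fast_read_batch_size=2048,         # records decoded per batch on the fast path
    fast_read_auto_infer_schema=True,  # unwrap single-value list columns when no tf_schema is given
)
print(ds.schema())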