feat: add performant way to read large tfrecord datasets
Signed-off-by: Martin Bomio <martinbomio@spotify.com>
martinbomio committed Jan 9, 2024
1 parent 601709e commit 1223af6
Showing 3 changed files with 146 additions and 7 deletions.
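
For reference, a minimal usage sketch of the options this commit adds to ray.data.read_tfrecords; the path and values below are illustrative, not part of the change itself:

    import ray

    # Fast path: requires tfx-bsl (`pip install tfx_bsl --no-dependencies`).
    # With force_fast_read=True the read raises if tfx-bsl is missing instead
    # of silently falling back to the slower per-example decoder.
    ds = ray.data.read_tfrecords(
        "example_data.tfrecords",  # illustrative path
        force_fast_read=True,
        batch_size=4096,           # records decoded per Arrow batch on the fast path
        schema_inference=True,     # unwrap columns whose lists always hold one value
    )
    print(ds.schema())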
90 changes: 86 additions & 4 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,17 +1,19 @@
import struct
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np

from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
import pandas as pd
import pyarrow
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

from ray.data.dataset import Dataset


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
@@ -23,13 +25,23 @@ def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
fast_read: bool = False,
batch_size: Optional[int] = None,
**file_based_datasource_kwargs,
):
super().__init__(paths, **file_based_datasource_kwargs)

- self.tf_schema = tf_schema
+ self._tf_schema = tf_schema
self._fast_read = fast_read
self._batch_size = batch_size or 2048

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._fast_read:
yield from self._fast_read_stream(f, path)
else:
yield from self._slow_read_stream(f, path)

def _slow_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import pyarrow as pa
import tensorflow as tf
from google.protobuf.message import DecodeError
@@ -46,9 +58,30 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
)

yield pa.Table.from_pydict(
- _convert_example_to_dict(example, self.tf_schema)
+ _convert_example_to_dict(example, self._tf_schema)
)

def _fast_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import pyarrow as pa
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)

for record in tf.data.TFRecordDataset(path, compression_type=compression).batch(
self._batch_size
):
yield pa.Table.from_batches([decoder.DecodeBatch(record.numpy())])


def _convert_example_to_dict(
example: "tf.train.Example",
@@ -361,6 +394,54 @@ def _read_records(
raise RuntimeError(error_message) from e


def unwrap_single_value_columns(dataset: "Dataset"):
list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))

return dataset.map_batches(
_unwrap_single_value_lists,
fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
batch_format="pandas",
)


def _unwrap_single_value_lists(batch: "pd.DataFrame", col_lengths: Dict[str, int]):
for col in col_lengths:
if col_lengths[col] == 1:
batch[col] = batch[col].str[0]

return batch


class _MaxListSize(AggregateFn):
def __init__(self, columns: List[str]):
self._columns = columns
super().__init__(
init=self._init,
merge=self._merge,
accumulate_row=self._accumulate_row,
finalize=lambda a: a,
name="max_list_size",
)

def _init(self, k: str):
return {col: 0 for col in self._columns}

def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
merged = {}
for col in self._columns:
merged[col] = max(acc1[col], acc2[col])

return merged

def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
for k in row:
value = row[k]
if value:
acc[k] = max(len(value), acc[k])

return acc


# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
#
# MIT License
@@ -402,6 +483,7 @@ def _write_record(
def _masked_crc(data: bytes) -> bytes:
"""CRC checksum."""
import crc32c
import numpy as np

mask = 0xA282EAD8
crc = crc32c.crc32(data)
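
For context, a standalone sketch of the decoding approach used by _fast_read_stream above: batched TFRecord reads decoded into Arrow tables via tfx-bsl. It assumes tensorflow and tfx-bsl are installed; the file path and batch size are illustrative.

    import pyarrow as pa
    import tensorflow as tf
    from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

    # Passing None lets the decoder infer column types; a serialized
    # schema_pb2.Schema can be passed instead, as the datasource does.
    decoder = ExamplesToRecordBatchDecoder(None)

    records = tf.data.TFRecordDataset("example_data.tfrecords").batch(2048)
    for batch in records:
        # DecodeBatch consumes the serialized tf.train.Example protos of one
        # batch and returns a pyarrow.RecordBatch.
        table = pa.Table.from_batches([decoder.DecodeBatch(batch.numpy())])
        print(table.num_rows, table.schema.names)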
53 changes: 52 additions & 1 deletion python/ray/data/read_api.py
@@ -1490,6 +1490,9 @@ def read_tfrecords(
tf_schema: Optional["schema_pb2.Schema"] = None,
shuffle: Union[Literal["files"], None] = None,
file_extensions: Optional[List[str]] = None,
batch_size: Optional[int] = None,
schema_inference: bool = True,
force_fast_read: bool = False,
) -> Dataset:
"""Create a :class:`~ray.data.Dataset` from TFRecord files that contain
`tf.train.Example <https://www.tensorflow.org/api_docs/python/tf/train/Example>`_
@@ -1564,13 +1567,52 @@
shuffle: If set to "files", randomly shuffle the input file order before read.
Defaults to ``None`` (no shuffling).
file_extensions: A list of file extensions to filter files by.
batch_size: The number of consecutive records to combine into a single batch
when the fast read path is used.
schema_inference: Whether to infer the schema and unwrap single-value list
columns; only applicable if the ``tf_schema`` argument is not provided.
Defaults to True.
force_fast_read: Force the fast read path, raising an error if the required
dependencies (tfx-bsl) are not installed.
Returns:
A :class:`~ray.data.Dataset` that contains the example features.
Raises:
ValueError: If a file contains a message that isn't a ``tf.train.Example``.
"""
import platform

fast_read = False

try:
from tfx_bsl.cc.tfx_bsl_extension.coders import ( # noqa: F401
ExamplesToRecordBatchDecoder,
)

fast_read = force_fast_read
except ModuleNotFoundError:
if platform.processor() == "arm":
logger.warning(
"This function depends on tfx-bsl which is currently not supported"
" on devices with Apple silicon (e.g. M1) and requires an"
" environment with x86 CPU architecture."
)
else:
if force_fast_read:
raise ModuleNotFoundError(
"This function was called with `force_fast_read` but tfx-bsl is not"
" installed. Please install the tfx-bsl package with `pip install"
" tfx_bsl --no-dependencies`."
)
logger.warning(
"To use TFRecordDatasource with large datasets, please install the"
" tfx-bsl package with `pip install tfx_bsl --no-dependencies`."
)
logger.info(
"Falling back to the slower strategy for reading TFRecord files. This"
" reading strategy should be avoided when reading large datasets."
)

if meta_provider is None:
meta_provider = get_generic_metadata_provider(
TFRecordDatasource._FILE_EXTENSIONS
@@ -1587,8 +1629,17 @@
shuffle=shuffle,
include_paths=include_paths,
file_extensions=file_extensions,
fast_read=fast_read,
batch_size=batch_size,
)
- return read_datasource(datasource, parallelism=parallelism)
+ ds = read_datasource(datasource, parallelism=parallelism)

if schema_inference and fast_read and not tf_schema:
from ray.data.datasource.tfrecords_datasource import unwrap_single_value_columns

return unwrap_single_value_columns(ds)

return ds


@PublicAPI(stability="alpha")
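
To illustrate the schema-inference step wired in above: the fast decoder returns every feature as a list column, and unwrap_single_value_columns turns columns whose longest list holds a single element back into scalar columns. A minimal pandas sketch of that transformation; the data is made up for illustration:

    import pandas as pd

    batch = pd.DataFrame(
        {
            "label": [[0], [1], [1]],            # max list size 1 -> unwrapped
            "tokens": [[1, 2], [3], [4, 5, 6]],  # max list size 3 -> kept as lists
        }
    )

    # Same unwrapping as _unwrap_single_value_lists in the datasource.
    col_lengths = {"label": 1, "tokens": 3}
    for col, size in col_lengths.items():
        if size == 1:
            batch[col] = batch[col].str[0]

    print(batch)  # "label" is now 0, 1, 1; "tokens" keeps its lists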
10 changes: 8 additions & 2 deletions python/ray/data/tests/test_tfrecords.py
@@ -338,9 +338,13 @@ def _str2bytes(d):
assert ds_expected.take() == ds_actual.take()


@pytest.mark.parametrize("with_tf_schema", (True, False))
@pytest.mark.parametrize(
"with_tf_schema,force_fast_read",
[(True, True), (True, False), (False, True), (False, False)],
)
def test_read_tfrecords(
with_tf_schema,
force_fast_read,
ray_start_regular_shared,
tmp_path,
):
@@ -357,7 +361,9 @@ def test_read_tfrecords(
with tf.io.TFRecordWriter(path=path) as writer:
writer.write(example.SerializeToString())

- ds = ray.data.read_tfrecords(path, tf_schema=tf_schema)
+ ds = ray.data.read_tfrecords(
+     path, tf_schema=tf_schema, force_fast_read=force_fast_read
+ )
df = ds.to_pandas()
# Protobuf serializes features in a non-deterministic order.
if with_tf_schema:
