[Data] Add performant way to read large tfrecord datasets (#42277)
The main motivation for this PR is that ray.data.read_tfrecords yields suboptimal performance when reading large datasets.
This PR adds a default "fast" route for reading TFRecord files that relies on the tfx-bsl decoder. When no tf_schema is provided, this route also infers the schema by doing a pass over the data to determine the cardinality of the feature lists.
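
A minimal usage sketch of the new read path. This assumes ray.data.read_tfrecords is wired to accept the new tfx_read_options parameter in the read-API change that is part of this PR but not shown in the diff below, and that tfx-bsl is installed; the bucket path is hypothetical:

import ray
from ray.data.datasource.tfrecords_datasource import TFXReadOptions

ds = ray.data.read_tfrecords(
    "s3://example-bucket/tfrecords/",  # hypothetical path
    tfx_read_options=TFXReadOptions(
        batch_size=2048,         # DEFAULT_BATCH_SIZE introduced by this PR
        auto_infer_schema=True,  # extra pass to unwrap single-value feature lists
    ),
)
print(ds.schema())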

Signed-off-by: Martin Bomio <martinbomio@spotify.com>
Signed-off-by: Scott Lee <sjl@anyscale.com>
Signed-off-by: Martin <martinbomio@gmail.com>
Co-authored-by: Scott Lee <sjl@anyscale.com>
Co-authored-by: Cheng Su <scnju13@gmail.com>
3 people committed Feb 29, 2024
1 parent fe554c1 commit 2c37909
Showing 7 changed files with 361 additions and 21 deletions.
15 changes: 15 additions & 0 deletions .buildkite/data.rayci.yml
@@ -24,6 +24,9 @@ steps:
- name: datamongobuild
wanda: ci/docker/datamongo.build.wanda.yaml

- name: datatfxbslbuild
wanda: ci/docker/datatfxbsl.build.wanda.yaml

# tests
- label: ":database: data: arrow 6 tests"
tags:
@@ -84,6 +87,18 @@ steps:
--build-name datanbuild
--except-tags data_integration,doctest
depends_on: datanbuild

- label: ":database: data: TFRecords (tfx-bsl) tests"
tags:
- python
- data
instance_type: medium
commands:
- bazel run //ci/ray_ci:test_in_docker -- //python/ray/data/... data
--parallelism-per-worker 3
--build-name datatfxbslbuild
--only-tags tfxbsl
depends_on: datatfxbslbuild

- label: ":database: data: doc tests"
tags:
29 changes: 29 additions & 0 deletions ci/docker/data-tfxbsl.build.Dockerfile
@@ -0,0 +1,29 @@
# syntax=docker/dockerfile:1.3-labs

ARG DOCKER_IMAGE_BASE_BUILD=cr.ray.io/rayproject/oss-ci-base_ml
FROM $DOCKER_IMAGE_BASE_BUILD

ARG ARROW_VERSION=14.*
ARG ARROW_MONGO_VERSION=
ARG RAY_CI_JAVA_BUILD=

# Unset dind settings; we are using the host's docker daemon.
ENV DOCKER_TLS_CERTDIR=
ENV DOCKER_HOST=
ENV DOCKER_TLS_VERIFY=
ENV DOCKER_CERT_PATH=

SHELL ["/bin/bash", "-ice"]

COPY . .

RUN <<EOF
#!/bin/bash

ARROW_VERSION=$ARROW_VERSION ./ci/env/install-dependencies.sh
# We manually install tfx-bsl here. Adding the library via data- or
# test-requirements.txt files causes unresolvable dependency conflicts with pandas.

pip install -U tfx-bsl==1.14.0 crc32c==2.3

EOF
14 changes: 14 additions & 0 deletions ci/docker/datatfxbsl.build.wanda.yaml
@@ -0,0 +1,14 @@
name: "datatfxbslbuild"
froms: ["cr.ray.io/rayproject/oss-ci-base_ml"]
dockerfile: ci/docker/data-tfxbsl.build.Dockerfile
srcs:
- ci/env/install-dependencies.sh
- python/requirements.txt
- python/requirements_compiled.txt
- python/requirements/test-requirements.txt
- python/requirements/ml/dl-cpu-requirements.txt
- python/requirements/ml/data-requirements.txt
build_args:
- ARROW_VERSION=14.*
tags:
- cr.ray.io/rayproject/datatfxbslbuild
2 changes: 1 addition & 1 deletion python/ray/data/BUILD
@@ -317,7 +317,7 @@ py_test(
name = "test_tfrecords",
size = "small",
srcs = ["tests/test_tfrecords.py"],
tags = ["team:data", "exclusive"],
tags = ["team:data", "exclusive", "tfxbsl"],
deps = ["//:ray_lib", ":conftest"],
)

203 changes: 197 additions & 6 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,17 +1,48 @@
import struct
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np
import pyarrow

from ray.data._internal.dataset_logger import DatasetLogger
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
import pyarrow
import pandas as pd
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

from ray.data.dataset import Dataset


# Default batch size used when reading TFRecord files with tfx-bsl.
# Ray uses this value by default to read the tf.Examples in batches.
DEFAULT_BATCH_SIZE = 2048

logger = DatasetLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
class TFXReadOptions:
"""
Specifies read options when reading TFRecord files with TFX.
"""

# An int representing the number of consecutive elements of
# this dataset to combine in a single batch when tfx-bsl is used to read
# the tfrecord files.
batch_size: int = DEFAULT_BATCH_SIZE

# Whether to infer the schema automatically; applicable
# only when tfx-bsl is used and the tf_schema argument is not provided.
# Defaults to True.
auto_infer_schema: bool = True


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
@@ -23,13 +54,31 @@ def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
tfx_read_options: Optional[TFXReadOptions] = None,
**file_based_datasource_kwargs,
):
"""
Args:
tf_schema: Optional TensorFlow Schema which is used to explicitly set
the schema of the underlying Dataset.
tfx_read_options: Optional options for reading TFRecord files
using tfx-bsl.
"""
super().__init__(paths, **file_based_datasource_kwargs)

self.tf_schema = tf_schema
self._tf_schema = tf_schema
self._tfx_read_options = tfx_read_options

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._tfx_read_options:
yield from self._tfx_read_stream(f, path)
else:
yield from self._default_read_stream(f, path)

def _default_read_stream(
self, f: "pyarrow.NativeFile", path: str
) -> Iterator[Block]:
import pyarrow as pa
import tensorflow as tf
from google.protobuf.message import DecodeError
Expand All @@ -46,14 +95,66 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
)

yield pa.Table.from_pydict(
_convert_example_to_dict(example, self.tf_schema)
_convert_example_to_dict(example, self._tf_schema)
)

def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

full_path = self._resolve_full_path(path)

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._tfx_read_options.batch_size):
yield _cast_large_list_to_list(
pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
)
except Exception as error:
logger.get_logger().exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error

# We need this workaround of raising the exception outside of the
# except block because the TensorFlow DataLossError is unpicklable, and
# even if we raise a RuntimeError there, Ray keeps information about the
# original error, which still makes it unpicklable.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"

return relative_path


def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, "pyarrow.Array"]:
) -> Dict[str, pyarrow.Array]:
record = {}
schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
@@ -73,7 +174,7 @@


def _convert_arrow_table_to_examples(
arrow_table: "pyarrow.Table",
arrow_table: pyarrow.Table,
tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
import tensorflow as tf
@@ -118,7 +219,7 @@ def _get_single_true_type(dct) -> str:
def _get_feature_value(
feature: "tf.train.Feature",
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "pyarrow.Array":
) -> pyarrow.Array:
import pyarrow as pa

underlying_feature_type = {
@@ -361,6 +462,96 @@ def _read_records(
raise RuntimeError(error_message) from e


def _cast_large_list_to_list(batch: pyarrow.Table):
"""
This function transforms pyarrow.large_list into pyarrow.list_ and pyarrow.large_binary
into pyarrow.binary so that all types produced by the TFRecord datasource are usable
with dataset.to_tf().
"""
old_schema = batch.schema
fields = {}

for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
if type(field_type) == pyarrow.lib.LargeListType:
value_type = field_type.value_type

if value_type == pyarrow.large_binary():
value_type = pyarrow.binary()

fields[column_name] = pyarrow.list_(value_type)
elif field_type == pyarrow.large_binary():
fields[column_name] = pyarrow.binary()
else:
fields[column_name] = old_schema.field(column_name)

new_schema = pyarrow.schema(fields)
return batch.cast(new_schema)


def _infer_schema_and_transform(dataset: "Dataset"):
list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))

return dataset.map_batches(
_unwrap_single_value_lists,
fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
batch_format="pyarrow",
)


def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]):
"""
This function transforms the dataset by converting list types that always
contain a single value to their underlying data type
(e.g. pyarrow.int64() or pyarrow.float64()).
"""
columns = {}

for col in col_lengths:
value_type = batch[col].type.value_type

if col_lengths[col] == 1:
if batch[col]:
columns[col] = pyarrow.array(
[x.as_py()[0] if x.as_py() else None for x in batch[col]],
type=value_type,
)
else:
columns[col] = batch[col]

return pyarrow.table(columns)


class _MaxListSize(AggregateFn):
def __init__(self, columns: List[str]):
self._columns = columns
super().__init__(
init=self._init,
merge=self._merge,
accumulate_row=self._accumulate_row,
finalize=lambda a: a,
name="max_list_size",
)

def _init(self, k: str):
return {col: 0 for col in self._columns}

def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
merged = {}
for col in self._columns:
merged[col] = max(acc1[col], acc2[col])

return merged

def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
for k in row:
value = row[k]
if value:
acc[k] = max(len(value), acc[k])

return acc


# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
#
# MIT License
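
For reference, a standalone sketch of what the schema-inference pass in this file does to a decoded batch: cast large list/binary types to their regular counterparts and unwrap list columns whose maximum length across the dataset is 1. This toy example uses plain pyarrow with a hard-coded max-size computation in place of the Dataset aggregation (_MaxListSize), so the column names and values are illustrative only:

import pyarrow as pa

# A decoded batch as the tfx-bsl decoder might produce it: every feature is a list.
batch = pa.table({
    "label": pa.array([[1], [0], [1]], type=pa.large_list(pa.int64())),
    "embedding": pa.array(
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], type=pa.large_list(pa.float64())
    ),
})

# Step 1: cast large_list -> list (mirrors _cast_large_list_to_list).
fields = {}
for name in batch.schema.names:
    t = batch.schema.field(name).type
    fields[name] = pa.list_(t.value_type) if pa.types.is_large_list(t) else t
batch = batch.cast(pa.schema(fields))

# Step 2: per-column max list length. The PR computes this across the whole
# dataset with a Dataset aggregation; here it is done in-process on the toy batch.
max_sizes = {
    name: max(len(v) for v in batch[name].to_pylist())
    for name in batch.schema.names
}  # {'label': 1, 'embedding': 2}

# Step 3: unwrap columns whose lists always hold a single value
# (mirrors _unwrap_single_value_lists).
columns = {}
for name, size in max_sizes.items():
    if size == 1:
        values = [v[0] if v else None for v in batch[name].to_pylist()]
        columns[name] = pa.array(values, type=batch[name].type.value_type)
    else:
        columns[name] = batch[name]

unwrapped = pa.table(columns)
# unwrapped.schema: label: int64, embedding: list<item: double>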
