[Data] Add performant way to read large tfrecord datasets #42277

Merged: 40 commits, merged on Feb 29, 2024

Changes from all commits

Commits (40)
af7b677
feat: add performant way to read large tfrecord datasets
martinbomio Jan 9, 2024
f25fa2e
add tfx-bsl as a test dependency
martinbomio Jan 9, 2024
14ed874
address PR comments
martinbomio Jan 11, 2024
3b8bf91
properly enable/disable fast read on tests
martinbomio Jan 16, 2024
b303cac
resolve rabsolute path from relative
martinbomio Jan 24, 2024
ddfca72
add tensorflow-io for s3 fs impl
martinbomio Jan 24, 2024
c3beb21
try adding tfx-bsl, cython in data-test-requirements
scottjlee Jan 24, 2024
4ac5221
skip tfx-bsl install in data
scottjlee Jan 25, 2024
552c2e0
Apply suggestions from code review
martinbomio Jan 25, 2024
4a007c6
new tfx-bsl build
scottjlee Jan 25, 2024
3c20737
Merge branch 'martinbomio/fast-tfrecord-read' of https://github.com/m…
scottjlee Jan 25, 2024
bf061db
lint
scottjlee Jan 25, 2024
3d7bd33
fix missing build dependency
scottjlee Jan 25, 2024
b8519a1
add datatfxbsl build
scottjlee Jan 25, 2024
b6982a3
remove workers arg
scottjlee Jan 25, 2024
7934adc
worker config
scottjlee Jan 25, 2024
386be26
data target
scottjlee Jan 26, 2024
1d83dce
try pinning pandas<2
scottjlee Jan 26, 2024
2c49d51
pin pandas to pandas==1.5.3
scottjlee Jan 26, 2024
953af46
comment
scottjlee Jan 26, 2024
c373724
update tag
scottjlee Jan 26, 2024
e19c1fc
add tfxbsl dockerfile
scottjlee Jan 29, 2024
fb6b290
add crc32c
scottjlee Jan 30, 2024
7d7a08f
Merge branch 'master' into martinbomio/fast-tfrecord-read
scottjlee Jan 30, 2024
ca27c49
rewrite unwrap single value function to use pyarrow
martinbomio Feb 5, 2024
aac5ec6
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 6, 2024
7408c10
Merge branch 'martinbomio/fast-tfrecord-read' of https://github.com/m…
scottjlee Feb 7, 2024
030556c
cast large_list to list always on fast read
martinbomio Feb 13, 2024
681f753
move casting to datasource
martinbomio Feb 13, 2024
bca5bff
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 14, 2024
8cf1c2d
clean up docstrings
martinbomio Feb 20, 2024
9aa260b
rename fast_* variables to tfx_
martinbomio Feb 26, 2024
befc187
fix failing tests
martinbomio Feb 26, 2024
413c1f0
add flag in data context to disable using tfx read
martinbomio Feb 26, 2024
1e2c627
disable tfx_read by default
martinbomio Feb 27, 2024
9894178
add TFXREadOptions
martinbomio Feb 28, 2024
bf06a37
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 28, 2024
bf64415
fix build
martinbomio Feb 28, 2024
663c39c
Merge branch 'master' into martinbomio/fast-tfrecord-read
martinbomio Feb 28, 2024
550392e
Merge branch 'master' into martinbomio/fast-tfrecord-read
c21 Feb 29, 2024
15 changes: 15 additions & 0 deletions .buildkite/data.rayci.yml
@@ -24,6 +24,9 @@ steps:
- name: datamongobuild
wanda: ci/docker/datamongo.build.wanda.yaml

- name: datatfxbslbuild
wanda: ci/docker/datatfxbsl.build.wanda.yaml

# tests
- label: ":database: data: arrow 6 tests"
tags:
@@ -84,6 +87,18 @@ steps:
--build-name datanbuild
--except-tags data_integration,doctest
depends_on: datanbuild

- label: ":database: data: TFRecords (tfx-bsl) tests"
tags:
- python
- data
instance_type: medium
commands:
- bazel run //ci/ray_ci:test_in_docker -- //python/ray/data/... data
--parallelism-per-worker 3
--build-name datatfxbslbuild
--only-tags tfxbsl
depends_on: datatfxbslbuild

- label: ":database: data: doc tests"
tags:
29 changes: 29 additions & 0 deletions ci/docker/data-tfxbsl.build.Dockerfile
@@ -0,0 +1,29 @@
# syntax=docker/dockerfile:1.3-labs

ARG DOCKER_IMAGE_BASE_BUILD=cr.ray.io/rayproject/oss-ci-base_ml
FROM $DOCKER_IMAGE_BASE_BUILD

ARG ARROW_VERSION=14.*
ARG ARROW_MONGO_VERSION=
ARG RAY_CI_JAVA_BUILD=

# Unset dind settings; we are using the host's docker daemon.
ENV DOCKER_TLS_CERTDIR=
ENV DOCKER_HOST=
ENV DOCKER_TLS_VERIFY=
ENV DOCKER_CERT_PATH=

SHELL ["/bin/bash", "-ice"]

COPY . .

RUN <<EOF
#!/bin/bash

ARROW_VERSION=$ARROW_VERSION ./ci/env/install-dependencies.sh
# We manually install tfx-bsl here. Adding the library via data- or
# test-requirements.txt files causes unresolvable dependency conflicts with pandas.

pip install -U tfx-bsl==1.14.0 crc32c==2.3

EOF
14 changes: 14 additions & 0 deletions ci/docker/datatfxbsl.build.wanda.yaml
@@ -0,0 +1,14 @@
name: "datatfxbslbuild"
froms: ["cr.ray.io/rayproject/oss-ci-base_ml"]
dockerfile: ci/docker/data-tfxbsl.build.Dockerfile
srcs:
- ci/env/install-dependencies.sh
- python/requirements.txt
- python/requirements_compiled.txt
- python/requirements/test-requirements.txt
- python/requirements/ml/dl-cpu-requirements.txt
- python/requirements/ml/data-requirements.txt
build_args:
- ARROW_VERSION=14.*
tags:
- cr.ray.io/rayproject/datatfxbslbuild
2 changes: 1 addition & 1 deletion python/ray/data/BUILD
@@ -317,7 +317,7 @@ py_test(
name = "test_tfrecords",
size = "small",
srcs = ["tests/test_tfrecords.py"],
tags = ["team:data", "exclusive"],
tags = ["team:data", "exclusive", "tfxbsl"],
deps = ["//:ray_lib", ":conftest"],
)

203 changes: 197 additions & 6 deletions python/ray/data/datasource/tfrecords_datasource.py
@@ -1,17 +1,48 @@
import struct
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union

import numpy as np
import pyarrow

from ray.data._internal.dataset_logger import DatasetLogger
from ray.data.aggregate import AggregateFn
from ray.data.block import Block
from ray.data.datasource.file_based_datasource import FileBasedDatasource
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
import pyarrow
import pandas as pd
import tensorflow as tf
from tensorflow_metadata.proto.v0 import schema_pb2

from ray.data.dataset import Dataset


# Default batch size to be used when reading TFRecord files with tfx-bsl.
# Ray will use this value by default to read the tf.Examples in batches.
DEFAULT_BATCH_SIZE = 2048

logger = DatasetLogger(__name__)


@PublicAPI(stability="alpha")
@dataclass
Review comment (Contributor): lint: add API annotation: @PublicAPI(stability="alpha")

class TFXReadOptions:
"""
Specifies read options when reading TFRecord files with tfx-bsl.
"""

# An int representing the number of consecutive elements of
# this dataset to combine in a single batch when tfx-bsl is used to read
# the tfrecord files.
batch_size: int = DEFAULT_BATCH_SIZE

# Whether to automatically infer the schema; applicable
# only if tfx-bsl is used and the tf_schema argument is not provided.
# Defaults to True.
auto_infer_schema: bool = True
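
For illustration only (not part of this diff): a minimal usage sketch. It assumes the public ray.data.read_tfrecords entry point accepts a tfx_read_options argument and forwards it to this datasource; that wiring, and the bucket path below, are placeholders.

import ray
from ray.data.datasource.tfrecords_datasource import TFXReadOptions

# Hypothetical call; read_tfrecords is assumed to forward tfx_read_options
# to TFRecordDatasource (those changes are not shown in this diff).
ds = ray.data.read_tfrecords(
    "s3://example-bucket/path/to/tfrecords/",
    tfx_read_options=TFXReadOptions(batch_size=4096, auto_infer_schema=True),
)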


@PublicAPI(stability="alpha")
class TFRecordDatasource(FileBasedDatasource):
@@ -23,13 +54,31 @@ def __init__(
self,
paths: Union[str, List[str]],
tf_schema: Optional["schema_pb2.Schema"] = None,
tfx_read_options: Optional[TFXReadOptions] = None,
**file_based_datasource_kwargs,
):
"""
Args:
tf_schema: Optional TensorFlow Schema which is used to explicitly set
the schema of the underlying Dataset.
tfx_read_options: Optional options for reading TFRecord files
using tfx-bsl.

"""
super().__init__(paths, **file_based_datasource_kwargs)

self.tf_schema = tf_schema
self._tf_schema = tf_schema
self._tfx_read_options = tfx_read_options

def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
if self._tfx_read_options:
yield from self._tfx_read_stream(f, path)
else:
yield from self._default_read_stream(f, path)

def _default_read_stream(
self, f: "pyarrow.NativeFile", path: str
) -> Iterator[Block]:
import pyarrow as pa
import tensorflow as tf
from google.protobuf.message import DecodeError
@@ -46,14 +95,66 @@ def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
)

yield pa.Table.from_pydict(
_convert_example_to_dict(example, self.tf_schema)
_convert_example_to_dict(example, self._tf_schema)
)

def _tfx_read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
import tensorflow as tf
from tfx_bsl.cc.tfx_bsl_extension.coders import ExamplesToRecordBatchDecoder

full_path = self._resolve_full_path(path)

compression = (self._open_stream_args or {}).get("compression", None)

if compression:
compression = compression.upper()

tf_schema_string = (
self._tf_schema.SerializeToString() if self._tf_schema else None
)

decoder = ExamplesToRecordBatchDecoder(tf_schema_string)
exception_thrown = None
try:
for record in tf.data.TFRecordDataset(
full_path, compression_type=compression
).batch(self._tfx_read_options.batch_size):
yield _cast_large_list_to_list(
pyarrow.Table.from_batches([decoder.DecodeBatch(record.numpy())])
)
except Exception as error:
logger.get_logger().exception(f"Failed to read TFRecord file {full_path}")
exception_thrown = error

# We need to do this hack where we raise the exception outside of the
# except block, because TensorFlow's DataLossError is unpicklable, and
# even if we raise a RuntimeError, Ray keeps information about the
# original error, which still makes it unpicklable.
if exception_thrown:
raise RuntimeError(f"Failed to read TFRecord file {full_path}.")

def _resolve_full_path(self, relative_path):
if isinstance(self._filesystem, pyarrow.fs.S3FileSystem):
return f"s3://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.GcsFileSystem):
return f"gs://{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.HadoopFileSystem):
return f"hdfs:///{relative_path}"
if isinstance(self._filesystem, pyarrow.fs.PyFileSystem):
protocol = self._filesystem.handler.fs.protocol
if isinstance(protocol, list) or isinstance(protocol, tuple):
protocol = protocol[0]
if protocol == "gcs":
protocol = "gs"
return f"{protocol}://{relative_path}"

return relative_path
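
Illustrative mapping performed by _resolve_full_path (hypothetical paths), since tf.data.TFRecordDataset expects full URIs while Ray Data hands the datasource filesystem-relative paths:

# S3FileSystem     + "bucket/key.tfrecord"  -> "s3://bucket/key.tfrecord"
# GcsFileSystem    + "bucket/key.tfrecord"  -> "gs://bucket/key.tfrecord"
# HadoopFileSystem + "user/key.tfrecord"    -> "hdfs:///user/key.tfrecord"
# fsspec-backed PyFileSystem                -> "<protocol>://<relative_path>"
# local filesystem + "/tmp/key.tfrecord"    -> "/tmp/key.tfrecord" (unchanged)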


def _convert_example_to_dict(
example: "tf.train.Example",
tf_schema: Optional["schema_pb2.Schema"],
) -> Dict[str, "pyarrow.Array"]:
) -> Dict[str, pyarrow.Array]:
record = {}
schema_dict = {}
# Convert user-specified schema into dict for convenient mapping
@@ -73,7 +174,7 @@ def _convert_example_to_dict(


def _convert_arrow_table_to_examples(
arrow_table: "pyarrow.Table",
arrow_table: pyarrow.Table,
tf_schema: Optional["schema_pb2.Schema"] = None,
) -> Iterable["tf.train.Example"]:
import tensorflow as tf
@@ -118,7 +219,7 @@ def _get_single_true_type(dct) -> str:
def _get_feature_value(
feature: "tf.train.Feature",
schema_feature_type: Optional["schema_pb2.FeatureType"] = None,
) -> "pyarrow.Array":
) -> pyarrow.Array:
import pyarrow as pa

underlying_feature_type = {
@@ -361,6 +462,96 @@ def _read_records(
raise RuntimeError(error_message) from e


def _cast_large_list_to_list(batch: pyarrow.Table):
"""
This function transforms pyarrow.large_list into list and pyarrow.large_binary into
pyarrow.binary so that all types produced by the TFRecord datasource are usable
with dataset.to_tf().
"""
old_schema = batch.schema
fields = {}

for column_name in old_schema.names:
field_type = old_schema.field(column_name).type
if type(field_type) == pyarrow.lib.LargeListType:
value_type = field_type.value_type

if value_type == pyarrow.large_binary():
value_type = pyarrow.binary()

fields[column_name] = pyarrow.list_(value_type)
elif field_type == pyarrow.large_binary():
fields[column_name] = pyarrow.binary()
else:
fields[column_name] = old_schema.field(column_name)

new_schema = pyarrow.schema(fields)
return batch.cast(new_schema)
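
A minimal sketch of the cast this helper performs, on a hand-built table (illustrative, not from the PR):

import pyarrow

# A batch shaped the way tfx-bsl typically decodes it: large_list columns,
# with binary features decoded as large_binary.
batch = pyarrow.table(
    {
        "labels": pyarrow.array([[1], [0]], type=pyarrow.large_list(pyarrow.int64())),
        "images": pyarrow.array(
            [[b"\x00"], [b"\x01"]], type=pyarrow.large_list(pyarrow.large_binary())
        ),
    }
)

casted = _cast_large_list_to_list(batch)
# casted.schema is now:
#   labels: list<item: int64>
#   images: list<item: binary>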


def _infer_schema_and_transform(dataset: "Dataset"):
list_sizes = dataset.aggregate(_MaxListSize(dataset.schema().names))

Review comment (Contributor): Can you add a comment on why this function is needed? It seems we'll read the datasource twice, once for the aggregate and once for map_batches. Won't that be less efficient?

Review comment (Contributor, author): @raulchen It will indeed mean an extra pass over the data. The function is needed because the tfx-bsl ExamplesToRecordBatchDecoder always returns lists of lists when no schema is provided, and this function infers the schema for the fields that are single-value fields. Performance-wise, our benchmarks of this implementation (we have been running it internally for a while) show more than a 15x improvement over the current implementation: some of our datasets take ~30 minutes to load with the Ray-native implementation versus less than 2 minutes with this tfx-bsl implementation. Let me know if you need more benchmark numbers; happy to provide more.

Review comment (Contributor): Thanks for the explanation. It's okay to proceed without more benchmarks for now.

return dataset.map_batches(
_unwrap_single_value_lists,
fn_kwargs={"col_lengths": list_sizes["max_list_size"]},
batch_format="pyarrow",
)


def _unwrap_single_value_lists(batch: pyarrow.Table, col_lengths: Dict[str, int]):
Review comment (Contributor, author): @scottjlee This function now also casts large_list to list, preserving the underlying data format, and casts large_binary to binary, since to_tf does not have an implementation for large_binary. One thing I just realized is that this conversion is only applied when schema inference runs (i.e. when no tf_schema is provided), which means there may be cases where the fast read is used with a tf_schema and to_tf could fail if there is a large_list in the schema.

Review comment (Contributor): Good catch. Maybe in that case it would be appropriate to catch the failure and try applying the large_list -> list_ / large_binary -> binary casting.

Review comment (Contributor, author): @scottjlee Meaning that in this error scenario we would leave the user to cast the schema?

Review comment (Contributor): I was thinking that Ray Data would automatically try the casting. Alternatively, we could add to the error message to indicate that the user should modify the schema themselves.

"""
This function transforms the dataset by converting list columns that always
contain a single value to their underlying data type
(e.g. pyarrow.int64() and pyarrow.float64()).
"""
columns = {}

for col in col_lengths:
value_type = batch[col].type.value_type

if col_lengths[col] == 1:
if batch[col]:
columns[col] = pyarrow.array(
[x.as_py()[0] if x.as_py() else None for x in batch[col]],
type=value_type,
)
else:
columns[col] = batch[col]

return pyarrow.table(columns)
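
A small worked example of the unwrap step (illustrative values; col_lengths is what the _MaxListSize aggregate below would report):

import pyarrow

batch = pyarrow.table(
    {
        "label": pyarrow.array([[1], [0], [1]], type=pyarrow.list_(pyarrow.int64())),
        "tags": pyarrow.array([[1, 2], [3], [4, 5]], type=pyarrow.list_(pyarrow.int64())),
    }
)

unwrapped = _unwrap_single_value_lists(batch, col_lengths={"label": 1, "tags": 2})
# "label" becomes a plain int64 column: [1, 0, 1]
# "tags" keeps its list<int64> type because some rows contain two values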


class _MaxListSize(AggregateFn):
def __init__(self, columns: List[str]):
self._columns = columns
super().__init__(
init=self._init,
merge=self._merge,
accumulate_row=self._accumulate_row,
finalize=lambda a: a,
name="max_list_size",
)

def _init(self, k: str):
return {col: 0 for col in self._columns}

def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]):
merged = {}
for col in self._columns:
merged[col] = max(acc1[col], acc2[col])

return merged

def _accumulate_row(self, acc: Dict[str, int], row: "pd.Series"):
for k in row:
value = row[k]
if value:
acc[k] = max(len(value), acc[k])

return acc
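
A quick sketch of the aggregation semantics, calling the private hooks directly with plain dicts standing in for the pandas rows Ray would pass (illustration only):

agg = _MaxListSize(columns=["label", "tags"])

acc = agg._init(None)                                        # {"label": 0, "tags": 0}
acc = agg._accumulate_row(acc, {"label": [1], "tags": [1, 2]})
acc = agg._accumulate_row(acc, {"label": [0], "tags": [3]})
partial = agg._merge(agg._init(None), {"label": 1, "tags": 2})
result = agg._merge(acc, partial)                            # {"label": 1, "tags": 2}
# Any column whose max list size is 1 is safe to unwrap into a scalar column.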


# Adapted from https://github.com/vahidk/tfrecord/blob/74b2d24a838081356d993ec0e147eaf59ccd4c84/tfrecord/writer.py#L57-L72 # noqa: E501
#
# MIT License