[Data] Distributed reads for from_huggingface #42599

Merged: 14 commits, Jan 29, 2024

48 changes: 48 additions & 0 deletions python/ray/data/datasource/huggingface_datasource.py
@@ -4,6 +4,7 @@
from ray.data._internal.dataset_logger import DatasetLogger
from ray.data._internal.util import _check_pyarrow_version
from ray.data.block import Block, BlockAccessor, BlockMetadata
from ray.data.dataset import Dataset
from ray.data.datasource import Datasource, ReadTask
from ray.util.annotations import DeveloperAPI

@@ -71,6 +72,53 @@ def __init__(
self._dataset = dataset
self._batch_size = batch_size

@classmethod
def list_parquet_urls_from_dataset(
cls, dataset: Union["datasets.Dataset", "datasets.IterableDataset"]
    ) -> List[str]:
"""Return list of Hugging Face hosted parquet file URLs if they
exist for the data (i.e. if the dataset is a public dataset that
has not been transformed) else return an empty list."""
import datasets

# We can use the dataset name, config name, and split name to load
        # public Hugging Face datasets from the Hugging Face Hub. More info
# here: https://huggingface.co/docs/datasets-server/parquet
dataset_name = dataset.info.dataset_name
config_name = dataset.info.config_name
split_name = str(dataset.split)

        # If the dataset is not an iterable dataset, check whether the dataset
        # with the matching dataset name, config name, and split name on the
        # Hugging Face Hub has the same fingerprint as the dataset passed into
        # this function. If the fingerprints do not match, transforms or other
        # operations have been performed, so we cannot use the parquet files
        # on the Hugging Face Hub and instead return an empty list.
if not isinstance(dataset, datasets.IterableDataset):
from datasets import load_dataset

try:
ds = load_dataset(dataset_name, config_name, split=split_name)
if ds._fingerprint != dataset._fingerprint:
return []
except Exception:
# If an exception is thrown when trying to reload the dataset
# we should exit gracefully by returning an empty list.
return []

import requests

public_url = (
f"https://huggingface.co/api/datasets/{dataset_name}"
f"/parquet/{config_name}/{split_name}"
)
resp = requests.get(public_url)
if resp.status_code == requests.codes["ok"]:
            # Public dataset: return the list of hosted parquet file URLs.
return resp.json()
else:
return []

def estimate_inmemory_data_size(self) -> Optional[int]:
return self._dataset.dataset_size
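
For context, here is a minimal sketch (not part of the diff) of how the new classmethod above might be exercised on its own. The dataset name mirrors the one used in this PR's tests; everything else follows from the changes in this file.

```python
# Hedged sketch: exercise the new helper directly (assumes `ray` with this change
# and the `datasets` package are installed).
import datasets

from ray.data.datasource.huggingface_datasource import HuggingFaceDatasource

# A public, untransformed dataset should resolve to Hub-hosted parquet file URLs.
hf_ds = datasets.load_dataset("tweet_eval", "emotion", split="train")
urls = HuggingFaceDatasource.list_parquet_urls_from_dataset(hf_ds)

if urls:
    print(f"{len(urls)} hosted parquet file(s): a distributed read is possible")
else:
    print("No hosted parquet files: falling back to a single-node read")
```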

35 changes: 29 additions & 6 deletions python/ray/data/read_api.py
@@ -2358,13 +2358,18 @@ def from_spark(

@PublicAPI
def from_huggingface(
dataset: Union["datasets.Dataset", "datasets.IterableDataset"],
dataset: Union["datasets.Dataset", "datasets.IterableDataset"], parallelism=-1
Reviewer comment (Contributor):

Can we add `parallelism` to the docstring? Let's make sure we state that the user should not need to set `parallelism` in most cases, since it is auto-configured, similar to the docstring for `parallelism` under `read_parquet()` and other read methods.

In addition, we should also clarify that this is only used when the distributed parquet read is possible (i.e., when it is a public dataset with no transformations); otherwise it's a single-node read.

) -> Union[MaterializedDataset, Dataset]:
"""Create a :class:`~ray.data.MaterializedDataset` from a
`Hugging Face Datasets Dataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset/>`_
or a :class:`~ray.data.Dataset` from a `Hugging Face Datasets IterableDataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDataset/>`_.
For an `IterableDataset`, we use a streaming implementation to read data.

If the dataset is a public Hugging Face Dataset that is hosted on the Hugging Face Hub and
no transformations have been applied, then the `hosted parquet files <https://huggingface.co/docs/datasets-server/parquet#list-parquet-files>`_
will be passed to :meth:`~ray.data.read_parquet` to perform a distributed read. All
    other cases fall back to a single-node read.

Example:

..
@@ -2403,18 +2408,36 @@ def from_huggingface(
`DatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.DatasetDict/>`_
and `IterableDatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDatasetDict/>`_
are not supported.
parallelism: The amount of parallelism to use for the dataset if applicable (i.e.
if the dataset is a public Hugging Face Dataset without transforms applied).
Defaults to -1, which automatically determines the optimal parallelism for your
configuration. You should not need to manually set this value in most cases.
For details on how the parallelism is automatically determined and guidance
on how to tune it, see :ref:`Tuning read parallelism
<read_parallelism>`. Parallelism is upper bounded by the total number of
records in all the parquet files.

Returns:
A :class:`~ray.data.Dataset` holding rows from the `Hugging Face Datasets Dataset`_.
""" # noqa: E501
import datasets

if isinstance(dataset, datasets.IterableDataset):
# HuggingFaceDatasource should not be imported at top level, because
# we only want the Hugging Face datasets package to be imported
# if Hugging Face Datasets are used.
from ray.data.datasource.huggingface_datasource import HuggingFaceDatasource
from ray.data.datasource.huggingface_datasource import HuggingFaceDatasource

if isinstance(dataset, (datasets.IterableDataset, datasets.Dataset)):
# Attempt to read data via Hugging Face Hub parquet files. If the
        # returned list of files is empty, fall back to the other read paths below.
file_urls = HuggingFaceDatasource.list_parquet_urls_from_dataset(dataset)
if len(file_urls) > 0:
            # If file URLs are returned, the parquet files are available via the API.
# TODO: Add support for reading from http filesystem in FileBasedDatasource
# GH Issue: https://github.com/ray-project/ray/issues/42706
import fsspec.implementations.http

http = fsspec.implementations.http.HTTPFileSystem()
return read_parquet(file_urls, parallelism=parallelism, filesystem=http)

if isinstance(dataset, datasets.IterableDataset):
# For an IterableDataset, we can use a streaming implementation to read data.
return read_datasource(HuggingFaceDatasource(dataset=dataset))
if isinstance(dataset, datasets.Dataset):
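
For reference, a hedged usage sketch (not part of the diff) of the updated `from_huggingface` API, based on the docstring above; the `tweet_eval` dataset mirrors the one used in this PR's tests.

```python
# Hedged sketch of the read paths described in the docstring.
import datasets
import ray

# Public, untransformed dataset: hosted parquet files exist on the Hub, so this
# takes the distributed read_parquet path. parallelism defaults to -1 and is
# auto-configured; it usually should not be set manually.
hf_train = datasets.load_dataset("tweet_eval", "emotion", split="train")
ds = ray.data.from_huggingface(hf_train)

# Transformed dataset (e.g. after train_test_split): the fingerprint no longer
# matches the Hub copy, so this falls back to a single-node read.
split = hf_train.train_test_split(test_size=0.2)
ds_train = ray.data.from_huggingface(split["train"])

# Streaming IterableDataset: read via the streaming HuggingFaceDatasource.
hf_stream = datasets.load_dataset(
    "tweet_eval", "emotion", split="train", streaming=True
)
ds_stream = ray.data.from_huggingface(hf_stream)
```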
37 changes: 23 additions & 14 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -1319,6 +1319,8 @@ def test_from_arrow_refs_e2e(ray_start_regular_shared, enable_optimizer):
def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer):
import datasets

from ray.data.tests.test_huggingface import hfds_assert_equals

data = datasets.load_dataset("tweet_eval", "emotion")
assert isinstance(data, datasets.DatasetDict)
ray_datasets = {
@@ -1333,20 +1335,27 @@ def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer):
# needed for checking operator usage below.
assert len(ds.take_all()) > 0
# Check that metadata fetch is included in stats;
# the underlying implementation uses the `FromArrow` operator.
assert "FromArrow" in ds.stats()
assert ds._plan._logical_plan.dag.name == "FromArrow"
assert ray.get(ray_datasets[ds_key].to_arrow_refs())[0].equals(
data[ds_key].data.table
)
_check_usage_record(["FromArrow"])

ray_dataset = ray.data.from_huggingface(data["train"])
assert isinstance(ray_dataset, ray.data.Dataset)
assert len(ray_dataset.take_all()) > 0
assert "FromArrow" in ray_dataset.stats()
assert ray_dataset._plan._logical_plan.dag.name == "FromArrow"
assert ray.get(ray_dataset.to_arrow_refs())[0].equals(data["train"].data.table)
# the underlying implementation uses the `ReadParquet` operator
        # as this is an untransformed public dataset.
assert "ReadParquet" in ds.stats()
assert ds._plan._logical_plan.dag.name == "ReadParquet"
Reviewer comment (Contributor):

Can we add a comment briefly explaining why we check for `ReadParquet` in this test for Hugging Face?

Even better would be to have two test paths: one that uses the new distributed parquet path, and one that uses the backup direct read from the HF dataset. I know that might be a bit tricky, though, since you'd likely need to find a private dataset (that we can access) or a public dataset whose parquet conversion is unavailable for some reason. So no worries if you cannot find a way to easily test both paths.

Reply (Contributor Author):

Ended up adding a second test case that runs the same checks on a modified dataset (just did a train/test split and applied similar assertions as for the unmodified dataset), so we should now be testing both paths!

        # hfds_assert_equals sorts rows by the first column ("text") before comparing.
hfds_assert_equals(data[ds_key], ds)
_check_usage_record(["ReadParquet"])

    # Test that a transformed public dataset falls back to the single-node read path.
base_hf_dataset = data["train"]
hf_dataset_split = base_hf_dataset.train_test_split(test_size=0.2)
ray_dataset_split_train = ray.data.from_huggingface(hf_dataset_split["train"])
assert isinstance(ray_dataset_split_train, ray.data.Dataset)
# `ds.take_all()` triggers execution with new backend, which is
# needed for checking operator usage below.
assert len(ray_dataset_split_train.take_all()) > 0
# Check that metadata fetch is included in stats;
# the underlying implementation uses the `FromArrow` operator.
assert "FromArrow" in ray_dataset_split_train.stats()
assert ray_dataset_split_train._plan._logical_plan.dag.name == "FromArrow"
assert ray_dataset_split_train.count() == hf_dataset_split["train"].num_rows
_check_usage_record(["FromArrow"])


58 changes: 37 additions & 21 deletions python/ray/data/tests/test_huggingface.py
@@ -4,43 +4,59 @@
from packaging.version import Version

import ray
from ray.data.dataset import Dataset
from ray.tests.conftest import * # noqa


def test_from_huggingface(ray_start_regular_shared):
data = datasets.load_dataset("tweet_eval", "emotion")
@pytest.fixture(scope="session")
def hf_dataset():
return datasets.load_dataset("tweet_eval", "stance_climate")


def _arrow_sort_values(table: pyarrow.lib.Table) -> pyarrow.lib.Table:
"""
Sort an Arrow table by the values in the first column. Used for testing
compatibility with pyarrow 6 where `sort_by` does not exist. Inspired by:
https://stackoverflow.com/questions/70893521/how-to-sort-a-pyarrow-table
"""
by = [table.schema.names[0]] # grab first col_name
table_sorted_indexes = pyarrow.compute.bottom_k_unstable(
table, sort_keys=by, k=len(table)
)
table_sorted = table.take(table_sorted_indexes)
return table_sorted


def hfds_assert_equals(hfds: datasets.Dataset, ds: Dataset):
hfds_table = _arrow_sort_values(hfds.data.table)
ds_table = _arrow_sort_values(
pyarrow.concat_tables([ray.get(tbl) for tbl in ds.to_arrow_refs()])
)
assert hfds_table.equals(ds_table)


@pytest.mark.parametrize("num_par", [1, 4])
def test_from_huggingface(hf_dataset, ray_start_regular_shared, num_par):
# Check that DatasetDict is not directly supported.
assert isinstance(data, datasets.DatasetDict)
assert isinstance(hf_dataset, datasets.DatasetDict)
with pytest.raises(
DeprecationWarning,
match="You provided a Hugging Face DatasetDict",
):
ray.data.from_huggingface(data)
ray.data.from_huggingface(hf_dataset)

ray_datasets = {
"train": ray.data.from_huggingface(data["train"]),
"validation": ray.data.from_huggingface(data["validation"]),
"test": ray.data.from_huggingface(data["test"]),
"train": ray.data.from_huggingface(hf_dataset["train"], parallelism=num_par),
}

assert ray.get(ray_datasets["train"].to_arrow_refs())[0].equals(
data["train"].data.table
)
assert ray_datasets["train"].count() == data["train"].num_rows
assert ray_datasets["test"].count() == data["test"].num_rows

ray_dataset = ray.data.from_huggingface(data["train"])
assert isinstance(ray_dataset, ray.data.Dataset)
assert ray.get(ray_dataset.to_arrow_refs())[0].equals(data["train"].data.table)
assert isinstance(ray_datasets["train"], ray.data.Dataset)
hfds_assert_equals(hf_dataset["train"], ray_datasets["train"])

    # Test that reading a split Hugging Face dataset yields the correct sub-datasets.
base_hf_dataset = data["train"]
base_hf_dataset = hf_dataset["train"]
hf_dataset_split = base_hf_dataset.train_test_split(test_size=0.2)
ray_dataset_split_train = ray.data.from_huggingface(hf_dataset_split["train"])
ray_dataset_split_test = ray.data.from_huggingface(hf_dataset_split["test"])
assert ray_dataset_split_train.count() == hf_dataset_split["train"].num_rows
assert ray_dataset_split_test.count() == hf_dataset_split["test"].num_rows


@pytest.mark.skipif(
@@ -58,11 +74,11 @@ def test_from_huggingface(ray_start_regular_shared):
)
def test_from_huggingface_streaming(batch_format, ray_start_regular_shared):
hfds = datasets.load_dataset(
"tweet_eval", "emotion", streaming=True, split="train"
"tweet_eval", "stance_climate", streaming=True, split="train"
).with_format(batch_format)
assert isinstance(hfds, datasets.IterableDataset)
ds = ray.data.from_huggingface(hfds)
assert ds.count() == 3257
assert ds.count() == 355


if __name__ == "__main__":