responding to pr feedback, adding docs, improving tests
Signed-off-by: Matthew Owen <mowen@anyscale.com>
omatthew98 committed Jan 25, 2024
1 parent dddf9d4 commit edd4c16
Showing 4 changed files with 60 additions and 23 deletions.
3 changes: 3 additions & 0 deletions python/ray/data/datasource/huggingface_datasource.py
@@ -76,6 +76,9 @@ def __init__(
def list_parquet_urls_from_dataset(
cls, dataset: Union["datasets.Dataset", "datasets.IterableDataset"]
) -> Dataset:
"""Return list of Hugging Face hosted parquet file URLs if they
exist for the data (i.e. if the dataset is a public dataset that
has not been transformed) else return an empty list."""
# We can use the dataset name, config name, and split name to load
# public hugging face datasets from the Hugging Face Hub. More info
# here: https://huggingface.co/docs/datasets-server/parquet
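For context on the mechanism this comment points at, the sketch below shows how the datasets-server parquet listing can be queried directly. It is illustrative only and is not the datasource's implementation; the endpoint URL and the `parquet_files`/`config`/`split`/`url` response fields are taken from the linked Hugging Face docs.

```python
import requests

def list_hosted_parquet_urls(dataset_name: str, config: str, split: str) -> list:
    """Return hosted parquet file URLs for a public Hub dataset, or [] if none."""
    resp = requests.get(
        "https://datasets-server.huggingface.co/parquet",
        params={"dataset": dataset_name},
        timeout=30,
    )
    if resp.status_code != 200:
        # Private, gated, or unknown datasets have no hosted parquet files.
        return []
    parquet_files = resp.json().get("parquet_files", [])
    # Keep only the shards for the requested config and split.
    return [
        f["url"]
        for f in parquet_files
        if f.get("config") == config and f.get("split") == split
    ]

# Example: the parquet shards backing the "emotion" config of tweet_eval.
urls = list_hosted_parquet_urls("tweet_eval", "emotion", "train")
```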
15 changes: 15 additions & 0 deletions python/ray/data/read_api.py
@@ -2365,6 +2365,11 @@ def from_huggingface(
or a :class:`~ray.data.Dataset` from a `Hugging Face Datasets IterableDataset <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDataset/>`_.
For an `IterableDataset`, we use a streaming implementation to read data.
If the dataset is a public Hugging Face Dataset that is hosted on the Hugging Face Hub and
no transformations have been applied, then the `hosted parquet files <https://huggingface.co/docs/datasets-server/parquet#list-parquet-files>`_
will be passed to :meth:`~ray.data.read_parquet` to allow for a distributed read. All
other cases are handled with a single-node read.
Example:
..
@@ -2403,6 +2408,14 @@ def from_huggingface(
`DatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.DatasetDict/>`_
and `IterableDatasetDict <https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.IterableDatasetDict/>`_
are not supported.
parallelism: The amount of parallelism to use for the dataset if applicable (e.g.
if the dataset is a public Hugging Face Dataset without transforms applied).
Defaults to -1, which automatically determines the optimal parallelism for your
configuration. You should not need to manually set this value in most cases.
For details on how the parallelism is automatically determined and guidance
on how to tune it, see :ref:`Tuning read parallelism
<read_parallelism>`. Parallelism is upper bounded by the total number of
records in all the parquet files.
Returns:
A :class:`~ray.data.Dataset` holding rows from the `Hugging Face Datasets Dataset`_.
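As a usage illustration of the docstring above (a hedged sketch, not an example shipped with this commit): a public, untransformed Hub split takes the distributed parquet path, while a transformed split falls back to a single-node read.

```python
import datasets
import ray

hf_splits = datasets.load_dataset("tweet_eval", "emotion")

# Public, untransformed split: backed by hosted parquet files, so the read
# is distributed; parallelism may be set explicitly or left at -1 (auto).
train_ds = ray.data.from_huggingface(hf_splits["train"], parallelism=4)

# A transform (here, a re-split) disqualifies the hosted-parquet path, so
# this dataset is read on a single node instead.
resplit = hf_splits["train"].train_test_split(test_size=0.2)
fallback_ds = ray.data.from_huggingface(resplit["train"])
```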
@@ -2417,6 +2430,8 @@ def from_huggingface(
file_urls = HuggingFaceDatasource.list_parquet_urls_from_dataset(dataset)
if len(file_urls) > 0:
# If file URLs are returned, the parquet files are available via the Hub API
# TODO: Add support for reading from http filesystem in FileBasedDatasource
# GH Issue: https://github.com/ray-project/ray/issues/42706
import fsspec.implementations.http

http = fsspec.implementations.http.HTTPFileSystem()
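The remainder of this branch is collapsed in the diff. A plausible, hedged sketch of how the hosted URLs and the HTTP filesystem could be handed to a distributed read (assumed names `file_urls` and `parallelism`; not necessarily the commit's exact code):

```python
import fsspec.implementations.http
import ray

def _read_hosted_parquet(file_urls, parallelism=-1):
    # Read the Hub-hosted parquet shards over HTTP so the work is
    # distributed across the cluster rather than done on a single node.
    http = fsspec.implementations.http.HTTPFileSystem()
    return ray.data.read_parquet(
        file_urls,
        parallelism=parallelism,
        filesystem=http,
    )
```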
25 changes: 23 additions & 2 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -1334,8 +1334,9 @@ def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer):
# needed for checking operator usage below.
assert len(ds.take_all()) > 0
# Check that metadata fetch is included in stats;
# the underlying implementation uses the `FromArrow` operator.
# assert "FromArrow" in ds.stats()
# the underlying implementation uses the `ReadParquet` operator
# as this is an un-transformed public dataset.
assert "ReadParquet" in ds.stats()
assert ds._plan._logical_plan.dag.name == "ReadParquet"
# use sort by 'text' to match order of rows
expected_table = data[ds_key].data.table.sort_by("text")
@@ -1345,6 +1346,26 @@ def test_from_huggingface_e2e(ray_start_regular_shared, enable_optimizer):
assert expected_table.equals(output_full_table)
_check_usage_record(["ReadParquet"])

# Test a transformed public dataset to exercise the fallback behavior.
base_hf_dataset = data["train"]
hf_dataset_split = base_hf_dataset.train_test_split(test_size=0.2)
ray_dataset_split_train = ray.data.from_huggingface(hf_dataset_split["train"])
assert isinstance(ray_dataset_split_train, ray.data.Dataset)
# `ds.take_all()` triggers execution with new backend, which is
# needed for checking operator usage below.
assert len(ray_dataset_split_train.take_all()) > 0
# Check that metadata fetch is included in stats;
# the underlying implementation uses the `FromArrow` operator.
assert "FromArrow" in ray_dataset_split_train.stats()
assert ray_dataset_split_train._plan._logical_plan.dag.name == "FromArrow"
# use sort by 'text' to match order of rows
expected_table = hf_dataset_split["train"].data.table.sort_by("text")
output_full_table = pyarrow.concat_tables(
[ray.get(tbl) for tbl in ray_dataset_split_train.to_arrow_refs()]
).sort_by("text")
assert expected_table.equals(output_full_table)
_check_usage_record(["FromArrow"])


def test_from_tf_e2e(ray_start_regular_shared, enable_optimizer):
import tensorflow as tf
40 changes: 19 additions & 21 deletions python/ray/data/tests/test_huggingface.py
@@ -7,10 +7,9 @@
from ray.tests.conftest import * # noqa


def test_from_huggingface(ray_start_regular_shared):
data = datasets.load_dataset(
"tweet_eval", "emotion", download_mode="force_redownload"
)
@pytest.mark.parametrize("num_par", [1, 4])
def test_from_huggingface(ray_start_regular_shared, num_par):
data = datasets.load_dataset("tweet_eval", "emotion")

# Check that DatasetDict is not directly supported.
assert isinstance(data, datasets.DatasetDict)
@@ -20,26 +19,25 @@ def test_from_huggingface(ray_start_regular_shared):
):
ray.data.from_huggingface(data)

for num_par in [1, 4]:
ray_datasets = {
"train": ray.data.from_huggingface(data["train"], parallelism=num_par),
"validation": ray.data.from_huggingface(
data["validation"], parallelism=num_par
),
"test": ray.data.from_huggingface(data["test"], parallelism=num_par),
}
ray_datasets = {
"train": ray.data.from_huggingface(data["train"], parallelism=num_par),
"validation": ray.data.from_huggingface(
data["validation"], parallelism=num_par
),
"test": ray.data.from_huggingface(data["test"], parallelism=num_par),
}

assert isinstance(ray_datasets["train"], ray.data.Dataset)
assert isinstance(ray_datasets["train"], ray.data.Dataset)

# use sort by 'text' to match order of rows
expected_table = data["train"].data.table.sort_by("text")
output_full_table = pyarrow.concat_tables(
[ray.get(tbl) for tbl in ray_datasets["train"].to_arrow_refs()]
).sort_by("text")
# use sort by 'text' to match order of rows
expected_table = data["train"].data.table.sort_by("text")
output_full_table = pyarrow.concat_tables(
[ray.get(tbl) for tbl in ray_datasets["train"].to_arrow_refs()]
).sort_by("text")

assert expected_table.equals(output_full_table)
assert ray_datasets["train"].count() == data["train"].num_rows
assert ray_datasets["test"].count() == data["test"].num_rows
assert expected_table.equals(output_full_table)
assert ray_datasets["train"].count() == data["train"].num_rows
assert ray_datasets["test"].count() == data["test"].num_rows

# Test that reading a split Hugging Face dataset yields correct individual datasets
base_hf_dataset = data["train"]
