123 changes: 79 additions & 44 deletions src/guidellm/data/deserializers/deserializer.py
@@ -1,10 +1,9 @@
from __future__ import annotations

import contextlib
from collections.abc import Callable
from typing import Any, Protocol, Union, runtime_checkable

from datasets import Dataset, IterableDataset
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from transformers import PreTrainedTokenizerBase

from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]: ...
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...


class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
remove_columns: list[str] | None = None,
**data_kwargs: dict[str, Any],
) -> Dataset | IterableDataset:
dataset = None
dataset: Dataset

if type_ is None:
errors = []
# Note: There is no priority order for the deserializers, so all deserializers
# must be mutually exclusive to ensure deterministic behavior.
for name, deserializer in cls.registry.items():
deserializer_fn: DatasetDeserializer = (
deserializer() if isinstance(deserializer, type) else deserializer
)

try:
with contextlib.suppress(DataNotSupportedError):
dataset = deserializer_fn(
data=data,
processor_factory=processor_factory,
random_seed=random_seed,
**data_kwargs,
)
except Exception as e:
errors.append(e)

if dataset is not None:
break # Found one that works. Continuing could overwrite it.

if dataset is None and len(errors) > 0:
raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
f"attempting to deserialize data {data}: {errors}")

elif deserializer := cls.get_registered_object(type_) is not None:
deserializer_fn: DatasetDeserializer = (
deserializer() if isinstance(deserializer, type) else deserializer
dataset = cls._deserialize_with_registered_deserializers(
data, processor_factory, random_seed, **data_kwargs
)

dataset = deserializer_fn(
data=data,
processor_factory=processor_factory,
random_seed=random_seed,
**data_kwargs,
)

if dataset is None:
raise DataNotSupportedError(
f"No suitable deserializer found for data {data} "
f"with kwargs {data_kwargs} and deserializer type {type_}."
else:
dataset = cls._deserialize_with_specified_deserializer(
data, type_, processor_factory, random_seed, **data_kwargs
)

if resolve_split:
@@ -107,3 +71,74 @@ def deserialize
dataset = dataset.remove_columns(remove_columns)

return dataset

@classmethod
def _deserialize_with_registered_deserializers(
cls,
data: Any,
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int = 42,
**data_kwargs: dict[str, Any],
) -> Dataset:
if cls.registry is None:
raise RuntimeError("registry is None; cannot deserialize dataset")
dataset: Dataset | None = None

errors: dict[str, Exception] = {}
# Note: There is no priority order for the deserializers, so all deserializers
# must be mutually exclusive to ensure deterministic behavior.
for _name, deserializer in cls.registry.items():
deserializer_fn: DatasetDeserializer = (
deserializer() if isinstance(deserializer, type) else deserializer
)

try:
dataset = deserializer_fn(
data=data,
processor_factory=processor_factory,
random_seed=random_seed,
**data_kwargs,
)
except Exception as e: # noqa: BLE001 # The exceptions are saved.
errors[_name] = e

if dataset is not None:
return dataset # Success

if len(errors) > 0:
err_msgs = ""
def sort_key(item):
return (isinstance(item[1], DataNotSupportedError), item[0])
for key, err in sorted(errors.items(), key=sort_key):
err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
raise ValueError(
"Data deserialization failed, likely because the input doesn't "
f"match any of the input formats. See the {len(errors)} error(s) that "
f"occurred while attempting to deserialize the data {data}:{err_msgs}"
)
return dataset

@classmethod
def _deserialize_with_specified_deserializer(
cls,
data: Any,
type_: str,
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int = 42,
**data_kwargs: dict[str, Any],
) -> Dataset:
deserializer_from_type = cls.get_registered_object(type_)
if deserializer_from_type is None:
raise ValueError(f"Deserializer type '{type_}' is not registered.")
if isinstance(deserializer_from_type, type):
deserializer_fn = deserializer_from_type()
else:
deserializer_fn = deserializer_from_type

return deserializer_fn(
data=data,
processor_factory=processor_factory,
random_seed=random_seed,
**data_kwargs,
)
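
Usage sketch (not part of the diff): a minimal illustration of how the refactored factory entry point might be exercised, assuming deserialize is exposed as a classmethod, as the cls-based helpers above suggest. The list-of-dicts payload, the tokenizer choice, and the "json_file" registry key are hypothetical; the keyword names follow the signatures visible in this diff.

from transformers import AutoTokenizer

from guidellm.data.deserializers.deserializer import DatasetDeserializerFactory


def processor_factory():
    # Hypothetical tokenizer; any PreTrainedTokenizerBase factory works here.
    return AutoTokenizer.from_pretrained("gpt2")


# Auto-detection path: every registered deserializer is tried, per-deserializer
# errors are collected, and a single ValueError summarizes them if none succeeds.
dataset = DatasetDeserializerFactory.deserialize(
    data=[{"prompt": "hello"}, {"prompt": "world"}],
    processor_factory=processor_factory,
    random_seed=42,
)

# Explicit path: an unregistered type_ now fails fast with
# ValueError("Deserializer type '...' is not registered.").
dataset = DatasetDeserializerFactory.deserialize(
    data="samples.jsonl",  # hypothetical path
    type_="json_file",  # hypothetical registry key
    processor_factory=processor_factory,
    random_seed=42,
)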

28 changes: 14 additions & 14 deletions src/guidellm/data/deserializers/file.py
@@ -34,11 +34,11 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed) # Ignore unused args format errors

if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
) -> dict[str, list]:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
) -> dict[str, list]:
_ = (processor_factory, random_seed)
if (
not isinstance(data, (str, Path))
not isinstance(data, str | Path)
or not (path := Path(data)).exists()
or not path.is_file()
or path.suffix.lower() != ".tar"
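
Side note on the repeated isinstance change in this file: isinstance with a PEP 604 union such as str | Path is only accepted on Python 3.10 and newer, where it behaves the same as the old tuple form. A quick standalone check, independent of this repository:

from pathlib import Path

data = Path("samples.csv")
# On Python 3.10+ both spellings are valid and equivalent.
assert isinstance(data, (str, Path)) == isinstance(data, str | Path)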
2 changes: 1 addition & 1 deletion src/guidellm/data/deserializers/huggingface.py
@@ -36,7 +36,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
_ = (processor_factory, random_seed)

if isinstance(
38 changes: 20 additions & 18 deletions src/guidellm/data/deserializers/memory.py
@@ -33,7 +33,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed) # Ignore unused args format errors

if (
@@ -67,7 +67,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed) # Ignore unused args format errors

if (
@@ -81,9 +81,9 @@ def __call__(
f"expected list of dicts, got {data}"
)

data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
first_keys = set(data[0].keys())
for index, item in enumerate(data):
typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
first_keys = set(typed_data[0].keys())
for index, item in enumerate(typed_data):
if set(item.keys()) != first_keys:
raise DataNotSupportedError(
f"All dictionaries must have the same keys. "
@@ -92,8 +92,8 @@ def __call__(
)

# Convert list of dicts to dict of lists
result_dict = {key: [] for key in first_keys}
for item in data:
result_dict: dict = {key: [] for key in first_keys}
for item in typed_data:
for key, value in item.items():
result_dict[key].append(value)

@@ -108,7 +108,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
_ = (processor_factory, random_seed) # Ignore unused args format errors

primitive_types = (str, int, float, bool, type(None))
@@ -135,7 +135,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
if (
isinstance(data, str)
and (json_str := data.strip())
@@ -145,16 +145,18 @@ def __call__(
)
):
with contextlib.suppress(Exception):
parsed = json.loads(data)
parsed_data = json.loads(data)

for deserializer in [
InMemoryDictDatasetDeserializer,
InMemoryDictListDatasetDeserializer,
InMemoryItemListDatasetDeserializer,
]:
deserializers = [
InMemoryDictDatasetDeserializer(),
InMemoryDictListDatasetDeserializer(),
InMemoryItemListDatasetDeserializer(),
]

for deserializer in deserializers:
with contextlib.suppress(DataNotSupportedError):
return deserializer()(
parsed, data_kwargs, processor_factory, random_seed
return deserializer(
parsed_data, processor_factory, random_seed, **data_kwargs
)

raise DataNotSupportedError(
@@ -171,7 +173,7 @@ def __call__(
processor_factory: Callable[[], PreTrainedTokenizerBase],
random_seed: int,
**data_kwargs: dict[str, Any],
) -> dict[str, list]:
) -> Dataset:
if (
isinstance(data, str)
and (csv_str := data.strip())
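
Usage sketch for the in-memory path (not part of the diff): calling one of the deserializers named above directly, with hypothetical column names. The call signature (data, processor_factory, random_seed, **data_kwargs) mirrors the __call__ signatures in this file; per the diff, this deserializer ignores the processor and seed, and it raises DataNotSupportedError when the rows do not all share the same keys.

from transformers import AutoTokenizer

from guidellm.data.deserializers.memory import InMemoryDictListDatasetDeserializer

rows = [
    {"prompt": "hello", "label": 0},  # hypothetical columns
    {"prompt": "world", "label": 1},
]

dataset = InMemoryDictListDatasetDeserializer()(
    data=rows,
    processor_factory=lambda: AutoTokenizer.from_pretrained("gpt2"),  # unused by this deserializer
    random_seed=42,
)
print(dataset.column_names)  # e.g. ["prompt", "label"]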