From 5c06a8adf24a782e9741c0e21d86f5cf856aa762 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 21 Oct 2025 16:03:27 -0400 Subject: [PATCH 1/7] Fix data package type and quality errors Signed-off-by: Jared O'Connell --- .../data/deserializers/deserializer.py | 84 ++++++++++++------- src/guidellm/data/deserializers/file.py | 28 +++---- .../data/deserializers/huggingface.py | 2 +- src/guidellm/data/deserializers/memory.py | 38 +++++---- src/guidellm/data/deserializers/synthetic.py | 28 ++++--- src/guidellm/data/loaders.py | 12 ++- src/guidellm/data/preprocessors/formatters.py | 26 ++++-- src/guidellm/data/preprocessors/mappers.py | 4 +- .../data/preprocessors/preprocessor.py | 8 +- src/guidellm/data/processor.py | 5 +- src/guidellm/data/schemas.py | 1 + src/guidellm/data/utils/dataset.py | 4 +- src/guidellm/data/utils/functions.py | 4 +- 13 files changed, 143 insertions(+), 101 deletions(-) diff --git a/src/guidellm/data/deserializers/deserializer.py b/src/guidellm/data/deserializers/deserializer.py index b1e69f37..96e3d4ba 100644 --- a/src/guidellm/data/deserializers/deserializer.py +++ b/src/guidellm/data/deserializers/deserializer.py @@ -4,7 +4,7 @@ from collections.abc import Callable from typing import Any, Protocol, Union, runtime_checkable -from datasets import Dataset, IterableDataset +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase from guidellm.data.utils import resolve_dataset_split @@ -29,7 +29,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: ... + ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ... class DatasetDeserializerFactory( @@ -47,40 +47,19 @@ def deserialize( remove_columns: list[str] | None = None, **data_kwargs: dict[str, Any], ) -> Dataset | IterableDataset: - dataset = None + dataset: Dataset | None = None if type_ is None: - errors = [] - # Note: There is no priority order for the deserializers, so all deserializers - # must be mutually exclusive to ensure deterministic behavior. - for name, deserializer in cls.registry.items(): - deserializer_fn: DatasetDeserializer = ( - deserializer() if isinstance(deserializer, type) else deserializer - ) - - try: - with contextlib.suppress(DataNotSupportedError): - dataset = deserializer_fn( - data=data, - processor_factory=processor_factory, - random_seed=random_seed, - **data_kwargs, - ) - except Exception as e: - errors.append(e) - - if dataset is not None: - break # Found one that works. Continuing could overwrite it. - - if dataset is None and len(errors) > 0: - raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while " - f"attempting to deserialize data {data}: {errors}") - - elif deserializer := cls.get_registered_object(type_) is not None: - deserializer_fn: DatasetDeserializer = ( - deserializer() if isinstance(deserializer, type) else deserializer + dataset = cls._deserialize_with_registered_deserializers( + data, processor_factory, random_seed, **data_kwargs ) + elif (deserializer_from_type := cls.get_registered_object(type_)) is not None: + if isinstance(deserializer_from_type, type): + deserializer_fn = deserializer_from_type() + else: + deserializer_fn = deserializer_from_type + dataset = deserializer_fn( data=data, processor_factory=processor_factory, @@ -107,3 +86,44 @@ def deserialize( dataset = dataset.remove_columns(remove_columns) return dataset + + @classmethod + def _deserialize_with_registered_deserializers( + cls, + data: Any, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int = 42, + **data_kwargs: dict[str, Any], + ) -> Dataset: + if cls.registry is None: + raise RuntimeError("registry is None; cannot deserialize dataset") + dataset: Dataset | None = None + + errors = [] + # Note: There is no priority order for the deserializers, so all deserializers + # must be mutually exclusive to ensure deterministic behavior. + for _name, deserializer in cls.registry.items(): + deserializer_fn: DatasetDeserializer = ( + deserializer() if isinstance(deserializer, type) else deserializer + ) + + try: + with contextlib.suppress(DataNotSupportedError): + dataset = deserializer_fn( + data=data, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) + except Exception as e: # noqa: BLE001 # The exceptions are saved. + errors.append(e) + + if dataset is not None: + break # Found one that works. Continuing could overwrite it. + + if dataset is None and len(errors) > 0: + raise DataNotSupportedError( + f"data deserialization failed; {len(errors)} errors occurred while " + f"attempting to deserialize data {data}: {errors}" + ) + return dataset diff --git a/src/guidellm/data/deserializers/file.py b/src/guidellm/data/deserializers/file.py index d57403db..9819e173 100644 --- a/src/guidellm/data/deserializers/file.py +++ b/src/guidellm/data/deserializers/file.py @@ -34,11 +34,11 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) # Ignore unused args format errors if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() not in {".txt", ".text"} @@ -62,10 +62,10 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() != ".csv" @@ -86,10 +86,10 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() not in {".json", ".jsonl"} @@ -110,10 +110,10 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() != ".parquet" @@ -134,10 +134,10 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() != ".arrow" @@ -158,10 +158,10 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() not in {".hdf5", ".h5"} @@ -185,7 +185,7 @@ def __call__( ) -> dict[str, list]: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() != ".db" @@ -209,7 +209,7 @@ def __call__( ) -> dict[str, list]: _ = (processor_factory, random_seed) if ( - not isinstance(data, (str, Path)) + not isinstance(data, str | Path) or not (path := Path(data)).exists() or not path.is_file() or path.suffix.lower() != ".tar" diff --git a/src/guidellm/data/deserializers/huggingface.py b/src/guidellm/data/deserializers/huggingface.py index 80e0ed8c..efe6882a 100644 --- a/src/guidellm/data/deserializers/huggingface.py +++ b/src/guidellm/data/deserializers/huggingface.py @@ -36,7 +36,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: _ = (processor_factory, random_seed) if isinstance( diff --git a/src/guidellm/data/deserializers/memory.py b/src/guidellm/data/deserializers/memory.py index 6f8888ec..59051b45 100644 --- a/src/guidellm/data/deserializers/memory.py +++ b/src/guidellm/data/deserializers/memory.py @@ -33,7 +33,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) # Ignore unused args format errors if ( @@ -67,7 +67,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) # Ignore unused args format errors if ( @@ -81,9 +81,9 @@ def __call__( f"expected list of dicts, got {data}" ) - data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data) - first_keys = set(data[0].keys()) - for index, item in enumerate(data): + typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data) + first_keys = set(typed_data[0].keys()) + for index, item in enumerate(typed_data): if set(item.keys()) != first_keys: raise DataNotSupportedError( f"All dictionaries must have the same keys. " @@ -92,8 +92,8 @@ def __call__( ) # Convert list of dicts to dict of lists - result_dict = {key: [] for key in first_keys} - for item in data: + result_dict: dict = {key: [] for key in first_keys} + for item in typed_data: for key, value in item.items(): result_dict[key].append(value) @@ -108,7 +108,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: _ = (processor_factory, random_seed) # Ignore unused args format errors primitive_types = (str, int, float, bool, type(None)) @@ -135,7 +135,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: if ( isinstance(data, str) and (json_str := data.strip()) @@ -145,16 +145,18 @@ def __call__( ) ): with contextlib.suppress(Exception): - parsed = json.loads(data) + parsed_data = json.loads(data) - for deserializer in [ - InMemoryDictDatasetDeserializer, - InMemoryDictListDatasetDeserializer, - InMemoryItemListDatasetDeserializer, - ]: + deserializers = [ + InMemoryDictDatasetDeserializer(), + InMemoryDictListDatasetDeserializer(), + InMemoryItemListDatasetDeserializer(), + ] + + for deserializer in deserializers: with contextlib.suppress(DataNotSupportedError): - return deserializer()( - parsed, data_kwargs, processor_factory, random_seed + return deserializer( + parsed_data, processor_factory, random_seed, **data_kwargs ) raise DataNotSupportedError( @@ -171,7 +173,7 @@ def __call__( processor_factory: Callable[[], PreTrainedTokenizerBase], random_seed: int, **data_kwargs: dict[str, Any], - ) -> dict[str, list]: + ) -> Dataset: if ( isinstance(data, str) and (csv_str := data.strip()) diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index d9e415c6..a52fb907 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -99,21 +99,25 @@ class SyntheticTextDatasetConfig(StandardBaseModel): @model_validator(mode="after") def check_prefix_options(self) -> SyntheticTextDatasetConfig: - prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] - prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] - if prefix_count is not None or prefix_tokens is not None: - if self.prefix_buckets: + prefix_count: Any | None = None + prefix_tokens: Any | None = None + if self.__pydantic_extra__ is not None: + prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] + prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] + + if (prefix_count is not None or prefix_tokens is not None + and self.prefix_buckets): raise ValueError( "prefix_buckets is mutually exclusive" " with prefix_count and prefix_tokens" ) - self.prefix_buckets = [ - SyntheticTextPrefixBucketConfig( - prefix_count=prefix_count or 1, - prefix_tokens=prefix_tokens or 0, - ) - ] + self.prefix_buckets = [ + SyntheticTextPrefixBucketConfig( + prefix_count=prefix_count or 1, + prefix_tokens=prefix_tokens or 0, + ) + ] return self @@ -174,14 +178,14 @@ def __iter__(self) -> Iterator[dict[str, Any]]: def _create_prompt( self, prompt_tokens_count: int, faker: Faker, unique: str = "" ) -> str: - prompt_token_ids = [] + prompt_token_ids: list[int] = [] avg_chars_per_token = 5 margin_of_safety = 1.5 attempts = 0 while len(prompt_token_ids) < prompt_tokens_count: attempts += 1 - num_chars = ( + num_chars = math.ceil( prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts ) text = unique + faker.text(max_nb_chars=num_chars) diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py index fd46334d..1fa1f849 100644 --- a/src/guidellm/data/loaders.py +++ b/src/guidellm/data/loaders.py @@ -16,6 +16,8 @@ __all__ = ["DataLoader", "DatasetsIterator"] +from guidellm.schemas import GenerationRequest + class DatasetsIterator(TorchIterableDataset): def __init__( @@ -85,7 +87,7 @@ def generator( while max_items is None or gen_count < max_items: try: - row = { + row: dict[str, Any] = { "items": [next(dataset_iter) for dataset_iter in dataset_iters] } gen_count += 1 @@ -98,9 +100,13 @@ def generator( continue for preprocessor in self.preprocessors: - row = preprocessor(row) + processed_row = preprocessor(row) + if isinstance(processed_row, GenerationRequest): + yield processed_row + else: + row = processed_row yield row - except Exception as err: + except Exception as err: # noqa: BLE001 # Exception logged logger.error(f"Skipping data row due to error: {err}") gen_count -= 1 diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py index a5d3d0bc..46407d4a 100644 --- a/src/guidellm/data/preprocessors/formatters.py +++ b/src/guidellm/data/preprocessors/formatters.py @@ -7,7 +7,6 @@ DatasetPreprocessor, PreprocessorRegistry, ) -from guidellm.data.schemas import GenerativeDatasetColumnType from guidellm.data.utils import text_stats from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics @@ -59,9 +58,13 @@ def __init__( self.max_tokens: int | None = max_tokens or max_completion_tokens def __call__( - self, columns: dict[GenerativeDatasetColumnType, list[Any]] + self, columns: dict[str, list[Any]] ) -> GenerationRequest: - arguments: GenerationRequestArguments = GenerationRequestArguments(body={}) + """ + :param columns: A dict of GenerativeDatasetColumnType to Any + """ + arguments: GenerationRequestArguments = GenerationRequestArguments() + arguments.body = {} # The type checker works better setting this field here input_metrics = UsageMetrics() output_metrics = UsageMetrics() @@ -101,7 +104,7 @@ def __call__( if prefix or text: arguments.body["prompt"] = prefix + text stats = text_stats(arguments.body["prompt"]) - input_metrics.text_characters = stats.get("num_chars") + input_metrics.text_characters = stats.get("num_chars") # type: ignore[assignment] # input_metrics.text_words = stats.get("num_words") return GenerationRequest( @@ -142,9 +145,13 @@ def __init__( ) def __call__( # noqa: C901, PLR0912, PLR0915 - self, columns: dict[GenerativeDatasetColumnType, list[Any]] + self, columns: dict[str, list[Any]] ) -> GenerationRequest: - arguments = GenerationRequestArguments(body={}) + """ + :param columns: A dict of GenerativeDatasetColumnType to Any + """ + arguments = GenerationRequestArguments() + arguments.body = {} # The type checker works best with body assigned here input_metrics = UsageMetrics() output_metrics = UsageMetrics() @@ -329,9 +336,10 @@ def __init__( self.encode_audio_kwargs = encode_kwargs or {} def __call__( # noqa: C901 - self, columns: dict[GenerativeDatasetColumnType, list[Any]] + self, columns: dict[str, list[Any]] ) -> GenerationRequest: - arguments = GenerationRequestArguments(body={}, files={}) + arguments = GenerationRequestArguments(files={}) + arguments.body = {} # The type checker works best with body assigned here input_metrics = UsageMetrics() output_metrics = UsageMetrics() @@ -405,7 +413,7 @@ class GenerativeAudioTranslationRequestFormatter( GenerativeAudioTranscriptionRequestFormatter ): def __call__( - self, columns: dict[GenerativeDatasetColumnType, list[Any]] + self, columns: dict[str, list[Any]] ) -> GenerationRequest: result = super().__call__(columns) result.request_type = "audio_translations" diff --git a/src/guidellm/data/preprocessors/mappers.py b/src/guidellm/data/preprocessors/mappers.py index 0783103b..1eced9fe 100644 --- a/src/guidellm/data/preprocessors/mappers.py +++ b/src/guidellm/data/preprocessors/mappers.py @@ -169,12 +169,12 @@ def __init__( def __call__( self, row: dict[str, Any] - ) -> dict[GenerativeDatasetColumnType, list[Any]]: + ) -> dict[str, list[Any]]: if self.datasets_column_mappings is None: raise ValueError("DefaultGenerativeColumnMapper not setup with data.") items = cast("dict[int, dict[str, Any]]", row.pop("items")) - mapped: dict[GenerativeDatasetColumnType, list[Any]] = defaultdict(list) + mapped: dict[str, Any] = defaultdict(list) for column_type, column_mappings in self.datasets_column_mappings.items(): for ( diff --git a/src/guidellm/data/preprocessors/preprocessor.py b/src/guidellm/data/preprocessors/preprocessor.py index eefb53d3..0b4bc49a 100644 --- a/src/guidellm/data/preprocessors/preprocessor.py +++ b/src/guidellm/data/preprocessors/preprocessor.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, Protocol, Union, runtime_checkable +from typing import Any, Protocol, runtime_checkable from datasets import Dataset, IterableDataset +from guidellm.schemas import GenerationRequest from guidellm.utils import RegistryMixin __all__ = ["DataDependentPreprocessor", "DatasetPreprocessor", "PreprocessorRegistry"] @@ -11,7 +12,8 @@ @runtime_checkable class DatasetPreprocessor(Protocol): - def __call__(self, item: dict[str, Any]) -> dict[str, Any]: ... + def __call__(self, item: dict[str, Any]) -> ( + GenerationRequest | dict[str, Any]): ... @runtime_checkable @@ -24,6 +26,6 @@ def setup_data( class PreprocessorRegistry( - RegistryMixin[Union[DataDependentPreprocessor, type[DataDependentPreprocessor]]] + RegistryMixin[DataDependentPreprocessor | type[DataDependentPreprocessor]] ): pass diff --git a/src/guidellm/data/processor.py b/src/guidellm/data/processor.py index 645683c4..7962bfbf 100644 --- a/src/guidellm/data/processor.py +++ b/src/guidellm/data/processor.py @@ -23,8 +23,9 @@ def __call__(self) -> PreTrainedTokenizerBase: if isinstance(self.processor, PreTrainedTokenizerBase): return self.processor else: - self.processor = AutoTokenizer.from_pretrained( + from_pretrained = AutoTokenizer.from_pretrained( self.processor, **(self.processor_args or {}), ) - return self.processor + self.processor = from_pretrained + return from_pretrained diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index c4421e07..cfc1d0a5 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -10,4 +10,5 @@ "image_column", "video_column", "audio_column", + "items" # special case ] diff --git a/src/guidellm/data/utils/dataset.py b/src/guidellm/data/utils/dataset.py index 9656c1a7..e5108c44 100644 --- a/src/guidellm/data/utils/dataset.py +++ b/src/guidellm/data/utils/dataset.py @@ -73,7 +73,7 @@ def resolve_dataset_split( dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict, split: str | None = None, ) -> Dataset | IterableDataset: - if split is not None and isinstance(dataset, (DatasetDict, IterableDatasetDict)): + if split is not None and isinstance(dataset, DatasetDict | IterableDatasetDict): if split in dataset: return dataset[split] @@ -83,7 +83,7 @@ def resolve_dataset_split( f"Requested split '{split}' but dataset has no splits: {dataset}." ) - if isinstance(dataset, (Dataset, IterableDataset)): + if isinstance(dataset, Dataset | IterableDataset): return dataset for _, default_splits in DEFAULT_SPLITS.items(): diff --git a/src/guidellm/data/utils/functions.py b/src/guidellm/data/utils/functions.py index 4260b1f1..0d47fb96 100644 --- a/src/guidellm/data/utils/functions.py +++ b/src/guidellm/data/utils/functions.py @@ -5,14 +5,12 @@ def text_stats( text: str, -) -> dict[Literal["type", "text", "num_chars", "num_words"], str | int]: +) -> dict[Literal["num_chars", "num_words"], int]: """Compute basic text statistics.""" num_chars = len(text) num_words = len(text.split()) return { - "type": "text", - "text": text, "num_chars": num_chars, "num_words": num_words, } From 59ab00354e386c87154957f7b0e8ad9dacf2333d Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 21 Oct 2025 16:07:14 -0400 Subject: [PATCH 2/7] Remove no longer applicable type ignore Signed-off-by: Jared O'Connell --- src/guidellm/data/preprocessors/formatters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py index 46407d4a..15e8d788 100644 --- a/src/guidellm/data/preprocessors/formatters.py +++ b/src/guidellm/data/preprocessors/formatters.py @@ -104,7 +104,7 @@ def __call__( if prefix or text: arguments.body["prompt"] = prefix + text stats = text_stats(arguments.body["prompt"]) - input_metrics.text_characters = stats.get("num_chars") # type: ignore[assignment] # + input_metrics.text_characters = stats.get("num_chars") input_metrics.text_words = stats.get("num_words") return GenerationRequest( From 787d6feb9512b920581dfa2a95987fa450f32961 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 21 Oct 2025 16:08:41 -0400 Subject: [PATCH 3/7] Remove change based on reverted changes Signed-off-by: Jared O'Connell --- src/guidellm/data/schemas.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/guidellm/data/schemas.py b/src/guidellm/data/schemas.py index cfc1d0a5..c4421e07 100644 --- a/src/guidellm/data/schemas.py +++ b/src/guidellm/data/schemas.py @@ -10,5 +10,4 @@ "image_column", "video_column", "audio_column", - "items" # special case ] From ad41a6664a31ed000affdb88073d67411c72a20e Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 21 Oct 2025 16:27:21 -0400 Subject: [PATCH 4/7] Fix logic error introduced in prior changes Signed-off-by: Jared O'Connell --- src/guidellm/data/deserializers/synthetic.py | 26 +++++++++----------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index a52fb907..dc905130 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -99,25 +99,23 @@ class SyntheticTextDatasetConfig(StandardBaseModel): @model_validator(mode="after") def check_prefix_options(self) -> SyntheticTextDatasetConfig: - prefix_count: Any | None = None - prefix_tokens: Any | None = None if self.__pydantic_extra__ is not None: prefix_count = self.__pydantic_extra__.get("prefix_count", None) # type: ignore[attr-defined] prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None) # type: ignore[attr-defined] - if (prefix_count is not None or prefix_tokens is not None - and self.prefix_buckets): - raise ValueError( - "prefix_buckets is mutually exclusive" - " with prefix_count and prefix_tokens" - ) + if prefix_count is not None or prefix_tokens is not None: + if self.prefix_buckets: + raise ValueError( + "prefix_buckets is mutually exclusive" + " with prefix_count and prefix_tokens" + ) - self.prefix_buckets = [ - SyntheticTextPrefixBucketConfig( - prefix_count=prefix_count or 1, - prefix_tokens=prefix_tokens or 0, - ) - ] + self.prefix_buckets = [ + SyntheticTextPrefixBucketConfig( + prefix_count=prefix_count or 1, + prefix_tokens=prefix_tokens or 0, + ) + ] return self From 81817f4c53b4af579891f749fba6f6bbf80127bb Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Tue, 21 Oct 2025 18:44:33 -0400 Subject: [PATCH 5/7] Address review comments Signed-off-by: Jared O'Connell --- src/guidellm/data/deserializers/synthetic.py | 2 +- src/guidellm/data/loaders.py | 10 ++--- src/guidellm/data/preprocessors/formatters.py | 43 +++++++++---------- src/guidellm/data/utils/__init__.py | 4 -- src/guidellm/data/utils/functions.py | 16 ------- 5 files changed, 25 insertions(+), 50 deletions(-) delete mode 100644 src/guidellm/data/utils/functions.py diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py index dc905130..f1184e9e 100644 --- a/src/guidellm/data/deserializers/synthetic.py +++ b/src/guidellm/data/deserializers/synthetic.py @@ -183,7 +183,7 @@ def _create_prompt( while len(prompt_token_ids) < prompt_tokens_count: attempts += 1 - num_chars = math.ceil( + num_chars = int( prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts ) text = unique + faker.text(max_nb_chars=num_chars) diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py index 1fa1f849..e260eef5 100644 --- a/src/guidellm/data/loaders.py +++ b/src/guidellm/data/loaders.py @@ -16,7 +16,6 @@ __all__ = ["DataLoader", "DatasetsIterator"] -from guidellm.schemas import GenerationRequest class DatasetsIterator(TorchIterableDataset): @@ -100,11 +99,10 @@ def generator( continue for preprocessor in self.preprocessors: - processed_row = preprocessor(row) - if isinstance(processed_row, GenerationRequest): - yield processed_row - else: - row = processed_row + # This can assign a GenerationRequest, which would then be + # passed into the preprocessor, which is a type violation. + # This should be fixed at some point. + row = preprocessor(row) # type: ignore[assignment] yield row except Exception as err: # noqa: BLE001 # Exception logged logger.error(f"Skipping data row due to error: {err}") diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py index 15e8d788..a253539d 100644 --- a/src/guidellm/data/preprocessors/formatters.py +++ b/src/guidellm/data/preprocessors/formatters.py @@ -7,7 +7,6 @@ DatasetPreprocessor, PreprocessorRegistry, ) -from guidellm.data.utils import text_stats from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics __all__ = [ @@ -102,10 +101,10 @@ def __call__( prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) text = "".join(txt for txt in columns.get("text_column", []) if txt) if prefix or text: - arguments.body["prompt"] = prefix + text - stats = text_stats(arguments.body["prompt"]) - input_metrics.text_characters = stats.get("num_chars") - input_metrics.text_words = stats.get("num_words") + prompt = prefix + text + arguments.body["prompt"] = prompt + input_metrics.text_characters = len(prompt) + input_metrics.text_words = len(prompt.split()) return GenerationRequest( request_type="text_completions", @@ -198,13 +197,12 @@ def __call__( # noqa: C901, PLR0912, PLR0915 if not prefix: continue - stats = text_stats(prefix) - if (num_chars := stats.get("num_chars")) is not None: - input_metrics.text_characters = ( - input_metrics.text_characters or 0 - ) + num_chars - if (num_words := stats.get("num_words")) is not None: - input_metrics.text_words = (input_metrics.text_words or 0) + num_words + input_metrics.text_characters = ( + input_metrics.text_characters or 0 + ) + len(prefix) + + input_metrics.text_words = (input_metrics.text_words or 0) + \ + len(prefix.split()) arguments.body["messages"].append({"role": "system", "content": prefix}) @@ -212,13 +210,12 @@ def __call__( # noqa: C901, PLR0912, PLR0915 if not text: continue - stats = text_stats(text) - if (num_chars := stats.get("num_chars")) is not None: - input_metrics.text_characters = ( - input_metrics.text_characters or 0 - ) + num_chars - if (num_words := stats.get("num_words")) is not None: - input_metrics.text_words = (input_metrics.text_words or 0) + num_words + input_metrics.text_characters = ( + input_metrics.text_characters or 0 + ) + len(text) + input_metrics.text_words = ( + input_metrics.text_words or 0 + ) + len(text.split()) arguments.body["messages"].append( {"role": "user", "content": [{"type": "text", "text": text}]} @@ -395,10 +392,10 @@ def __call__( # noqa: C901 prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) text = "".join(txt for txt in columns.get("text_column", []) if txt) if prefix or text: - arguments.body["prompt"] = prefix + text - stats = text_stats(arguments.body["prompt"]) - input_metrics.text_characters = stats.get("num_chars") - input_metrics.text_words = stats.get("num_words") + prompt = prefix + text + arguments.body["prompt"] = prompt + input_metrics.text_characters = len(prompt) + input_metrics.text_words = len(prompt.split()) return GenerationRequest( request_type="audio_transcriptions", diff --git a/src/guidellm/data/utils/__init__.py b/src/guidellm/data/utils/__init__.py index d71e6236..c2748cd9 100644 --- a/src/guidellm/data/utils/__init__.py +++ b/src/guidellm/data/utils/__init__.py @@ -1,10 +1,6 @@ from .dataset import DEFAULT_SPLITS, resolve_dataset_split -from .functions import ( - text_stats, -) __all__ = [ "DEFAULT_SPLITS", "resolve_dataset_split", - "text_stats", ] diff --git a/src/guidellm/data/utils/functions.py b/src/guidellm/data/utils/functions.py deleted file mode 100644 index 0d47fb96..00000000 --- a/src/guidellm/data/utils/functions.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Literal - -__all__ = ["text_stats"] - - -def text_stats( - text: str, -) -> dict[Literal["num_chars", "num_words"], int]: - """Compute basic text statistics.""" - num_chars = len(text) - num_words = len(text.split()) - - return { - "num_chars": num_chars, - "num_words": num_words, - } From 93836bb70a88f3b21f39fa10bc345585fe95c55c Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 22 Oct 2025 15:31:25 -0400 Subject: [PATCH 6/7] Address review comment regarding the err messages Signed-off-by: Jared O'Connell --- .../data/deserializers/deserializer.py | 83 +++++++++++-------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/src/guidellm/data/deserializers/deserializer.py b/src/guidellm/data/deserializers/deserializer.py index 96e3d4ba..f1041f20 100644 --- a/src/guidellm/data/deserializers/deserializer.py +++ b/src/guidellm/data/deserializers/deserializer.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib from collections.abc import Callable from typing import Any, Protocol, Union, runtime_checkable @@ -47,30 +46,16 @@ def deserialize( remove_columns: list[str] | None = None, **data_kwargs: dict[str, Any], ) -> Dataset | IterableDataset: - dataset: Dataset | None = None + dataset: Dataset if type_ is None: dataset = cls._deserialize_with_registered_deserializers( data, processor_factory, random_seed, **data_kwargs ) - elif (deserializer_from_type := cls.get_registered_object(type_)) is not None: - if isinstance(deserializer_from_type, type): - deserializer_fn = deserializer_from_type() - else: - deserializer_fn = deserializer_from_type - - dataset = deserializer_fn( - data=data, - processor_factory=processor_factory, - random_seed=random_seed, - **data_kwargs, - ) - - if dataset is None: - raise DataNotSupportedError( - f"No suitable deserializer found for data {data} " - f"with kwargs {data_kwargs} and deserializer type {type_}." + else: + dataset = cls._deserialize_with_specified_deserializer( + data, type_, processor_factory, random_seed, **data_kwargs ) if resolve_split: @@ -99,7 +84,7 @@ def _deserialize_with_registered_deserializers( raise RuntimeError("registry is None; cannot deserialize dataset") dataset: Dataset | None = None - errors = [] + errors: dict[str, Exception] = {} # Note: There is no priority order for the deserializers, so all deserializers # must be mutually exclusive to ensure deterministic behavior. for _name, deserializer in cls.registry.items(): @@ -108,22 +93,52 @@ def _deserialize_with_registered_deserializers( ) try: - with contextlib.suppress(DataNotSupportedError): - dataset = deserializer_fn( - data=data, - processor_factory=processor_factory, - random_seed=random_seed, - **data_kwargs, - ) + dataset = deserializer_fn( + data=data, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) except Exception as e: # noqa: BLE001 # The exceptions are saved. - errors.append(e) + errors[_name] = e if dataset is not None: - break # Found one that works. Continuing could overwrite it. - - if dataset is None and len(errors) > 0: - raise DataNotSupportedError( - f"data deserialization failed; {len(errors)} errors occurred while " - f"attempting to deserialize data {data}: {errors}" + return dataset # Success + + if len(errors) > 0: + err_msgs = "" + def sort_key(item): + return (isinstance(item[1], DataNotSupportedError), item[0]) + for key, err in sorted(errors.items(), key=sort_key): + err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}" + raise ValueError( + "Data deserialization failed, likely because the input doesn't " + f"match any of the input formats. See the {len(errors)} error(s) that " + f"occurred while attempting to deserialize the data {data}:{err_msgs}" ) return dataset + + @classmethod + def _deserialize_with_specified_deserializer( + cls, + data: Any, + type_: str, + processor_factory: Callable[[], PreTrainedTokenizerBase], + random_seed: int = 42, + **data_kwargs: dict[str, Any], + ) -> Dataset: + deserializer_from_type = cls.get_registered_object(type_) + if deserializer_from_type is None: + raise ValueError(f"Deserializer type '{type_}' is not registered.") + if isinstance(deserializer_from_type, type): + deserializer_fn = deserializer_from_type() + else: + deserializer_fn = deserializer_from_type + + return deserializer_fn( + data=data, + processor_factory=processor_factory, + random_seed=random_seed, + **data_kwargs, + ) + From a5f60b08d89c092f9e1f904bc1165bd74ab2f953 Mon Sep 17 00:00:00 2001 From: Jared O'Connell Date: Wed, 22 Oct 2025 22:51:29 -0400 Subject: [PATCH 7/7] Reduce duplicated code for text metrics Signed-off-by: Jared O'Connell --- src/guidellm/data/preprocessors/formatters.py | 21 ++++--------------- src/guidellm/schemas/request.py | 10 +++++++++ 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py index a253539d..272cf604 100644 --- a/src/guidellm/data/preprocessors/formatters.py +++ b/src/guidellm/data/preprocessors/formatters.py @@ -103,8 +103,7 @@ def __call__( if prefix or text: prompt = prefix + text arguments.body["prompt"] = prompt - input_metrics.text_characters = len(prompt) - input_metrics.text_words = len(prompt.split()) + input_metrics.add_text_metrics(prompt) return GenerationRequest( request_type="text_completions", @@ -197,25 +196,14 @@ def __call__( # noqa: C901, PLR0912, PLR0915 if not prefix: continue - input_metrics.text_characters = ( - input_metrics.text_characters or 0 - ) + len(prefix) - - input_metrics.text_words = (input_metrics.text_words or 0) + \ - len(prefix.split()) - + input_metrics.add_text_metrics(prefix) arguments.body["messages"].append({"role": "system", "content": prefix}) for text in columns.get("text_column", []): if not text: continue - input_metrics.text_characters = ( - input_metrics.text_characters or 0 - ) + len(text) - input_metrics.text_words = ( - input_metrics.text_words or 0 - ) + len(text.split()) + input_metrics.add_text_metrics(text) arguments.body["messages"].append( {"role": "user", "content": [{"type": "text", "text": text}]} @@ -394,8 +382,7 @@ def __call__( # noqa: C901 if prefix or text: prompt = prefix + text arguments.body["prompt"] = prompt - input_metrics.text_characters = len(prompt) - input_metrics.text_words = len(prompt.split()) + input_metrics.add_text_metrics(prompt) return GenerationRequest( request_type="audio_transcriptions", diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py index 9e9189fc..1f90d130 100644 --- a/src/guidellm/schemas/request.py +++ b/src/guidellm/schemas/request.py @@ -169,6 +169,16 @@ def total_tokens(self) -> int | None: self.video_tokens or 0 ) + (self.audio_tokens or 0) or None + def add_text_metrics(self, text): + """ + Adds the metrics from the given text to the fields + `text_characters` and `text_words`. + + :param text: Text to add metrics from + """ + self.text_characters = (self.text_characters or 0) + len(text) + self.text_words = (self.text_words or 0) + len(text.split()) + class GenerationRequest(StandardBaseModel): """