diff --git a/src/guidellm/data/deserializers/deserializer.py b/src/guidellm/data/deserializers/deserializer.py
index b1e69f37..f1041f20 100644
--- a/src/guidellm/data/deserializers/deserializer.py
+++ b/src/guidellm/data/deserializers/deserializer.py
@@ -1,10 +1,9 @@
 from __future__ import annotations

-import contextlib
 from collections.abc import Callable
 from typing import Any, Protocol, Union, runtime_checkable

-from datasets import Dataset, IterableDataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import PreTrainedTokenizerBase

 from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]: ...
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...


 class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
         remove_columns: list[str] | None = None,
         **data_kwargs: dict[str, Any],
     ) -> Dataset | IterableDataset:
-        dataset = None
+        dataset: Dataset

         if type_ is None:
-            errors = []
-            # Note: There is no priority order for the deserializers, so all deserializers
-            # must be mutually exclusive to ensure deterministic behavior.
-            for name, deserializer in cls.registry.items():
-                deserializer_fn: DatasetDeserializer = (
-                    deserializer() if isinstance(deserializer, type) else deserializer
-                )
-
-                try:
-                    with contextlib.suppress(DataNotSupportedError):
-                        dataset = deserializer_fn(
-                            data=data,
-                            processor_factory=processor_factory,
-                            random_seed=random_seed,
-                            **data_kwargs,
-                        )
-                except Exception as e:
-                    errors.append(e)
-
-                if dataset is not None:
-                    break  # Found one that works. Continuing could overwrite it.
-
-            if dataset is None and len(errors) > 0:
-                raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
-                                            f"attempting to deserialize data {data}: {errors}")
-
-        elif deserializer := cls.get_registered_object(type_) is not None:
-            deserializer_fn: DatasetDeserializer = (
-                deserializer() if isinstance(deserializer, type) else deserializer
+            dataset = cls._deserialize_with_registered_deserializers(
+                data, processor_factory, random_seed, **data_kwargs
             )
-            dataset = deserializer_fn(
-                data=data,
-                processor_factory=processor_factory,
-                random_seed=random_seed,
-                **data_kwargs,
-            )
-
-        if dataset is None:
-            raise DataNotSupportedError(
-                f"No suitable deserializer found for data {data} "
-                f"with kwargs {data_kwargs} and deserializer type {type_}."
+        else:
+            dataset = cls._deserialize_with_specified_deserializer(
+                data, type_, processor_factory, random_seed, **data_kwargs
             )

         if resolve_split:
@@ -107,3 +71,74 @@ def deserialize(
             dataset = dataset.remove_columns(remove_columns)

         return dataset
+
+    @classmethod
+    def _deserialize_with_registered_deserializers(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if cls.registry is None:
+            raise RuntimeError("registry is None; cannot deserialize dataset")
+        dataset: Dataset | None = None
+
+        errors: dict[str, Exception] = {}
+        # Note: There is no priority order for the deserializers, so all deserializers
+        # must be mutually exclusive to ensure deterministic behavior.
+        for _name, deserializer in cls.registry.items():
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            try:
+                dataset = deserializer_fn(
+                    data=data,
+                    processor_factory=processor_factory,
+                    random_seed=random_seed,
+                    **data_kwargs,
+                )
+            except Exception as e:  # noqa: BLE001  # The exceptions are saved.
+                errors[_name] = e
+
+            if dataset is not None:
+                return dataset  # Success
+
+        if len(errors) > 0:
+            err_msgs = ""
+
+            def sort_key(item):
+                return (isinstance(item[1], DataNotSupportedError), item[0])
+
+            for key, err in sorted(errors.items(), key=sort_key):
+                err_msgs += f"\n  - Deserializer '{key}': ({type(err).__name__}) {err}"
+            raise ValueError(
+                "Data deserialization failed, likely because the input doesn't "
+                f"match any of the input formats. See the {len(errors)} error(s) that "
+                f"occurred while attempting to deserialize the data {data}:{err_msgs}"
+            )
+        # No deserializer matched and none raised; raise rather than silently
+        # returning None, which would violate the declared return type.
+        raise DataNotSupportedError(
+            f"No suitable deserializer found for data {data} "
+            f"with kwargs {data_kwargs}."
+        )
+
+    @classmethod
+    def _deserialize_with_specified_deserializer(
+        cls,
+        data: Any,
+        type_: str,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        deserializer_from_type = cls.get_registered_object(type_)
+        if deserializer_from_type is None:
+            raise ValueError(f"Deserializer type '{type_}' is not registered.")
+        if isinstance(deserializer_from_type, type):
+            deserializer_fn = deserializer_from_type()
+        else:
+            deserializer_fn = deserializer_from_type
+
+        return deserializer_fn(
+            data=data,
+            processor_factory=processor_factory,
+            random_seed=random_seed,
+            **data_kwargs,
+        )
diff --git a/src/guidellm/data/deserializers/file.py b/src/guidellm/data/deserializers/file.py
index d57403db..9819e173 100644
--- a/src/guidellm/data/deserializers/file.py
+++ b/src/guidellm/data/deserializers/file.py
@@ -34,11 +34,11 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".tar"
diff --git a/src/guidellm/data/deserializers/huggingface.py b/src/guidellm/data/deserializers/huggingface.py
index 80e0ed8c..efe6882a 100644
--- a/src/guidellm/data/deserializers/huggingface.py
+++ b/src/guidellm/data/deserializers/huggingface.py
@@ -36,7 +36,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
         _ = (processor_factory, random_seed)

         if isinstance(
diff --git a/src/guidellm/data/deserializers/memory.py b/src/guidellm/data/deserializers/memory.py
index 6f8888ec..59051b45 100644
--- a/src/guidellm/data/deserializers/memory.py
+++ b/src/guidellm/data/deserializers/memory.py
@@ -33,7 +33,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
@@ -67,7 +67,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
@@ -81,9 +81,9 @@ def __call__(
                 f"expected list of dicts, got {data}"
             )

-        data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
-        first_keys = set(data[0].keys())
-        for index, item in enumerate(data):
+        typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
+        first_keys = set(typed_data[0].keys())
+        for index, item in enumerate(typed_data):
             if set(item.keys()) != first_keys:
                 raise DataNotSupportedError(
                     f"All dictionaries must have the same keys. "
@@ -92,8 +92,8 @@ def __call__(
                 )

         # Convert list of dicts to dict of lists
-        result_dict = {key: [] for key in first_keys}
-        for item in data:
+        result_dict: dict = {key: [] for key in first_keys}
+        for item in typed_data:
             for key, value in item.items():
                 result_dict[key].append(value)

@@ -108,7 +108,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         primitive_types = (str, int, float, bool, type(None))
@@ -135,7 +135,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (json_str := data.strip())
             and (
@@ -145,16 +145,18 @@ def __call__(
             )
         ):
             with contextlib.suppress(Exception):
-                parsed = json.loads(data)
+                parsed_data = json.loads(data)

-        for deserializer in [
-            InMemoryDictDatasetDeserializer,
-            InMemoryDictListDatasetDeserializer,
-            InMemoryItemListDatasetDeserializer,
-        ]:
+        deserializers = [
+            InMemoryDictDatasetDeserializer(),
+            InMemoryDictListDatasetDeserializer(),
+            InMemoryItemListDatasetDeserializer(),
+        ]
+
+        for deserializer in deserializers:
             with contextlib.suppress(DataNotSupportedError):
-                return deserializer()(
-                    parsed, data_kwargs, processor_factory, random_seed
+                return deserializer(
+                    parsed_data, processor_factory, random_seed, **data_kwargs
                 )

         raise DataNotSupportedError(
@@ -171,7 +173,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (csv_str := data.strip())
diff --git a/src/guidellm/data/deserializers/synthetic.py b/src/guidellm/data/deserializers/synthetic.py
index d9e415c6..f1184e9e 100644
--- a/src/guidellm/data/deserializers/synthetic.py
+++ b/src/guidellm/data/deserializers/synthetic.py
@@ -99,21 +99,23 @@ class SyntheticTextDatasetConfig(StandardBaseModel):

     @model_validator(mode="after")
     def check_prefix_options(self) -> SyntheticTextDatasetConfig:
-        prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
-        prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
-        if prefix_count is not None or prefix_tokens is not None:
-            if self.prefix_buckets:
-                raise ValueError(
-                    "prefix_buckets is mutually exclusive"
-                    " with prefix_count and prefix_tokens"
-                )
+        if self.__pydantic_extra__ is not None:
+            prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
+            prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
+
+            if prefix_count is not None or prefix_tokens is not None:
+                if self.prefix_buckets:
+                    raise ValueError(
+                        "prefix_buckets is mutually exclusive"
+                        " with prefix_count and prefix_tokens"
+                    )

-            self.prefix_buckets = [
-                SyntheticTextPrefixBucketConfig(
-                    prefix_count=prefix_count or 1,
-                    prefix_tokens=prefix_tokens or 0,
-                )
-            ]
+                self.prefix_buckets = [
+                    SyntheticTextPrefixBucketConfig(
+                        prefix_count=prefix_count or 1,
+                        prefix_tokens=prefix_tokens or 0,
+                    )
+                ]

         return self
@@ -174,14 +176,14 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
     def _create_prompt(
         self, prompt_tokens_count: int, faker: Faker, unique: str = ""
     ) -> str:
-        prompt_token_ids = []
+        prompt_token_ids: list[int] = []
         avg_chars_per_token = 5
         margin_of_safety = 1.5
         attempts = 0

         while len(prompt_token_ids) < prompt_tokens_count:
             attempts += 1
-            num_chars = (
+            num_chars = int(
                 prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
             )
             text = unique + faker.text(max_nb_chars=num_chars)
diff --git a/src/guidellm/data/loaders.py b/src/guidellm/data/loaders.py
index fd46334d..e260eef5 100644
--- a/src/guidellm/data/loaders.py
+++ b/src/guidellm/data/loaders.py
@@ -17,6 +17,7 @@

 __all__ = ["DataLoader", "DatasetsIterator"]

+
 class DatasetsIterator(TorchIterableDataset):
     def __init__(
         self,
@@ -85,7 +86,7 @@ def generator(

         while max_items is None or gen_count < max_items:
             try:
-                row = {
+                row: dict[str, Any] = {
                     "items": [next(dataset_iter) for dataset_iter in dataset_iters]
                 }
                 gen_count += 1
@@ -98,9 +99,12 @@ def generator(
                     continue

                 for preprocessor in self.preprocessors:
-                    row = preprocessor(row)
+                    # This can assign a GenerationRequest, which would then be
+                    # passed into the preprocessor, which is a type violation.
+                    # This should be fixed at some point.
+                    row = preprocessor(row)  # type: ignore[assignment]
                 yield row
-            except Exception as err:
+            except Exception as err:  # noqa: BLE001  # Exception logged
                 logger.error(f"Skipping data row due to error: {err}")
                 gen_count -= 1
diff --git a/src/guidellm/data/preprocessors/formatters.py b/src/guidellm/data/preprocessors/formatters.py
index a5d3d0bc..272cf604 100644
--- a/src/guidellm/data/preprocessors/formatters.py
+++ b/src/guidellm/data/preprocessors/formatters.py
@@ -7,8 +7,6 @@
     DatasetPreprocessor,
     PreprocessorRegistry,
 )
-from guidellm.data.schemas import GenerativeDatasetColumnType
-from guidellm.data.utils import text_stats
 from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics

 __all__ = [
@@ -59,9 +57,13 @@ def __init__(
         self.max_tokens: int | None = max_tokens or max_completion_tokens

     def __call__(
-        self, columns: dict[GenerativeDatasetColumnType, list[Any]]
+        self, columns: dict[str, list[Any]]
     ) -> GenerationRequest:
-        arguments: GenerationRequestArguments = GenerationRequestArguments(body={})
+        """
+        :param columns: A dict of GenerativeDatasetColumnType keys to lists of values
+        """
+        arguments: GenerationRequestArguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
         input_metrics = UsageMetrics()
         output_metrics = UsageMetrics()

@@ -99,10 +101,9 @@ def __call__(
         prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
         text = "".join(txt for txt in columns.get("text_column", []) if txt)
         if prefix or text:
-            arguments.body["prompt"] = prefix + text
-            stats = text_stats(arguments.body["prompt"])
-            input_metrics.text_characters = stats.get("num_chars")
-            input_metrics.text_words = stats.get("num_words")
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+            input_metrics.add_text_metrics(prompt)

         return GenerationRequest(
             request_type="text_completions",
@@ -142,9 +143,13 @@ def __init__(
         )

     def __call__(  # noqa: C901, PLR0912, PLR0915
-        self, columns: dict[GenerativeDatasetColumnType, list[Any]]
+        self, columns: dict[str, list[Any]]
     ) -> GenerationRequest:
-        arguments = GenerationRequestArguments(body={})
+        """
+        :param columns: A dict of GenerativeDatasetColumnType keys to lists of values
+        """
+        arguments = GenerationRequestArguments()
+        arguments.body = {}  # The type checker works best with body assigned here
         input_metrics = UsageMetrics()
         output_metrics = UsageMetrics()

@@ -191,27 +196,14 @@ def __call__(  # noqa: C901, PLR0912, PLR0915
             if not prefix:
                 continue

-            stats = text_stats(prefix)
-            if (num_chars := stats.get("num_chars")) is not None:
-                input_metrics.text_characters = (
-                    input_metrics.text_characters or 0
-                ) + num_chars
-            if (num_words := stats.get("num_words")) is not None:
-                input_metrics.text_words = (input_metrics.text_words or 0) + num_words
-
+            input_metrics.add_text_metrics(prefix)
             arguments.body["messages"].append({"role": "system", "content": prefix})

         for text in columns.get("text_column", []):
             if not text:
                 continue

-            stats = text_stats(text)
-            if (num_chars := stats.get("num_chars")) is not None:
-                input_metrics.text_characters = (
-                    input_metrics.text_characters or 0
-                ) + num_chars
-            if (num_words := stats.get("num_words")) is not None:
-                input_metrics.text_words = (input_metrics.text_words or 0) + num_words
+            input_metrics.add_text_metrics(text)

             arguments.body["messages"].append(
                 {"role": "user", "content": [{"type": "text", "text": text}]}
@@ -329,9 +321,10 @@ def __init__(
         self.encode_audio_kwargs = encode_kwargs or {}

     def __call__(  # noqa: C901
-        self, columns: dict[GenerativeDatasetColumnType, list[Any]]
+        self, columns: dict[str, list[Any]]
     ) -> GenerationRequest:
-        arguments = GenerationRequestArguments(body={}, files={})
+        arguments = GenerationRequestArguments(files={})
+        arguments.body = {}  # The type checker works best with body assigned here
         input_metrics = UsageMetrics()
         output_metrics = UsageMetrics()

@@ -387,10 +380,9 @@ def __call__(  # noqa: C901
         prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
         text = "".join(txt for txt in columns.get("text_column", []) if txt)
         if prefix or text:
-            arguments.body["prompt"] = prefix + text
-            stats = text_stats(arguments.body["prompt"])
-            input_metrics.text_characters = stats.get("num_chars")
-            input_metrics.text_words = stats.get("num_words")
+            prompt = prefix + text
+            arguments.body["prompt"] = prompt
+            input_metrics.add_text_metrics(prompt)

         return GenerationRequest(
             request_type="audio_transcriptions",
@@ -405,7 +397,7 @@ class GenerativeAudioTranslationRequestFormatter(
     GenerativeAudioTranscriptionRequestFormatter
 ):
     def __call__(
-        self, columns: dict[GenerativeDatasetColumnType, list[Any]]
+        self, columns: dict[str, list[Any]]
     ) -> GenerationRequest:
         result = super().__call__(columns)
         result.request_type = "audio_translations"
diff --git a/src/guidellm/data/preprocessors/mappers.py b/src/guidellm/data/preprocessors/mappers.py
index 0783103b..1eced9fe 100644
--- a/src/guidellm/data/preprocessors/mappers.py
+++ b/src/guidellm/data/preprocessors/mappers.py
@@ -169,12 +169,12 @@ def __init__(

     def __call__(
         self, row: dict[str, Any]
-    ) -> dict[GenerativeDatasetColumnType, list[Any]]:
+    ) -> dict[str, list[Any]]:
         if self.datasets_column_mappings is None:
             raise ValueError("DefaultGenerativeColumnMapper not setup with data.")

         items = cast("dict[int, dict[str, Any]]", row.pop("items"))
-        mapped: dict[GenerativeDatasetColumnType, list[Any]] = defaultdict(list)
+        mapped: dict[str, Any] = defaultdict(list)

         for column_type, column_mappings in self.datasets_column_mappings.items():
             for (
diff --git a/src/guidellm/data/preprocessors/preprocessor.py b/src/guidellm/data/preprocessors/preprocessor.py
index eefb53d3..0b4bc49a 100644
--- a/src/guidellm/data/preprocessors/preprocessor.py
+++ b/src/guidellm/data/preprocessors/preprocessor.py
@@ -1,9 +1,10 @@
 from __future__ import annotations

-from typing import Any, Protocol, Union, runtime_checkable
+from typing import Any, Protocol, runtime_checkable

 from datasets import Dataset, IterableDataset

+from guidellm.schemas import GenerationRequest
 from guidellm.utils import RegistryMixin

 __all__ = ["DataDependentPreprocessor", "DatasetPreprocessor", "PreprocessorRegistry"]

@@ -11,7 +12,8 @@
 @runtime_checkable
 class DatasetPreprocessor(Protocol):
-    def __call__(self, item: dict[str, Any]) -> dict[str, Any]: ...
+    def __call__(self, item: dict[str, Any]) -> (
+        GenerationRequest | dict[str, Any]): ...


 @runtime_checkable
@@ -24,6 +26,6 @@ def setup_data(


 class PreprocessorRegistry(
-    RegistryMixin[Union[DataDependentPreprocessor, type[DataDependentPreprocessor]]]
+    RegistryMixin[DataDependentPreprocessor | type[DataDependentPreprocessor]]
 ):
     pass
diff --git a/src/guidellm/data/processor.py b/src/guidellm/data/processor.py
index 645683c4..7962bfbf 100644
--- a/src/guidellm/data/processor.py
+++ b/src/guidellm/data/processor.py
@@ -23,8 +23,9 @@ def __call__(self) -> PreTrainedTokenizerBase:
         if isinstance(self.processor, PreTrainedTokenizerBase):
             return self.processor
         else:
-            self.processor = AutoTokenizer.from_pretrained(
+            from_pretrained = AutoTokenizer.from_pretrained(
                 self.processor,
                 **(self.processor_args or {}),
             )
-            return self.processor
+            self.processor = from_pretrained
+            return from_pretrained
diff --git a/src/guidellm/data/utils/__init__.py b/src/guidellm/data/utils/__init__.py
index d71e6236..c2748cd9 100644
--- a/src/guidellm/data/utils/__init__.py
+++ b/src/guidellm/data/utils/__init__.py
@@ -1,10 +1,6 @@
 from .dataset import DEFAULT_SPLITS, resolve_dataset_split
-from .functions import (
-    text_stats,
-)

 __all__ = [
     "DEFAULT_SPLITS",
     "resolve_dataset_split",
-    "text_stats",
 ]
diff --git a/src/guidellm/data/utils/dataset.py b/src/guidellm/data/utils/dataset.py
index 9656c1a7..e5108c44 100644
--- a/src/guidellm/data/utils/dataset.py
+++ b/src/guidellm/data/utils/dataset.py
@@ -73,7 +73,7 @@ def resolve_dataset_split(
     dataset: Dataset | IterableDataset | DatasetDict | IterableDatasetDict,
     split: str | None = None,
 ) -> Dataset | IterableDataset:
-    if split is not None and isinstance(dataset, (DatasetDict, IterableDatasetDict)):
+    if split is not None and isinstance(dataset, DatasetDict | IterableDatasetDict):
         if split in dataset:
             return dataset[split]

@@ -83,7 +83,7 @@ def resolve_dataset_split(
             f"Requested split '{split}' but dataset has no splits: {dataset}."
         )

-    if isinstance(dataset, (Dataset, IterableDataset)):
+    if isinstance(dataset, Dataset | IterableDataset):
         return dataset

     for _, default_splits in DEFAULT_SPLITS.items():
diff --git a/src/guidellm/data/utils/functions.py b/src/guidellm/data/utils/functions.py
deleted file mode 100644
index 4260b1f1..00000000
--- a/src/guidellm/data/utils/functions.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from typing import Literal
-
-__all__ = ["text_stats"]
-
-
-def text_stats(
-    text: str,
-) -> dict[Literal["type", "text", "num_chars", "num_words"], str | int]:
-    """Compute basic text statistics."""
-    num_chars = len(text)
-    num_words = len(text.split())
-
-    return {
-        "type": "text",
-        "text": text,
-        "num_chars": num_chars,
-        "num_words": num_words,
-    }
diff --git a/src/guidellm/schemas/request.py b/src/guidellm/schemas/request.py
index 9e9189fc..1f90d130 100644
--- a/src/guidellm/schemas/request.py
+++ b/src/guidellm/schemas/request.py
@@ -169,6 +169,16 @@ def total_tokens(self) -> int | None:
             self.video_tokens or 0
         ) + (self.audio_tokens or 0) or None

+    def add_text_metrics(self, text: str) -> None:
+        """
+        Add the character and word counts of the given text to the
+        `text_characters` and `text_words` fields.
+
+        :param text: Text to add metrics from
+        """
+        self.text_characters = (self.text_characters or 0) + len(text)
+        self.text_words = (self.text_words or 0) + len(text.split())
+

 class GenerationRequest(StandardBaseModel):
     """
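
Reviewer note (not part of the patch): a minimal usage sketch of the two dispatch
paths in the refactored DatasetDeserializerFactory.deserialize, and of the new
UsageMetrics.add_text_metrics helper. The file path "benchmarks.csv", the "csv"
type name, and the "gpt2" tokenizer are illustrative assumptions, not values
taken from the codebase:

    from transformers import AutoTokenizer

    from guidellm.data.deserializers.deserializer import DatasetDeserializerFactory
    from guidellm.schemas import UsageMetrics


    def processor_factory():
        # Any zero-arg callable returning a PreTrainedTokenizerBase works here;
        # the "gpt2" tokenizer is an arbitrary example.
        return AutoTokenizer.from_pretrained("gpt2")


    # type_=None: every registered deserializer is attempted in turn; if none
    # matches, the collected per-deserializer errors are reported together in
    # a single ValueError instead of being silently suppressed.
    dataset = DatasetDeserializerFactory.deserialize(
        data="benchmarks.csv",
        processor_factory=processor_factory,
        random_seed=42,
    )

    # type_ given (assumed registration name "csv"): dispatches to exactly one
    # deserializer and fails fast with a ValueError for an unregistered name,
    # rather than falling through the old walrus-in-elif bug.
    dataset = DatasetDeserializerFactory.deserialize(
        data="benchmarks.csv",
        type_="csv",
        processor_factory=processor_factory,
        random_seed=42,
    )

    # add_text_metrics accumulates onto any counts already present rather than
    # overwriting them, matching the chat formatter's old accumulation loops.
    metrics = UsageMetrics()
    metrics.add_text_metrics("hello world")  # text_characters=11, text_words=2
    metrics.add_text_metrics("again")        # text_characters=16, text_words=3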