diff --git a/pyproject.toml b/pyproject.toml
index 2effa113..f51560dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -174,6 +174,7 @@ module = [
     "transformers.*",
     "setuptools.*",
     "setuptools_git_versioning.*",
+    "torchcodec.*"
 ]
 ignore_missing_imports = true
diff --git a/src/guidellm/extras/multimodal.py b/src/guidellm/extras/multimodal.py
index 5fe0a141..37d64f96 100644
--- a/src/guidellm/extras/multimodal.py
+++ b/src/guidellm/extras/multimodal.py
@@ -230,7 +230,7 @@ def encode_video(
     else:
         raise ValueError(f"Unsupported video type: {type(video)} for {video}")

-    video_base64 = base64.b64encode(video).decode("utf-8")
+    video_base64 = base64.b64encode(video_bytes).decode("utf-8")

     return {
         "type": "video_base64",
@@ -266,8 +266,9 @@ def encode_audio(
         "audio_samples",
         "audio_seconds",
         "audio_bytes",
+        "file_name",
     ],
-    str | int | float | None,
+    str | int | float | bytes | None,
 ]:
     """Decode audio (if necessary) and re-encode to specified format."""
     samples = _decode_audio(audio, sample_rate=sample_rate, max_duration=max_duration)
@@ -338,10 +339,10 @@ def _decode_audio(  # noqa: C901, PLR0912
     samples: AudioSamples
+    data: torch.Tensor | bytes

     # HF datasets return AudioDecoder for audio column
     if isinstance(audio, AudioDecoder):
         samples = audio.get_samples_played_in_range(stop_seconds=max_duration)
-
     elif isinstance(audio, torch.Tensor):
         # If float stream assume decoded audio
         if torch.is_floating_point(audio):
diff --git a/src/guidellm/mock_server/handlers/chat_completions.py b/src/guidellm/mock_server/handlers/chat_completions.py
index de2781b0..5f198a31 100644
--- a/src/guidellm/mock_server/handlers/chat_completions.py
+++ b/src/guidellm/mock_server/handlers/chat_completions.py
@@ -136,7 +136,7 @@ async def _handle_non_stream(self, req: ChatCompletionsRequest) -> HTTPResponse:

         # Token counts
         prompt_text = self.tokenizer.apply_chat_template(req.messages)
-        prompt_tokens = len(self.tokenizer(prompt_text))
+        prompt_tokens = len(self.tokenizer(prompt_text))  # type: ignore[arg-type]
         max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
         completion_tokens_count = min(
             sample_number(self.config.output_tokens, self.config.output_tokens_std),
@@ -197,7 +197,7 @@ async def generate_stream(stream_response):

         # Token counts
         prompt_text = self.tokenizer.apply_chat_template(req.messages)
-        prompt_tokens = len(self.tokenizer(prompt_text))
+        prompt_tokens = len(self.tokenizer(prompt_text))  # type: ignore[arg-type]
         max_tokens = req.max_completion_tokens or req.max_tokens or math.inf
         completion_tokens_count = int(
             min(
diff --git a/src/guidellm/mock_server/utils.py b/src/guidellm/mock_server/utils.py
index 8348d0a6..a839f484 100644
--- a/src/guidellm/mock_server/utils.py
+++ b/src/guidellm/mock_server/utils.py
@@ -58,12 +58,15 @@ def __call__(self, text: str | list[str], **kwargs) -> list[int]:  # noqa: ARG00
             return self.convert_tokens_to_ids(tokens)
         elif isinstance(text, list):
             # Handle batch processing
-            return [self.__call__(t) for t in text]
+            result = []
+            for t in text:
+                result.extend(self.__call__(t))
+            return result
         else:
             msg = f"text input must be of type `str` or `list[str]`, got {type(text)}"
             raise ValueError(msg)

-    def tokenize(self, text: TextInput, **_kwargs) -> list[str]:
+    def tokenize(self, text: TextInput, **_kwargs) -> list[str]:  # type: ignore[override]
         """
         Tokenize input text into a list of token strings.
@@ -76,7 +79,7 @@ def tokenize(self, text: TextInput, **_kwargs) -> list[str]:

         # Split text into tokens: words, spaces, and punctuation
         return re.findall(r"\w+|[^\w\s]|\s+", text)

-    def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
+    def convert_tokens_to_ids(self, tokens: str | list[str]) -> list[int]:
         """
         Convert token strings to numeric token IDs.
@@ -87,12 +90,12 @@ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
         :return: Single token ID or list of token IDs
         """
         if isinstance(tokens, str):
-            return hash(tokens) % self.VocabSize
+            return [hash(tokens) % self.VocabSize]
         return [hash(token) % self.VocabSize for token in tokens]

-    def convert_ids_to_tokens(
-        self, ids: int | list[int], _skip_special_tokens: bool = False
-    ) -> str | list[str]:
+    def convert_ids_to_tokens(  # type: ignore[override]
+        self, ids: list[int], _skip_special_tokens: bool = False
+    ) -> list[str]:
         """
         Convert numeric token IDs back to token strings.
@@ -102,17 +105,9 @@ def convert_ids_to_tokens(
         :param ids: Single token ID or list of token IDs to convert
         :return: Single token string or list of token strings
         """
-        if not ids and not isinstance(ids, list):
-            return ""
-        elif not ids:
+        if not ids:
             return [""]

-        if isinstance(ids, int):
-            fake = Faker()
-            fake.seed_instance(ids % self.VocabSize)
-
-            return fake.word()
-
         fake = Faker()
         fake.seed_instance(sum(ids) % self.VocabSize)
@@ -162,7 +157,7 @@ def _add_tokens(
         """
         return 0

-    def apply_chat_template(
+    def apply_chat_template(  # type: ignore[override]
         self,
         conversation: list,
         tokenize: bool = False,  # Changed default to False to match transformers
@@ -193,7 +188,7 @@ def apply_chat_template(
             return self.convert_tokens_to_ids(self.tokenize(formatted_text))
         return formatted_text

-    def decode(
+    def decode(  # type: ignore[override]
         self,
         token_ids: list[int],
         skip_special_tokens: bool = True,
@@ -255,7 +250,7 @@ def create_fake_tokens_str(
     fake = Faker()
     fake.seed_instance(seed)

-    tokens = []
+    tokens: list[str] = []

     while len(tokens) < num_tokens:
         text = fake.text(
diff --git a/src/guidellm/presentation/data_models.py b/src/guidellm/presentation/data_models.py
index 62bf97d8..f98dd5a2 100644
--- a/src/guidellm/presentation/data_models.py
+++ b/src/guidellm/presentation/data_models.py
@@ -117,25 +117,23 @@ def from_benchmarks(cls, benchmarks: list["GenerativeBenchmark"]):
             range(len(successful_requests)), min(5, len(successful_requests))
         )
         sample_prompts = [
-            successful_requests[i].request_args.replace("\n", " ").replace('"', "'")
-            if successful_requests[i].request_args is not None
-            else ""
+            req.request_args.replace("\n", " ").replace('"', "'")
+            if (req := successful_requests[i]).request_args else ""
             for i in sample_indices
         ]
         sample_outputs = [
-            successful_requests[i].output.replace("\n", " ").replace('"', "'")
-            if successful_requests[i].output is not None
-            else ""
+            req.output.replace("\n", " ").replace('"', "'")
+            if (req := successful_requests[i]).output else ""
             for i in sample_indices
         ]
         prompt_tokens = [
-            float(req.prompt_tokens)
+            float(req.prompt_tokens) if req.prompt_tokens is not None else -1
             for bm in benchmarks
             for req in bm.requests.successful
         ]
         output_tokens = [
-            float(req.output_tokens)
+            float(req.output_tokens) if req.output_tokens is not None else -1
             for bm in benchmarks
             for req in bm.requests.successful
         ]
diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py
index 7ececef5..50e7dce0 100644
--- a/src/guidellm/utils/encoding.py
+++ b/src/guidellm/utils/encoding.py
@@ -32,7 +32,7 @@

     HAS_MSGSPEC = True
 except ImportError:
-    MsgspecDecoder = MsgspecEncoder = None
+    MsgspecDecoder = MsgspecEncoder = None  # type: ignore[misc, assignment]

     # HAS_MSGSPEC will be checked at runtime
     HAS_MSGSPEC = False
diff --git a/src/guidellm/utils/imports.py b/src/guidellm/utils/imports.py
index 9a6b82d1..8b7ad5f6 100644
--- a/src/guidellm/utils/imports.py
+++ b/src/guidellm/utils/imports.py
@@ -3,7 +3,7 @@

 try:
     import orjson as json
 except ImportError:
-    import json
+    import json  # type: ignore[no-redef]  # Done only after a failure.

 __all__ = ["json"]
diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py
index 1a1a213f..cf450bd1 100644
--- a/src/guidellm/utils/registry.py
+++ b/src/guidellm/utils/registry.py
@@ -65,7 +65,7 @@ class TokenProposal(RegistryMixin):
     :cvar registry_populated: Track whether auto-discovery has completed
     """

-    registry: ClassVar[dict[str, RegistryObjT] | None] = None
+    registry: ClassVar[dict[str, RegistryObjT] | None] = None  # type: ignore[misc]
     registry_auto_discovery: ClassVar[bool] = False
     registry_populated: ClassVar[bool] = False
diff --git a/src/guidellm/utils/statistics.py b/src/guidellm/utils/statistics.py
index 0529cb0c..a8403c72 100644
--- a/src/guidellm/utils/statistics.py
+++ b/src/guidellm/utils/statistics.py
@@ -283,40 +283,12 @@ def from_request_times(
         )

         # First convert to timing events based on type
-        events: list[tuple[float, float]] = []
-
-        if distribution_type == "concurrency":
-            # For concurrency, each request adds to concurrency at start
-            # and subtracts at end
-            for (start, end), weight in zip(requests, weights, strict=False):
-                events.append((start, weight))
-                events.append((end, -1 * weight))
-        elif distribution_type == "rate":
-            # For rate, each request is added at the end time only
-            global_start = min(start for start, _ in requests) if requests else 0.0
-            events.append((global_start, 0.0))
-            for (_, end), weight in zip(requests, weights, strict=False):
-                events.append((end, weight))
-        else:
-            raise ValueError(
-                f"Invalid distribution_type '{distribution_type}'. "
-                "Must be 'concurrency' or 'rate'."
-            )
-
-        # Combine any events within epsilon of each other for stability
-        sorted_events = sorted(events, key=lambda event: event[0])
-        flattened_events: list[tuple[float, float]] = (
-            [sorted_events.pop(0)] if sorted_events else []
+        events = DistributionSummary._convert_to_timing_events(
+            requests, distribution_type, weights
         )
-        last_time = flattened_events[0][0] if flattened_events else 0.0
-        for time, val in sorted_events:
-            if abs(time - last_time) <= epsilon:
-                last_val = flattened_events[-1][1]
-                flattened_events[-1] = (last_time, last_val + val)
-            else:
-                last_time = time
-                flattened_events.append((time, val))
+        # Combine any events within epsilon of each other for stability
+        flattened_events = DistributionSummary._combine_events(events, epsilon)

         # Convert events to value distribution function
         distribution: dict[float, float] = defaultdict(float)
@@ -357,6 +329,53 @@ def from_request_times(
             include_cdf=include_cdf,
         )

+    @staticmethod
+    def _convert_to_timing_events(
+        requests: list[tuple[float, float]],
+        distribution_type: Literal["concurrency", "rate"],
+        weights: list[float],
+    ) -> list[tuple[float, float]]:
+        events: list[tuple[float, float]] = []
+
+        if distribution_type == "concurrency":
+            # For concurrency, each request adds to concurrency at start
+            # and subtracts at end
+            for (start, end), weight in zip(requests, weights, strict=False):
+                events.append((start, weight))
+                events.append((end, -1 * weight))
+        elif distribution_type == "rate":
+            # For rate, each request is added at the end time only
+            global_start = min(start for start, _ in requests) if requests else 0.0
+            events.append((global_start, 0.0))
+            for (_, end), weight in zip(requests, weights, strict=False):
+                events.append((end, weight))
+        else:
+            raise ValueError(
+                f"Invalid distribution_type '{distribution_type}'. "
+                "Must be 'concurrency' or 'rate'."
+            )
+        return events
+
+    @staticmethod
+    def _combine_events(
+        events: list[tuple[float, float]],
+        epsilon: float,
+    ) -> list[tuple[float, float]]:
+        sorted_events = sorted(events, key=lambda event: event[0])
+        flattened_events: list[tuple[float, float]] = (
+            [sorted_events.pop(0)] if sorted_events else []
+        )
+        last_time = flattened_events[0][0] if flattened_events else 0.0
+
+        for time, val in sorted_events:
+            if abs(time - last_time) <= epsilon:
+                last_val = flattened_events[-1][1]
+                flattened_events[-1] = (last_time, last_val + val)
+            else:
+                last_time = time
+                flattened_events.append((time, val))
+        return flattened_events
+
     @staticmethod
     def from_iterable_request_times(
         requests: list[tuple[float, float]],
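
Context on the statistics refactor above: `_convert_to_timing_events` and `_combine_events` are straight extractions of the logic previously inlined in `DistributionSummary.from_request_times`. The sketch below is a minimal standalone illustration of the same event-sweep behavior; the local function names and example inputs are assumptions for the illustration only and do not import or mirror guidellm's API surface.

```python
# Standalone sketch of the event-sweep logic factored out into the new helpers.
# Local names only; not part of the diff and not guidellm imports.


def convert_to_timing_events(requests, distribution_type, weights):
    events = []
    if distribution_type == "concurrency":
        # Each request raises concurrency at its start and lowers it at its end.
        for (start, end), weight in zip(requests, weights):
            events.append((start, weight))
            events.append((end, -weight))
    elif distribution_type == "rate":
        # Each request contributes weight only at its end time.
        global_start = min(start for start, _ in requests) if requests else 0.0
        events.append((global_start, 0.0))
        for (_, end), weight in zip(requests, weights):
            events.append((end, weight))
    else:
        raise ValueError(f"Invalid distribution_type '{distribution_type}'.")
    return events


def combine_events(events, epsilon):
    # Merge events whose timestamps fall within epsilon of each other.
    sorted_events = sorted(events, key=lambda event: event[0])
    flattened = [sorted_events.pop(0)] if sorted_events else []
    last_time = flattened[0][0] if flattened else 0.0
    for time, val in sorted_events:
        if abs(time - last_time) <= epsilon:
            flattened[-1] = (last_time, flattened[-1][1] + val)
        else:
            last_time = time
            flattened.append((time, val))
    return flattened


if __name__ == "__main__":
    requests = [(0.0, 1.0), (0.5, 2.0)]  # (start, end) pairs, assumed example data
    weights = [1.0, 1.0]
    events = convert_to_timing_events(requests, "concurrency", weights)
    print(combine_events(events, epsilon=1e-6))
    # [(0.0, 1.0), (0.5, 1.0), (1.0, -1.0), (2.0, -1.0)]
```

For "concurrency" the merged events trace the running request count over time; for "rate" only completion times carry weight, so extracting the two helpers changes structure but not the computed distribution.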