diff --git a/src/guidellm/utils/__init__.py b/src/guidellm/utils/__init__.py index bd6b5a90..702b2a9d 100644 --- a/src/guidellm/utils/__init__.py +++ b/src/guidellm/utils/__init__.py @@ -81,7 +81,6 @@ "EndlessTextCreator", "InfoMixin", "IntegerRangeSampler", - "camelize_str", "InterProcessMessaging", "InterProcessMessagingManagerQueue", "InterProcessMessagingPipe", @@ -107,14 +106,15 @@ "ThreadSafeSingletonMixin", "TimeRunningStats", "all_defined", + "camelize_str", "check_load_processor", "clean_text", "filter_text", "format_value_display", "get_literal_vals", "is_punctuation", - "recursive_key_update", "load_text", + "recursive_key_update", "safe_add", "safe_divide", "safe_format_timestamp", diff --git a/src/guidellm/utils/console.py b/src/guidellm/utils/console.py index c8cd6825..54e90cf7 100644 --- a/src/guidellm/utils/console.py +++ b/src/guidellm/utils/console.py @@ -155,7 +155,7 @@ def print_update_details(self, details: Any | None): block = Padding( Text.from_markup(str(details)), (0, 0, 0, 2), - style=StatusStyles.get("debug"), + style=StatusStyles.get("debug", "dim"), ) self.print(block) diff --git a/src/guidellm/utils/encoding.py b/src/guidellm/utils/encoding.py index ccd26982..6823fb77 100644 --- a/src/guidellm/utils/encoding.py +++ b/src/guidellm/utils/encoding.py @@ -12,10 +12,10 @@ import json from collections.abc import Mapping -from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, cast try: - import msgpack + import msgpack # type: ignore[import-untyped] # Optional dependency from msgpack import Packer, Unpacker HAS_MSGPACK = True @@ -24,8 +24,12 @@ HAS_MSGPACK = False try: - from msgspec.msgpack import Decoder as MsgspecDecoder - from msgspec.msgpack import Encoder as MsgspecEncoder + from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency + Decoder as MsgspecDecoder, + ) + from msgspec.msgpack import ( # type: ignore[import-not-found] # Optional dependency + Encoder as MsgspecEncoder, + ) HAS_MSGSPEC = True except ImportError: @@ -33,7 +37,7 @@ HAS_MSGSPEC = False try: - import orjson + import orjson # type: ignore[import-not-found] # Optional dependency HAS_ORJSON = True except ImportError: @@ -116,7 +120,7 @@ def encode_message( """ serialized = serializer.serialize(obj) if serializer else obj - return encoder.encode(serialized) if encoder else serialized + return cast("MsgT", encoder.encode(serialized) if encoder else serialized) @classmethod def decode_message( @@ -137,7 +141,9 @@ def decode_message( """ serialized = encoder.decode(message) if encoder else message - return serializer.deserialize(serialized) if serializer else serialized + return cast( + "ObjT", serializer.deserialize(serialized) if serializer else serialized + ) def __init__( self, @@ -296,6 +302,15 @@ def _get_available_encoder_decoder( return None, None, None +PayloadType = Literal[ + "pydantic", + "python", + "collection_tuple", + "collection_sequence", + "collection_mapping", +] + + class Serializer: """ Object serialization with specialized Pydantic model support. 
@@ -474,6 +489,7 @@ def to_sequence(self, obj: Any) -> str | Any: :param obj: Object to serialize to sequence format :return: Serialized sequence string or bytes """ + payload_type: PayloadType if isinstance(obj, BaseModel): payload_type = "pydantic" payload = self.to_sequence_pydantic(obj) @@ -515,7 +531,9 @@ def to_sequence(self, obj: Any) -> str | Any: payload_type = "python" payload = self.to_sequence_python(obj) - return self.pack_next_sequence(payload_type, payload, None) + return self.pack_next_sequence( + payload_type, payload if payload is not None else "", None + ) def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 """ @@ -529,6 +547,7 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 :raises ValueError: If sequence format is invalid or contains multiple packed sequences """ + payload: str | bytes | None type_, payload, remaining = self.unpack_next_sequence(data) if remaining is not None: raise ValueError("Data contains multiple packed sequences; expected one.") @@ -540,16 +559,16 @@ def from_sequence(self, data: str | Any) -> Any: # noqa: C901, PLR0912 return self.from_sequence_python(payload) if type_ in {"collection_sequence", "collection_tuple"}: - items = [] + c_items = [] while payload: type_, item_payload, payload = self.unpack_next_sequence(payload) if type_ == "pydantic": - items.append(self.from_sequence_pydantic(item_payload)) + c_items.append(self.from_sequence_pydantic(item_payload)) elif type_ == "python": - items.append(self.from_sequence_python(item_payload)) + c_items.append(self.from_sequence_python(item_payload)) else: raise ValueError("Invalid type in collection sequence") - return items + return c_items if type_ != "collection_mapping": raise ValueError(f"Invalid type for mapping sequence: {type_}") @@ -604,6 +623,7 @@ def from_sequence_pydantic(self, data: str | bytes) -> BaseModel: :param data: Sequence data containing class metadata and JSON :return: Reconstructed Pydantic model instance """ + json_data: str | bytes | bytearray if isinstance(data, bytes): class_name_end = data.index(b"|") class_name = data[:class_name_end].decode() @@ -647,13 +667,7 @@ def from_sequence_python(self, data: str | bytes) -> Any: def pack_next_sequence( # noqa: C901, PLR0912 self, - type_: Literal[ - "pydantic", - "python", - "collection_tuple", - "collection_sequence", - "collection_mapping", - ], + type_: PayloadType, payload: str | bytes, current: str | bytes | None, ) -> str | bytes: @@ -672,9 +686,11 @@ def pack_next_sequence( # noqa: C901, PLR0912 raise ValueError("Payload and current must be of the same type") payload_len = len(payload) - + payload_len_output: str | bytes + payload_type: str | bytes + delimiter: str | bytes if isinstance(payload, bytes): - payload_len = payload_len.to_bytes( + payload_len_output = payload_len.to_bytes( length=(payload_len.bit_length() + 7) // 8 if payload_len > 0 else 1, byteorder="big", ) @@ -692,7 +708,7 @@ def pack_next_sequence( # noqa: C901, PLR0912 raise ValueError(f"Unknown type for packing: {type_}") delimiter = b"|" else: - payload_len = str(payload_len) + payload_len_output = str(payload_len) if type_ == "pydantic": payload_type = "P" elif type_ == "python": @@ -707,20 +723,16 @@ def pack_next_sequence( # noqa: C901, PLR0912 raise ValueError(f"Unknown type for packing: {type_}") delimiter = "|" - next_sequence = payload_type + delimiter + payload_len + delimiter + payload - - return current + next_sequence if current else next_sequence + # Type ignores because types are enforced at 
runtime + next_sequence = ( + payload_type + delimiter + payload_len_output + delimiter + payload # type: ignore[operator] + ) + return current + next_sequence if current else next_sequence # type: ignore[operator] def unpack_next_sequence( # noqa: C901, PLR0912 self, data: str | bytes ) -> tuple[ - Literal[ - "pydantic", - "python", - "collection_tuple", - "collection_sequence", - "collection_mapping", - ], + PayloadType, str | bytes, str | bytes | None, ]: @@ -731,57 +743,58 @@ def unpack_next_sequence( # noqa: C901, PLR0912 :return: Tuple of (type, payload, remaining_data) :raises ValueError: If sequence format is invalid or unknown type character """ + type_: PayloadType if isinstance(data, bytes): if len(data) < len(b"T|N") or data[1:2] != b"|": raise ValueError("Invalid packed data format") - type_char = data[0:1] - if type_char == b"P": + type_char_b = data[0:1] + if type_char_b == b"P": type_ = "pydantic" - elif type_char == b"p": + elif type_char_b == b"p": type_ = "python" - elif type_char == b"T": + elif type_char_b == b"T": type_ = "collection_tuple" - elif type_char == b"S": + elif type_char_b == b"S": type_ = "collection_sequence" - elif type_char == b"M": + elif type_char_b == b"M": type_ = "collection_mapping" else: raise ValueError("Unknown type character in packed data") len_end = data.index(b"|", 2) payload_len = int.from_bytes(data[2:len_end], "big") - payload = data[len_end + 1 : len_end + 1 + payload_len] - remaining = ( + payload_b = data[len_end + 1 : len_end + 1 + payload_len] + remaining_b = ( data[len_end + 1 + payload_len :] if len_end + 1 + payload_len < len(data) else None ) - return type_, payload, remaining + return type_, payload_b, remaining_b if len(data) < len("T|N") or data[1] != "|": raise ValueError("Invalid packed data format") - type_char = data[0] - if type_char == "P": + type_char_s = data[0] + if type_char_s == "P": type_ = "pydantic" - elif type_char == "p": + elif type_char_s == "p": type_ = "python" - elif type_char == "S": + elif type_char_s == "S": type_ = "collection_sequence" - elif type_char == "M": + elif type_char_s == "M": type_ = "collection_mapping" else: raise ValueError("Unknown type character in packed data") len_end = data.index("|", 2) payload_len = int(data[2:len_end]) - payload = data[len_end + 1 : len_end + 1 + payload_len] - remaining = ( + payload_s = data[len_end + 1 : len_end + 1 + payload_len] + remaining_s = ( data[len_end + 1 + payload_len :] if len_end + 1 + payload_len < len(data) else None ) - return type_, payload, remaining + return type_, payload_s, remaining_s diff --git a/src/guidellm/utils/functions.py b/src/guidellm/utils/functions.py index 6343cbf2..ed4a2075 100644 --- a/src/guidellm/utils/functions.py +++ b/src/guidellm/utils/functions.py @@ -96,19 +96,20 @@ def safe_add( if not values: return default - values = list(values) + values_list = list(values) if signs is None: - signs = [1] * len(values) + signs = [1] * len(values_list) - if len(signs) != len(values): + if len(signs) != len(values_list): raise ValueError("Length of signs must match length of values") - result = values[0] if values[0] is not None else default + result = values_list[0] if values_list[0] is not None else default - for ind in range(1, len(values)): - val = values[ind] if values[ind] is not None else default - result += signs[ind] * val + for ind in range(1, len(values_list)): + value = values_list[ind] + checked_value = value if value is not None else default + result += signs[ind] * checked_value return result diff --git 
a/src/guidellm/utils/messaging.py b/src/guidellm/utils/messaging.py index c56ec29a..9311259d 100644 --- a/src/guidellm/utils/messaging.py +++ b/src/guidellm/utils/messaging.py @@ -22,7 +22,7 @@ from multiprocessing.managers import SyncManager from multiprocessing.synchronize import Event as ProcessingEvent from threading import Event as ThreadingEvent -from typing import Any, Callable, Generic, Protocol, TypeVar +from typing import Any, Callable, Generic, Protocol, TypeVar, cast import culsans from pydantic import BaseModel @@ -48,19 +48,21 @@ ReceiveMessageT = TypeVar("ReceiveMessageT", bound=Any) """Generic type variable for messages received through the messaging system""" +CheckStopCallableT = Callable[[bool, int], bool] + class MessagingStopCallback(Protocol): """Protocol for evaluating stop conditions in messaging operations.""" def __call__( - self, messaging: InterProcessMessaging, pending: bool, queue_empty: int + self, messaging: InterProcessMessaging, pending: bool, queue_empty_count: int ) -> bool: """ Evaluate whether messaging operations should stop. :param messaging: The messaging instance to evaluate :param pending: Whether there are pending operations - :param queue_empty: The number of times in a row the queue has been empty + :param queue_empty_count: The number of times in a row the queue has been empty :return: True if operations should stop, False otherwise """ ... @@ -90,7 +92,7 @@ class InterProcessMessaging(Generic[SendMessageT, ReceiveMessageT], ABC): await messaging.stop() """ - STOP_REQUIRED_QUEUE_EMPTY: int = 3 + STOP_REQUIRED_QUEUE_EMPTY_COUNT: int = 3 def __init__( self, @@ -126,13 +128,13 @@ def __init__( self.max_buffer_receive_size = max_buffer_receive_size self.poll_interval = poll_interval - self.send_stopped_event: ThreadingEvent | ProcessingEvent = None - self.receive_stopped_event: ThreadingEvent | ProcessingEvent = None - self.shutdown_event: ThreadingEvent = None - self.buffer_send_queue: culsans.Queue[SendMessageT] = None - self.buffer_receive_queue: culsans.Queue[ReceiveMessageT] = None - self.send_task: asyncio.Task = None - self.receive_task: asyncio.Task = None + self.send_stopped_event: ThreadingEvent | ProcessingEvent | None = None + self.receive_stopped_event: ThreadingEvent | ProcessingEvent | None = None + self.shutdown_event: ThreadingEvent | None = None + self.buffer_send_queue: culsans.Queue[SendMessageT] | None = None + self.buffer_receive_queue: culsans.Queue[ReceiveMessageT] | None = None + self.send_task: asyncio.Task | None = None + self.receive_task: asyncio.Task | None = None self.running = False @abstractmethod @@ -152,7 +154,7 @@ def create_send_messages_threads( self, send_items: Iterable[Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create send message processing threads for transport implementation. @@ -169,7 +171,7 @@ def create_receive_messages_threads( self, receive_callback: Callable[[Any], Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create receive message processing threads for transport implementation. 
@@ -216,9 +218,8 @@ async def start( self.buffer_receive_queue = culsans.Queue[ReceiveMessageT]( maxsize=self.max_buffer_receive_size or 0 ) - self.tasks_lock = threading.Lock() - message_encoding = MessageEncoding( + message_encoding: MessageEncoding = MessageEncoding( serialization=self.serialization, encoding=self.encoding, pydantic_models=pydantic_models, @@ -245,18 +246,29 @@ async def stop(self): """ Stop message processing tasks and clean up resources. """ - self.shutdown_event.set() - with contextlib.suppress(asyncio.CancelledError): - await asyncio.gather( - self.send_task, self.receive_task, return_exceptions=True + if self.shutdown_event is not None: + self.shutdown_event.set() + else: + raise RuntimeError( + "shutdown_event is not set; was start() not called or " + "is this a redundant stop() call?" ) + tasks = [self.send_task, self.receive_task] + tasks_to_run: list[asyncio.Task[Any]] = [ + task for task in tasks if task is not None + ] + if len(tasks_to_run) > 0: + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(*tasks_to_run, return_exceptions=True) self.send_task = None self.receive_task = None if self.worker_index is None: - self.buffer_send_queue.clear() - await self.buffer_send_queue.aclose() - self.buffer_receive_queue.clear() - await self.buffer_receive_queue.aclose() + if self.buffer_send_queue is not None: + self.buffer_send_queue.clear() + await self.buffer_send_queue.aclose() + if self.buffer_receive_queue is not None: + self.buffer_receive_queue.clear() + await self.buffer_receive_queue.aclose() self.buffer_send_queue = None self.buffer_receive_queue = None self.send_stopped_event = None @@ -298,7 +310,8 @@ async def send_messages_coroutine( canceled_event.set() raise finally: - self.send_stopped_event.set() + if self.send_stopped_event is not None: + self.send_stopped_event.set() async def receive_messages_coroutine( self, @@ -334,15 +347,20 @@ async def receive_messages_coroutine( canceled_event.set() raise finally: - self.receive_stopped_event.set() + if self.receive_stopped_event is not None: + self.receive_stopped_event.set() async def get(self, timeout: float | None = None) -> ReceiveMessageT: """ - Retrieve message from receive buffer with optional timeout. + Retrieve a message from receive buffer with optional timeout. 
:param timeout: Maximum time to wait for a message
         :return: Decoded message from the receive buffer
         """
+        if self.buffer_receive_queue is None:
+            raise RuntimeError(
+                "buffer receive queue is None; check start()/stop() calls"
+            )
         return await asyncio.wait_for(
             self.buffer_receive_queue.async_get(), timeout=timeout
         )
@@ -354,6 +372,10 @@ def get_sync(self, timeout: float | None = None) -> ReceiveMessageT:

         :param timeout: Maximum time to wait for a message, if <=0 uses get_nowait
         :return: Decoded message from the receive buffer
         """
+        if self.buffer_receive_queue is None:
+            raise RuntimeError(
+                "buffer receive queue is None; check start()/stop() calls"
+            )
         if timeout is not None and timeout <= 0:
             return self.buffer_receive_queue.get_nowait()
         else:
@@ -366,6 +388,10 @@ async def put(self, item: SendMessageT, timeout: float | None = None):

         :param item: Message item to add to the send buffer
         :param timeout: Maximum time to wait for buffer space
         """
+        if self.buffer_send_queue is None:
+            raise RuntimeError(
+                "buffer send queue is None; check start()/stop() calls"
+            )
         await asyncio.wait_for(self.buffer_send_queue.async_put(item), timeout=timeout)

     def put_sync(self, item: SendMessageT, timeout: float | None = None):
@@ -375,6 +401,10 @@ def put_sync(self, item: SendMessageT, timeout: float | None = None):

         :param item: Message item to add to the send buffer
         :param timeout: Maximum time to wait for buffer space, if <=0 uses put_nowait
         """
+        if self.buffer_send_queue is None:
+            raise RuntimeError(
+                "buffer send queue is None; check start()/stop() calls"
+            )
         if timeout is not None and timeout <= 0:
             self.buffer_send_queue.put_nowait(item)
         else:
@@ -394,18 +424,21 @@ def _create_check_stop_callable(
         )
         stop_callbacks = tuple(item for item in stop_criteria or [] if callable(item))

-        def check_stop(pending: bool, queue_empty: int) -> bool:
+        def check_stop(pending: bool, queue_empty_count: int) -> bool:
             if canceled_event.is_set():
                 return True

             if stop_callbacks and any(
-                cb(self, pending, queue_empty) for cb in stop_callbacks
+                cb(self, pending, queue_empty_count) for cb in stop_callbacks
             ):
                 return True

+            if self.shutdown_event is None:
+                return True
+
             return (
                 not pending
-                and queue_empty >= self.STOP_REQUIRED_QUEUE_EMPTY
+                and queue_empty_count >= self.STOP_REQUIRED_QUEUE_EMPTY_COUNT
                 and (
                     self.shutdown_event.is_set()
                     or any(event.is_set() for event in stop_events)
@@ -437,6 +470,9 @@ class InterProcessMessagingQueue(InterProcessMessaging[SendMessageT, ReceiveMess
        worker_messaging = messaging.create_worker_copy(worker_index=0)
    """

+    pending_queue: multiprocessing.Queue | queue.Queue[Any] | None
+    done_queue: multiprocessing.Queue | queue.Queue[Any] | None
+
    def __init__(
        self,
        mp_context: BaseContext | None = None,
@@ -448,8 +484,8 @@ def __init__(
        max_buffer_receive_size: int | None = None,
        poll_interval: float = 0.1,
        worker_index: int | None = None,
-        pending_queue: multiprocessing.Queue | None = None,
-        done_queue: multiprocessing.Queue | None = None,
+        pending_queue: multiprocessing.Queue | queue.Queue[Any] | None = None,
+        done_queue: multiprocessing.Queue | queue.Queue[Any] | None = None,
    ):
        """
        Initialize queue-based messaging for inter-process communication. 
@@ -506,9 +542,9 @@ def create_worker_copy( "pending_queue": self.pending_queue, "done_queue": self.done_queue, } - copy_args.update(kwargs) + final_args = {**copy_args, **kwargs} - return InterProcessMessagingQueue[ReceiveMessageT, SendMessageT](**copy_args) + return InterProcessMessagingQueue[ReceiveMessageT, SendMessageT](**final_args) async def stop(self): """ @@ -517,15 +553,21 @@ async def stop(self): await super().stop() if self.worker_index is None: # only main process should close the queues + if self.pending_queue is None: + raise RuntimeError("pending_queue is None; was stop() already called?") with contextlib.suppress(queue.Empty): while True: self.pending_queue.get_nowait() - self.pending_queue.close() + if hasattr(self.pending_queue, "close"): + self.pending_queue.close() + if self.done_queue is None: + raise RuntimeError("done_queue is None; was stop() already called?") with contextlib.suppress(queue.Empty): while True: self.done_queue.get_nowait() - self.done_queue.close() + if hasattr(self.done_queue, "close"): + self.done_queue.close() self.pending_queue = None self.done_queue = None @@ -534,7 +576,7 @@ def create_send_messages_threads( self, send_items: Iterable[Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create send message processing threads for queue-based transport. @@ -555,7 +597,7 @@ def create_receive_messages_threads( self, receive_callback: Callable[[Any], Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create receive message processing threads for queue-based transport. @@ -576,35 +618,51 @@ def _send_messages_task_thread( # noqa: C901, PLR0912 self, send_items: Iterable[Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ): send_items_iter = iter(send_items) if send_items is not None else None pending_item = None - queue_empty = 0 + queue_empty_count = 0 - while not check_stop(pending_item is not None, queue_empty): + while not check_stop(pending_item is not None, queue_empty_count): if pending_item is None: try: if send_items_iter is not None: item = next(send_items_iter) else: + if self.buffer_send_queue is None: + raise RuntimeError( + "buffer_send_queue is None; was stop() already called?" + ) item = self.buffer_send_queue.sync_get( timeout=self.poll_interval ) pending_item = message_encoding.encode(item) - queue_empty = 0 + queue_empty_count = 0 except (culsans.QueueEmpty, queue.Empty, StopIteration): - queue_empty += 1 + queue_empty_count += 1 if pending_item is not None: try: if self.worker_index is None: # Main publisher + if self.pending_queue is None: + raise RuntimeError( + "pending_queue is None; was stop() already called?" + ) self.pending_queue.put(pending_item, timeout=self.poll_interval) else: # Worker + if self.done_queue is None: + raise RuntimeError( + "done_queue is None; was stop() already called?" + ) self.done_queue.put(pending_item, timeout=self.poll_interval) if send_items_iter is None: + if self.buffer_send_queue is None: + raise RuntimeError( + "buffer_send_queue is None; was stop() already called?" 
+ ) self.buffer_send_queue.task_done() pending_item = None except (culsans.QueueFull, queue.Full): @@ -614,25 +672,33 @@ def _receive_messages_task_thread( # noqa: C901 self, receive_callback: Callable[[Any], Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ): pending_item = None received_item = None - queue_empty = 0 + queue_empty_count = 0 - while not check_stop(pending_item is not None, queue_empty): + while not check_stop(pending_item is not None, queue_empty_count): if pending_item is None: try: if self.worker_index is None: # Main publisher + if self.done_queue is None: + raise RuntimeError( + "done_queue is None; check start()/stop() calls" + ) item = self.done_queue.get(timeout=self.poll_interval) else: # Worker + if self.pending_queue is None: + raise RuntimeError( + "pending_queue is None; check start()/stop() calls" + ) item = self.pending_queue.get(timeout=self.poll_interval) pending_item = message_encoding.decode(item) - queue_empty = 0 + queue_empty_count = 0 except (culsans.QueueEmpty, queue.Empty): - queue_empty += 1 + queue_empty_count += 1 if pending_item is not None or received_item is not None: try: @@ -643,7 +709,13 @@ def _receive_messages_task_thread( # noqa: C901 else receive_callback(pending_item) ) - self.buffer_receive_queue.sync_put(received_item) + if self.buffer_receive_queue is None: + raise RuntimeError( + "buffer_receive_queue is None; check start()/stop() calls" + ) + self.buffer_receive_queue.sync_put( + cast("ReceiveMessageT", received_item) + ) pending_item = None received_item = None except (culsans.QueueFull, queue.Full): @@ -714,8 +786,8 @@ def __init__( max_buffer_receive_size=max_buffer_receive_size, poll_interval=poll_interval, worker_index=worker_index, - pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0), # type: ignore [assignment] - done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0), # type: ignore [assignment] + pending_queue=pending_queue or manager.Queue(maxsize=max_pending_size or 0), + done_queue=done_queue or manager.Queue(maxsize=max_done_size or 0), ) def create_worker_copy( @@ -741,9 +813,9 @@ def create_worker_copy( "pending_queue": self.pending_queue, "done_queue": self.done_queue, } - copy_args.update(kwargs) + final_args = {**copy_args, **kwargs} - return InterProcessMessagingManagerQueue(**copy_args) + return InterProcessMessagingManagerQueue(**final_args) async def stop(self): """ @@ -818,12 +890,11 @@ def __init__( ) self.num_workers = num_workers + self.pipes: list[tuple[Connection, Connection]] if pipe is None: - self.pipes: list[tuple[Connection, Connection]] = [ - self.mp_context.Pipe(duplex=True) for _ in range(num_workers) - ] + self.pipes = [self.mp_context.Pipe(duplex=True) for _ in range(num_workers)] else: - self.pipes: list[tuple[Connection, Connection]] = [pipe] + self.pipes = [pipe] def create_worker_copy( self, worker_index: int, **kwargs @@ -847,9 +918,10 @@ def create_worker_copy( "worker_index": worker_index, "pipe": self.pipes[worker_index], } - copy_args.update(kwargs) - return InterProcessMessagingPipe(**copy_args) + final_args = {**copy_args, **kwargs} + + return InterProcessMessagingPipe(**final_args) async def stop(self): """ @@ -866,7 +938,7 @@ def create_send_messages_threads( self, send_items: Iterable[Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create send 
message processing threads for pipe-based transport. @@ -897,7 +969,7 @@ def create_receive_messages_threads( self, receive_callback: Callable[[Any], Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ) -> list[tuple[Callable, tuple[Any, ...]]]: """ Create receive message processing threads for pipe-based transport. @@ -924,18 +996,18 @@ def create_receive_messages_threads( ) ] - def _send_messages_task_thread( # noqa: C901, PLR0912 + def _send_messages_task_thread( # noqa: C901, PLR0912, PLR0915 self, pipe: tuple[Connection, Connection], send_items: Iterable[Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ): local_stop = ThreadingEvent() send_connection: Connection = pipe[0] if self.worker_index is None else pipe[1] send_items_iter = iter(send_items) if send_items is not None else None pending_item = None - queue_empty = 0 + queue_empty_count = 0 pipe_item = None pipe_lock = threading.Lock() @@ -957,19 +1029,23 @@ def _background_pipe_recv(): threading.Thread(target=_background_pipe_recv, daemon=True).start() try: - while not check_stop(pending_item is not None, queue_empty): + while not check_stop(pending_item is not None, queue_empty_count): if pending_item is None: try: if send_items_iter is not None: item = next(send_items_iter) else: + if self.buffer_send_queue is None: + raise RuntimeError( + "buffer_send_queue is None; check start()/stop() calls" # noqa: E501 + ) item = self.buffer_send_queue.sync_get( timeout=self.poll_interval ) pending_item = message_encoding.encode(item) - queue_empty = 0 + queue_empty_count = 0 except (culsans.QueueEmpty, queue.Empty, StopIteration): - queue_empty += 1 + queue_empty_count += 1 if pending_item is not None: try: @@ -980,6 +1056,10 @@ def _background_pipe_recv(): else: pipe_item = pending_item if send_items_iter is None: + if self.buffer_send_queue is None: + raise RuntimeError( + "buffer_send_queue is None; check start()/stop() calls" # noqa: E501 + ) self.buffer_send_queue.task_done() pending_item = None except (culsans.QueueFull, queue.Full): @@ -992,16 +1072,16 @@ def _receive_messages_task_thread( # noqa: C901 pipe: tuple[Connection, Connection], receive_callback: Callable[[Any], Any] | None, message_encoding: MessageEncoding, - check_stop: Callable[[bool, bool], bool], + check_stop: CheckStopCallableT, ): receive_connection: Connection = ( pipe[0] if self.worker_index is not None else pipe[1] ) pending_item = None received_item = None - queue_empty = 0 + queue_empty_count = 0 - while not check_stop(pending_item is not None, queue_empty): + while not check_stop(pending_item is not None, queue_empty_count): if pending_item is None: try: if receive_connection.poll(self.poll_interval): @@ -1009,9 +1089,9 @@ def _receive_messages_task_thread( # noqa: C901 pending_item = message_encoding.decode(item) else: raise queue.Empty - queue_empty = 0 + queue_empty_count = 0 except (culsans.QueueEmpty, queue.Empty): - queue_empty += 1 + queue_empty_count += 1 if pending_item is not None or received_item is not None: try: @@ -1021,8 +1101,13 @@ def _receive_messages_task_thread( # noqa: C901 if not receive_callback else receive_callback(pending_item) ) - - self.buffer_receive_queue.sync_put(received_item) + if self.buffer_receive_queue is None: + raise RuntimeError( + "buffer receive queue is None; check start()/stop() calls" + ) + self.buffer_receive_queue.sync_put( + cast("ReceiveMessageT", received_item) 
+ ) pending_item = None received_item = None except (culsans.QueueFull, queue.Full): diff --git a/src/guidellm/utils/pydantic_utils.py b/src/guidellm/utils/pydantic_utils.py index 27c2e1cf..515b445e 100644 --- a/src/guidellm/utils/pydantic_utils.py +++ b/src/guidellm/utils/pydantic_utils.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, TypeVar +from typing import Any, ClassVar, Generic, TypeVar, cast from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -29,7 +29,7 @@ BaseModelT = TypeVar("BaseModelT", bound=BaseModel) -RegisterClassT = TypeVar("RegisterClassT") +RegisterClassT = TypeVar("RegisterClassT", bound=type) SuccessfulT = TypeVar("SuccessfulT") ErroredT = TypeVar("ErroredT") IncompleteT = TypeVar("IncompleteT") @@ -300,7 +300,7 @@ def register_decorator( super().register_decorator(clazz, name=name) cls.reload_schema() - return clazz + return cast("RegisterClassT", clazz) @classmethod def __get_pydantic_core_schema__( diff --git a/src/guidellm/utils/registry.py b/src/guidellm/utils/registry.py index b9e3faf5..e6f1b657 100644 --- a/src/guidellm/utils/registry.py +++ b/src/guidellm/utils/registry.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Callable, ClassVar, Generic, TypeVar, cast +from typing import Any, Callable, ClassVar, Generic, TypeVar, cast from guidellm.utils.auto_importer import AutoImporterMixin @@ -19,7 +19,9 @@ RegistryObjT = TypeVar("RegistryObjT") """Generic type variable for objects managed by the registry system.""" -RegisterT = TypeVar("RegisterT") +RegisterT = TypeVar( + "RegisterT", bound=type +) # Must be bound to type to ensure __name__ is available. """Generic type variable for the args and return values within the registry.""" @@ -62,7 +64,7 @@ class TokenProposal(RegistryMixin): :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[dict[str, RegistryObjT] | None] = None + registry: ClassVar[dict[str, Any] | None] = None registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False @@ -209,6 +211,9 @@ def get_registered_object(cls, name: str) -> RegistryObjT | None: if name in cls.registry: return cls.registry[name] - lower_key_map = {key.lower(): key for key in cls.registry} + name_casefold = name.lower() + for k, v in cls.registry.items(): + if name_casefold == k.lower(): + return v - return cls.registry.get(lower_key_map.get(name.lower())) + return None # Not found diff --git a/src/guidellm/utils/singleton.py b/src/guidellm/utils/singleton.py index 3ec10f79..693bbf2e 100644 --- a/src/guidellm/utils/singleton.py +++ b/src/guidellm/utils/singleton.py @@ -36,6 +36,9 @@ def __init__(self, config_path: str): assert manager1 is manager2 """ + _singleton_initialized: bool + _init_lock: threading.Lock + def __new__(cls, *args, **kwargs): # noqa: ARG004 """ Create or return the singleton instance. 
diff --git a/src/guidellm/utils/synchronous.py b/src/guidellm/utils/synchronous.py index 14f3d908..64c14e94 100644 --- a/src/guidellm/utils/synchronous.py +++ b/src/guidellm/utils/synchronous.py @@ -131,8 +131,9 @@ async def wait_for_sync_objects( :param poll_interval: Time in seconds between polling checks for each object :return: Index (for list/single) or key name (for dict) of the first completed object - :raises asyncio.CancelledError: If the async task is cancelled + :raises asyncio.CancelledError: If the async task is canceled """ + keys: list[int | str] if isinstance(objects, dict): keys = list(objects.keys()) objects = list(objects.values()) diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py index 726b5ddf..b1278f51 100644 --- a/tests/unit/utils/test_pydantic_utils.py +++ b/tests/unit/utils/test_pydantic_utils.py @@ -41,7 +41,7 @@ def test_register_class_t(): """Test that RegisterClassT is configured correctly as a TypeVar.""" assert isinstance(RegisterClassT, type(TypeVar("test"))) assert RegisterClassT.__name__ == "RegisterClassT" - assert RegisterClassT.__bound__ is None + assert RegisterClassT.__bound__ is type assert RegisterClassT.__constraints__ == () diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py index eed126d3..7bd0eaf8 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -26,7 +26,7 @@ def test_registered_type(): """Test that RegisterT is configured correctly as a TypeVar.""" assert isinstance(RegisterT, type(TypeVar("test"))) assert RegisterT.__name__ == "RegisterT" - assert RegisterT.__bound__ is None + assert RegisterT.__bound__ is type assert RegisterT.__constraints__ == ()
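
For reviewers, a minimal standalone sketch of the string-mode framing that Serializer.pack_next_sequence / unpack_next_sequence implement (a one-character payload type, a "|"-delimited decimal length, then the payload). The helper names pack_str and unpack_str are hypothetical, the sketch does not import guidellm, and it only approximates the str branch shown in the diff above:

from __future__ import annotations


def pack_str(type_char: str, payload: str, current: str | None = None) -> str:
    # Frame the payload as "<type>|<decimal length>|<payload>" and append it
    # to any previously packed data (mirrors the str branch of pack_next_sequence).
    frame = f"{type_char}|{len(payload)}|{payload}"
    return current + frame if current else frame


def unpack_str(data: str) -> tuple[str, str, str | None]:
    # Split one frame off the front of the data, returning
    # (type character, payload, remaining data or None).
    if len(data) < len("T|N") or data[1] != "|":
        raise ValueError("Invalid packed data format")
    type_char = data[0]
    len_end = data.index("|", 2)
    payload_len = int(data[2:len_end])
    start = len_end + 1
    payload = data[start : start + payload_len]
    remaining = data[start + payload_len :] or None
    return type_char, payload, remaining


if __name__ == "__main__":
    packed = pack_str("p", '{"a": 1}')
    packed = pack_str("p", '{"b": 2}', packed)
    while packed:
        kind, item, packed = unpack_str(packed)
        print(kind, item)  # p {"a": 1} then p {"b": 2}

The bytes branch follows the same layout but encodes the length as big-endian bytes, which is why the diff introduces separate payload_len_output, payload_type, and delimiter variables per branch.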