diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py
index 2eb24bdb..e24419ea 100644
--- a/src/guidellm/scheduler/constraints.py
+++ b/src/guidellm/scheduler/constraints.py
@@ -1005,9 +1005,7 @@ def info(self) -> dict[str, Any]:
         return self.model_dump()
 
     def __call__(
-        self,
-        state: SchedulerState,
-        request_info: RequestInfo,  # noqa: ARG002
+        self, state: SchedulerState, _request: RequestInfo
     ) -> SchedulerUpdateAction:
         create_exceeded = state.created_requests >= self.num_requests
         processed_exceeded = state.processed_requests >= self.num_requests
diff --git a/src/guidellm/scheduler/environments.py b/src/guidellm/scheduler/environments.py
index 4f02d772..fae85d54 100644
--- a/src/guidellm/scheduler/environments.py
+++ b/src/guidellm/scheduler/environments.py
@@ -84,7 +84,7 @@ async def sync_run_start(self) -> float:
     async def update_run_iteration(
         self,
         response: ResponseT | None,
-        request: RequestT,
+        request: RequestT | MultiTurnRequestT[RequestT],
         request_info: RequestInfo,
         state: SchedulerState,
     ):
@@ -201,7 +201,7 @@ async def sync_run_start(self) -> float:
     async def update_run_iteration(
         self,
         response: ResponseT | None,
-        request: RequestT,
+        request: RequestT | MultiTurnRequestT[RequestT],
         request_info: RequestInfo,
         state: SchedulerState,
     ):
diff --git a/src/guidellm/scheduler/scheduler.py b/src/guidellm/scheduler/scheduler.py
index ca5935fa..6da76438 100644
--- a/src/guidellm/scheduler/scheduler.py
+++ b/src/guidellm/scheduler/scheduler.py
@@ -69,7 +69,7 @@ async def run(
     ) -> AsyncIterator[
         tuple[
             ResponseT | None,
-            RequestT,
+            RequestT | MultiTurnRequestT[RequestT],
             RequestInfo,
             SchedulerState,
         ]
diff --git a/src/guidellm/scheduler/strategies.py b/src/guidellm/scheduler/strategies.py
index 0cd3bc63..448266cf 100644
--- a/src/guidellm/scheduler/strategies.py
+++ b/src/guidellm/scheduler/strategies.py
@@ -70,8 +70,8 @@ def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]:
         description="Number of worker processes to use for this strategy",
         ge=0,
     )
-    max_concurrency: int = Field(
-        default=0,
+    max_concurrency: int | None = Field(
+        default=None,
         description="Maximum number of concurrent requests to allow",
         ge=0,
     )
@@ -122,8 +122,8 @@ def init_processes_timings(
         self.startup_duration = startup_duration
 
         self._processes_request_index = Value("i", 0)
-        self._processes_lock = Lock()
         self._processes_start_time = Value("d", -1.0)
+        self._processes_lock = Lock()
 
     def init_processes_start(self, start_time: float):
         """
@@ -137,6 +137,10 @@ def init_processes_start(self, start_time: float):
                 "SchedulingStrategy init_processes_start called before "
                 "init_processes_timings"
             )
+        if self._processes_start_time is None:
+            raise RuntimeError(
+                "_processes_lock is not None but _processes_start_time is None"
+            )
 
         with self._processes_lock:
             self._processes_start_time.value = start_time
@@ -153,6 +157,10 @@ async def get_processes_start_time(self) -> float:
                 "SchedulingStrategy get_processes_start_time called before "
                 "init_processes_timings"
             )
+        if self._processes_start_time is None:
+            raise RuntimeError(
+                "_processes_lock is not None but _processes_start_time is None"
+            )
 
         while self._cached_processes_start_time is None:
             with self._processes_lock:
@@ -175,6 +183,10 @@ def next_request_index(self) -> int:
                 "SchedulingStrategy next_request_index called before "
                 "init_processes_timings"
            )
+        if self._processes_request_index is None:
+            raise RuntimeError(
+                "_processes_lock is not None but _processes_request_index is None"
+            )
 
         with self._processes_lock:
             self._processes_request_index.value += 1
@@ -369,7 +381,8 @@ async def next_request_time(self, offset: int) -> float:
         start_time = await self.get_processes_start_time()
 
         if (
-            self.startup_duration > 0
+            self.max_concurrency is not None
+            and self.startup_duration > 0
             and (time.time() - start_time) < self.startup_duration
             and (current_index := self.next_request_index()) <= self.max_concurrency
         ):
@@ -477,6 +490,8 @@ def init_processes_timings(
         :param startup_duration: Duration in seconds for request startup ramping
         """
         super().init_processes_timings(worker_count, max_concurrency, startup_duration)
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in init_processes_timings")
         with self._processes_lock:
             self._offset = Value("d", -1.0)
 
@@ -487,6 +502,12 @@ def init_processes_start(self, start_time: float):
         :param start_time: Unix timestamp when request processing should begin
         """
         ThroughputStrategy.init_processes_start(self, start_time)
+
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in init_processes_start")
+        if self._offset is None:
+            raise RuntimeError("_offset is None in init_processes_start; was "
+                               "init_processes_timings not called?")
         with self._processes_lock:
             self._offset.value = start_time
 
@@ -505,6 +526,12 @@ async def next_request_time(self, offset: int) -> float:
         next_delay = self._random.expovariate(self.rate)
 
+        if self._processes_lock is None:
+            raise RuntimeError("_processes_lock is None in next_request_time; was "
+                               "init_processes_timings not called?")
+        if self._offset is None:
+            raise RuntimeError("_offset is None in next_request_time; was "
+                               "init_processes_timings not called?")
         with self._processes_lock:
             self._offset.value += next_delay
diff --git a/src/guidellm/scheduler/worker.py b/src/guidellm/scheduler/worker.py
index a46455f9..45b4042b 100644
--- a/src/guidellm/scheduler/worker.py
+++ b/src/guidellm/scheduler/worker.py
@@ -23,11 +23,9 @@
         bool, "Flag indicating uvloop availability for event loop optimization"
     ] = True
 except ImportError:
-    uvloop = None
+    uvloop = None  # type: ignore[assignment]  # Optional dependency
 
-    HAS_UVLOOP: Annotated[
-        bool, "Flag indicating uvloop availability for event loop optimization"
-    ] = False
+    HAS_UVLOOP = False
 
 
 from guidellm.scheduler.schemas import (
@@ -84,6 +82,10 @@ def __init__(
                 RequestT | MultiTurnRequestT[RequestT],
                 RequestInfo,
             ],
+            tuple[
+                RequestT | MultiTurnRequestT[RequestT],
+                RequestInfo,
+            ],
         ],
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
@@ -201,8 +203,11 @@ async def run_async(self):
 
     async def _stop_monitor(
         self,
-    ) -> Literal["error_event", "shutdown_event"]:
-        """Monitor shutdown and error events for worker termination."""
+    ) -> None:
+        """
+        Monitor shutdown and error events for worker termination.
+        :raises RuntimeError: if the worker process received an error signal.
+ """ exit_key = await wait_for_sync_objects( { "error_event": self.error_event, @@ -322,7 +327,7 @@ async def _cancel_requests_loop(self): """Cancel all remaining queued requests until worker process terminates.""" while True: try: - request: RequestT + request: RequestT | MultiTurnRequestT[RequestT] request_info: RequestInfo request, request_info = await self.messaging.get( timeout=self.messaging.poll_interval @@ -350,31 +355,19 @@ async def _process_next_request(self, target_start: float): try: # Pull request from the queue, update state, and send "pending" update - request, request_info = await self.messaging.get() - request_info.timings.dequeued = time.time() - request_info.scheduler_node_id = self.messaging.worker_index or -1 - request_info.timings.targeted_start = target_start - self._send_update("pending", response, request, request_info) - - if request is None or request_info is None: - raise RuntimeError("Received invalid request or request info") - if isinstance(request, list | tuple): - raise NotImplementedError("Multi-turn requests are not yet supported") - - # Schedule the request - current_time = time.time() - request_info.timings.scheduled_at = current_time - if target_start > current_time: - await asyncio.sleep(target_start - current_time) - # Adapt delay so that scheduled at reflects the sleep time - request_info.timings.scheduled_at = target_start - - # Process the request with the backend - request_info.timings.resolve_start = time.time() - self._send_update("in_progress", response, request, request_info) - async for resp, info in self.backend.resolve(request, request_info, None): + request, request_info = await self._dequeue_next_request(target_start) + + # Schedule the request and send "in_progress" update + await self._schedule_request(request, request_info, target_start) + + async for resp, info in self.backend.resolve( # type: ignore[attr-defined] + request, request_info, None + ): + response = resp request_info = info + if request_info is None: + raise RuntimeError("Received invalid request info from backend") # Complete the request request_info.timings.resolve_end = time.time() @@ -397,6 +390,39 @@ async def _process_next_request(self, target_start: float): if request_info is not None: self.strategy.request_completed(request_info) + async def _dequeue_next_request( + self, target_start: float + ) -> tuple[RequestT, RequestInfo]: + request, request_info = await self.messaging.get() + dequeued_time = time.time() # Ensure accurate dequeue timing + if request is None or request_info is None: + raise RuntimeError("Received invalid request or request info") + if isinstance(request, list | tuple): + raise NotImplementedError("Multi-turn requests are not yet supported") + + request_info.timings.dequeued = dequeued_time + request_info.scheduler_node_id = self.messaging.worker_index or -1 + request_info.timings.targeted_start = target_start + self._send_update("pending", None, request, request_info) + return request, request_info + + async def _schedule_request( + self, + request: RequestT, + request_info: RequestInfo, + target_start: float + ): + current_time = time.time() + request_info.timings.scheduled_at = current_time + if target_start > current_time: + await asyncio.sleep(target_start - current_time) + # Adapt delay so that scheduled at reflects the sleep time + request_info.timings.scheduled_at = target_start + + # Process the request with the backend + request_info.timings.resolve_start = time.time() + self._send_update("in_progress", None, request, request_info) 
+
     def _send_update(
         self,
         new_status: Literal[
diff --git a/src/guidellm/scheduler/worker_group.py b/src/guidellm/scheduler/worker_group.py
index c6027989..2a0a51de 100644
--- a/src/guidellm/scheduler/worker_group.py
+++ b/src/guidellm/scheduler/worker_group.py
@@ -84,7 +84,7 @@ def __init__(
         backend: BackendInterface[RequestT, ResponseT],
         strategy: SchedulingStrategy,
         startup_duration: float,
-        **constraints: dict[str, Constraint],
+        **constraints: Constraint,
     ):
         """
         Initialize a worker process group for distributed request processing.
@@ -232,7 +232,7 @@ async def create_processes(self):
                     worker_index=rank,
                     max_buffer_send_size=None,
                     max_buffer_receive_size=per_proc_max_buffer_size,
-                ),
+                ),  # Type error: the non-group worker lacks the SchedulerState type
                 backend=self.backend,
                 strategy=self.strategy,
                 async_limit=async_limit,
@@ -478,9 +478,9 @@ def __init__(
             num_processes=len(processes),
             start_time=start_time,
         )
-        self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
-        self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
+        self._queued_request_ids: set[str] = set()
+        self._pending_request_ids: set[str] = set()
+        self._processing_request_ids: set[str] = set()
 
     def requests_generator(
         self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]]
@@ -517,11 +517,13 @@ def requests_generator(
                 )
                 state_update = self._locked_update(request_info)
                 request_info.timings.queued = time.time()
+                if self.messaging.buffer_receive_queue is None:
+                    raise RuntimeError("buffer receive queue is None")
                 self.messaging.buffer_receive_queue.sync_put(
                     (None, request, request_info, state_update.state)
                 )
 
-                yield (request, request_info)
+                yield request, request_info
 
                 if state_update.stop_queueing:
                     self.stop_send_requests_event.set()
@@ -530,8 +532,8 @@ def requests_generator(
             # Reached the end, inject a RequestsExhaustedConstraint to record
             self._locked_update(
                 info=None,
-                requests_exhausted={
-                    "requests_exhausted": RequestsExhaustedConstraint(
+                add_constraints={
+                    "requests_exhausted": RequestsExhaustedConstraint(  # type: ignore[dict-item]
                         num_requests=count
                     )
                 },
@@ -610,10 +612,10 @@ def received_callback(
     def _locked_update(
         self,
         info: RequestInfo | None = None,
-        **add_constraints: dict[str, Constraint],
+        add_constraints: dict[str, Constraint] | None = None,
     ) -> _StateUpdate:
         with self._update_lock:
-            if add_constraints:
+            if add_constraints is not None:
                 self.constraints.update(add_constraints)
 
             if info is not None:
@@ -631,34 +633,34 @@ def _update_state_request_counts(self, info: RequestInfo):
         if info.status == "queued":
-            self._queued_requests.add(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
+            self._queued_request_ids.add(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
             self._state.created_requests += 1
         elif info.status == "pending":
-            self._queued_requests.remove(info.request_id)
-            self._state.queued_requests = len(self._queued_requests)
-            self._pending_requests.add(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
+            self._queued_request_ids.remove(info.request_id)
+            self._state.queued_requests = len(self._queued_request_ids)
+            self._pending_request_ids.add(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
         elif info.status == "in_progress":
-            self._pending_requests.remove(info.request_id)
-            self._state.pending_requests = len(self._pending_requests)
-            self._processing_requests.add(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._pending_request_ids.remove(info.request_id)
+            self._state.pending_requests = len(self._pending_request_ids)
+            self._processing_request_ids.add(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
         elif info.status == "completed":
-            self._processing_requests.remove(info.request_id)
-            self._state.processing_requests = len(self._processing_requests)
+            self._processing_request_ids.remove(info.request_id)
+            self._state.processing_requests = len(self._processing_request_ids)
             self._state.processed_requests += 1
             self._state.successful_requests += 1
         elif info.status in ("errored", "cancelled"):
-            if info.request_id in self._queued_requests:
-                self._queued_requests.remove(info.request_id)
-                self._state.queued_requests = len(self._queued_requests)
-            elif info.request_id in self._pending_requests:
-                self._pending_requests.remove(info.request_id)
-                self._state.pending_requests = len(self._pending_requests)
-            elif info.request_id in self._processing_requests:
-                self._processing_requests.remove(info.request_id)
-                self._state.processing_requests = len(self._processing_requests)
+            if info.request_id in self._queued_request_ids:
+                self._queued_request_ids.remove(info.request_id)
+                self._state.queued_requests = len(self._queued_request_ids)
+            elif info.request_id in self._pending_request_ids:
+                self._pending_request_ids.remove(info.request_id)
+                self._state.pending_requests = len(self._pending_request_ids)
+            elif info.request_id in self._processing_request_ids:
+                self._processing_request_ids.remove(info.request_id)
+                self._state.processing_requests = len(self._processing_request_ids)
             self._state.processed_requests += 1
             self._state.errored_requests += 1 if info.status == "errored" else 0
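
Note on the strategies.py guards: the multiprocessing primitives are created in init_processes_timings rather than in __init__, so an accessor called out of order can legitimately observe them as None, and the added checks make that state an explicit RuntimeError while also narrowing the types for static checkers. A minimal, self-contained sketch of the same pattern; the names here are illustrative, not from guidellm:

    from multiprocessing import Lock, Value


    class LazyProcessTimings:
        """Sketch: cross-process state created lazily and guarded before use."""

        def __init__(self) -> None:
            # Deliberately None until init_timings() runs in the parent process
            self._lock = None
            self._start_time = None

        def init_timings(self) -> None:
            self._start_time = Value("d", -1.0)
            self._lock = Lock()

        def set_start(self, start: float) -> None:
            if self._lock is None:
                raise RuntimeError("init_timings() must be called first")
            if self._start_time is None:
                raise RuntimeError("_lock is set but _start_time is None")
            with self._lock:
                self._start_time.value = start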
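
Note on the worker.py refactor: _schedule_request sleeps until target_start and then back-dates scheduled_at to target_start rather than to the wake-up time, so event-loop sleep overshoot does not skew the recorded schedule. A standalone sketch of that timing logic, assuming a plain dict in place of the real timings object:

    import asyncio
    import time


    async def schedule_at(target_start: float, timings: dict) -> None:
        current_time = time.time()
        timings["scheduled_at"] = current_time
        if target_start > current_time:
            await asyncio.sleep(target_start - current_time)
            # Sleep can overshoot; report the intended start, not the wake-up time
            timings["scheduled_at"] = target_start
        timings["resolve_start"] = time.time()


    timings: dict = {}
    asyncio.run(schedule_at(time.time() + 0.1, timings))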
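
Note on the worker_group.py signature change: the old **add_constraints: dict[str, Constraint] annotation typed each keyword value as a dict and accepted arbitrary keyword names, while the explicit optional parameter types the whole mapping and lets the "is not None" test distinguish "no constraints passed" from an empty dict. A hedged before/after sketch using a stand-in Constraint type, not guidellm's real protocol:

    from typing import Any, Callable

    Constraint = Callable[..., Any]  # stand-in for guidellm's Constraint protocol
    constraints: dict[str, Constraint] = {}


    def locked_update(
        info: Any | None = None,
        add_constraints: dict[str, Constraint] | None = None,
    ) -> None:
        # "is not None" lets even an explicit empty dict trigger an update
        if add_constraints is not None:
            constraints.update(add_constraints)


    # Old call shape: locked_update(info=None, requests_exhausted={...})
    # New call shape:
    locked_update(info=None, add_constraints={"requests_exhausted": lambda state: None})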