4 changes: 1 addition & 3 deletions src/guidellm/scheduler/constraints.py
@@ -1005,9 +1005,7 @@ def info(self) -> dict[str, Any]:
return self.model_dump()

def __call__(
self,
state: SchedulerState,
request_info: RequestInfo, # noqa: ARG002
self, state: SchedulerState, _request: RequestInfo
) -> SchedulerUpdateAction:
create_exceeded = state.created_requests >= self.num_requests
processed_exceeded = state.processed_requests >= self.num_requests
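For reviewers skimming the hunk above: a minimal sketch of the pattern it applies. Renaming the intentionally unused parameter to `_request` makes the `# noqa: ARG002` (unused-argument) suppression unnecessary while keeping the constraint call signature intact. The `MaxRequestsConstraint`, `SchedulerState`, and `RequestInfo` classes below are simplified stand-ins, not the real guidellm types.

```python
from dataclasses import dataclass


@dataclass
class SchedulerState:  # stand-in for guidellm's SchedulerState
    created_requests: int = 0
    processed_requests: int = 0


@dataclass
class RequestInfo:  # stand-in for guidellm's RequestInfo
    request_id: str = ""


@dataclass
class MaxRequestsConstraint:  # hypothetical constraint for illustration only
    num_requests: int

    def __call__(self, state: SchedulerState, _request: RequestInfo) -> bool:
        # `_request` is part of the constraint interface but intentionally unused,
        # so the leading underscore replaces the old `# noqa: ARG002` suppression.
        return state.created_requests >= self.num_requests


stop = MaxRequestsConstraint(num_requests=10)
print(stop(SchedulerState(created_requests=12), RequestInfo(request_id="r-1")))  # True
```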
4 changes: 2 additions & 2 deletions src/guidellm/scheduler/environments.py
@@ -84,7 +84,7 @@ async def sync_run_start(self) -> float:
async def update_run_iteration(
self,
response: ResponseT | None,
request: RequestT,
request: RequestT | MultiTurnRequestT[RequestT],
request_info: RequestInfo,
state: SchedulerState,
):
@@ -201,7 +201,7 @@ async def sync_run_start(self) -> float:
async def update_run_iteration(
self,
response: ResponseT | None,
request: RequestT,
request: RequestT | MultiTurnRequestT[RequestT],
request_info: RequestInfo,
state: SchedulerState,
):
2 changes: 1 addition & 1 deletion src/guidellm/scheduler/scheduler.py
@@ -69,7 +69,7 @@ async def run(
) -> AsyncIterator[
tuple[
ResponseT | None,
RequestT,
RequestT | MultiTurnRequestT[RequestT],
RequestInfo,
SchedulerState,
]
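Both `environments.py` and `scheduler.py` above widen the `request` parameter from `RequestT` to `RequestT | MultiTurnRequestT[RequestT]`. Below is a hedged sketch of what a downstream consumer now has to handle; `MultiTurnRequestT` is defined in `guidellm.scheduler.schemas`, and the list/tuple check here only mirrors the worker's current "not yet supported" guard, not the real type definition.

```python
from typing import Any


def handle_iteration(
    response: Any | None,
    request: Any,       # RequestT | MultiTurnRequestT[RequestT] in guidellm
    request_info: Any,  # RequestInfo
    state: Any,         # SchedulerState
) -> None:
    if isinstance(request, (list, tuple)):
        # Multi-turn requests now flow through the type signatures, although the
        # worker still raises NotImplementedError when asked to resolve one.
        print(f"multi-turn request with {len(request)} turns")
    else:
        print("single-turn request")


handle_iteration(None, {"prompt": "hi"}, None, None)
handle_iteration(None, [{"prompt": "a"}, {"prompt": "b"}], None, None)
```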
35 changes: 31 additions & 4 deletions src/guidellm/scheduler/strategies.py
@@ -70,8 +70,8 @@ def __pydantic_schema_base_type__(cls) -> type[SchedulingStrategy]:
description="Number of worker processes to use for this strategy",
ge=0,
)
max_concurrency: int = Field(
default=0,
max_concurrency: int | None = Field(
default=None,
description="Maximum number of concurrent requests to allow",
ge=0,
)
@@ -122,8 +122,8 @@ def init_processes_timings(
self.startup_duration = startup_duration

self._processes_request_index = Value("i", 0)
self._processes_lock = Lock()
self._processes_start_time = Value("d", -1.0)
self._processes_lock = Lock()

def init_processes_start(self, start_time: float):
"""
@@ -137,6 +137,10 @@ def init_processes_start(self, start_time: float):
"SchedulingStrategy init_processes_start called before "
"init_processes_timings"
)
if self._processes_start_time is None:
raise RuntimeError(
"_processes_lock is not None but _processes_start_time is None"
)

with self._processes_lock:
self._processes_start_time.value = start_time
@@ -153,6 +157,10 @@ async def get_processes_start_time(self) -> float:
"SchedulingStrategy get_processes_start_time called before "
"init_processes_timings"
)
if self._processes_start_time is None:
raise RuntimeError(
"_processes_lock is not None but _processes_start_time is None"
)

while self._cached_processes_start_time is None:
with self._processes_lock:
@@ -175,6 +183,10 @@ def next_request_index(self) -> int:
"SchedulingStrategy next_request_index called before "
"init_processes_timings"
)
if self._processes_request_index is None:
raise RuntimeError(
"_processes_lock is not None but _processes_request_index is None"
)

with self._processes_lock:
self._processes_request_index.value += 1
@@ -369,7 +381,8 @@ async def next_request_time(self, offset: int) -> float:
start_time = await self.get_processes_start_time()

if (
self.startup_duration > 0
self.max_concurrency is not None
and self.startup_duration > 0
and (time.time() - start_time) < self.startup_duration
and (current_index := self.next_request_index()) <= self.max_concurrency
):
@@ -477,6 +490,8 @@ def init_processes_timings(
:param startup_duration: Duration in seconds for request startup ramping
"""
super().init_processes_timings(worker_count, max_concurrency, startup_duration)
if self._processes_lock is None:
raise RuntimeError("_processes_lock is None in init_processes_timings")
with self._processes_lock:
self._offset = Value("d", -1.0)

@@ -487,6 +502,12 @@ def init_processes_start(self, start_time: float):
:param start_time: Unix timestamp when request processing should begin
"""
ThroughputStrategy.init_processes_start(self, start_time)

if self._processes_lock is None:
raise RuntimeError("_processes_lock is None in init_processes_start")
if self._offset is None:
raise RuntimeError("_offset is None in init_processes_start; was "
"init_processes_timings not called?")
with self._processes_lock:
self._offset.value = start_time

@@ -505,6 +526,12 @@ async def next_request_time(self, offset: int) -> float:

next_delay = self._random.expovariate(self.rate)

if self._processes_lock is None:
raise RuntimeError("_processes_lock is None in next_request_time; was "
"init_processes_timings not called?")
if self._offset is None:
raise RuntimeError("_offset is None in next_request_time; was "
"init_processes_timings not called?")
with self._processes_lock:
self._offset.value += next_delay

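The `strategies.py` changes above share one pattern: the shared multiprocessing primitives are created lazily in `init_processes_timings()`, so every later accessor re-checks them for `None`, which fails fast if initialization was skipped and narrows the optional types for the type checker. A simplified, self-contained sketch of that guard pattern (the class and attribute names are stand-ins, not the real strategy classes):

```python
from multiprocessing import Lock, Value


class LazyTimings:
    def __init__(self) -> None:
        # Created lazily so the object can be pickled before worker startup.
        self._processes_lock = None
        self._processes_start_time = None

    def init_processes_timings(self) -> None:
        self._processes_start_time = Value("d", -1.0)
        self._processes_lock = Lock()

    def init_processes_start(self, start_time: float) -> None:
        if self._processes_lock is None:
            raise RuntimeError(
                "init_processes_start called before init_processes_timings"
            )
        if self._processes_start_time is None:
            raise RuntimeError(
                "_processes_lock is set but _processes_start_time is None"
            )
        with self._processes_lock:
            self._processes_start_time.value = start_time


timings = LazyTimings()
timings.init_processes_timings()
timings.init_processes_start(0.0)
```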
86 changes: 56 additions & 30 deletions src/guidellm/scheduler/worker.py
@@ -23,11 +23,9 @@
bool, "Flag indicating uvloop availability for event loop optimization"
] = True
except ImportError:
uvloop = None
uvloop = None # type: ignore[assignment] # Optional dependency

HAS_UVLOOP: Annotated[
bool, "Flag indicating uvloop availability for event loop optimization"
] = False
HAS_UVLOOP = False


from guidellm.scheduler.schemas import (
@@ -84,6 +82,10 @@ def __init__(
RequestT | MultiTurnRequestT[RequestT],
RequestInfo,
],
tuple[
RequestT | MultiTurnRequestT[RequestT],
RequestInfo,
],
],
backend: BackendInterface[RequestT, ResponseT],
strategy: SchedulingStrategy,
@@ -201,8 +203,11 @@ async def run_async(self):

async def _stop_monitor(
self,
) -> Literal["error_event", "shutdown_event"]:
"""Monitor shutdown and error events for worker termination."""
) -> None:
"""
Monitor shutdown and error events for worker termination.
:raises RuntimeError: if the worker process received an error signal.
"""
exit_key = await wait_for_sync_objects(
{
"error_event": self.error_event,
@@ -322,7 +327,7 @@ async def _cancel_requests_loop(self):
"""Cancel all remaining queued requests until worker process terminates."""
while True:
try:
request: RequestT
request: RequestT | MultiTurnRequestT[RequestT]
request_info: RequestInfo
request, request_info = await self.messaging.get(
timeout=self.messaging.poll_interval
@@ -350,31 +355,19 @@ async def _process_next_request(self, target_start: float):

try:
# Pull request from the queue, update state, and send "pending" update
request, request_info = await self.messaging.get()
request_info.timings.dequeued = time.time()
request_info.scheduler_node_id = self.messaging.worker_index or -1
request_info.timings.targeted_start = target_start
self._send_update("pending", response, request, request_info)

if request is None or request_info is None:
raise RuntimeError("Received invalid request or request info")
if isinstance(request, list | tuple):
raise NotImplementedError("Multi-turn requests are not yet supported")

# Schedule the request
current_time = time.time()
request_info.timings.scheduled_at = current_time
if target_start > current_time:
await asyncio.sleep(target_start - current_time)
# Adapt delay so that scheduled at reflects the sleep time
request_info.timings.scheduled_at = target_start

# Process the request with the backend
request_info.timings.resolve_start = time.time()
self._send_update("in_progress", response, request, request_info)
async for resp, info in self.backend.resolve(request, request_info, None):
request, request_info = await self._dequeue_next_request(target_start)

# Schedule the request and send "in_progress" update
await self._schedule_request(request, request_info, target_start)

async for resp, info in self.backend.resolve( # type: ignore[attr-defined]
request, request_info, None
):

response = resp
request_info = info
if request_info is None:
raise RuntimeError("Received invalid request info from backend")

# Complete the request
request_info.timings.resolve_end = time.time()
@@ -397,6 +390,39 @@
if request_info is not None:
self.strategy.request_completed(request_info)

async def _dequeue_next_request(
self, target_start: float
) -> tuple[RequestT, RequestInfo]:
request, request_info = await self.messaging.get()
dequeued_time = time.time() # Ensure accurate dequeue timing
if request is None or request_info is None:
raise RuntimeError("Received invalid request or request info")
if isinstance(request, list | tuple):
raise NotImplementedError("Multi-turn requests are not yet supported")

request_info.timings.dequeued = dequeued_time
request_info.scheduler_node_id = self.messaging.worker_index or -1
request_info.timings.targeted_start = target_start
self._send_update("pending", None, request, request_info)
return request, request_info

async def _schedule_request(
self,
request: RequestT,
request_info: RequestInfo,
target_start: float
):
current_time = time.time()
request_info.timings.scheduled_at = current_time
if target_start > current_time:
await asyncio.sleep(target_start - current_time)
# Adapt delay so that scheduled at reflects the sleep time
request_info.timings.scheduled_at = target_start

# Process the request with the backend
request_info.timings.resolve_start = time.time()
self._send_update("in_progress", None, request, request_info)

def _send_update(
self,
new_status: Literal[
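The `worker.py` refactor above splits `_process_next_request` into `_dequeue_next_request` (pull from the queue, validate, stamp the dequeue and targeted-start timings, send the "pending" update) and `_schedule_request` (sleep until the targeted start, stamp `scheduled_at` and `resolve_start`, send "in_progress"). Below is a stripped-down async sketch of that control flow, with plain dictionaries and an `asyncio.Queue` standing in for the messaging layer and timing objects.

```python
import asyncio
import time


async def dequeue_next_request(queue: asyncio.Queue, target_start: float) -> dict:
    request = await queue.get()
    dequeued_time = time.time()  # capture dequeue time before any other work
    if request is None:
        raise RuntimeError("Received invalid request")
    request["dequeued"] = dequeued_time
    request["targeted_start"] = target_start
    return request


async def schedule_request(request: dict, target_start: float) -> None:
    current_time = time.time()
    request["scheduled_at"] = current_time
    if target_start > current_time:
        # Sleep until the targeted start and record the target as the effective
        # scheduling time, so the timings reflect the intended schedule.
        await asyncio.sleep(target_start - current_time)
        request["scheduled_at"] = target_start
    request["resolve_start"] = time.time()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    await queue.put({"prompt": "hello"})
    request = await dequeue_next_request(queue, target_start=time.time() + 0.1)
    await schedule_request(request, target_start=request["targeted_start"])
    print(request)


asyncio.run(main())
```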
64 changes: 33 additions & 31 deletions src/guidellm/scheduler/worker_group.py
@@ -84,7 +84,7 @@ def __init__(
backend: BackendInterface[RequestT, ResponseT],
strategy: SchedulingStrategy,
startup_duration: float,
**constraints: dict[str, Constraint],
**constraints: Constraint,
):
"""
Initialize a worker process group for distributed request processing.
@@ -232,7 +232,7 @@ async def create_processes(self):
worker_index=rank,
max_buffer_send_size=None,
max_buffer_receive_size=per_proc_max_buffer_size,
),
), # The non-group worker lacks the SchedulerState type, so this is flagged as a type error.
backend=self.backend,
strategy=self.strategy,
async_limit=async_limit,
@@ -478,9 +478,9 @@ def __init__(
num_processes=len(processes),
start_time=start_time,
)
self._queued_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
self._pending_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
self._processing_requests: set[RequestT | MultiTurnRequestT[RequestT]] = set()
self._queued_request_ids: set[str] = set()
self._pending_request_ids: set[str] = set()
self._processing_request_ids: set[str] = set()

def requests_generator(
self, requests: Iterable[RequestT | MultiTurnRequestT[RequestT]]
@@ -517,11 +517,13 @@ def requests_generator(
)
state_update = self._locked_update(request_info)
request_info.timings.queued = time.time()
if self.messaging.buffer_receive_queue is None:
raise RuntimeError("buffer receive queue is None")
self.messaging.buffer_receive_queue.sync_put(
(None, request, request_info, state_update.state)
)

yield (request, request_info)
yield request, request_info

if state_update.stop_queueing:
self.stop_send_requests_event.set()
@@ -530,8 +532,8 @@ def requests_generator(
# Reached the end, inject a RequestsExhaustedConstraint to record
self._locked_update(
info=None,
requests_exhausted={
"requests_exhausted": RequestsExhaustedConstraint(
add_constraints={
"requests_exhausted": RequestsExhaustedConstraint( # type: ignore[dict-item]
num_requests=count
)
},
@@ -610,10 +612,10 @@ def received_callback(
def _locked_update(
self,
info: RequestInfo | None = None,
**add_constraints: dict[str, Constraint],
add_constraints: dict[str, Constraint] | None = None,
) -> _StateUpdate:
with self._update_lock:
if add_constraints:
if add_constraints is not None:
self.constraints.update(add_constraints)

if info is not None:
@@ -631,34 +633,34 @@

def _update_state_request_counts(self, info: RequestInfo):
if info.status == "queued":
self._queued_requests.add(info.request_id)
self._state.queued_requests = len(self._queued_requests)
self._queued_request_ids.add(info.request_id)
self._state.queued_requests = len(self._queued_request_ids)
self._state.created_requests += 1
elif info.status == "pending":
self._queued_requests.remove(info.request_id)
self._state.queued_requests = len(self._queued_requests)
self._pending_requests.add(info.request_id)
self._state.pending_requests = len(self._pending_requests)
self._queued_request_ids.remove(info.request_id)
self._state.queued_requests = len(self._queued_request_ids)
self._pending_request_ids.add(info.request_id)
self._state.pending_requests = len(self._pending_request_ids)
elif info.status == "in_progress":
self._pending_requests.remove(info.request_id)
self._state.pending_requests = len(self._pending_requests)
self._processing_requests.add(info.request_id)
self._state.processing_requests = len(self._processing_requests)
self._pending_request_ids.remove(info.request_id)
self._state.pending_requests = len(self._pending_request_ids)
self._processing_request_ids.add(info.request_id)
self._state.processing_requests = len(self._processing_request_ids)
elif info.status == "completed":
self._processing_requests.remove(info.request_id)
self._state.processing_requests = len(self._processing_requests)
self._processing_request_ids.remove(info.request_id)
self._state.processing_requests = len(self._processing_request_ids)
self._state.processed_requests += 1
self._state.successful_requests += 1
elif info.status in ("errored", "cancelled"):
if info.request_id in self._queued_requests:
self._queued_requests.remove(info.request_id)
self._state.queued_requests = len(self._queued_requests)
elif info.request_id in self._pending_requests:
self._pending_requests.remove(info.request_id)
self._state.pending_requests = len(self._pending_requests)
elif info.request_id in self._processing_requests:
self._processing_requests.remove(info.request_id)
self._state.processing_requests = len(self._processing_requests)
if info.request_id in self._queued_request_ids:
self._queued_request_ids.remove(info.request_id)
self._state.queued_requests = len(self._queued_request_ids)
elif info.request_id in self._pending_request_ids:
self._pending_request_ids.remove(info.request_id)
self._state.pending_requests = len(self._pending_request_ids)
elif info.request_id in self._processing_request_ids:
self._processing_request_ids.remove(info.request_id)
self._state.processing_requests = len(self._processing_request_ids)

self._state.processed_requests += 1
self._state.errored_requests += 1 if info.status == "errored" else 0
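The `worker_group.py` bookkeeping change above tracks request IDs (the strings carried by `RequestInfo.request_id`) instead of the request objects themselves, which avoids requiring requests to be hashable. A simplified stand-in showing the same state transitions (this is not the real worker group state class):

```python
from dataclasses import dataclass, field


@dataclass
class RequestCounts:
    queued_request_ids: set[str] = field(default_factory=set)
    pending_request_ids: set[str] = field(default_factory=set)
    processing_request_ids: set[str] = field(default_factory=set)

    def update(self, request_id: str, status: str) -> None:
        # Mirrors the diff: queued -> pending -> in_progress -> completed, with
        # errored/cancelled removing the ID from whichever set currently holds it.
        if status == "queued":
            self.queued_request_ids.add(request_id)
        elif status == "pending":
            self.queued_request_ids.remove(request_id)
            self.pending_request_ids.add(request_id)
        elif status == "in_progress":
            self.pending_request_ids.remove(request_id)
            self.processing_request_ids.add(request_id)
        elif status == "completed":
            self.processing_request_ids.remove(request_id)
        elif status in ("errored", "cancelled"):
            self.queued_request_ids.discard(request_id)
            self.pending_request_ids.discard(request_id)
            self.processing_request_ids.discard(request_id)


counts = RequestCounts()
counts.update("req-1", "queued")
counts.update("req-1", "pending")
counts.update("req-1", "in_progress")
counts.update("req-1", "completed")
print(counts)
```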