diff --git a/proposals/imgs/enqueue_flowchart.png b/proposals/imgs/enqueue_flowchart.png
new file mode 100644
index 000000000..1e1b17d1b
Binary files /dev/null and b/proposals/imgs/enqueue_flowchart.png differ
diff --git a/proposals/imgs/sched_loop_flowchart.png b/proposals/imgs/sched_loop_flowchart.png
new file mode 100644
index 000000000..c28eebc02
Binary files /dev/null and b/proposals/imgs/sched_loop_flowchart.png differ
diff --git a/proposals/queue_manager_README.md b/proposals/queue_manager_README.md
new file mode 100644
index 000000000..465f88722
--- /dev/null
+++ b/proposals/queue_manager_README.md
@@ -0,0 +1,147 @@
+# Queue Manager for LLM Endpoint Routing
+
+This module implements an asynchronous queue manager for dispatching LLM inference requests to backend endpoints. It dispatches a request immediately or queues it based on endpoint load and metrics, improving overall Quality of Experience (QoE).
+
+---
+
+## Features
+
+- Per-endpoint asynchronous request queues
+- Condition-based dispatching using `asyncio.Condition`
+- Request rerouting if an endpoint remains overloaded for too long
+- Session affinity preservation (stubbed for future KV cache usage)
+- Graceful shutdown of all schedulers
+- Queuing can be enabled or disabled via the `--enable_queue` flag (disabled by default).
+  - Note that the queue manager is still instantiated when queuing is disabled, just not used.
+
+---
+
+## Flow Chart
+
+![Logic flow for incoming request.](imgs/enqueue_flowchart.png)
+![Logic flow for scheduler loop that runs per endpoint.](imgs/sched_loop_flowchart.png)
+
+---
+
+## File: `src/vllm_router/services/queue_service/queue.py`
+
+### Class: `EndpointQueueManager`
+
+This class manages:
+
+- `endpoint_queues`: A `PriorityQueue` per endpoint holding pending requests.
+- `conditions`: An `asyncio.Condition` per endpoint used to notify the scheduler loop.
+- `endpoint_tasks`: Background async tasks for each endpoint’s queue loop.
+- `scraper`: An `EngineStatsScraper` that periodically scrapes GPU & model stats per endpoint.
+
+---
+
+## Request Lifecycle
+
+### 1. Check Endpoint Availability
+
+```python
+if not queue_manager._endpoint_is_free(server_url):
+```
+
+- If the endpoint is overloaded (e.g. high GPU usage or too many active requests), the request is queued.
+- If it's free, the request is dispatched immediately.
+
+---
+
+### 2. Enqueue Logic
+
+```python
+await queue_manager.register_endpoint(server_url)
+
+await queue_manager.enqueue(
+    server_url,
+    {
+        "request": request,
+        "request_id": request_id,
+        "body": request_body,
+        "endpoint": endpoint,
+        "background_tasks": background_tasks,
+        "result_future": response_future,
+    },
+    priority=queue_manager.calculate_request_priority(request)
+)
+```
+
+- Registers the endpoint queue and scheduler if not already present.
+- Adds the request to a `PriorityQueue`.
+- Notifies the condition variable to wake the scheduler.
+
+If the request is queued, the caller awaits the returned future for the response.
+
+---
+
+### 3. Scheduler Loop
+
+```python
+async def _scheduler_loop(self, endpoint_url: str):
+```
+
+Runs as a background task for each endpoint:
+
+- Waits for new requests in the queue.
+- If the endpoint is free, dispatches the request.
+- If a request has waited longer than `max_queue_wait_time`, the scheduler calls `_reroute_or_dispatch_stale_request` to decide what to do next.
+
+---
+
+### 4. Dispatch Logic
+
+```python
+async def _dispatch_and_signal(...)
+```
+
+- Sends the request to the backend via `process_request(...)`.
+- Returns a streaming response with appropriate headers.
+
+---
+
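+Taken together, steps 1-4 boil down to a condition-protected priority queue per endpoint. The sketch below is illustrative only (simplified names, not the router's actual code) and shows that wait/notify pattern in isolation:
+
+```python
+import asyncio
+import time
+
+
+async def scheduler_loop(queue, cond, endpoint_is_free):
+    # Per-endpoint loop: sleep until notified, then drain the queue while
+    # the endpoint still reports spare capacity.
+    while True:
+        async with cond:
+            await cond.wait_for(lambda: not queue.empty())
+        while not queue.empty() and endpoint_is_free():
+            _priority, _enqueued_at, item = queue.get_nowait()
+            print(f"dispatching {item['request_id']}")
+        await asyncio.sleep(0.05)
+
+
+async def enqueue(queue, cond, item, priority=0):
+    # Lower priority values are dequeued first; the timestamp breaks ties
+    # and later serves as the basis for stale-request detection.
+    await queue.put((priority, time.time(), item))
+    async with cond:
+        cond.notify()
+
+
+async def main():
+    queue, cond = asyncio.PriorityQueue(), asyncio.Condition()
+    task = asyncio.create_task(scheduler_loop(queue, cond, lambda: True))
+    await enqueue(queue, cond, {"request_id": "r1"})
+    await asyncio.sleep(0.2)
+    task.cancel()
+    await asyncio.gather(task, return_exceptions=True)
+
+
+asyncio.run(main())
+```
+
+---
+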
+### 5. Rerouting Stale Requests
+
+If a request exceeds the `max_queue_wait_time` threshold:
+
+```python
+await self._reroute_or_dispatch_stale_request(request, original_endpoint)
+```
+
+- Attempts to reroute the request to a different free endpoint.
+- Currently always reroutes.
+- If the new endpoint is also busy, queues the request there.
+
+In the future, the scheduler could keep a request at its original endpoint if that endpoint holds its session history or a matching KV cache.
+
+---
+
+## Configuration
+
+```python
+initialize_queue_manager(max_queue_wait_time=10)
+queue_manager = get_queue_manager()
+```
+
+In the router these values are taken from the CLI flags at startup (see `parser.py` and `app.py`).
+
+- `enable_queue`: Whether requests are actually queued (`--enable_queue`); when disabled the manager is instantiated but bypassed.
+- `max_queue_wait_time`: Max seconds a request can wait in the queue before being rerouted or retried (`--max-wait-time`).
+- `max_running_requests`: Max number of running requests on an endpoint before the router queues new requests (`--max-running-requests`).
+- `max_gpu_perc`: Max GPU usage percentage of an endpoint before the router queues new requests (`--max-gpu-perc`).
+
+---
+
+## Dependencies
+
+- `asyncio`
+- `EngineStatsScraper` from `vllm_router.stats.engine_stats`
+- `process_request()` from `vllm_router.services.request_service.request`
+
+---
+
+## TODOs
+
+- [ ] Implement KV cache-aware session affinity logic
+- [ ] Implement request priority classification
+- [ ] Replace round-robin stale routing policy
+- [ ] Retry policies and smarter rerouting heuristics
+- [ ] Implement knapsack-like selection allowing for a group of requests to be dispatched at once
+
+---
diff --git a/src/tests/test_queue.py b/src/tests/test_queue.py
new file mode 100644
index 000000000..b1ea19486
--- /dev/null
+++ b/src/tests/test_queue.py
@@ -0,0 +1,214 @@
+import asyncio
+import json
+import time
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+import pytest_asyncio
+from fastapi.responses import StreamingResponse
+
+from vllm_router.services.queue_service.queue import (
+    get_queue_manager,
+    initialize_queue_manager,
+)
+
+
+@pytest.fixture
+def mock_scraper():
+    scraper = MagicMock()
+    scraper.get_engine_stats.return_value = {
+        "endpoint1": MagicMock(num_running_requests=0, gpu_cache_usage_perc=0),
+        "endpoint2": MagicMock(num_running_requests=5, gpu_cache_usage_perc=50),
+    }
+    return scraper
+
+
+@pytest_asyncio.fixture
+async def queue_manager(mock_scraper):
+    initialize_queue_manager(
+        max_queue_wait_time=10,
+        max_running_requests=10,
+        max_gpu_perc=95,
+        scraper=mock_scraper,
+    )
+    manager = get_queue_manager()
+    await manager.register_endpoint("endpoint1")
+    await manager.register_endpoint("endpoint2")
+    yield manager
+    await manager.close()
+
+
+@pytest.mark.asyncio
+async def test_queue_manager_initialization(mock_scraper):
+    initialize_queue_manager(
+        max_queue_wait_time=10,
+        max_running_requests=10,
+        max_gpu_perc=95,
+        scraper=mock_scraper,
+    )
+    manager = get_queue_manager()
+    assert manager.max_queue_wait_time == 10
+    assert manager.max_running_requests == 10
+    assert manager.max_gpu_perc == 95
+    assert manager.scraper == mock_scraper
+
+
+@pytest.mark.asyncio
+async def test_register_endpoint(queue_manager):
+    for endpoint in ["endpoint1", "endpoint2"]:
+        assert endpoint in queue_manager.endpoint_queues
+        assert endpoint in queue_manager.conditions
+        assert endpoint in queue_manager.endpoint_tasks
+
+
+@pytest.mark.asyncio
+async def test_enqueue_request(queue_manager):
+    test_request = {"request_id": "test123", "body": "test"}
+    future = asyncio.Future()
+    test_request["_result_future"] = future
+    await queue_manager.enqueue("endpoint1", test_request, priority=1)
+    assert not queue_manager.endpoint_queues["endpoint1"].empty()
+    assert not future.done()
+
+
+@pytest.mark.asyncio
+async def test_endpoint_is_free(queue_manager, mock_scraper):
+    assert queue_manager._endpoint_is_free("endpoint1") is True
+    assert queue_manager._endpoint_is_free("endpoint2") is True
+
+    
mock_scraper.get_engine_stats.return_value["endpoint2"].num_running_requests = 15 + assert queue_manager._endpoint_is_free("endpoint2") is False + + mock_scraper.get_engine_stats.return_value["endpoint2"].num_running_requests = 5 + mock_scraper.get_engine_stats.return_value["endpoint2"].gpu_cache_usage_perc = 96 + assert queue_manager._endpoint_is_free("endpoint2") is False + + +@pytest.mark.asyncio +async def test_dispatch_and_signal(queue_manager): + test_request = { + "request_id": "test123", + "body": json.dumps({"prompt": "hello"}), + "request": MagicMock(), + "endpoint": "endpoint1", + "background_tasks": MagicMock(), + "result_future": asyncio.Future(), + } + + with patch( + "vllm_router.services.request_service.request.process_request", + new_callable=AsyncMock, + ) as mock_process: + + async def mock_stream(): + yield ("content-type", 200) + yield StreamingResponse(content=MagicMock()) + + mock_process.return_value.__aiter__.return_value = mock_stream() + + await queue_manager._dispatch_and_signal("endpoint1", test_request) + + +@pytest.mark.asyncio +async def test_scheduler_loop(queue_manager): + test_request = { + "request_id": "test123", + "body": json.dumps({"prompt": "hello"}), + "request": MagicMock(), + "endpoint": "endpoint1", + "background_tasks": MagicMock(), + "result_future": asyncio.Future(), + } + + with patch( + "vllm_router.services.request_service.request.process_request" + ) as mock_process: + mock_headers = {"content-type": "application/json"} + mock_status = 200 + mock_stream = MagicMock() + mock_process.return_value = (mock_headers, mock_status, mock_stream) + + await queue_manager.enqueue("endpoint1", test_request) + await asyncio.sleep(1.5) # Wait enough time for scheduler loop + + assert test_request["result_future"].done() + + +@pytest.mark.asyncio +@patch( + "vllm_router.services.request_service.request.process_request", + new_callable=AsyncMock, +) +@patch( + "vllm_router.services.queue_service.queue.EndpointQueueManager._reroute_or_dispatch_stale_request", + new_callable=AsyncMock, +) +async def test_stale_request_rerouting( + mock_reroute, mock_process_request, queue_manager +): + dummy_request = MagicMock() + dummy_request.state = MagicMock() + + # Simulate a quick response stream + async def dummy_stream(): + yield ({"content-type": "application/json"}, 200) + + mock_process_request.return_value = dummy_stream() + + # Simulate a stale request + stale_request = { + "request_id": "stale123", + "body": '{"input": "hello"}', + "model_name": "test-model", + "session_id": "abc", + "request": dummy_request, + "endpoint": "endpoint1", + "background_tasks": MagicMock(), + "result_future": asyncio.Future(), + "enqueue_timestamp": time.time() - 15, # 15s ago + } + queue_manager._endpoint_is_free = MagicMock(return_value=False) + + await queue_manager.enqueue("endpoint1", stale_request) + + # Let scheduler tick + await asyncio.sleep(15) + + mock_reroute.assert_called_once() + + +@pytest.mark.asyncio +async def test_shutdown(queue_manager): + assert not queue_manager._shutdown_event.is_set() + await queue_manager.close() + assert queue_manager._shutdown_event.is_set() + for task in queue_manager.endpoint_tasks.values(): + assert task.done() + + +@pytest.mark.asyncio +async def test_singleton_pattern(): + from vllm_router.services.queue_service import queue as queue_module + + queue_module._global_queue_manager = None + + scraper = MagicMock() + scraper.get_engine_stats.return_value = { + "endpoint1": MagicMock(num_running_requests=0, gpu_cache_usage_perc=0), + } 
+ + queue_module.initialize_queue_manager( + max_queue_wait_time=10, + max_running_requests=10, + max_gpu_perc=95, + scraper=scraper, + ) + manager1 = queue_module.get_queue_manager() + manager2 = queue_module.get_queue_manager() + assert manager1 is manager2 + + await manager1.close() + queue_module._global_queue_manager = None + + with pytest.raises(ValueError, match="Queue manager not initialized"): + queue_module.get_queue_manager() diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index 0713e9c0f..8090f7dc9 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -43,6 +43,10 @@ from vllm_router.services.batch_service import initialize_batch_processor from vllm_router.services.callbacks_service.callbacks import configure_custom_callbacks from vllm_router.services.files_service import initialize_storage +from vllm_router.services.queue_service.queue import ( + get_queue_manager, + initialize_queue_manager, +) from vllm_router.services.request_service.rewriter import ( get_request_rewriter, ) @@ -108,6 +112,16 @@ async def lifespan(app: FastAPI): logger.info("Closing dynamic config watcher") dyn_cfg_watcher.close() + # Close the queue manager + try: + queue_manager = get_queue_manager() + if queue_manager is not None: + logger.info("Closing per endpoint queues and tasks") + await queue_manager.close() + except ValueError: + # Queue manager was not initialized + pass + def initialize_all(app: FastAPI, args): """ @@ -175,6 +189,14 @@ def initialize_all(app: FastAPI, args): initialize_engine_stats_scraper(args.engine_stats_interval) initialize_request_stats_monitor(args.request_stats_window) + # Initialize queue + initialize_queue_manager( + args.enable_queue, + args.max_wait_time, + args.max_running_requests, + args.max_gpu_perc, + ) + if args.enable_batch_api: logger.info("Initializing batch API") app.state.batch_storage = initialize_storage( diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 8b12cf983..3c691a4de 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -379,6 +379,32 @@ def parse_args(): help="The threshold for kv-aware routing.", ) + parser.add_argument( + "--enable_queue", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable router-side queuing. Note that queue will still be initialized, just not actually enqueued.", + ) + parser.add_argument( + "--max-wait-time", + type=int, + default=10, + help="The maximum amount of time a request waits in a queue before it gets rerouted. 
E.g., 10s", + ) + + parser.add_argument( + "--max-running-requests", + type=int, + default=10, + help="The maximum number of running requests in an endpoint before the router enqueues an incoming request", + ) + + parser.add_argument( + "--max-gpu-perc", + type=int, + default=95, + help="The maximum GPU use percentage of an endpoint before the router enqueues an incoming request", + ) args = parser.parse_args() args = load_initial_config_from_config_file_if_required(parser, args) diff --git a/src/vllm_router/services/queue_service/queue.py b/src/vllm_router/services/queue_service/queue.py new file mode 100644 index 000000000..9b1e4ca0c --- /dev/null +++ b/src/vllm_router/services/queue_service/queue.py @@ -0,0 +1,350 @@ +# services/queue_manager.py + +import asyncio +import time +from threading import Lock +from typing import Any, Dict + +from fastapi.responses import StreamingResponse + +from vllm_router.stats.engine_stats import get_engine_stats_scraper + +_global_queue_manager = None + + +class EndpointQueueManager: + def __init__( + self, + enable_queue, + max_queue_wait_time, + max_running_requests, + max_gpu_perc, + scraper=None, + ): + """ + Initializes the queue manager responsible for scheduling and dispatching + requests to backend endpoints based on GPU load, request priority, and wait time. + + Args: + max_queue_wait_time (float): Maximum time (in seconds) a request can wait before being rerouted. + max_running_requests (int): Maximum number of concurrent requests allowed on an endpoint. + max_gpu_perc (float): Maximum allowed GPU usage percentage per endpoint. + scraper: Optional engine stats scraper for monitoring backend load. + """ + self.enable_queue = enable_queue + self.endpoint_queues: Dict[str, asyncio.PriorityQueue] = {} + self.conditions: Dict[str, asyncio.Condition] = {} + self._register_lock = asyncio.Lock() + + self.scraper = scraper or get_engine_stats_scraper() + if self.scraper is None: + raise RuntimeError("Engine stats scraper not initialized.") + + # User configurable fields + self.max_running_requests = max_running_requests + self.max_gpu_perc = max_gpu_perc + self.max_queue_wait_time = max_queue_wait_time + + self.stale_check_interval = 2 + + # Stale request round-robin fallback strategy + self.req_id = 0 + self._lock = Lock() + + # Kept for shutdown + self.endpoint_tasks: Dict[str, asyncio.Task] = {} + self._shutdown_event = asyncio.Event() + + async def register_endpoint(self, endpoint_url: str): + """ + Registers an endpoint with the queue manager. Initializes a queue and + a scheduler loop for the endpoint if not already registered. + + Args: + endpoint_url (str): The unique identifier (typically URL) for the backend endpoint. + """ + async with self._register_lock: + if endpoint_url in self.endpoint_queues: + return # Already registered + + self.endpoint_queues[endpoint_url] = asyncio.PriorityQueue() + self.conditions[endpoint_url] = asyncio.Condition() + task = asyncio.create_task(self._scheduler_loop(endpoint_url)) + self.endpoint_tasks[endpoint_url] = task + + async def enqueue( + self, endpoint_url: str, request: Dict[str, Any], priority: int = 0 + ): + """ + Adds a request to the endpoint-specific priority queue and notifies + the scheduler that a new request is available. + + Args: + endpoint_url (str): The endpoint to which the request should be enqueued. + request (dict): Metadata and payload for the request. + priority (int): Priority value (lower values are dequeued earlier). 
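+
+        Example (illustrative only; a real request dict carries the full set of
+        fields shown in the README, e.g. "request", "body", "endpoint"):
+
+            fut = asyncio.get_event_loop().create_future()
+            await manager.register_endpoint("http://backend-1:8000/")
+            await manager.enqueue(
+                "http://backend-1:8000/",
+                {"request_id": "abc123", "result_future": fut},
+                priority=0,
+            )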
+ """ + if self._shutdown_event.is_set(): + raise RuntimeError( + "Scheduler is shutting down, can't enqueue new requests." + ) + + await self.endpoint_queues[endpoint_url].put((priority, time.time(), request)) + async with self.conditions[endpoint_url]: + self.conditions[ + endpoint_url + ].notify() # Tell queue that a request is available + + async def _scheduler_loop(self, endpoint_url: str): + """ + Continuously monitors the request queue for the given endpoint, and + dispatches or reroutes requests based on endpoint load and wait time. + + This function runs in the background per endpoint. + """ + + queue = self.endpoint_queues[endpoint_url] + condition = self.conditions[endpoint_url] + + last_stale_check = 0 + + while not self._shutdown_event.is_set(): + async with condition: + # Wait until queue not empty or shutdown + await condition.wait_for( + lambda: (not queue.empty()) or self._shutdown_event.is_set() + ) + if self._shutdown_event.is_set(): + break + + # Dispatch as many requests as endpoint allows + while not queue.empty() and self._endpoint_is_free(endpoint_url): + _, _, request = queue.get_nowait() + asyncio.create_task( + self._dispatch_and_signal(endpoint_url, request) + ) + + # After dispatching, periodically check stale requests outside the condition + now = time.time() + if now - last_stale_check > self.stale_check_interval: + last_stale_check = now + + # Check for stale requests without holding the condition lock + try: + priority, enqueue_time, stale_request = queue._queue[0] + except IndexError: + continue # queue empty + + wait_duration = now - enqueue_time + if wait_duration > self.max_queue_wait_time: + async with condition: + try: + _, _, stale_request = queue.get_nowait() + except asyncio.QueueEmpty: + continue + + await self._reroute_or_dispatch_stale_request( + stale_request, endpoint_url + ) + + await asyncio.sleep(0.05) # small sleep to avoid tight loop + + def _endpoint_is_free( + self, endpoint_url: str + ) -> bool: # TODO: What stats could be relevant + """ + Determines whether the specified endpoint is currently available to handle a new request, + based on configured load and GPU thresholds. + + Args: + endpoint_url (str): The endpoint to check. + + Returns: + bool: True if the endpoint is under capacity, False otherwise. + """ + + stats = self.scraper.get_engine_stats().get(endpoint_url) + return ( + stats + and stats.num_running_requests < self.max_running_requests + and stats.gpu_cache_usage_perc < self.max_gpu_perc + ) + + async def _dispatch_and_signal(self, endpoint_url: str, request: Dict[str, Any]): + """ + Sends a request to the target endpoint and fulfills any associated future + used by upstream logic to await response. + + Args: + endpoint_url (str): The backend endpoint to dispatch the request to. + request (dict): Request metadata, including content and completion future. 
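+
+        If the request carries a "result_future", it is resolved with the
+        resulting StreamingResponse on success or with the raised exception
+        on failure, so the coroutine that enqueued the request can await it.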
+ """ + from vllm_router.services.request_service.request import process_request + + result_future = request.get("result_future") + try: + stream_generator = process_request( + request["request"], + request["body"], + endpoint_url, + request["request_id"], + request["endpoint"], + request["background_tasks"], + self.conditions[endpoint_url], + ) + headers, status_code = await anext(stream_generator) + headers_dict = dict(headers) + headers_dict["X-Request-Id"] = request["request_id"] + + response = StreamingResponse( + stream_generator, + status_code=status_code, + headers=headers_dict, + media_type="text/event-stream", + ) + + # Fulfill the future + if result_future and not result_future.done(): + result_future.set_result(response) + + except Exception as e: + if result_future and not result_future.done(): + result_future.set_exception(e) + else: + print(f"[Queue Dispatch Error] {e}") + + return + + async def _reroute_or_dispatch_stale_request( + self, request: dict, original_endpoint: str + ): + """ + Handles requests that have waited in the queue too long. Either reroutes + them to a different eligible endpoint or re-enqueues them with higher priority. + + Args: + request (dict): The request object to be rerouted or re-enqueued. + original_endpoint (str): The endpoint where the request was originally queued. + """ + + # TODO: Use KV cache hit estimation in future, session aware id + + priority = max( + 0, self.calculate_request_priority(request) - 1 + ) # priority is boosted + + if ( + True + ): # Replace with conditionals, ie, no session affinity or high KV cache matches + + new_endpoint = self.find_new_endpoint(exclude=original_endpoint) + await self.register_endpoint(new_endpoint) + if new_endpoint and new_endpoint != original_endpoint: + # print(f"[Rerouting] Request {request_id} → {new_endpoint} (was {original_endpoint})") + + if self._endpoint_is_free(new_endpoint): + asyncio.create_task( + self._dispatch_and_signal(new_endpoint, request) + ) + else: + await self.enqueue(new_endpoint, request, priority) + return + + # Keep original endpoint + # print(f"[Requeue] Request {request_id} stays at {original_endpoint}") + await self.enqueue(original_endpoint, request, priority) + + def find_new_endpoint(self, exclude: str) -> str: + """ + Selects a new endpoint to reroute a stale request, excluding the original one. + Uses round-robin logic to rotate among available endpoints. + + Args: + exclude (str): The endpoint to avoid in selection. + + Returns: + str: Chosen new endpoint (or original if no other available). + """ + # TODO: Get currently used router and pass in list of endpoints excluding orig endpoint to preserve routing strategy + endpoints = [ep for ep in self.endpoint_queues.keys() if ep != exclude] + + if not endpoints: + return exclude + + with self._lock: + new_endpoint = sorted(endpoints, key=lambda e: e)[ + self.req_id % len(endpoints) + ] + self.req_id += 1 + return new_endpoint + + def calculate_request_priority(self, request) -> int: # TODO + """ + Determines the priority of a request. Placeholder for future QoS heuristics. + + Args: + request (dict): The request to score. + + Returns: + int: Priority value (lower = higher priority). + """ + return 0 + + async def close(self): + """ + Shuts down the queue manager by cancelling all scheduler tasks + and waiting for them to complete. Ensures no new requests are accepted. 
+ """ + + self._shutdown_event.set() + + for task in self.endpoint_tasks.values(): + task.cancel() + + # wait for all tasks to cancel + await asyncio.gather(*self.endpoint_tasks.values(), return_exceptions=True) + + print("Scheduler shutdown complete.") + + +def initialize_queue_manager( + enable_queue=True, + max_queue_wait_time=10, + max_running_requests=10, + max_gpu_perc=95, + scraper=None, +): + """ + Initializes and globally registers the queue manager with the specified configuration. + + Args: + max_queue_wait_time (float): Max time a request can wait in queue before reroute. + max_running_requests (int): Max concurrent requests per endpoint. + max_gpu_perc (float): Max allowed GPU usage per endpoint. + scraper: Optional engine stats scraper override. + """ + + global _global_queue_manager + _global_queue_manager = EndpointQueueManager( + enable_queue=enable_queue, + max_queue_wait_time=max_queue_wait_time, + max_running_requests=max_running_requests, + max_gpu_perc=max_gpu_perc, + scraper=scraper, + ) + + +def get_queue_manager() -> "EndpointQueueManager": + """ + Returns the globally initialized queue manager instance. + + Raises: + ValueError: If the queue manager has not been initialized. + + Returns: + EndpointQueueManager: The singleton instance of the queue manager. + """ + + if _global_queue_manager is None: + raise ValueError("Queue manager not initialized") + return _global_queue_manager diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 83e647927..af71a7666 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio + # --- Request Processing & Routing --- import json import os @@ -30,6 +32,7 @@ PrefixAwareRouter, ) from vllm_router.service_discovery import get_service_discovery +from vllm_router.services.queue_service.queue import get_queue_manager from vllm_router.services.request_service.rewriter import ( get_request_rewriter, is_request_rewriter_initialized, @@ -59,6 +62,7 @@ async def process_request( endpoint, background_tasks: BackgroundTasks, debug_request=None, + condition: asyncio.Condition = None, ): """ Process a request by sending it to the chosen backend. @@ -120,7 +124,11 @@ async def process_request( request.app.state.request_stats_monitor.on_request_complete( backend_url, request_id, time.time() ) - + if ( + condition + ): # lets scheduler know that an endpoint-specific request has completed, can perhaps dispatch new + async with condition: + condition.notify() # if debug_request: # logger.debug(f"Finished the request with request id: {debug_request.headers.get('x-request-id', None)} at {time.time()}") # Store in semantic cache if applicable @@ -154,6 +162,9 @@ async def route_general_request( Returns: StreamingResponse: A response object that streams data from the backend server to the client. """ + # if queue enabled? 
+ queue_manager = get_queue_manager() + if isinstance(request.app.state.router, DisaggregatedPrefillRouter): response = await route_disaggregated_prefill_request( request, endpoint, background_tasks @@ -279,9 +290,34 @@ async def route_general_request( logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}") logger.debug(f"Debug session extraction - Extracted session ID: {session_id}") + await queue_manager.register_endpoint(server_url) # if queue does not already exist + # Enqueue if endpoint load is too high + if queue_manager.enable_queue and not queue_manager._endpoint_is_free(server_url): + + response_future = asyncio.get_event_loop().create_future() + + await queue_manager.enqueue( + server_url, + { + "request": request, + "request_id": request_id, + "body": request_body, + "endpoint": endpoint, + "background_tasks": background_tasks, + "result_future": response_future, + }, + priority=queue_manager.calculate_request_priority(request), + ) + + return await response_future + logger.info( f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}" ) + condition = None + if queue_manager.enable_queue: + condition = queue_manager.conditions[server_url] + stream_generator = process_request( request, request_body, @@ -289,6 +325,7 @@ async def route_general_request( request_id, endpoint, background_tasks, + condition=condition, ) headers, status = await anext(stream_generator) headers_dict = {key: value for key, value in headers.items()}