diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index c67c291d5..93fa08ce5 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -291,6 +291,7 @@ def initialize_all(app: FastAPI, args): max_instance_failover_reroute_attempts=args.max_instance_failover_reroute_attempts, lmcache_health_check_interval=args.lmcache_health_check_interval, lmcache_worker_timeout=args.lmcache_worker_timeout, + fallback_routing_logic=args.fallback_routing_logic, ) # Initialize feature gates diff --git a/src/vllm_router/dynamic_config.py b/src/vllm_router/dynamic_config.py index 8bce9535c..fdb093d29 100644 --- a/src/vllm_router/dynamic_config.py +++ b/src/vllm_router/dynamic_config.py @@ -70,6 +70,7 @@ class DynamicRouterConfig: # Routing logic configurations session_key: Optional[str] = None + fallback_routing_logic: Optional[str] = None # Logging Options callbacks: Optional[str] = None @@ -109,6 +110,7 @@ def from_args(args) -> "DynamicRouterConfig": # Routing logic configurations routing_logic=args.routing_logic, session_key=args.session_key, + fallback_routing_logic=args.fallback_routing_logic, # Logging Options callbacks=args.callbacks, ) @@ -224,7 +226,9 @@ def reconfigure_routing_logic(self, config: DynamicRouterConfig): Reconfigures the router with the given config. """ routing_logic = reconfigure_routing_logic( - config.routing_logic, session_key=config.session_key + config.routing_logic, + session_key=config.session_key, + fallback_routing_logic=config.fallback_routing_logic, ) self.app.state.router = routing_logic logger.info("DynamicConfigWatcher: Routing logic reconfiguration complete") diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 8adc74637..e27063484 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -443,6 +443,14 @@ def parse_args(): help="Use secure (TLS) connection for OTLP exporter. Default is insecure.", ) + parser.add_argument( + "--fallback-routing-logic", + type=str, + choices=["roundrobin", "session"], + default=None, + help="Routing logic for models without prefill/decode endpoints (only used with disaggregated_prefill_orchestrated).", + ) + parser.add_argument( "--prefill-model-labels", type=str, diff --git a/src/vllm_router/parsers/yaml_utils.py b/src/vllm_router/parsers/yaml_utils.py index 3cf398c1b..80506d79e 100644 --- a/src/vllm_router/parsers/yaml_utils.py +++ b/src/vllm_router/parsers/yaml_utils.py @@ -21,7 +21,8 @@ def generate_static_models(models: dict[str, Any]) -> str: static_models = [] for name, details in models.items(): if "static_backends" in details: - static_models.extend([name] * len(details["static_backends"])) + model_name = details.get("model_name", name) + static_models.extend([model_name] * len(details["static_backends"])) return ",".join(static_models) @@ -62,6 +63,15 @@ def generate_static_aliases(aliases: dict[str, Any]) -> str: return ",".join(parts) +def generate_static_model_labels(models: dict[str, Any]) -> str: + static_model_labels = [] + for _, details in models.items(): + if "static_backends" in details: + label = details.get("model_label", "default") + static_model_labels.extend([label] * len(details["static_backends"])) + return ",".join(static_model_labels) + + def generate_static_model_types(models: dict[str, Any]) -> str: static_model_types = [] for _, details in models.items(): @@ -101,6 +111,7 @@ def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]: if models: yaml_config["static_backends"] = generate_static_backends(models) yaml_config["static_models"] = generate_static_models(models) + yaml_config["static_model_labels"] = generate_static_model_labels(models) yaml_config["static_model_types"] = generate_static_model_types(models) yaml_config["static_healthcheck_disabled"] = ( generate_static_healthcheck_disabled(models) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 583a38b8c..5c7d58446 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -568,19 +568,35 @@ class DisaggregatedPrefillOrchestratedRouter(RoutingInterface): Load balancing: Uses round-robin across available prefill and decode pods. """ - def __init__(self, prefill_model_labels: List[str], decode_model_labels: List[str]): + def __init__( + self, + prefill_model_labels: List[str], + decode_model_labels: List[str], + fallback_routing_logic: Optional[str] = None, + session_key: Optional[str] = None, + ): if hasattr(self, "_initialized"): return self.prefill_model_labels = prefill_model_labels or [] self.decode_model_labels = decode_model_labels or [] - # Round-robin counters for load balancing across xPyD pods self.prefill_idx = 0 self.decode_idx = 0 + + if fallback_routing_logic == "session": + self._fallback_router: RoutingInterface = SessionRouter(session_key) + elif fallback_routing_logic == "roundrobin": + self._fallback_router = RoundRobinRouter() + elif session_key: + self._fallback_router = SessionRouter(session_key) + else: + self._fallback_router = RoundRobinRouter() + self._initialized = True logger.info( f"Initialized DisaggregatedPrefillOrchestratedRouter with " f"prefill_labels={self.prefill_model_labels}, " - f"decode_labels={self.decode_model_labels}" + f"decode_labels={self.decode_model_labels}, " + f"fallback={type(self._fallback_router).__name__}" ) def _find_endpoints(self, endpoints: List[EndpointInfo]): @@ -652,13 +668,15 @@ async def route_request( request_json: Optional[Dict] = None, ) -> str: """ - This method is called by the router framework but for orchestrated routing, - we need to handle the full flow differently. This returns the prefill URL - as a placeholder - the actual orchestration happens in route_orchestrated_disaggregated_request. + Fallback routing for models without prefill/decode endpoints. + Delegates to the configured fallback router (session or roundrobin). + P/D models are handled in route_orchestrated_disaggregated_request. """ - prefiller_endpoints, _ = self._find_endpoints(endpoints) - # Return prefill URL - actual orchestration is done in request.py - return prefiller_endpoints[0].url + if not endpoints: + raise ValueError("No endpoints available") + return await self._fallback_router.route_request( + endpoints, engine_stats, request_stats, request, request_json + ) # Instead of managing a global _global_router, we can define the initialization functions as: @@ -696,7 +714,10 @@ def initialize_routing_logic( elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL_ORCHESTRATED: logger.info("Initializing disaggregated prefill orchestrated routing logic") return DisaggregatedPrefillOrchestratedRouter( - kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels") + kwargs.get("prefill_model_labels"), + kwargs.get("decode_model_labels"), + fallback_routing_logic=kwargs.get("fallback_routing_logic"), + session_key=kwargs.get("session_key"), ) else: raise ValueError(f"Invalid routing logic {routing_logic}") diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index bf04a80a8..f061d88f6 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -401,23 +401,50 @@ async def route_general_request( Returns: StreamingResponse: A response object that streams data from the backend server to the client. """ + # Read body once upfront; Starlette caches it for subsequent reads + request_body = await request.body() + request_json = json.loads(request_body) if request_body else {} + if isinstance(request.app.state.router, DisaggregatedPrefillRouter): response = await route_disaggregated_prefill_request( request, endpoint, background_tasks ) return response - # Handle orchestrated disaggregated inference (NxDI pattern) + # Orchestrated disaggregated inference — only for models that have + # both prefill-labeled and decode-labeled endpoints configured. + # Models without P/D endpoints use standard routing regardless of + # the router type (config-driven, no model names in code). if isinstance(request.app.state.router, DisaggregatedPrefillOrchestratedRouter): - response = await route_orchestrated_disaggregated_request( - request, endpoint, background_tasks + router = request.app.state.router + service_discovery = get_service_discovery() + all_eps = service_discovery.get_endpoint_info() + requested_model = (request_json or {}).get("model", "") + aliases = getattr(service_discovery, "aliases", None) + if aliases and requested_model in aliases: + alias_config = normalize_alias_config( + requested_model, aliases[requested_model] + ) + requested_model = alias_config.model + model_eps = [ + e for e in all_eps + if requested_model in e.model_names and not e.sleep + ] + has_prefill = any( + e.model_label in router.prefill_model_labels for e in model_eps ) - return response + has_decode = any( + e.model_label in router.decode_model_labels for e in model_eps + ) + if has_prefill and has_decode: + response = await route_orchestrated_disaggregated_request( + request, endpoint, background_tasks + ) + return response + in_router_time = time.time() # Same as vllm, Get request_id from X-Request-Id header if available request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4()) - request_body = await request.body() - request_json = json.loads(request_body) if request_body else {} # OpenTelemetry tracing: extract incoming context and create parent span span, span_context = None, None