Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/vllm_router/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def initialize_all(app: FastAPI, args):
max_instance_failover_reroute_attempts=args.max_instance_failover_reroute_attempts,
lmcache_health_check_interval=args.lmcache_health_check_interval,
lmcache_worker_timeout=args.lmcache_worker_timeout,
fallback_routing_logic=args.fallback_routing_logic,
)

# Initialize feature gates
Expand Down
6 changes: 5 additions & 1 deletion src/vllm_router/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class DynamicRouterConfig:

# Routing logic configurations
session_key: Optional[str] = None
fallback_routing_logic: Optional[str] = None

# Logging Options
callbacks: Optional[str] = None
Expand Down Expand Up @@ -109,6 +110,7 @@ def from_args(args) -> "DynamicRouterConfig":
# Routing logic configurations
routing_logic=args.routing_logic,
session_key=args.session_key,
fallback_routing_logic=args.fallback_routing_logic,
# Logging Options
callbacks=args.callbacks,
)
Expand Down Expand Up @@ -224,7 +226,9 @@ def reconfigure_routing_logic(self, config: DynamicRouterConfig):
Reconfigures the router with the given config.
"""
routing_logic = reconfigure_routing_logic(
config.routing_logic, session_key=config.session_key
config.routing_logic,
session_key=config.session_key,
fallback_routing_logic=config.fallback_routing_logic,
)
self.app.state.router = routing_logic
logger.info("DynamicConfigWatcher: Routing logic reconfiguration complete")
Expand Down
8 changes: 8 additions & 0 deletions src/vllm_router/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ def parse_args():
help="Use secure (TLS) connection for OTLP exporter. Default is insecure.",
)

parser.add_argument(
"--fallback-routing-logic",
type=str,
choices=["roundrobin", "session"],
default=None,
help="Routing logic for models without prefill/decode endpoints (only used with disaggregated_prefill_orchestrated).",
)

parser.add_argument(
"--prefill-model-labels",
type=str,
Expand Down
13 changes: 12 additions & 1 deletion src/vllm_router/parsers/yaml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ def generate_static_models(models: dict[str, Any]) -> str:
static_models = []
for name, details in models.items():
if "static_backends" in details:
static_models.extend([name] * len(details["static_backends"]))
model_name = details.get("model_name", name)
static_models.extend([model_name] * len(details["static_backends"]))
return ",".join(static_models)


Expand Down Expand Up @@ -62,6 +63,15 @@ def generate_static_aliases(aliases: dict[str, Any]) -> str:
return ",".join(parts)


def generate_static_model_labels(models: dict[str, Any]) -> str:
    """Return the comma-separated model labels for all static backends.

    One label is emitted per entry in a model's ``static_backends``, in the
    iteration order of ``models``, using the same per-backend repetition as
    ``generate_static_models``. Models without a ``static_backends`` key
    contribute nothing.

    Args:
        models: Mapping of model name to that model's configuration mapping.
            Each configuration may carry ``static_backends`` (a sized
            collection) and an optional ``model_label``.

    Returns:
        The labels joined with ``","``; an empty string when no model
        declares ``static_backends``.
    """
    static_model_labels: list[str] = []
    for details in models.values():
        if "static_backends" not in details:
            continue
        label = details.get("model_label")
        # A YAML key written with no value (``model_label:``) loads as None;
        # treat it like a missing key instead of crashing in ``str.join``.
        if label is None:
            label = "default"
        # Non-string scalars (e.g. a numeric label) are stringified so the
        # join below cannot raise TypeError.
        static_model_labels.extend([str(label)] * len(details["static_backends"]))
    return ",".join(static_model_labels)


def generate_static_model_types(models: dict[str, Any]) -> str:
static_model_types = []
for _, details in models.items():
Expand Down Expand Up @@ -101,6 +111,7 @@ def read_and_process_yaml_config_file(config_path: str) -> dict[str, Any]:
if models:
yaml_config["static_backends"] = generate_static_backends(models)
yaml_config["static_models"] = generate_static_models(models)
yaml_config["static_model_labels"] = generate_static_model_labels(models)
yaml_config["static_model_types"] = generate_static_model_types(models)
yaml_config["static_healthcheck_disabled"] = (
generate_static_healthcheck_disabled(models)
Expand Down
41 changes: 31 additions & 10 deletions src/vllm_router/routers/routing_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,19 +568,35 @@ class DisaggregatedPrefillOrchestratedRouter(RoutingInterface):
Load balancing: Uses round-robin across available prefill and decode pods.
"""

def __init__(self, prefill_model_labels: List[str], decode_model_labels: List[str]):
def __init__(
self,
prefill_model_labels: List[str],
decode_model_labels: List[str],
fallback_routing_logic: Optional[str] = None,
session_key: Optional[str] = None,
):
if hasattr(self, "_initialized"):
return
self.prefill_model_labels = prefill_model_labels or []
self.decode_model_labels = decode_model_labels or []
# Round-robin counters for load balancing across xPyD pods
self.prefill_idx = 0
self.decode_idx = 0

if fallback_routing_logic == "session":
self._fallback_router: RoutingInterface = SessionRouter(session_key)
elif fallback_routing_logic == "roundrobin":
self._fallback_router = RoundRobinRouter()
elif session_key:
self._fallback_router = SessionRouter(session_key)
else:
self._fallback_router = RoundRobinRouter()

self._initialized = True
logger.info(
f"Initialized DisaggregatedPrefillOrchestratedRouter with "
f"prefill_labels={self.prefill_model_labels}, "
f"decode_labels={self.decode_model_labels}"
f"decode_labels={self.decode_model_labels}, "
f"fallback={type(self._fallback_router).__name__}"
)

def _find_endpoints(self, endpoints: List[EndpointInfo]):
Expand Down Expand Up @@ -652,13 +668,15 @@ async def route_request(
request_json: Optional[Dict] = None,
) -> str:
"""
This method is called by the router framework but for orchestrated routing,
we need to handle the full flow differently. This returns the prefill URL
as a placeholder - the actual orchestration happens in route_orchestrated_disaggregated_request.
Fallback routing for models without prefill/decode endpoints.
Delegates to the configured fallback router (session or roundrobin).
P/D models are handled in route_orchestrated_disaggregated_request.
"""
prefiller_endpoints, _ = self._find_endpoints(endpoints)
# Return prefill URL - actual orchestration is done in request.py
return prefiller_endpoints[0].url
if not endpoints:
raise ValueError("No endpoints available")
return await self._fallback_router.route_request(
endpoints, engine_stats, request_stats, request, request_json
)


# Instead of managing a global _global_router, we can define the initialization functions as:
Expand Down Expand Up @@ -696,7 +714,10 @@ def initialize_routing_logic(
elif routing_logic == RoutingLogic.DISAGGREGATED_PREFILL_ORCHESTRATED:
logger.info("Initializing disaggregated prefill orchestrated routing logic")
return DisaggregatedPrefillOrchestratedRouter(
kwargs.get("prefill_model_labels"), kwargs.get("decode_model_labels")
kwargs.get("prefill_model_labels"),
kwargs.get("decode_model_labels"),
fallback_routing_logic=kwargs.get("fallback_routing_logic"),
session_key=kwargs.get("session_key"),
)
else:
raise ValueError(f"Invalid routing logic {routing_logic}")
Expand Down
39 changes: 33 additions & 6 deletions src/vllm_router/services/request_service/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,23 +401,50 @@ async def route_general_request(
Returns:
StreamingResponse: A response object that streams data from the backend server to the client.
"""
# Read body once upfront; Starlette caches it for subsequent reads
request_body = await request.body()
request_json = json.loads(request_body) if request_body else {}

if isinstance(request.app.state.router, DisaggregatedPrefillRouter):
response = await route_disaggregated_prefill_request(
request, endpoint, background_tasks
)
return response

# Handle orchestrated disaggregated inference (NxDI pattern)
# Orchestrated disaggregated inference — only for models that have
# both prefill-labeled and decode-labeled endpoints configured.
# Models without P/D endpoints use standard routing regardless of
# the router type (config-driven, no model names in code).
if isinstance(request.app.state.router, DisaggregatedPrefillOrchestratedRouter):
response = await route_orchestrated_disaggregated_request(
request, endpoint, background_tasks
router = request.app.state.router
service_discovery = get_service_discovery()
all_eps = service_discovery.get_endpoint_info()
requested_model = (request_json or {}).get("model", "")
aliases = getattr(service_discovery, "aliases", None)
if aliases and requested_model in aliases:
alias_config = normalize_alias_config(
requested_model, aliases[requested_model]
)
requested_model = alias_config.model
model_eps = [
e for e in all_eps
if requested_model in e.model_names and not e.sleep
]
has_prefill = any(
e.model_label in router.prefill_model_labels for e in model_eps
)
return response
has_decode = any(
e.model_label in router.decode_model_labels for e in model_eps
)
if has_prefill and has_decode:
response = await route_orchestrated_disaggregated_request(
request, endpoint, background_tasks
)
return response

in_router_time = time.time()
# Same as vllm, Get request_id from X-Request-Id header if available
request_id = request.headers.get("X-Request-Id") or str(uuid.uuid4())
request_body = await request.body()
request_json = json.loads(request_body) if request_body else {}

# OpenTelemetry tracing: extract incoming context and create parent span
span, span_context = None, None
Expand Down
Loading