Merged

32 commits
5111db5
initial changes
meher-m Sep 23, 2025
3c42d1e
Merge branch 'main' into meher-m/vllm-upgrade
meher-m Sep 25, 2025
71b5d99
reverting some forwarder changes that aren't needed
meher-m Sep 25, 2025
8d30ab3
remove some other unneeded stuff
meher-m Sep 25, 2025
22a0cf9
not sure
meher-m Sep 25, 2025
28cea65
adding cpu
meher-m Sep 25, 2025
6ea87b4
add column
meher-m Sep 25, 2025
5d8a634
add file for db model change
meher-m Sep 25, 2025
5229c55
update readme instructions
meher-m Sep 25, 2025
1b7414c
fix column name
meher-m Sep 25, 2025
494ea1e
reformat
meher-m Sep 25, 2025
43b5054
remove unused commits
meher-m Sep 25, 2025
a2e4b50
fix
meher-m Sep 25, 2025
8c4930f
fix readme
meher-m Sep 25, 2025
ff8766b
fix types
meher-m Sep 25, 2025
4b2103e
leave the existing variables for backwards compatibility
meher-m Sep 25, 2025
29877c3
edit types
meher-m Sep 25, 2025
24cc393
remove EXTRA ROUTES completely. its not used by the async or triton e…
meher-m Sep 25, 2025
68fd91d
adding FORWARDER_SYNC_ROUTES and FORWARDER_STREAMING_ROUTES to the tr…
meher-m Sep 25, 2025
8b623b3
change to pass unit tests
meher-m Sep 25, 2025
cbb0e05
update orm
meher-m Sep 26, 2025
fd3ad7a
test change
meher-m Sep 26, 2025
9b59d4d
change test bundle
meher-m Sep 26, 2025
fa1a5b3
add debug logs
meher-m Sep 26, 2025
edcf542
trying to fix
meher-m Sep 26, 2025
3397278
changes
meher-m Sep 26, 2025
c849a5e
cleanup debug code
meher-m Sep 26, 2025
ca603dd
reformat
meher-m Sep 26, 2025
9a5f789
remove 1
meher-m Sep 26, 2025
6affb1a
remove 2
meher-m Sep 27, 2025
b25e7a7
revert 3
meher-m Sep 27, 2025
dcca5a8
reorder params
meher-m Sep 29, 2025
20 changes: 6 additions & 14 deletions charts/model-engine/templates/service_template_config_map.yaml
@@ -181,10 +181,6 @@ data:
- --set
- "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}"
- --set
- "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
- --set
- "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
- --set
- "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
- --set
- "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
@@ -370,7 +366,7 @@ data:
name: {{ $service_template_aws_config_map_name }}
{{- else }}
name: {{ $aws_config_map_name }}
{{- end }}
{{- end }}
{{- end }}
- name: user-config
configMap:
@@ -487,15 +483,15 @@ data:
threshold: "${CONCURRENCY}"
metricName: request_concurrency_average
query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000
serverAddress: ${PROMETHEUS_SERVER_ADDRESS}
serverAddress: ${PROMETHEUS_SERVER_ADDRESS}
{{- range $device := tuple "gpu" }}
{{- range $mode := tuple "streaming"}}
leader-worker-set-{{ $mode }}-{{ $device }}.yaml: |-
apiVersion: leaderworkerset.x-k8s.io/v1
kind: LeaderWorkerSet
metadata:
name: ${RESOURCE_NAME}
namespace: ${NAMESPACE}
name: ${RESOURCE_NAME}
namespace: ${NAMESPACE}
labels:
{{- $service_template_labels | nindent 8 }}
spec:
@@ -617,10 +613,6 @@ data:
- --set
- "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}"
- --set
- "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}"
- --set
- "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}"
- --set
- "forwarder.sync.forwarder_type=${FORWARDER_TYPE}"
- --set
- "forwarder.stream.forwarder_type=${FORWARDER_TYPE}"
@@ -748,7 +740,7 @@ data:
name: {{ $service_template_aws_config_map_name }}
{{- else }}
name: {{ $aws_config_map_name }}
{{- end }}
{{- end }}
{{- end }}
- name: user-config
configMap:
@@ -856,7 +848,7 @@ data:
name: {{ $service_template_aws_config_map_name }}
{{- else }}
name: {{ $aws_config_map_name }}
{{- end }}
{{- end }}
{{- end }}
- name: user-config
configMap:
5 changes: 4 additions & 1 deletion model-engine/model_engine_server/db/migrations/README
@@ -4,7 +4,7 @@ We introduce alembic by
1. dumping the current db schemas into 'initial.sql' via pg_dump

```
pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql
pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql
```

2. writing an initial revision that reads and applies the initial.sql script
@@ -19,6 +19,9 @@ alembic revision -m “initial”
alembic stamp fa3267c80731
```

# Steps to make generic database schema changes

Steps can be found here: https://alembic.sqlalchemy.org/en/latest/tutorial.html#running-our-second-migration
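
For illustration, the usual loop looks roughly like this (a sketch; the revision message is a placeholder, and the generated file lands in the configured versions directory):

```
alembic revision -m "describe your change"   # generate a new revision file
# edit the generated upgrade()/downgrade() functions
alembic upgrade head                         # apply the migration to the target database
```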

# Test db migration from scratch

@@ -0,0 +1,31 @@
"""add routes column

Revision ID: 221aa19d3f32
Revises: e580182d6bfd
Create Date: 2025-09-25 19:40:24.927198

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '221aa19d3f32'
down_revision = 'e580182d6bfd'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
'bundles',
sa.Column('runnable_image_routes', sa.ARRAY(sa.Text), nullable=True),
schema='hosted_model_inference',
)


def downgrade() -> None:
op.drop_column(
'bundles',
'runnable_image_routes',
schema='hosted_model_inference',
)
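
For illustration (not part of the diff), one way to sanity-check this migration after `alembic upgrade head` is to inspect the table with SQLAlchemy; the connection URL below is a placeholder:

```
from sqlalchemy import create_engine, inspect

# Placeholder connection string; point this at the migrated database.
engine = create_engine("postgresql://user:password@localhost:5432/dbname")
inspector = inspect(engine)
columns = {c["name"] for c in inspector.get_columns("bundles", schema="hosted_model_inference")}
assert "runnable_image_routes" in columns, "migration did not add the routes column"
```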
@@ -146,6 +146,7 @@ class Bundle(Base):
runnable_image_env = Column(JSON, nullable=True)
runnable_image_protocol = Column(Text, nullable=True)
runnable_image_readiness_initial_delay_seconds = Column(Integer, nullable=True)
runnable_image_routes = Column(ARRAY(Text), nullable=True)
runnable_image_extra_routes = Column(ARRAY(Text), nullable=True)
runnable_image_forwarder_type = Column(Text, nullable=True)
runnable_image_worker_command = Column(ARRAY(Text), nullable=True)
@@ -209,6 +210,7 @@ def __init__(
runnable_image_env: Optional[Dict[str, Any]] = None,
runnable_image_protocol: Optional[str] = None,
runnable_image_readiness_initial_delay_seconds: Optional[int] = None,
runnable_image_routes: Optional[List[str]] = None,
runnable_image_extra_routes: Optional[List[str]] = None,
runnable_image_forwarder_type: Optional[str] = None,
runnable_image_worker_command: Optional[List[str]] = None,
@@ -268,6 +270,7 @@ def __init__(
self.runnable_image_healthcheck_route = runnable_image_healthcheck_route
self.runnable_image_env = runnable_image_env
self.runnable_image_protocol = runnable_image_protocol
self.runnable_image_routes = runnable_image_routes
self.runnable_image_extra_routes = runnable_image_extra_routes
self.runnable_image_forwarder_type = runnable_image_forwarder_type
self.runnable_image_worker_command = runnable_image_worker_command
@@ -1019,7 +1019,7 @@ async def create_vllm_bundle(
healthcheck_route="/health",
predict_route="/predict",
streaming_predict_route="/stream",
extra_routes=[
routes=[
OPENAI_CHAT_COMPLETION_PATH,
OPENAI_COMPLETION_PATH,
],
@@ -1101,7 +1101,7 @@ async def create_vllm_multinode_bundle(
healthcheck_route="/health",
predict_route="/predict",
streaming_predict_route="/stream",
extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH],
env=common_vllm_envs,
worker_command=worker_command,
worker_env=common_vllm_envs,
181 changes: 5 additions & 176 deletions model-engine/model_engine_server/inference/vllm/vllm_server.py
@@ -1,33 +1,14 @@
import asyncio
import code
import json
import os
import signal
import subprocess
import traceback
from logging import Logger
from typing import AsyncGenerator, Dict, List, Optional

import vllm.envs as envs
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi.responses import Response, StreamingResponse
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.openai.api_server import (
build_app,
build_async_engine_client,
init_app_state,
load_log_config,
maybe_register_tokenizer_info_endpoint,
setup_server,
)
from vllm.entrypoints.openai.api_server import run_server
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.utils import FlexibleArgumentParser

logger = Logger("vllm_server")

@@ -36,88 +17,8 @@
TIMEOUT_KEEP_ALIVE = 5 # seconds.
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds

router = APIRouter()


@router.post("/predict")
@router.post("/stream")
async def generate(request: Request) -> Response:
"""Generate completion for the request.

The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
# check health before accepting request and fail fast if engine isn't healthy
try:
await engine_client.check_health()

request_dict = await request.json()
prompt = request_dict.pop("prompt")
stream = request_dict.pop("stream", False)

sampling_params = SamplingParams(**request_dict)

request_id = random_uuid()

results_generator = engine_client.generate(prompt, sampling_params, request_id)

async def abort_request() -> None:
await engine_client.abort(request_id)

if stream:
# Streaming case
async def stream_results() -> AsyncGenerator[str, None]:
last_output_text = ""
async for request_output in results_generator:
log_probs = format_logprobs(request_output)
ret = {
"text": request_output.outputs[-1].text[len(last_output_text) :],
"count_prompt_tokens": len(request_output.prompt_token_ids),
"count_output_tokens": len(request_output.outputs[0].token_ids),
"log_probs": (
log_probs[-1] if log_probs and sampling_params.logprobs else None
),
"finished": request_output.finished,
}
last_output_text = request_output.outputs[-1].text
yield f"data:{json.dumps(ret)}\n\n"

background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)

return StreamingResponse(stream_results(), background=background_tasks)

# Non-streaming case
final_output = None
tokens = []
last_output_text = ""
async for request_output in results_generator:
tokens.append(request_output.outputs[-1].text[len(last_output_text) :])
last_output_text = request_output.outputs[-1].text
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine_client.abort(request_id)
return Response(status_code=499)
final_output = request_output

assert final_output is not None
prompt = final_output.prompt
ret = {
"text": final_output.outputs[0].text,
"count_prompt_tokens": len(final_output.prompt_token_ids),
"count_output_tokens": len(final_output.outputs[0].token_ids),
"log_probs": format_logprobs(final_output),
"tokens": tokens,
}
return Response(content=json.dumps(ret))

except AsyncEngineDeadError as e:
logger.error(f"The vllm engine is dead, exiting the pod: {e}")
os.kill(os.getpid(), signal.SIGINT)
raise e
# Legacy endpoints /predict and /stream removed - using vLLM's native OpenAI-compatible endpoints instead
# All requests now go through /v1/completions, /v1/chat/completions, etc.


def get_gpu_free_memory():
@@ -171,90 +72,18 @@ def debug(sig, frame):
i.interact(message)


def format_logprobs(
request_output: CompletionOutput,
) -> Optional[List[Dict[int, float]]]:
"""Given a request output, format the logprobs if they exist."""
output_logprobs = request_output.outputs[0].logprobs
if output_logprobs is None:
return None

def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
return {k: v.logprob for k, v in logprobs.items()}

return [extract_logprobs(logprobs) for logprobs in output_logprobs]


def parse_args(parser: FlexibleArgumentParser):
parser = make_arg_parser(parser)
parser.add_argument("--attention-backend", type=str, help="The attention backend to use")
Review comment from @dmchoiboi (Collaborator), Sep 23, 2025:

    you can remove run_server_worker and run_server, and just use run_server from vllm

return parser.parse_args()


async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


async def run_server_worker(
listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
"""Run a single API server worker."""

if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

server_index = client_config.get("client_index", 0) if client_config else 0

# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
uvicorn_kwargs["log_config"] = log_config

global engine_client

async with build_async_engine_client(args, client_config=client_config) as engine_client:
maybe_register_tokenizer_info_endpoint(args)
app = build_app(args)

vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)
app.include_router(router)

logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
h11_max_header_count=args.h11_max_header_count,
**uvicorn_kwargs,
)

# NB: Await server shutdown only after the backend context is exited
try:
await shutdown_task
finally:
sock.close()


if __name__ == "__main__":
check_unknown_startup_memory_usage()

parser = FlexibleArgumentParser()
args = parse_args(parser)
if args.attention_backend is not None:
os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_backend
# Using vllm's run_server
asyncio.run(run_server(args))
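
With the legacy /predict and /stream routes gone, clients hit vLLM's OpenAI-compatible routes directly. A minimal sketch of such a request (host, port, and model name are placeholders):

```
import json

import requests

# The served model name and address depend on how the server was launched.
resp = requests.post(
    "http://localhost:5005/v1/completions",
    json={"model": "my-model", "prompt": "Hello, world", "max_tokens": 16},
    timeout=60,
)
print(json.dumps(resp.json(), indent=2))
```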
@@ -227,6 +227,7 @@ async def streaming_predict(
if predict_request.num_retries is None
else predict_request.num_retries
)

response = self.make_request_with_retries(
request_url=deployment_url,
payload_json=predict_request.model_dump(exclude_none=True),
@@ -253,6 +253,7 @@ async def predict(
if predict_request.num_retries is None
else predict_request.num_retries
)

response = await self.make_request_with_retries(
request_url=deployment_url,
payload_json=predict_request.model_dump(exclude_none=True),