diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py index 12470cfc..0a44b832 100644 --- a/model-engine/model_engine_server/inference/forwarding/echo_server.py +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -33,7 +33,7 @@ def entrypoint(): parser.add_argument("--host", type=str, default="[::]") parser.add_argument("--port", type=int, default=5009) - args = parser.parse_args() + args, extra_args = parser.parse_known_args() command = [ "gunicorn", @@ -48,6 +48,7 @@ def entrypoint(): "--workers", str(args.num_workers), "model_engine_server.inference.forwarding.echo_server:app", + *extra_args, ] subprocess.run(command) diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py index 85de6ded..5943bc50 100644 --- a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -138,8 +138,9 @@ def entrypoint(): parser.add_argument("--host", type=str, default="[::]") parser.add_argument("--port", type=int, default=5000) parser.add_argument("--set", type=str, action="append") + parser.add_argument("--graceful-timeout", type=int, default=600) - args = parser.parse_args() + args, extra_args = parser.parse_known_args() values = [f"CONFIG_FILE={args.config}"] if args.set is not None: @@ -160,8 +161,11 @@ def entrypoint(): "uvicorn.workers.UvicornWorker", "--workers", str(args.num_workers), + "--graceful-timeout", + str(args.graceful_timeout), *envs, "model_engine_server.inference.forwarding.http_forwarder:app", + *extra_args, ] subprocess.run(command) diff --git a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py index 97aea0ed..2b3aef79 100644 --- a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py +++ b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py @@ -1,3 +1,4 @@ +import argparse import os import subprocess @@ -8,6 +9,10 @@ def start_server(): + parser = argparse.ArgumentParser() + parser.add_argument("--graceful-timeout", type=int, default=600) + args, extra_args = parser.parse_known_args() + # TODO: HTTPS command = [ "gunicorn", @@ -21,7 +26,10 @@ def start_server(): "uvicorn.workers.UvicornWorker", "--workers", str(NUM_PROCESSES), + "--graceful-timeout", + str(args.graceful_timeout), "model_engine_server.inference.sync_inference.fastapi_server:app", + *extra_args, ] unset_sensitive_envvars() subprocess.run(command)