# Fix streaming endpoint failure handling (#314)
Changes to the LLM model endpoint routes:

```diff
@@ -1,7 +1,10 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
+import traceback
+from datetime import datetime
 from typing import Optional
 
+import pytz
 from fastapi import APIRouter, Depends, HTTPException, Query
 from model_engine_server.api.dependencies import (
     ExternalInterfaces,
@@ -28,6 +31,8 @@
     ListLLMModelEndpointsV1Response,
     ModelDownloadRequest,
     ModelDownloadResponse,
+    StreamError,
+    StreamErrorContent,
 )
 from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
 from model_engine_server.core.auth.authentication_repository import User
```
```diff
@@ -71,6 +76,34 @@
 logger = make_logger(filename_wo_ext(__name__))
 
 
+def handle_streaming_exception(
+    e: Exception,
+    code: int,
+    message: str,
+):
+    tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
+    request_id = get_request_id()
+    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+    structured_log = {
+        "error": message,
+        "request_id": str(request_id),
+        "traceback": "".join(tb_str),
+    }
+    logger.error("Exception: %s", structured_log)
+    return {
+        "data": CompletionStreamV1Response(
+            request_id=str(request_id),
+            error=StreamError(
+                status_code=code,
+                content=StreamErrorContent(
+                    error=message,
+                    timestamp=timestamp,
+                ),
+            ),
+        ).json()
+    }
+
+
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     request: CreateLLMModelEndpointV1Request,
```
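For orientation (not part of the diff): the dict returned by `handle_streaming_exception` is what sse-starlette serializes into a single `data:` event, so a failure surfaces in-band instead of tearing down the stream. A hedged sketch of the resulting payload, assuming the Pydantic models serialize field-for-field as declared above; the `output` field and the request-id format are assumptions, not confirmed by this PR:

```python
# Hypothetical reconstruction of one in-band error event as it appears on the
# wire. Field names come from the diff; "output" and the id format are assumed.
error_event = {
    "data": (
        '{"request_id": "abc123", "output": null, '
        '"error": {"status_code": 404, "content": '
        '{"error": "endpoint not found", "timestamp": "2023-08-01 10:00:00 PDT"}}}'
    )
}

# sse-starlette renders the dict as one Server-Sent Event:
#
#   data: {"request_id": "abc123", "output": null, "error": {...}}
#
print(error_event["data"])
```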
```diff
@@ -226,42 +259,30 @@ async def create_completion_stream_task(
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
```

> **Review comment:** here
```diff
 
-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+            yield handle_streaming_exception(exc, 400, str(exc))
```

> **Review comment:** Just for my edification, was the primary fix for …
>
> **Reply:** exception in …
```diff
+        except (
+            ObjectNotFoundException,
+            ObjectNotAuthorizedException,
+            EndpointUnsupportedInferenceTypeException,
+        ) as exc:
+            yield handle_streaming_exception(exc, 404, str(exc))
+        except Exception as exc:
+            yield handle_streaming_exception(
+                exc, 500, "Internal error occurred. Our team has been notified."
+            )
 
-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())
 
 
 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
```
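The structural point of this hunk: once `EventSourceResponse` has begun streaming, the HTTP status line and headers are already on the wire, so a late `HTTPException` can no longer reach the client. Moving the `except` clauses inside `event_generator` converts failures into in-band events instead. A minimal client-side sketch of consuming that contract, assuming a plain `requests` client and the payload shape sketched earlier; the helper name and base URL are illustrative, not from this repo:

```python
import json

import requests  # assumed client library, not part of this PR


def stream_completion(base_url: str, endpoint_name: str, payload: dict):
    """Yield stream events, raising if the server reports an in-band error."""
    response = requests.post(
        f"{base_url}/v1/llm/completions-stream?model_endpoint_name={endpoint_name}",
        json=payload,
        stream=True,
    )
    response.raise_for_status()  # transport-level failures still use HTTP status codes
    for line in response.iter_lines():
        if not line.startswith(b"data:"):
            continue  # skip SSE keep-alives and comment lines
        event = json.loads(line[len(b"data:"):].strip())
        if event.get("error"):
            # Structured error emitted by handle_streaming_exception.
            raise RuntimeError(
                f"stream failed with {event['error']['status_code']}: "
                f"{event['error']['content']['error']}"
            )
        yield event
```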
```diff
@@ -405,12 +426,12 @@ async def delete_llm_model_endpoint(
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",
```
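The two `except` edits in this hunk are purely cosmetic: parentheses around a single exception class are redundant, since `(X)` is just `X`, while a comma is what makes a tuple that catches several types. A quick self-contained illustration:

```python
try:
    raise KeyError("missing")
except (KeyError) as exc:  # identical to `except KeyError as exc`: (X) is not a tuple
    print("caught:", exc)

try:
    raise ValueError("bad value")
except (KeyError, ValueError) as exc:  # a real tuple: catches either exception type
    print("caught:", exc)
```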
Changes to the LLM endpoint API tests:
```diff
@@ -113,6 +113,32 @@ def test_completion_sync_success(
     assert response_1.json().keys() == {"output", "request_id"}
 
 
+def test_completion_sync_endpoint_not_found_returns_404(
+    llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
+    completion_sync_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_sync[0]
+            .infra_state.deployment_name: llm_model_endpoint_sync[0]
+            .infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
+        auth=("no_user", ""),
+        json=completion_sync_request,
+    )
+    assert response_1.status_code == 404
+
+
 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
```
> **Review comment:** do we still want to skip this test?
>
> **Reply:** yes, unfortunately I still haven't figured out how to run two streaming tests in a row. There's something wrong with how the test client uses the event loop that I wasn't able to fix; more context in encode/starlette#1315.
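For context on the comment above: a workaround often suggested for this class of problem (a hedged sketch, assuming pytest-asyncio's overridable `event_loop` fixture; it is not something this PR adopts) is to hand each test a fresh event loop, so the second streaming test does not inherit a loop the first test's client already used:

```python
import asyncio

import pytest


@pytest.fixture
def event_loop():
    # Overrides pytest-asyncio's default fixture: each test gets a brand-new
    # loop, sidestepping loop-reuse issues like encode/starlette#1315.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
```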
```diff
@@ -136,6 +162,7 @@ def test_completion_stream_success(
         f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
         auth=("no_user", ""),
         json=completion_stream_request,
+        stream=True,
     )
     assert response_1.status_code == 200
     count = 0
```
```diff
@@ -146,3 +173,33 @@ def test_completion_stream_success(
     )
     count += 1
     assert count == 1
+
+
+@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
+def test_completion_stream_endpoint_not_found_returns_404(
+    llm_model_endpoint_streaming: ModelEndpoint,
+    completion_stream_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
+        auth=("no_user", ""),
+        json=completion_stream_request,
+        stream=True,
+    )
+
+    assert response_1.status_code == 200
+
+    for message in response_1:
+        assert "404" in message.decode("utf-8")
```
> **Review comment:** we shouldn't be using the clients/python mypy config for checking server files