Fix streaming endpoint failure handling #314

Merged · 11 commits · Oct 11, 2023
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -55,6 +55,7 @@ repos:
hooks:
- id: mypy
name: mypy-clients-python
files: clients/python/.*
Collaborator Author: we shouldn't be using the clients/python mypy hook to check server files.

entry: mypy --config-file clients/python/mypy.ini
language: system
- repo: https://github.com/pre-commit/mirrors-mypy
21 changes: 21 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -355,6 +355,24 @@ class CompletionStreamOutput(BaseModel):
"""Detailed token information."""


class StreamErrorContent(BaseModel):
error: str
"""Error message."""
timestamp: str
"""Timestamp of the error."""


class StreamError(BaseModel):
"""
Error object for a stream prompt completion task.
"""

status_code: int
"""The HTTP status code of the error."""
content: StreamErrorContent
"""The error content."""


class CompletionStreamResponse(BaseModel):
"""
Response object for a stream prompt completion task.
@@ -372,6 +390,9 @@ class CompletionStreamResponse(BaseModel):
output: Optional[CompletionStreamOutput] = None
"""Completion output."""

error: Optional[StreamError] = None
"""Error of the response (if any)."""


class CreateFineTuneRequest(BaseModel):
"""
11 changes: 5 additions & 6 deletions docs/getting_started.md
@@ -81,11 +81,10 @@ stream = Completion.create(
)

for response in stream:
try:
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
except: # an error occurred
print(stream.text) # print the error message out
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
else: # an error occurred
print(response.error) # print the error message out
break
```
11 changes: 5 additions & 6 deletions docs/guides/completions.md
@@ -87,12 +87,11 @@ stream = Completion.create(
)

for response in stream:
try:
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
except: # an error occurred
print(stream.text) # print the error message out
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
else: # an error occurred
print(response.error) # print the error message out
break
```

93 changes: 57 additions & 36 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -1,7 +1,10 @@
"""LLM Model Endpoint routes for the hosted model inference service.
"""
import traceback
from datetime import datetime
from typing import Optional

import pytz
from fastapi import APIRouter, Depends, HTTPException, Query
from model_engine_server.api.dependencies import (
ExternalInterfaces,
@@ -28,6 +31,8 @@
ListLLMModelEndpointsV1Response,
ModelDownloadRequest,
ModelDownloadResponse,
StreamError,
StreamErrorContent,
)
from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
from model_engine_server.core.auth.authentication_repository import User
@@ -71,6 +76,34 @@
logger = make_logger(filename_wo_ext(__name__))


def handle_streaming_exception(
e: Exception,
code: int,
message: str,
):
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": message,
"request_id": str(request_id),
"traceback": "".join(tb_str),
}
logger.error("Exception: %s", structured_log)
return {
"data": CompletionStreamV1Response(
request_id=str(request_id),
error=StreamError(
status_code=code,
content=StreamErrorContent(
error=message,
timestamp=timestamp,
),
),
).json()
}


@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
async def create_model_endpoint(
request: CreateLLMModelEndpointV1Request,
@@ -226,42 +259,30 @@ async def create_completion_stream_task(
logger.info(
f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
)
try:
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
Collaborator Author: here use_case.execute doesn't actually execute anything yet; nothing runs until the async for message in response: loop inside event_generator iterates it.
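A minimal sketch (hypothetical names, not the service code) of why this is the case: calling an async generator function returns the generator without running its body, so an exception inside it only surfaces once the async for loop starts iterating.

import asyncio


async def execute():
    # The body does not run when execute() is called, only when it is iterated.
    raise RuntimeError("simulated upstream failure")
    yield  # unreachable, but makes this an async generator function


async def main():
    response = execute()  # no exception raised here; nothing has executed yet
    try:
        async for message in response:  # the exception surfaces here
            print(message)
    except RuntimeError as exc:
        print(f"raised only at iteration time: {exc}")


asyncio.run(main())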


async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except InvalidRequestException as exc:
yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
return
async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
yield handle_streaming_exception(exc, 400, str(exc))
Collaborator: Just for my edification, was the primary fix for urllib3.exceptions.ProtocolError to push the exception handling inside the generator with yield?

Collaborator Author: An exception from use_case.execute won't be raised until async for message in response: runs, and since we originally didn't capture exceptions other than InvalidRequestException, my understanding is that 0 bytes would get returned and the client side would throw an error about not being able to decode the response.
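A rough client-side sketch of the effect of the fix (the URL, port, auth, and request fields are illustrative assumptions; only the payload shape mirrors CompletionStreamV1Response and handle_streaming_exception): instead of the connection closing with nothing decodable, the stream now delivers a parseable SSE data event whose error field the caller can inspect.

import json

import requests

# Hypothetical endpoint name and request body; the response events carry
# request_id, output, and error fields per CompletionStreamV1Response.
resp = requests.post(
    "http://localhost:5000/v1/llm/completions-stream?model_endpoint_name=my-endpoint",
    auth=("user_id", ""),
    json={"prompt": "hello", "max_new_tokens": 10, "temperature": 0.0},
    stream=True,
)

for line in resp.iter_lines():
    if not line or not line.startswith(b"data:"):
        continue  # skip blank keep-alives and SSE comment lines
    payload = json.loads(line[len(b"data:"):])
    if payload.get("error"):
        # e.g. {"status_code": 404, "content": {"error": "...", "timestamp": "..."}}
        print("stream error:", payload["error"])
        break
    if payload.get("output"):
        print(payload["output"]["text"], end="")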

except (
ObjectNotFoundException,
ObjectNotAuthorizedException,
EndpointUnsupportedInferenceTypeException,
) as exc:
yield handle_streaming_exception(exc, 404, str(exc))
except Exception as exc:
yield handle_streaming_exception(
exc, 500, "Internal error occurred. Our team has been notified."
)

return EventSourceResponse(event_generator())
except UpstreamServiceError:
request_id = get_request_id()
logger.exception(f"Upstream service error for request {request_id}")
return EventSourceResponse(
iter((CompletionStreamV1Response(request_id=request_id).json(),)) # type: ignore
)
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
status_code=404,
detail="The specified endpoint could not be found.",
) from exc
except ObjectHasInvalidValueException as exc:
raise HTTPException(status_code=400, detail=str(exc))
except EndpointUnsupportedInferenceTypeException as exc:
raise HTTPException(
status_code=400,
detail=f"Unsupported inference type: {str(exc)}",
) from exc
return EventSourceResponse(event_generator())


@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +426,12 @@ async def delete_llm_model_endpoint(
model_endpoint_service=external_interfaces.model_endpoint_service,
)
return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
except (ObjectNotFoundException) as exc:
except ObjectNotFoundException as exc:
raise HTTPException(
status_code=404,
detail="The requested model endpoint could not be found.",
) from exc
except (ObjectNotAuthorizedException) as exc:
except ObjectNotAuthorizedException as exc:
raise HTTPException(
status_code=403,
detail="You don't have permission to delete the requested model endpoint.",
20 changes: 20 additions & 0 deletions model-engine/model_engine_server/common/dtos/llms.py
@@ -202,13 +202,33 @@ class CompletionStreamOutput(BaseModel):
token: Optional[TokenOutput] = None


class StreamErrorContent(BaseModel):
error: str
"""Error message."""
timestamp: str
"""Timestamp of the error."""


class StreamError(BaseModel):
"""
Error object for a stream prompt completion task.
"""

status_code: int
"""The HTTP status code of the error."""
content: StreamErrorContent
"""The error content."""


class CompletionStreamV1Response(BaseModel):
"""
Response object for a stream prompt completion task.
"""

request_id: str
output: Optional[CompletionStreamOutput] = None
error: Optional[StreamError] = None
"""Error of the response (if any)."""


class CreateFineTuneRequest(BaseModel):
@@ -1308,7 +1308,7 @@ async def execute(
)

if len(model_endpoints) == 0:
raise ObjectNotFoundException
raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

if len(model_endpoints) > 1:
raise ObjectHasInvalidValueException(
57 changes: 57 additions & 0 deletions model-engine/tests/unit/api/test_llms.py
@@ -113,6 +113,32 @@ def test_completion_sync_success(
assert response_1.json().keys() == {"output", "request_id"}


def test_completion_sync_endpoint_not_found_returns_404(
llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
completion_sync_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_sync[0]
.infra_state.deployment_name: llm_model_endpoint_sync[0]
.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
auth=("no_user", ""),
json=completion_sync_request,
)
assert response_1.status_code == 404


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
Collaborator: Do we still want to skip this test?

Collaborator Author: Yes, unfortunately I still haven't figured out how to run two streaming tests in a row. There's something wrong with how the test client uses the event loop that I wasn't able to fix; more context in encode/starlette#1315.

def test_completion_stream_success(
llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)
assert response_1.status_code == 200
count = 0
@@ -146,3 +173,33 @@
)
count += 1
assert count == 1


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
def test_completion_stream_endpoint_not_found_returns_404(
llm_model_endpoint_streaming: ModelEndpoint,
completion_stream_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)

assert response_1.status_code == 200

for message in response_1:
assert "404" in message.decode("utf-8")
1 change: 1 addition & 0 deletions model-engine/tests/unit/api/test_tasks.py
@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
auth=(test_api_key, ""),
json=endpoint_predict_request_1[1],
stream=True,
)
assert response.status_code == 200
count = 0