Fix streaming endpoint failure handling #314

Merged · 11 commits · Oct 11, 2023
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -55,6 +55,7 @@ repos:
hooks:
- id: mypy
name: mypy-clients-python
files: clients/python/.*
Collaborator Author: we shouldn't be using the clients/python mypy hook to check server files.

entry: mypy --config-file clients/python/mypy.ini
language: system
- repo: https://github.com/pre-commit/mirrors-mypy
23 changes: 23 additions & 0 deletions clients/python/llmengine/data_types.py
@@ -355,6 +355,26 @@ class CompletionStreamOutput(BaseModel):
"""Detailed token information."""


class StreamErrorContent(BaseModel):
error: str
"""Error message."""
timestamp: str
"""Timestamp of the error."""
request_id: str
"""Server generated unique ID of the corresponding request."""


class StreamError(BaseModel):
"""
Error object for a stream prompt completion task.
"""

status_code: int
"""The HTTP status code of the error."""
content: StreamErrorContent
"""The error content."""


class CompletionStreamResponse(BaseModel):
"""
Response object for a stream prompt completion task.
@@ -372,6 +392,9 @@ class CompletionStreamResponse(BaseModel):
output: Optional[CompletionStreamOutput] = None
"""Completion output."""

error: Optional[StreamError] = None
"""Error of the response (if any)."""


class CreateFineTuneRequest(BaseModel):
"""
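For illustration only (not part of the diff), here is a minimal sketch of the error payload shape these new models describe, assuming they are importable from `llmengine.data_types` as the file path suggests; all field values are made up:

```python
# Minimal sketch, not part of the PR: constructs the nested error payload that
# the new models define. Field values are invented for illustration.
from llmengine.data_types import StreamError, StreamErrorContent  # assumed import path

error = StreamError(
    status_code=404,
    content=StreamErrorContent(
        error="The specified endpoint could not be found.",
        timestamp="2023-10-11 12:00:00 PDT",
        request_id="abc123",
    ),
)

# Serializes to the nested JSON a client would see on CompletionStreamResponse.error.
print(error.json())
```

A client iterating a stream can then branch on `response.error`, as the updated docs below show.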
11 changes: 5 additions & 6 deletions docs/getting_started.md
@@ -81,11 +81,10 @@ stream = Completion.create(
)

for response in stream:
try:
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
except: # an error occurred
print(stream.text) # print the error message out
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
else: # an error occurred
print(response.error) # print the error message out
break
```
11 changes: 5 additions & 6 deletions docs/guides/completions.md
@@ -87,12 +87,11 @@ stream = Completion.create(
)

for response in stream:
try:
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
except: # an error occurred
print(stream.text) # print the error message out
if response.output:
print(response.output.text, end="")
sys.stdout.flush()
else: # an error occurred
print(response.error) # print the error message out
break
```

94 changes: 58 additions & 36 deletions model-engine/model_engine_server/api/llms_v1.py
@@ -1,7 +1,11 @@
"""LLM Model Endpoint routes for the hosted model inference service.
"""
import json
import traceback
from datetime import datetime
from typing import Optional

import pytz
from fastapi import APIRouter, Depends, HTTPException, Query
from model_engine_server.api.dependencies import (
ExternalInterfaces,
@@ -71,6 +75,36 @@
logger = make_logger(filename_wo_ext(__name__))


def handle_streaming_exception(
e: Exception,
code: int,
message: str,
):
tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
request_id = get_request_id()
timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
structured_log = {
"error": message,
"request_id": str(request_id),
"traceback": "".join(tb_str),
}
logger.error("Exception: %s", structured_log)
return {
"data": json.dumps(
{
"request_id": str(request_id),
"error": {
"status_code": code,
Member: Maybe this is a bit pedantic (pydantic? lol), but since you ended up creating the DTOs anyway, I wonder if it makes sense to instantiate the DTOs and convert them to JSON? That way you get the typing.

Alternatively, I think we want a unit test to enforce the API contract behind the error return type.

Collaborator Author (@yunfeng-scale, Oct 11, 2023): Will do. For the test I added test_completion_stream_endpoint_not_found_returns_404, but unfortunately I still haven't figured out how to run it (it works individually).

"content": {
"error": message,
"timestamp": timestamp,
},
},
}
)
}
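Picking up the reviewer suggestion above about instantiating the DTOs, here is a hedged sketch of what a typed variant of handle_streaming_exception might look like. This is not the merged code; it assumes server-side equivalents of the client models StreamError, StreamErrorContent, and CompletionStreamResponse exist, and reuses the module-level get_request_id and logger already imported in llms_v1.py:

```python
# Hypothetical typed variant (not the merged code). Assumes Pydantic models
# equivalent to the client's StreamError / StreamErrorContent /
# CompletionStreamResponse are available server-side, plus the module-level
# get_request_id and logger used elsewhere in llms_v1.py.
def handle_streaming_exception_typed(e: Exception, code: int, message: str) -> dict:
    request_id = str(get_request_id())
    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
    logger.error("Exception: %s", {"error": message, "request_id": request_id})
    response = CompletionStreamResponse(
        request_id=request_id,
        error=StreamError(
            status_code=code,
            content=StreamErrorContent(
                error=message,
                timestamp=timestamp,
                request_id=request_id,
            ),
        ),
    )
    # .json() keeps the SSE payload in lockstep with the documented schema.
    return {"data": response.json()}
```

Compared to the hand-rolled dict above, the payload shape would then be enforced by the models, which is what the review comment was getting at.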


@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
async def create_model_endpoint(
request: CreateLLMModelEndpointV1Request,
@@ -226,42 +260,30 @@ async def create_completion_stream_task(
logger.info(
f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
)
try:
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(
user=auth, model_endpoint_name=model_endpoint_name, request=request
)
use_case = CompletionStreamV1UseCase(
model_endpoint_service=external_interfaces.model_endpoint_service,
llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
)
response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
Collaborator Author: here use_case.execute doesn't actually run until L237 (the async for message in response: below).


async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except InvalidRequestException as exc:
yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
return
async def event_generator():
try:
async for message in response:
yield {"data": message.json()}
except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
yield handle_streaming_exception(exc, 400, str(exc))
Collaborator: Just for my edification, was the primary fix for urllib3.exceptions.ProtocolError to push the exception handling inside the generator with yield?

Collaborator Author: the exception in use_case.execute won't be raised until async for message in response:, and since we originally only caught InvalidRequestException, my understanding is that zero bytes got returned and the client side threw an error about not being able to decode the stream.

except (
ObjectNotFoundException,
ObjectNotAuthorizedException,
EndpointUnsupportedInferenceTypeException,
) as exc:
yield handle_streaming_exception(exc, 404, str(exc))
except Exception as exc:
yield handle_streaming_exception(
exc, 500, "Internal error occurred. Our team has been notified."
)

return EventSourceResponse(event_generator())
except UpstreamServiceError:
request_id = get_request_id()
logger.exception(f"Upstream service error for request {request_id}")
return EventSourceResponse(
iter((CompletionStreamV1Response(request_id=request_id).json(),)) # type: ignore
)
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
raise HTTPException(
status_code=404,
detail="The specified endpoint could not be found.",
) from exc
except ObjectHasInvalidValueException as exc:
raise HTTPException(status_code=400, detail=str(exc))
except EndpointUnsupportedInferenceTypeException as exc:
raise HTTPException(
status_code=400,
detail=f"Unsupported inference type: {str(exc)}",
) from exc
return EventSourceResponse(event_generator())
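For the deferred-execution point discussed in the thread above, a tiny self-contained illustration (all names made up, not code from this PR): an exception raised inside an async generator only surfaces once the consumer iterates, so wrapping the call itself in try/except catches nothing, and the handling has to live inside event_generator:

```python
import asyncio


async def fake_execute():
    # Stands in for use_case.execute: the body does not run at call time.
    raise RuntimeError("upstream blew up")
    yield  # the bare yield makes this an async generator function


async def main():
    response = fake_execute()  # no exception here; the generator is merely created
    try:
        async for _ in response:  # the body runs (and raises) on the first iteration
            pass
    except RuntimeError as exc:
        print(f"caught during iteration: {exc}")


asyncio.run(main())
```

In the route, that first iteration happens inside EventSourceResponse after the handler has already returned, which is why the original outer try/except never saw the error and the client received an empty, undecodable stream.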


@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +427,12 @@ async def delete_llm_model_endpoint(
model_endpoint_service=external_interfaces.model_endpoint_service,
)
return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
except (ObjectNotFoundException) as exc:
except ObjectNotFoundException as exc:
raise HTTPException(
status_code=404,
detail="The requested model endpoint could not be found.",
) from exc
except (ObjectNotAuthorizedException) as exc:
except ObjectNotAuthorizedException as exc:
raise HTTPException(
status_code=403,
detail="You don't have permission to delete the requested model endpoint.",
@@ -1296,7 +1296,7 @@ async def execute(
)

if len(model_endpoints) == 0:
raise ObjectNotFoundException
raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

if len(model_endpoints) > 1:
raise ObjectHasInvalidValueException(
57 changes: 57 additions & 0 deletions model-engine/tests/unit/api/test_llms.py
@@ -113,6 +113,32 @@ def test_completion_sync_success(
assert response_1.json().keys() == {"output", "request_id"}


def test_completion_sync_endpoint_not_found_returns_404(
llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
completion_sync_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_sync[0]
.infra_state.deployment_name: llm_model_endpoint_sync[0]
.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
auth=("no_user", ""),
json=completion_sync_request,
)
assert response_1.status_code == 404


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
Collaborator: Do we still want to skip this test?

Collaborator Author: Yes, unfortunately I still haven't figured out how to run two streaming tests in a row. There's something wrong with how the test client uses the event loop that I wasn't able to fix; more context in encode/starlette#1315.

def test_completion_stream_success(
llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)
assert response_1.status_code == 200
count = 0
@@ -146,3 +173,33 @@
)
count += 1
assert count == 1


@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
def test_completion_stream_endpoint_not_found_returns_404(
llm_model_endpoint_streaming: ModelEndpoint,
completion_stream_request: Dict[str, Any],
get_test_client_wrapper,
):
client = get_test_client_wrapper(
fake_docker_repository_image_always_exists=True,
fake_model_bundle_repository_contents={},
fake_model_endpoint_record_repository_contents={},
fake_model_endpoint_infra_gateway_contents={
llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
},
fake_batch_job_record_repository_contents={},
fake_batch_job_progress_gateway_contents={},
fake_docker_image_batch_job_bundle_repository_contents={},
)
response_1 = client.post(
f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
auth=("no_user", ""),
json=completion_stream_request,
stream=True,
)

assert response_1.status_code == 200

for message in response_1:
assert "404" in message.decode("utf-8")
1 change: 1 addition & 0 deletions model-engine/tests/unit/api/test_tasks.py
@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
auth=(test_api_key, ""),
json=endpoint_predict_request_1[1],
stream=True,
)
assert response.status_code == 200
count = 0