From 1ebc95f776ec4236020922622ae5e5c2164cca81 Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Mon, 9 Oct 2023 21:22:55 -0700
Subject: [PATCH 1/8] Fix streaming endpoint failure handling

---
 .pre-commit-config.yaml                            |  1 +
 .../model_engine_server/api/llms_v1.py             | 68 +++++++++----------
 .../use_cases/llm_model_endpoint_use_cases.py      |  2 +-
 model-engine/tests/unit/api/test_llms.py           | 58 ++++++++++++++++
 model-engine/tests/unit/api/test_tasks.py          |  1 +
 5 files changed, 93 insertions(+), 37 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3fe2075c6..f75d40caf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -55,6 +55,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy-clients-python
+        files: clients/python/.*.py
        entry: mypy --config-file clients/python/mypy.ini
         language: system
   - repo: https://github.com/pre-commit/mirrors-mypy
diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 67abfefa4..ddbc4063b 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -226,42 +226,38 @@ async def create_completion_stream_task(
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)

-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except InvalidRequestException as exc:
+            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
+        except UpstreamServiceError as exc:
+            request_id = get_request_id()
+            logger.exception(f"Upstream service error for request {request_id}")
+            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
+        except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
+            print(str(exc))
+            yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
+        except ObjectHasInvalidValueException as exc:
+            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
+        except EndpointUnsupportedInferenceTypeException as exc:
+            yield {
+                "data": {
+                    "error": {
+                        "status_code": 400,
+                        "detail": f"Unsupported inference type: {str(exc)}",
+                    }
+                }
+            }

-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())


 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +401,12 @@ async def delete_llm_model_endpoint(
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",
diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index f1eb665ab..56fe18127 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -1296,7 +1296,7 @@ async def execute(
         )

         if len(model_endpoints) == 0:
-            raise ObjectNotFoundException
+            raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")

         if len(model_endpoints) > 1:
             raise ObjectHasInvalidValueException(
diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py
index 2e909aeb5..edbd5213d 100644
--- a/model-engine/tests/unit/api/test_llms.py
+++ b/model-engine/tests/unit/api/test_llms.py
@@ -113,6 +113,32 @@ def test_completion_sync_success(
     assert response_1.json().keys() == {"output", "request_id"}


+def test_completion_sync_endpoint_not_found_returns_404(
+    llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
+    completion_sync_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_sync[0]
+            .infra_state.deployment_name: llm_model_endpoint_sync[0]
+            .infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
+        auth=("no_user", ""),
+        json=completion_sync_request,
+    )
+    assert response_1.status_code == 404
+
+
 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
 def test_completion_stream_success(
     llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
         f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
         auth=("no_user", ""),
         json=completion_stream_request,
+        stream=True,
     )
     assert response_1.status_code == 200
     count = 0
@@ -146,3 +173,34 @@ def test_completion_stream_success(
         )
         count += 1
     assert count == 1
+
+
+@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
+def test_completion_stream_endpoint_not_found_returns_404(
+    llm_model_endpoint_streaming: ModelEndpoint,
+    completion_stream_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
+        auth=("no_user", ""),
+        json=completion_stream_request,
+        stream=True,
+    )
+
+    assert response_1.status_code == 200
+
+    for message in response_1:
+        print(message)
+        assert "404" in message.decode("utf-8")
diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py
index 5192f0250..611195bd3 100644
--- a/model-engine/tests/unit/api/test_tasks.py
+++ b/model-engine/tests/unit/api/test_tasks.py
@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
         f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
         auth=(test_api_key, ""),
         json=endpoint_predict_request_1[1],
+        stream=True,
     )
     assert response.status_code == 200
     count = 0
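Patch 1 works around a basic property of server-sent events: by the time `event_generator` starts executing, `EventSourceResponse` has already committed the 200 status and headers, so an exception raised there can no longer become an HTTP error response; it either kills the connection or has to be reported in-band. That is why the handlers were moved inside the generator and now yield error events instead of raising `HTTPException`. The toy app below is not part of the patch series; it isolates that pattern, assuming FastAPI plus sse-starlette (the library `EventSourceResponse` appears to come from in `llms_v1.py`).

```python
# Standalone sketch of the pattern patch 1 adopts: exceptions raised after
# streaming has begun are converted into in-band error events, because the
# HTTP status can no longer be changed. Assumes fastapi and sse-starlette.
from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def fake_tokens():
    yield "hello"
    raise RuntimeError("upstream died mid-stream")


@app.get("/demo-stream")
async def demo_stream():
    async def event_generator():
        try:
            async for token in fake_tokens():
                yield {"data": token}
        except Exception as exc:
            # Too late for HTTPException: the client already received a 200.
            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}

    return EventSourceResponse(event_generator())
```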
From 93e009b0117be958f64795ce8d11e002fc3baae3 Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Mon, 9 Oct 2023 21:23:36 -0700
Subject: [PATCH 2/8] Fix streaming endpoint failure handling

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f75d40caf..bb2d9cc0b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -55,7 +55,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy-clients-python
-        files: clients/python/.*.py
+        files: clients/python/.*
         entry: mypy --config-file clients/python/mypy.ini
         language: system
   - repo: https://github.com/pre-commit/mirrors-mypy

From 0e46b27103dfca6b4606aa63f4ede9afa1d515be Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Mon, 9 Oct 2023 21:27:39 -0700
Subject: [PATCH 3/8] remove print

---
 model-engine/model_engine_server/api/llms_v1.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index ddbc4063b..d9ce36799 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -243,7 +243,6 @@ async def event_generator():
             logger.exception(f"Upstream service error for request {request_id}")
             yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
         except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-            print(str(exc))
             yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
         except ObjectHasInvalidValueException as exc:
             yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}

From ff37909144663919796222e0de39dc874d90224d Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Wed, 11 Oct 2023 11:16:06 -0700
Subject: [PATCH 4/8] comments

---
 model-engine/model_engine_server/api/llms_v1.py | 8 ++++----
 model-engine/tests/unit/api/test_llms.py        | 1 -
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index d9ce36799..aebe65e56 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -238,10 +238,6 @@ async def event_generator():
                 yield {"data": message.json()}
         except InvalidRequestException as exc:
             yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-        except UpstreamServiceError as exc:
-            request_id = get_request_id()
-            logger.exception(f"Upstream service error for request {request_id}")
-            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
         except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
             yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
         except ObjectHasInvalidValueException as exc:
@@ -255,6 +251,10 @@ async def event_generator():
                     }
                 }
             }
+        except Exception as exc:
+            request_id = get_request_id()
+            logger.exception(f"Internal exception for request {request_id}")
+            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}

     return EventSourceResponse(event_generator())

diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py
index edbd5213d..32178b499 100644
--- a/model-engine/tests/unit/api/test_llms.py
+++ b/model-engine/tests/unit/api/test_llms.py
@@ -202,5 +202,4 @@ def test_completion_stream_endpoint_not_found_returns_404(
     assert response_1.status_code == 200

     for message in response_1:
-        print(message)
         assert "404" in message.decode("utf-8")
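With patches 1 through 4 applied, the stream reports failures as events of the form `{"error": {"status_code": ..., "detail": ...}}`, which is what the skipped stream tests assert on. The snippet below is not part of the series; it sketches how a raw-HTTP consumer might surface those in-band errors. The URL, query parameter, and auth style follow the tests above, the request-body fields are illustrative only, and it assumes the event payload arrives as JSON text (patch 5 below makes that serialization explicit with `json.dumps`).

```python
# Illustrative SSE consumer for the completions-stream endpoint; not part of
# the patches. The endpoint shape is taken from the tests above, while the
# request-body fields and JSON framing are assumptions.
import json

import requests


def stream_completion(base_url: str, endpoint_name: str, api_key: str, prompt: str):
    resp = requests.post(
        f"{base_url}/v1/llm/completions-stream?model_endpoint_name={endpoint_name}",
        auth=(api_key, ""),
        json={"prompt": prompt, "max_new_tokens": 10, "temperature": 0.0},
        stream=True,
    )
    resp.raise_for_status()  # only catches failures raised before streaming starts
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue  # skip SSE comments and keep-alive frames
        payload = json.loads(line[len("data:"):].strip())
        if "error" in payload:
            # In-band failure: the HTTP status was already 200 when this arrived.
            raise RuntimeError(f"stream failed: {payload['error']}")
        yield payload
```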
From 9143da2ec10d19070d730d2be0ff401b854ccade Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Wed, 11 Oct 2023 13:08:19 -0700
Subject: [PATCH 5/8] client side changes

---
 clients/python/llmengine/data_types.py            | 23 +++++++
 docs/getting_started.md                           | 11 ++--
 docs/guides/completions.md                        | 11 ++--
 .../model_engine_server/api/llms_v1.py            | 64 +++++++++++++------
 4 files changed, 79 insertions(+), 30 deletions(-)

diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py
index 2cdc2f894..207d8a795 100644
--- a/clients/python/llmengine/data_types.py
+++ b/clients/python/llmengine/data_types.py
@@ -354,6 +354,26 @@ class CompletionStreamOutput(BaseModel):
     """Detailed token information."""


+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+    request_id: str
+    """Server generated unique ID of the corresponding request."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamResponse(BaseModel):
     """
     Response object for a stream prompt completion task.
@@ -371,6 +391,9 @@ class CompletionStreamResponse(BaseModel):
     output: Optional[CompletionStreamOutput] = None
     """Completion output."""

+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""
+

 class CreateFineTuneRequest(BaseModel):
     """
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 46741d1b4..fea0531a0 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -81,11 +81,10 @@ stream = Completion.create(
 )

 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except:  # an error occurred
-        print(stream.text)  # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else:  # an error occurred
+        print(response.error)  # print the error message out
         break
 ```
diff --git a/docs/guides/completions.md b/docs/guides/completions.md
index 4719edc39..dee51f615 100644
--- a/docs/guides/completions.md
+++ b/docs/guides/completions.md
@@ -87,12 +87,11 @@ stream = Completion.create(
 )

 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except:  # an error occurred
-        print(stream.text)  # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else:  # an error occurred
+        print(response.error)  # print the error message out
         break
 ```
diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index aebe65e56..02b4cb13f 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -1,7 +1,11 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
+import json
+import traceback
+from datetime import datetime
 from typing import Optional

+import pytz
 from fastapi import APIRouter, Depends, HTTPException, Query
 from model_engine_server.api.dependencies import (
     ExternalInterfaces,
@@ -71,6 +75,37 @@
 logger = make_logger(filename_wo_ext(__name__))


+def handle_streaming_exception(
+    e: Exception,
+    code: int,
+    message: str,
+):
+    tb_str = traceback.format_exception(e)
+    request_id = get_request_id()
+    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+    structured_log = {
+        "error": message,
+        "request_id": str(request_id),
+        "traceback": "".join(tb_str),
+    }
+    logger.error("Exception: %s", structured_log)
+    return {
+        "data": json.dumps(
+            {
+                "request_id": str(request_id),
+                "error": {
+                    "status_code": code,
+                    "content": {
+                        "error": message,
+                        "timestamp": timestamp,
+                        "request_id": request_id,
+                    },
+                },
+            }
+        )
+    }
+
+
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     request: CreateLLMModelEndpointV1Request,
@@ -236,25 +271,18 @@ async def event_generator():
         try:
             async for message in response:
                 yield {"data": message.json()}
-        except InvalidRequestException as exc:
-            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-        except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-            yield {"data": {"error": {"status_code": 404, "detail": str(exc)}}}
-        except ObjectHasInvalidValueException as exc:
-            yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-        except EndpointUnsupportedInferenceTypeException as exc:
-            yield {
-                "data": {
-                    "error": {
-                        "status_code": 400,
-                        "detail": f"Unsupported inference type: {str(exc)}",
-                    }
-                }
-            }
+        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+            yield handle_streaming_exception(exc, 400, str(exc))
+        except (
+            ObjectNotFoundException,
+            ObjectNotAuthorizedException,
+            EndpointUnsupportedInferenceTypeException,
+        ) as exc:
+            yield handle_streaming_exception(exc, 404, str(exc))
         except Exception as exc:
-            request_id = get_request_id()
-            logger.exception(f"Internal exception for request {request_id}")
-            yield {"data": {"error": {"status_code": 500, "detail": str(exc)}}}
+            yield handle_streaming_exception(
+                exc, 500, "Internal error occurred. Our team has been notified."
+            )

     return EventSourceResponse(event_generator())

From 2f744ec544ac493d523b93fd8b1991819bebd51c Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Wed, 11 Oct 2023 13:11:26 -0700
Subject: [PATCH 6/8] client side changes

---
 model-engine/model_engine_server/api/llms_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 02b4cb13f..42897120b 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -80,7 +80,7 @@ def handle_streaming_exception(
     code: int,
     message: str,
 ):
-    tb_str = traceback.format_exception(e)
+    tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
     request_id = get_request_id()
     timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
     structured_log = {
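Patch 6 exists because `traceback.format_exception` changed signature in Python 3.10: the single-argument call removed above only works on 3.10 and newer, while 3.9 and earlier require the `(etype, value, tb)` triple. On 3.10+ the first parameter also became positional-only (renamed from `etype` to `exc`), so the keyword form the patch adopts matches the older interpreters; passing the triple positionally is the spelling that works on both sides of the change. A small sketch, not part of the patch:

```python
# Version-portable traceback formatting: the positional (type, value, tb)
# triple is accepted both before and after the Python 3.10 signature change.
import traceback


def format_exc_lines(e: BaseException):
    # Returns the list of formatted traceback lines for the given exception.
    return traceback.format_exception(type(e), e, e.__traceback__)
```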
From b28baa2a2515e76a36d9ba2501e35a3c75864217 Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Wed, 11 Oct 2023 14:47:51 -0700
Subject: [PATCH 7/8] fix

---
 model-engine/model_engine_server/api/llms_v1.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 42897120b..37e7203ee 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -98,7 +98,6 @@ def handle_streaming_exception(
                     "content": {
                         "error": message,
                         "timestamp": timestamp,
-                        "request_id": request_id,
                     },
                 },
             }

From 21b10e9c27de7991e3782d617dde0b9050747595 Mon Sep 17 00:00:00 2001
From: Yunfeng Bai
Date: Wed, 11 Oct 2023 15:03:28 -0700
Subject: [PATCH 8/8] strong typing

---
 clients/python/llmengine/data_types.py             |  2 --
 .../model_engine_server/api/llms_v1.py             | 25 +++++++++----------
 .../model_engine_server/common/dtos/llms.py        | 20 +++++++++++++++
 3 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py
index f75cca3ce..076124207 100644
--- a/clients/python/llmengine/data_types.py
+++ b/clients/python/llmengine/data_types.py
@@ -360,8 +360,6 @@ class StreamErrorContent(BaseModel):
     """Error message."""
     timestamp: str
     """Timestamp of the error."""
-    request_id: str
-    """Server generated unique ID of the corresponding request."""


 class StreamError(BaseModel):
diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 37e7203ee..92ddad0e7 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -1,6 +1,5 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
-import json
 import traceback
 from datetime import datetime
 from typing import Optional
@@ -32,6 +31,8 @@
     ListLLMModelEndpointsV1Response,
     ModelDownloadRequest,
     ModelDownloadResponse,
+    StreamError,
+    StreamErrorContent,
 )
 from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
 from model_engine_server.core.auth.authentication_repository import User
@@ -90,18 +91,16 @@ def handle_streaming_exception(
     }
     logger.error("Exception: %s", structured_log)
     return {
-        "data": json.dumps(
-            {
-                "request_id": str(request_id),
-                "error": {
-                    "status_code": code,
-                    "content": {
-                        "error": message,
-                        "timestamp": timestamp,
-                    },
-                },
-            }
-        )
+        "data": CompletionStreamV1Response(
+            request_id=str(request_id),
+            error=StreamError(
+                status_code=code,
+                content=StreamErrorContent(
+                    error=message,
+                    timestamp=timestamp,
+                ),
+            ),
+        ).json()
     }
diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py
index 27a12ddcf..bf0b75199 100644
--- a/model-engine/model_engine_server/common/dtos/llms.py
+++ b/model-engine/model_engine_server/common/dtos/llms.py
@@ -202,6 +202,24 @@ class CompletionStreamOutput(BaseModel):
     token: Optional[TokenOutput] = None


+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamV1Response(BaseModel):
     """
     Response object for a stream prompt completion task.
@@ -209,6 +227,8 @@ class CompletionStreamV1Response(BaseModel):
     request_id: str
     output: Optional[CompletionStreamOutput] = None
+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""


 class CreateFineTuneRequest(BaseModel):