diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3fe2075c6..bb2d9cc0b 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -55,6 +55,7 @@ repos:
     hooks:
       - id: mypy
         name: mypy-clients-python
+        files: clients/python/.*
         entry: mypy --config-file clients/python/mypy.ini
         language: system
   - repo: https://github.com/pre-commit/mirrors-mypy
diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py
index 211106d8f..076124207 100644
--- a/clients/python/llmengine/data_types.py
+++ b/clients/python/llmengine/data_types.py
@@ -355,6 +355,24 @@ class CompletionStreamOutput(BaseModel):
     """Detailed token information."""
 
 
+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamResponse(BaseModel):
     """
     Response object for a stream prompt completion task.
@@ -372,6 +390,9 @@ class CompletionStreamResponse(BaseModel):
     output: Optional[CompletionStreamOutput] = None
     """Completion output."""
 
+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""
+
 
 class CreateFineTuneRequest(BaseModel):
     """
diff --git a/docs/getting_started.md b/docs/getting_started.md
index 46741d1b4..fea0531a0 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -81,11 +81,10 @@ stream = Completion.create(
 )
 
 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except:  # an error occurred
-        print(stream.text)  # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else:  # an error occurred
+        print(response.error)  # print the error message out
         break
 ```
diff --git a/docs/guides/completions.md b/docs/guides/completions.md
index 4719edc39..dee51f615 100644
--- a/docs/guides/completions.md
+++ b/docs/guides/completions.md
@@ -87,12 +87,11 @@ stream = Completion.create(
 )
 
 for response in stream:
-    try:
-        if response.output:
-            print(response.output.text, end="")
-            sys.stdout.flush()
-    except:  # an error occurred
-        print(stream.text)  # print the error message out
+    if response.output:
+        print(response.output.text, end="")
+        sys.stdout.flush()
+    else:  # an error occurred
+        print(response.error)  # print the error message out
         break
 ```
 
diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py
index 67abfefa4..92ddad0e7 100644
--- a/model-engine/model_engine_server/api/llms_v1.py
+++ b/model-engine/model_engine_server/api/llms_v1.py
@@ -1,7 +1,10 @@
 """LLM Model Endpoint routes for the hosted model inference service.
 """
+import traceback
+from datetime import datetime
 from typing import Optional
 
+import pytz
 from fastapi import APIRouter, Depends, HTTPException, Query
 from model_engine_server.api.dependencies import (
     ExternalInterfaces,
@@ -28,6 +31,8 @@
     ListLLMModelEndpointsV1Response,
     ModelDownloadRequest,
     ModelDownloadResponse,
+    StreamError,
+    StreamErrorContent,
 )
 from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy
 from model_engine_server.core.auth.authentication_repository import User
@@ -71,6 +76,34 @@
 logger = make_logger(filename_wo_ext(__name__))
 
 
+def handle_streaming_exception(
+    e: Exception,
+    code: int,
+    message: str,
+):
+    tb_str = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)
+    request_id = get_request_id()
+    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z")
+    structured_log = {
+        "error": message,
+        "request_id": str(request_id),
+        "traceback": "".join(tb_str),
+    }
+    logger.error("Exception: %s", structured_log)
+    return {
+        "data": CompletionStreamV1Response(
+            request_id=str(request_id),
+            error=StreamError(
+                status_code=code,
+                content=StreamErrorContent(
+                    error=message,
+                    timestamp=timestamp,
+                ),
+            ),
+        ).json()
+    }
+
+
 @llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response)
 async def create_model_endpoint(
     request: CreateLLMModelEndpointV1Request,
@@ -226,42 +259,30 @@ async def create_completion_stream_task(
     logger.info(
         f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}"
     )
-    try:
-        use_case = CompletionStreamV1UseCase(
-            model_endpoint_service=external_interfaces.model_endpoint_service,
-            llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
-        )
-        response = use_case.execute(
-            user=auth, model_endpoint_name=model_endpoint_name, request=request
-        )
+    use_case = CompletionStreamV1UseCase(
+        model_endpoint_service=external_interfaces.model_endpoint_service,
+        llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service,
+    )
+    response = use_case.execute(user=auth, model_endpoint_name=model_endpoint_name, request=request)
 
-        async def event_generator():
-            try:
-                async for message in response:
-                    yield {"data": message.json()}
-            except InvalidRequestException as exc:
-                yield {"data": {"error": {"status_code": 400, "detail": str(exc)}}}
-                return
+    async def event_generator():
+        try:
+            async for message in response:
+                yield {"data": message.json()}
+        except (InvalidRequestException, ObjectHasInvalidValueException) as exc:
+            yield handle_streaming_exception(exc, 400, str(exc))
+        except (
+            ObjectNotFoundException,
+            ObjectNotAuthorizedException,
+            EndpointUnsupportedInferenceTypeException,
+        ) as exc:
+            yield handle_streaming_exception(exc, 404, str(exc))
+        except Exception as exc:
+            yield handle_streaming_exception(
+                exc, 500, "Internal error occurred. Our team has been notified."
+            )
 
-        return EventSourceResponse(event_generator())
-    except UpstreamServiceError:
-        request_id = get_request_id()
-        logger.exception(f"Upstream service error for request {request_id}")
-        return EventSourceResponse(
-            iter((CompletionStreamV1Response(request_id=request_id).json(),))  # type: ignore
-        )
-    except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
-        raise HTTPException(
-            status_code=404,
-            detail="The specified endpoint could not be found.",
-        ) from exc
-    except ObjectHasInvalidValueException as exc:
-        raise HTTPException(status_code=400, detail=str(exc))
-    except EndpointUnsupportedInferenceTypeException as exc:
-        raise HTTPException(
-            status_code=400,
-            detail=f"Unsupported inference type: {str(exc)}",
-        ) from exc
+    return EventSourceResponse(event_generator())
 
 
 @llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse)
@@ -405,12 +426,12 @@ async def delete_llm_model_endpoint(
             model_endpoint_service=external_interfaces.model_endpoint_service,
         )
         return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name)
-    except (ObjectNotFoundException) as exc:
+    except ObjectNotFoundException as exc:
         raise HTTPException(
             status_code=404,
             detail="The requested model endpoint could not be found.",
         ) from exc
-    except (ObjectNotAuthorizedException) as exc:
+    except ObjectNotAuthorizedException as exc:
         raise HTTPException(
             status_code=403,
             detail="You don't have permission to delete the requested model endpoint.",
diff --git a/model-engine/model_engine_server/common/dtos/llms.py b/model-engine/model_engine_server/common/dtos/llms.py
index 27a12ddcf..bf0b75199 100644
--- a/model-engine/model_engine_server/common/dtos/llms.py
+++ b/model-engine/model_engine_server/common/dtos/llms.py
@@ -202,6 +202,24 @@ class CompletionStreamOutput(BaseModel):
     token: Optional[TokenOutput] = None
 
 
+class StreamErrorContent(BaseModel):
+    error: str
+    """Error message."""
+    timestamp: str
+    """Timestamp of the error."""
+
+
+class StreamError(BaseModel):
+    """
+    Error object for a stream prompt completion task.
+    """
+
+    status_code: int
+    """The HTTP status code of the error."""
+    content: StreamErrorContent
+    """The error content."""
+
+
 class CompletionStreamV1Response(BaseModel):
     """
     Response object for a stream prompt completion task.
@@ -209,6 +227,8 @@ class CompletionStreamV1Response(BaseModel):
 
     request_id: str
     output: Optional[CompletionStreamOutput] = None
+    error: Optional[StreamError] = None
+    """Error of the response (if any)."""
 
 
 class CreateFineTuneRequest(BaseModel):
diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 97d9f69b3..5b1798722 100644
--- a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -1308,7 +1308,7 @@ async def execute(
         )
 
         if len(model_endpoints) == 0:
-            raise ObjectNotFoundException
+            raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.")
 
         if len(model_endpoints) > 1:
             raise ObjectHasInvalidValueException(
diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py
index 2e909aeb5..32178b499 100644
--- a/model-engine/tests/unit/api/test_llms.py
+++ b/model-engine/tests/unit/api/test_llms.py
@@ -113,6 +113,32 @@ def test_completion_sync_success(
     assert response_1.json().keys() == {"output", "request_id"}
 
 
+def test_completion_sync_endpoint_not_found_returns_404(
+    llm_model_endpoint_sync: Tuple[ModelEndpoint, Any],
+    completion_sync_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_sync[0]
+            .infra_state.deployment_name: llm_model_endpoint_sync[0]
+            .infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
+        auth=("no_user", ""),
+        json=completion_sync_request,
+    )
+    assert response_1.status_code == 404
+
+
 @pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
 def test_completion_stream_success(
     llm_model_endpoint_streaming: ModelEndpoint,
@@ -136,6 +162,7 @@ def test_completion_stream_success(
         f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
         auth=("no_user", ""),
         json=completion_stream_request,
+        stream=True,
     )
     assert response_1.status_code == 200
     count = 0
@@ -146,3 +173,33 @@ def test_completion_stream_success(
         )
         count += 1
     assert count == 1
+
+
+@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness")
+def test_completion_stream_endpoint_not_found_returns_404(
+    llm_model_endpoint_streaming: ModelEndpoint,
+    completion_stream_request: Dict[str, Any],
+    get_test_client_wrapper,
+):
+    client = get_test_client_wrapper(
+        fake_docker_repository_image_always_exists=True,
+        fake_model_bundle_repository_contents={},
+        fake_model_endpoint_record_repository_contents={},
+        fake_model_endpoint_infra_gateway_contents={
+            llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state,
+        },
+        fake_batch_job_record_repository_contents={},
+        fake_batch_job_progress_gateway_contents={},
+        fake_docker_image_batch_job_bundle_repository_contents={},
+    )
+    response_1 = client.post(
+        f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}",
+        auth=("no_user", ""),
+        json=completion_stream_request,
+        stream=True,
+    )
+
+    assert response_1.status_code == 200
+
+    for message in response_1:
+        assert "404" in message.decode("utf-8")
diff --git a/model-engine/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py
index 5192f0250..611195bd3 100644
--- a/model-engine/tests/unit/api/test_tasks.py
+++ b/model-engine/tests/unit/api/test_tasks.py
@@ -364,6 +364,7 @@ def test_create_streaming_task_success(
         f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}",
         auth=(test_api_key, ""),
         json=endpoint_predict_request_1[1],
+        stream=True,
     )
     assert response.status_code == 200
     count = 0