ENH: basic exception handling for RESTful api #111

Merged · 1 commit · Jul 7, 2023
xinference/client.py: 15 additions & 1 deletion
@@ -87,6 +87,11 @@ def list_models(self) -> Dict[str, Dict[str, Any]]:
         url = f"{self.base_url}/v1/models"
 
         response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to list models, detail: {response.json()['detail']}"
+            )
+
         response_data = response.json()
         return response_data

@@ -110,6 +115,11 @@ def launch_model(
             "kwargs": kwargs,
         }
         response = requests.post(url, json=payload)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to launch model, detail: {response.json()['detail']}"
+            )
+
         response_data = response.json()
         model_uid = response_data["model_uid"]
         return model_uid
@@ -119,11 +129,15 @@ def terminate_model(self, model_uid: str):
 
         response = requests.delete(url)
         if response.status_code != 200:
-            raise Exception(f"Error terminating the model.")
+            raise RuntimeError(
+                f"Failed to terminate model, detail: {response.json()['detail']}"
+            )
 
     def _get_supervisor_internal_address(self):
         url = f"{self.base_url}/v1/address"
         response = requests.get(url)
+        if response.status_code != 200:
+            raise RuntimeError(f"Failed to get supervisor internal address")
         response_data = response.json()
         return response_data

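Taken together, the client.py changes apply a single pattern: check the status code before using the body, and surface the server's "detail" message in the raised error. A minimal sketch of that pattern in isolation (the _raise_for_error helper and the localhost endpoint are illustrative assumptions, not part of this PR):

import requests


def _raise_for_error(response: requests.Response, action: str) -> None:
    # Non-200 responses from the RESTful API carry a JSON body whose
    # "detail" field holds the server-side error message (FastAPI's
    # HTTPException format).
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to {action}, detail: {response.json()['detail']}"
        )


response = requests.get("http://localhost:9997/v1/models")  # hypothetical address
_raise_for_error(response, "list models")
models = response.json()

One caveat worth noting: response.json() itself raises if the error body is not valid JSON (for example, a bare 502 from a proxy), so the detail lookup is best-effort rather than guaranteed.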
xinference/core/restful_api.py: 50 additions & 17 deletions
@@ -12,13 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import sys
 import threading
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import gradio as gr
 import xoscar as xo
-from fastapi import APIRouter, FastAPI, Request
+from fastapi import APIRouter, FastAPI, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
@@ -29,6 +30,8 @@
 from ..model.llm.types import ChatCompletion, Completion
 from .service import SupervisorActor
 
+logger = logging.getLogger(__name__)
+
 max_tokens_field = Field(
     default=16, ge=1, le=2048, description="The maximum number of tokens to generate."
 )
@@ -267,7 +270,12 @@ def serve(self):
         return f"http://{self._host}:{self._port}"
 
     async def list_models(self) -> Dict[str, Dict[str, Any]]:
-        models = await self._supervisor_ref.list_models()
+        try:
+            models = await self._supervisor_ref.list_models()
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
         models_dict = {}
         for model_uid, model_spec in models:
             models_dict[model_uid] = {
@@ -287,18 +295,26 @@ async def launch_model(self, request: Request) -> JSONResponse:
         quantization = payload.get("quantization")
         kwargs = payload.get("kwargs", {}) or {}
 
-        await self._supervisor_ref.launch_builtin_model(
-            model_uid=model_uid,
-            model_name=model_name,
-            model_size_in_billions=model_size_in_billions,
-            model_format=model_format,
-            quantization=quantization,
-            **kwargs,
-        )
+        try:
+            await self._supervisor_ref.launch_builtin_model(
+                model_uid=model_uid,
+                model_name=model_name,
+                model_size_in_billions=model_size_in_billions,
+                model_format=model_format,
+                quantization=quantization,
+                **kwargs,
+            )
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
         return JSONResponse(content={"model_uid": model_uid})
 
     async def terminate_model(self, model_uid: str):
-        await self._supervisor_ref.terminate_model(model_uid)
+        try:
+            await self._supervisor_ref.terminate_model(model_uid)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
 
     async def get_address(self):
         return self.address
@@ -318,11 +334,21 @@ async def create_completion(self, request: Request, body: CreateCompletionReques
         if body.logit_bias is not None:
             raise NotImplementedError
         model_uid = body.model
-        model = await self._supervisor_ref.get_model(model_uid)
+
+        try:
+            model = await self._supervisor_ref.get_model(model_uid)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
         if body.stream:
             raise NotImplementedError
         else:
-            return await model.generate(body.prompt, kwargs)
+            try:
+                return await model.generate(body.prompt, kwargs)
+            except Exception as e:
+                logger.error(e, exc_info=True)
+                raise HTTPException(status_code=500, detail=str(e))
 
     async def create_embedding(self, request: CreateEmbeddingRequest):
         raise NotImplementedError
@@ -351,18 +377,25 @@ async def create_chat_completion(
         if user_messages:
             prompt = user_messages[-1]
         else:
-            raise Exception("no prompt given")
+            raise HTTPException(status_code=400, detail="No prompt given")
         system_prompt = next(
             (msg["content"] for msg in body.messages if msg["role"] == "system"), None
         )
 
         chat_history = body.messages
 
         model_uid = body.model
-        model = await self._supervisor_ref.get_model(model_uid)
+        try:
+            model = await self._supervisor_ref.get_model(model_uid)
+        except Exception as e:
+            logger.error(e, exc_info=True)
+            raise HTTPException(status_code=500, detail=str(e))
+
         if body.stream:
             raise NotImplementedError
-
         else:
-            return await model.chat(prompt, system_prompt, chat_history, kwargs)
+            try:
+                return await model.chat(prompt, system_prompt, chat_history, kwargs)
+            except Exception as e:
+                logger.error(e, exc_info=True)
+                raise HTTPException(status_code=500, detail=str(e))
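The restful_api.py side applies an equally uniform pattern: wrap each supervisor or model call in try/except, log the full traceback server-side, and re-raise as an HTTPException whose detail carries the message. FastAPI serializes that into a JSON body of the form {"detail": "..."}, which is exactly the field client.py reads back. A self-contained sketch of the pattern outside xinference (the stand-in supervisor call is illustrative only):

import logging

from fastapi import FastAPI, HTTPException

logger = logging.getLogger(__name__)
app = FastAPI()


async def unreliable_supervisor_call() -> dict:
    # Stand-in for something like self._supervisor_ref.list_models();
    # it always fails here to exercise the error path.
    raise ValueError("supervisor not ready")


@app.get("/v1/models")
async def list_models() -> dict:
    try:
        return await unreliable_supervisor_call()
    except Exception as e:
        # Keep the full traceback in the server log...
        logger.error(e, exc_info=True)
        # ...but send only the message to the client, as a JSON 500 body.
        raise HTTPException(status_code=500, detail=str(e))

Note that the one validation error in the diff, the missing-prompt case in create_chat_completion, is raised as a 400 outside any try block, so it reaches the client as-is instead of being re-wrapped as a 500.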