REF: support query for engine feature #1294

Merged 41 commits on Apr 18, 2024
Commits
8846be1 Setup for Contributing Doc (Ago327, Mar 1, 2024)
0911f23 Merge branch 'xorbitsai:main' into main (Ago327, Mar 4, 2024)
2c1a96d Merge branch 'xorbitsai:main' into main (Ago327, Mar 5, 2024)
dead463 Merge branch 'xorbitsai:main' into main (Ago327, Mar 6, 2024)
a764092 Merge branch 'xorbitsai:main' into main (Ago327, Mar 6, 2024)
d421b37 Merge branch 'xorbitsai:main' into main (Ago327, Mar 6, 2024)
eec0c74 Merge branch 'xorbitsai:main' into main (Ago327, Mar 7, 2024)
cc9f55c Merge branch 'xorbitsai:main' into main (Ago327, Mar 11, 2024)
147e819 doc_development (Ago327, Mar 11, 2024)
0339462 Merge branch 'xorbitsai:main' into main (Ago327, Mar 12, 2024)
c4337eb Merge branch 'xorbitsai:main' into main (Ago327, Mar 13, 2024)
ab43e12 Merge branch 'xorbitsai:main' into main (Ago327, Mar 18, 2024)
fd15809 Merge branch 'xorbitsai:main' into main (Ago327, Mar 19, 2024)
40e76b6 Merge branch 'xorbitsai:main' into main (Ago327, Mar 21, 2024)
9f9d9b1 Merge branch 'xorbitsai:main' into main (Ago327, Mar 21, 2024)
3a6a175 Merge branch 'xorbitsai:main' into main (Ago327, Mar 28, 2024)
5cc8dc6 Merge branch 'xorbitsai:main' into main (Ago327, Apr 1, 2024)
601dc9b Merge branch 'xorbitsai:main' into main (Ago327, Apr 9, 2024)
4f0f414 Merge branch 'xorbitsai:main' into main (Ago327, Apr 14, 2024)
a479cd3 init structure for query (Ago327, Apr 14, 2024)
97e0c94 fix bug (Ago327, Apr 14, 2024)
517b15d Merge branch 'xorbitsai:main' into query-for-engine (Ago327, Apr 15, 2024)
81ca84d remove QUANTIZATION_PARAMS and add function match_engine_params (Ago327, Apr 16, 2024)
ecf8bd0 Merge branch 'xorbitsai:main' into query-for-engine (Ago327, Apr 16, 2024)
e5ff0e0 add UT (Ago327, Apr 16, 2024)
7a98129 fix UT (Ago327, Apr 16, 2024)
7564a66 temporary UT (Ago327, Apr 16, 2024)
b2a97d0 ut (Ago327, Apr 16, 2024)
c139b5c fix (Ago327, Apr 16, 2024)
197faf5 format detail (Ago327, Apr 17, 2024)
a4cbb8e temp ut (Ago327, Apr 17, 2024)
a085ba5 adjust ut position (Ago327, Apr 17, 2024)
508040a Merge branch 'xorbitsai:main' into query-for-engine (Ago327, Apr 17, 2024)
0f87574 add UT (Ago327, Apr 18, 2024)
4e447dd Merge branch 'query-for-engine' of github.com:Ago327/inference into q… (Ago327, Apr 18, 2024)
368ea4c fix chatglm UT (Ago327, Apr 18, 2024)
a2f2b78 detail fixes (Ago327, Apr 18, 2024)
e118b9a filter for supervisor (Ago327, Apr 18, 2024)
4aeaa0f Merge branch 'xorbitsai:main' into query-for-engine (Ago327, Apr 18, 2024)
cdd4961 fix supervisor (Ago327, Apr 18, 2024)
cddc60d fix debug (Ago327, Apr 18, 2024)
23 changes: 23 additions & 0 deletions xinference/api/restful_api.py
@@ -275,6 +275,16 @@ def serve(self, logging_conf: Optional[dict] = None):
self._router.add_api_route(
"/v1/cluster/auth", self.is_cluster_authenticated, methods=["GET"]
)
self._router.add_api_route(
"/v1/engines/{model_name}",
self.query_engines_by_model_name,
methods=["GET"],
dependencies=(
[Security(self._auth_service, scopes=["models:list"])]
if self.is_authenticated()
else None
),
)
# running instances
self._router.add_api_route(
"/v1/models/instances",
@@ -1418,6 +1428,19 @@ async def stream_results():
self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

async def query_engines_by_model_name(self, model_name: str) -> JSONResponse:
try:
content = await (
await self._get_supervisor_ref()
).query_engines_by_model_name(model_name)
return JSONResponse(content=content)
except ValueError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def register_model(self, model_type: str, request: Request) -> JSONResponse:
body = RegisterModelRequest.parse_obj(await request.json())
model = body.model
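A quick usage sketch for the new endpoint, assuming an Xinference server on its default endpoint 127.0.0.1:9997; the model name below is illustrative and not part of this diff:

import requests

# Query which engines can serve a given model; the endpoint returns a
# mapping of engine name -> list of matching parameter combinations.
resp = requests.get("http://127.0.0.1:9997/v1/engines/llama-2-chat")
resp.raise_for_status()
for engine, params in resp.json().items():
    print(engine, "->", len(params), "matching spec(s)")
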
18 changes: 18 additions & 0 deletions xinference/core/supervisor.py
@@ -591,6 +591,24 @@ def get_model_registration(self, model_type: str, model_name: str) -> Any:
else:
raise ValueError(f"Unsupported model type: {model_type}")

@log_async(logger=logger)
async def query_engines_by_model_name(self, model_name: str):
from copy import deepcopy

from ..model.llm.llm_family import LLM_ENGINES

if model_name not in LLM_ENGINES:
raise ValueError(f"Model {model_name} not found")

# filter llm_class
engine_params = deepcopy(LLM_ENGINES[model_name])
for engine in engine_params:
params = engine_params[engine]
for param in params:
del param["llm_class"]

return engine_params
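
For illustration, the value this method returns (after the llm_class entries are stripped) has the shape Dict[engine_name, List[param_dict]]; the entries below are hypothetical, not taken from the real registry:

# Hypothetical return value of query_engines_by_model_name("llama-2-chat"):
example_result = {
    "vLLM": [
        {
            "model_name": "llama-2-chat",
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "quantizations": ["none"],
        },
    ],
    "llama-cpp-python": [
        {
            "model_name": "llama-2-chat",
            "model_format": "ggmlv3",
            "model_size_in_billions": 7,
            "quantizations": ["q4_0", "q8_0"],
        },
    ],
}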

@log_async(logger=logger)
async def register_model(self, model_type: str, model: str, persist: bool):
if model_type in self._custom_register_type_to_cls:
88 changes: 88 additions & 0 deletions xinference/model/llm/__init__.py
@@ -30,8 +30,14 @@
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
BUILTIN_LLM_PROMPT_STYLE,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
LLAMA_CLASSES,
LLM_CLASSES,
LLM_ENGINES,
PEFT_SUPPORTED_CLASSES,
PYTORCH_CLASSES,
SGLANG_CLASSES,
SUPPORTED_ENGINES,
VLLM_CLASSES,
CustomLLMFamilyV1,
GgmlLLMSpecV1,
LLMFamilyV1,
@@ -47,6 +53,50 @@
)


def generate_engine_config_by_model_family(model_family):
model_name = model_family.model_name
specs = model_family.model_specs
engines = {} # structure for engine query
for spec in specs:
model_format = spec.model_format
model_size_in_billions = spec.model_size_in_billions
quantizations = spec.quantizations
for quantization in quantizations:
# traverse all supported engines to match the model's name, format, size in billions and quantization
for engine in SUPPORTED_ENGINES:
CLASSES = SUPPORTED_ENGINES[engine]
for cls in CLASSES:
if cls.match(model_family, spec, quantization):
engine_params = engines.get(engine, [])
already_exists = False
# if the model's name, format and size in billions already exist in the structure, add the new quantization
for param in engine_params:
if (
model_name == param["model_name"]
and model_format == param["model_format"]
and model_size_in_billions
== param["model_size_in_billions"]
and quantization not in param["quantizations"]
):
param["quantizations"].append(quantization)
already_exists = True
break
# first successful match for these params, add a new entry to the structure
if not already_exists:
engine_params.append(
{
"model_name": model_name,
"model_format": model_format,
"model_size_in_billions": model_size_in_billions,
"quantizations": [quantization],
"llm_class": cls,
}
)
engines[engine] = engine_params
break
LLM_ENGINES[model_name] = engines
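
A minimal standalone sketch of the merge logic above, with illustrative names: parameter dicts that differ only in quantization collapse into one entry whose "quantizations" list accumulates every matched value.

from typing import Any, Dict, List

def merge_engine_param(
    engine_params: List[Dict[str, Any]], candidate: Dict[str, Any]
) -> None:
    # Merge candidate into the list, extending "quantizations" when
    # name, format and size already match; otherwise append a new entry.
    for param in engine_params:
        if all(
            param[key] == candidate[key]
            for key in ("model_name", "model_format", "model_size_in_billions")
        ):
            for q in candidate["quantizations"]:
                if q not in param["quantizations"]:
                    param["quantizations"].append(q)
            return
    engine_params.append(candidate)

params: List[Dict[str, Any]] = []
merge_engine_param(params, {"model_name": "m", "model_format": "pytorch",
                            "model_size_in_billions": 7, "quantizations": ["4-bit"]})
merge_engine_param(params, {"model_name": "m", "model_format": "pytorch",
                            "model_size_in_billions": 7, "quantizations": ["8-bit"]})
assert params[0]["quantizations"] == ["4-bit", "8-bit"]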


def _install():
from .ggml.chatglm import ChatglmCppChatModel
from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel
@@ -76,8 +126,17 @@ def _install():
ChatglmCppChatModel,
]
)
LLAMA_CLASSES.extend(
[
ChatglmCppChatModel,
LlamaCppChatModel,
LlamaCppModel,
]
)
LLM_CLASSES.extend([SGLANGModel, SGLANGChatModel])
SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
LLM_CLASSES.extend([VLLMModel, VLLMChatModel])
VLLM_CLASSES.extend([VLLMModel, VLLMChatModel])
LLM_CLASSES.extend(
[
BaichuanPytorchChatModel,
@@ -96,6 +155,24 @@ def _install():
PytorchModel,
]
)
PYTORCH_CLASSES.extend(
[
BaichuanPytorchChatModel,
VicunaPytorchChatModel,
FalconPytorchChatModel,
ChatglmPytorchChatModel,
LlamaPytorchModel,
LlamaPytorchChatModel,
PytorchChatModel,
FalconPytorchModel,
Internlm2PytorchChatModel,
QwenVLChatModel,
OmniLMMModel,
YiVLChatModel,
DeepSeekVLChatModel,
PytorchModel,
]
)
PEFT_SUPPORTED_CLASSES.extend(
[
BaichuanPytorchChatModel,
@@ -113,6 +190,12 @@ def _install():
]
)

# support 4 engines for now
SUPPORTED_ENGINES["vLLM"] = VLLM_CLASSES
SUPPORTED_ENGINES["SGLang"] = SGLANG_CLASSES
SUPPORTED_ENGINES["PyTorch"] = PYTORCH_CLASSES
SUPPORTED_ENGINES["llama-cpp-python"] = LLAMA_CLASSES

json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family.json"
)
@@ -163,6 +246,11 @@ def _install():
if llm_spec.model_name not in LLM_MODEL_DESCRIPTIONS:
LLM_MODEL_DESCRIPTIONS.update(generate_llm_description(llm_spec))

# traverse all families and add engine parameters corresponding to the model name
for families in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]:
for family in families:
generate_engine_config_by_model_family(family)

from ...constants import XINFERENCE_MODEL_DIR

user_defined_llm_dir = os.path.join(XINFERENCE_MODEL_DIR, "llm")
40 changes: 40 additions & 0 deletions xinference/model/llm/llm_family.py
@@ -227,16 +227,25 @@ def parse_raw(
CustomLLMFamilyV1.update_forward_refs()


LLAMA_CLASSES: List[Type[LLM]] = []
LLM_CLASSES: List[Type[LLM]] = []
PEFT_SUPPORTED_CLASSES: List[Type[LLM]] = []

BUILTIN_LLM_FAMILIES: List["LLMFamilyV1"] = []
BUILTIN_MODELSCOPE_LLM_FAMILIES: List["LLMFamilyV1"] = []

SGLANG_CLASSES: List[Type[LLM]] = []
PYTORCH_CLASSES: List[Type[LLM]] = []

UD_LLM_FAMILIES: List["LLMFamilyV1"] = []

UD_LLM_FAMILIES_LOCK = Lock()

VLLM_CLASSES: List[Type[LLM]] = []

LLM_ENGINES: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
SUPPORTED_ENGINES: Dict[str, List[Type[LLM]]] = {}

LLM_LAUNCH_VERSIONS: Dict[str, List[str]] = {}


@@ -904,6 +913,7 @@ def _apply_format_to_model_id(spec: LLMSpecV1, q: str) -> LLMSpecV1:

def register_llm(llm_family: LLMFamilyV1, persist: bool):
from ..utils import is_valid_model_name
from . import generate_engine_config_by_model_family

if not is_valid_model_name(llm_family.model_name):
raise ValueError(f"Invalid model name {llm_family.model_name}.")
@@ -916,6 +926,7 @@ def register_llm(llm_family: LLMFamilyV1, persist: bool):
)

UD_LLM_FAMILIES.append(llm_family)
generate_engine_config_by_model_family(llm_family)

if persist:
# We only validate model URL when persist is True.
@@ -941,6 +952,7 @@ def unregister_llm(model_name: str, raise_error: bool = True):
break
if llm_family:
UD_LLM_FAMILIES.remove(llm_family)
del LLM_ENGINES[model_name]

persist_path = os.path.join(
XINFERENCE_MODEL_DIR, "llm", f"{llm_family.model_name}.json"
@@ -990,3 +1002,31 @@ def match_llm_cls(
if cls.match(family, llm_spec, quantization):
return cls
return None


def check_engine_by_spec_parameters(
model_engine: str,
model_name: str,
model_format: str,
model_size_in_billions: Union[str, int],
quantization: str,
) -> Optional[Type[LLM]]:
if model_name not in LLM_ENGINES:
logger.debug(f"Cannot find model {model_name}.")
return None
if model_engine not in LLM_ENGINES[model_name]:
logger.debug(f"Model {model_name} cannot be run on engine {model_engine}.")
return None
match_params = LLM_ENGINES[model_name][model_engine]
for param in match_params:
if (
model_name == param["model_name"]
and model_format == param["model_format"]
and model_size_in_billions == param["model_size_in_billions"]
and quantization in param["quantizations"]
):
return param["llm_class"]
logger.debug(
f"Model {model_name} with format {model_format}, size {model_size_in_billions} and quantization {quantization} cannot be run on engine {model_engine}."
)
return None
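
A hedged usage sketch of the lookup helper; the import path follows this diff, while the spec values are illustrative:

from xinference.model.llm.llm_family import check_engine_by_spec_parameters

# Resolve the LLM class for a concrete spec; None means the requested
# engine cannot serve this name/format/size/quantization combination.
llm_cls = check_engine_by_spec_parameters(
    model_engine="vLLM",
    model_name="llama-2-chat",  # illustrative
    model_format="pytorch",
    model_size_in_billions=7,
    quantization="none",
)
if llm_cls is None:
    print("engine/spec combination not supported")
else:
    print("matched class:", llm_cls.__name__)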