FEAT: support pytorch models #157

Merged · 7 commits · Jul 17, 2023
4 changes: 4 additions & 0 deletions .github/workflows/python.yaml
@@ -86,6 +86,10 @@ jobs:
MODULE: ${{ matrix.module }}
run: |
pip install llama-cpp-python
pip install transformers
pip install torch
pip install accelerate
pip install sentencepiece
pip install -e ".[dev]"
working-directory: .

4 changes: 4 additions & 0 deletions setup.cfg
@@ -54,6 +54,10 @@ dev =
all =
chatglm-cpp
llama-cpp-python
transformers
torch
accelerate
sentencepiece

[options.entry_points]
console_scripts =
47 changes: 33 additions & 14 deletions xinference/client.py
@@ -27,6 +27,7 @@
from .model import ModelSpec
from .model.llm.chatglm import ChatglmCppGenerateConfig
from .model.llm.core import LlamaCppGenerateConfig
from .model.llm.pytorch.core import PytorchGenerateConfig
from .types import (
ChatCompletion,
ChatCompletionChunk,
@@ -47,21 +48,27 @@ def __init__(self, model_ref: xo.ActorRefType["ModelActor"], isolation: Isolatio
self._isolation = isolation


class LlamaCppModelHandle(ModelHandle):
class GenerateModelHandle(ModelHandle):
def generate(
self, prompt: str, generate_config: Optional["LlamaCppGenerateConfig"] = None
self,
prompt: str,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
] = None,
) -> Union["Completion", Iterator["CompletionChunk"]]:
coro = self._model_ref.generate(prompt, generate_config)
return self._isolation.call(coro)


class LlamaCppChatModelHandle(LlamaCppModelHandle):
class ChatModelHandle(GenerateModelHandle):
def chat(
self,
prompt: str,
system_prompt: Optional[str] = None,
chat_history: Optional[List["ChatCompletionMessage"]] = None,
generate_config: Optional["LlamaCppGenerateConfig"] = None,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
] = None,
) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
coro = self._model_ref.chat(
prompt, system_prompt, chat_history, generate_config
@@ -91,9 +98,13 @@ def __init__(self, model_uid: str, base_url: str):
self._base_url = base_url


class RESTfulLlamaCppModelHandle(RESTfulModelHandle):
class RESTfulGenerateModelHandle(RESTfulModelHandle):
def generate(
self, prompt: str, generate_config: Optional["LlamaCppGenerateConfig"] = None
self,
prompt: str,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
] = None,
) -> Union["Completion", Iterator["CompletionChunk"]]:
url = f"{self._base_url}/v1/completions"
if generate_config is None:
@@ -117,13 +128,15 @@ def generate(
return response_data


class RESTfulLlamaCppChatModelHandle(RESTfulLlamaCppModelHandle):
class RESTfulChatModelHandle(RESTfulGenerateModelHandle):
def chat(
self,
prompt: str,
system_prompt: Optional[str] = None,
chat_history: Optional[List["ChatCompletionMessage"]] = None,
generate_config: Optional["LlamaCppGenerateConfig"] = None,
generate_config: Optional[
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
] = None,
) -> Union["ChatCompletion", Iterator["ChatCompletionChunk"]]:
url = f"{self._base_url}/v1/chat/completions"

@@ -249,10 +262,13 @@ def get_model(self, model_uid: str) -> "ModelHandle":

if model_spec.model_name == "chatglm" or model_spec.model_name == "chatglm2":
return ChatglmCppChatModelHandle(model_ref, self._isolation)
elif model_spec.model_name == "baichuan":
return LlamaCppModelHandle(model_ref, self._isolation)
elif (
model_spec.model_name == "baichuan"
or model_spec.model_name == "baichuan-inc/Baichuan-7B"
):
return GenerateModelHandle(model_ref, self._isolation)
else:
return LlamaCppChatModelHandle(model_ref, self._isolation)
return ChatModelHandle(model_ref, self._isolation)


class RESTfulClient:
@@ -336,7 +352,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
or model_spec["model_name"] == "chatglm2"
):
return RESTfulChatglmCppChatModelHandle(model_uid, self.base_url)
elif model_spec["model_name"] == "baichuan":
return RESTfulLlamaCppModelHandle(model_uid, self.base_url)
elif (
model_spec["model_name"] == "baichuan"
or model_spec["model_name"] == "baichuan-inc/Baichuan-7B"
):
return RESTfulGenerateModelHandle(model_uid, self.base_url)
else:
return RESTfulLlamaCppChatModelHandle(model_uid, self.base_url)
return RESTfulChatModelHandle(model_uid, self.base_url)
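
For orientation, a minimal usage sketch of the renamed RESTful handles. It assumes a running server, a PyTorch model already launched elsewhere (the base URL, model uid, and generate-config keys below are placeholders, not asserted by this diff).

# Sketch only: endpoint, uid, and config keys are assumptions.
from xinference.client import RESTfulClient

client = RESTfulClient("http://localhost:9997")   # hypothetical endpoint
handle = client.get_model("my-model-uid")         # placeholder uid

# Text completion works on both RESTfulGenerateModelHandle and
# RESTfulChatModelHandle, since the latter inherits generate().
completion = handle.generate(
    "Once upon a time",
    generate_config={"max_tokens": 128},          # illustrative keys
)

# chat() is only available when get_model returned a RESTfulChatModelHandle.
chat_completion = handle.chat(
    "What is the capital of France?",
    chat_history=[],
    generate_config={"max_tokens": 128},
)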
6 changes: 3 additions & 3 deletions xinference/core/restful_api.py
@@ -68,7 +68,7 @@
+ "Top-k sampling is a text generation method that selects the next token only from the top k most likely tokens predicted by the model. It helps reduce the risk of generating low-probability or nonsensical tokens, but it may also limit the diversity of the output. A higher value for top_k (e.g., 100) will consider more tokens and lead to more diverse text, while a lower value (e.g., 10) will focus on the most probable tokens and generate more conservative text.",
)

repeat_penalty_field = Field(
repetition_penalty_field = Field(
default=1.1,
ge=0.0,
description="A penalty applied to each token that is already generated. This helps prevent the model from repeating itself.\n\n"
@@ -139,7 +139,7 @@ class CreateCompletionRequest(BaseModel):

# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
repetition_penalty: float = repetition_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

class Config:
@@ -193,7 +193,7 @@ class CreateChatCompletionRequest(BaseModel):

# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: float = repeat_penalty_field
repetition_penalty: float = repetition_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)

class Config:
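
As a reference for API consumers, a hedged sketch of the renamed parameter in a raw completion request; the /v1/completions path matches the client code above, while the base URL, model uid, and the other payload values are placeholders.

# Sketch only: illustrates `repetition_penalty` (formerly `repeat_penalty`).
import requests

payload = {
    "model": "my-model-uid",            # placeholder uid
    "prompt": "Once upon a time",
    "max_tokens": 64,                   # illustrative value
    "repetition_penalty": 1.1,          # renamed field; the diff keeps 1.1 as its default
}
response = requests.post("http://localhost:9997/v1/completions", json=payload)
print(response.json())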
26 changes: 21 additions & 5 deletions xinference/model/__init__.py
@@ -85,18 +85,31 @@ def __str__(self):

def __iter__(self):
model_specs = []
for model_size in self.model_sizes_in_billions:
for quantization in self.quantizations:
if self.model_format == "pytorch":
for model_size in self.model_sizes_in_billions:
model_specs.append(
ModelSpec(
model_name=self.model_name,
model_size_in_billions=model_size,
model_format=self.model_format,
quantization=quantization,
url=self.url_generator(model_size, quantization),
quantization=None,
url=None,
)
)
return iter(model_specs)
return iter(model_specs)
else:
for model_size in self.model_sizes_in_billions:
for quantization in self.quantizations:
model_specs.append(
ModelSpec(
model_name=self.model_name,
model_size_in_billions=model_size,
model_format=self.model_format,
quantization=quantization,
url=self.url_generator(model_size, quantization),
)
)
return iter(model_specs)

def match(
self,
@@ -133,6 +146,9 @@ def cache(
model_size_in_billions: Optional[int] = None,
quantization: Optional[str] = None,
) -> str:
if self.model_format == "pytorch":
return self.model_name

# by default, choose the smallest size.
model_size_in_billions = (
model_size_in_billions or self.model_sizes_in_billions[0]
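
To make the new pytorch branch concrete, a small sketch of iterating and caching such a family; the constructor arguments mirror the registrations added below, but the snippet itself is illustrative and the import path of ModelFamily is assumed.

# Sketch only: arguments mirror the registrations in xinference/model/llm/__init__.py.
from xinference.model import ModelFamily                        # assumed import path
from xinference.model.llm.pytorch.baichuan import BaichuanPytorch

family = ModelFamily(
    model_name="baichuan-inc/Baichuan-7B",
    model_sizes_in_billions=[7],
    model_format="pytorch",
    quantizations=None,
    url_generator=None,
    cls=BaichuanPytorch,
)

# The pytorch branch of __iter__ emits one spec per size, with no
# quantization or download URL attached.
for spec in family:
    print(spec.model_name, spec.model_size_in_billions, spec.quantization, spec.url)

# cache() short-circuits for pytorch models and just returns the model name,
# presumably because the weights are resolved by transformers at load time.
print(family.cache())   # "baichuan-inc/Baichuan-7B"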
24 changes: 24 additions & 0 deletions xinference/model/llm/__init__.py
@@ -18,6 +18,8 @@ def install():
from .chatglm import ChatglmCppChatModel
from .core import LlamaCppModel
from .orca import OrcaMiniGgml
from .pytorch.baichuan import BaichuanPytorch
from .pytorch.vicuna import VicunaCensoredPytorch
from .vicuna import VicunaCensoredGgml
from .wizardlm import WizardlmGgml

@@ -208,3 +210,25 @@ def install():
cls=ChatglmCppChatModel,
)
)

MODEL_FAMILIES.append(
ModelFamily(
model_name="baichuan-inc/Baichuan-7B",
model_sizes_in_billions=[7],
model_format="pytorch",
quantizations=None,
url_generator=None,
cls=BaichuanPytorch,
),
)

MODEL_FAMILIES.append(
ModelFamily(
model_name="lmsys/vicuna-7b-v1.3",
model_sizes_in_billions=[7, 13],
model_format="pytorch",
quantizations=None,
url_generator=None,
cls=VicunaCensoredPytorch,
),
)
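
Finally, a hedged sketch of how the newly registered pytorch families could be inspected from the registry; the filtering below is illustrative and not an existing helper in xinference.

# Sketch only: MODEL_FAMILIES and install() are referenced in this module;
# whether install() has already run depends on how xinference is imported.
from xinference.model.llm import MODEL_FAMILIES, install

install()  # populates the registry, including the two pytorch entries above

pytorch_families = [f for f in MODEL_FAMILIES if f.model_format == "pytorch"]
for family in pytorch_families:
    # Expected to include "baichuan-inc/Baichuan-7B" and "lmsys/vicuna-7b-v1.3".
    print(family.model_name, family.model_sizes_in_billions)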