add baichuan model
pangyoki committed Jul 14, 2023
1 parent: 4002421; commit: 7952a61
Showing 9 changed files with 424 additions and 363 deletions.
xinference/client.py (7 changes: 4 additions & 3 deletions)

@@ -26,7 +26,8 @@
 if TYPE_CHECKING:
     from .model import ModelSpec
     from .model.llm.chatglm import ChatglmCppGenerateConfig
-    from .model.llm.core import LlamaCppGenerateConfig, PytorchGenerateConfig
+    from .model.llm.core import LlamaCppGenerateConfig
+    from .model.llm.pytorch.core import PytorchGenerateConfig
     from .types import (
         ChatCompletion,
         ChatCompletionChunk,
@@ -263,7 +264,7 @@ def get_model(self, model_uid: str) -> "ModelHandle":
             return ChatglmCppChatModelHandle(model_ref, self._isolation)
         elif (
             model_spec.model_name == "baichuan"
-            or model_spec.model_name == "facebook/opt-125m"
+            or model_spec.model_name == "baichuan-inc/Baichuan-7B"
         ):
             return GenerateModelHandle(model_ref, self._isolation)
         else:
@@ -353,7 +354,7 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
             return RESTfulChatglmCppChatModelHandle(model_uid, self.base_url)
         elif (
             model_spec["model_name"] == "baichuan"
-            or model_spec["model_name"] == "facebook/opt-125m"
+            or model_spec["model_name"] == "baichuan-inc/Baichuan-7B"
         ):
             return RESTfulGenerateModelHandle(model_uid, self.base_url)
         else:
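
With this change, get_model (and its RESTful counterpart) returns a plain generate handle when the model name matches the new "baichuan-inc/Baichuan-7B" entry, replacing the old "facebook/opt-125m" placeholder. A minimal usage sketch under stated assumptions: the endpoint URL is hypothetical, and the handle's generate method and its generate_config keys are assumptions about the handle API, not something this diff shows.

from xinference.client import RESTfulClient

# Hypothetical endpoint; point it at a running xinference service.
client = RESTfulClient(base_url="http://127.0.0.1:9997")

# Launch the newly registered PyTorch family, using the same keyword
# arguments the CLI in cmdline.py (below) forwards.
model_uid = client.launch_model(
    model_name="baichuan-inc/Baichuan-7B",
    model_size_in_billions=7,
    model_format="pytorch",
    quantization=None,
)

# Because the model name matches "baichuan-inc/Baichuan-7B", get_model takes
# the elif branch above and returns a RESTfulGenerateModelHandle.
handle = client.get_model(model_uid)

# `generate` and the config keys are assumptions about the handle's API.
completion = handle.generate("Once upon a time", generate_config={"max_tokens": 64})
print(completion)
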
xinference/deploy/cmdline.py (3 changes: 0 additions & 3 deletions)

@@ -114,22 +114,19 @@ def worker(log_level: str, endpoint: str, host: str):
 @click.option("--size-in-billions", "-s", default=None, type=int)
 @click.option("--model-format", "-f", default=None, type=str)
 @click.option("--quantization", "-q", default=None, type=str)
-@click.option("--device", "-d", default="cuda", type=str)
 def model_launch(
     endpoint: str,
     model_name: str,
     size_in_billions: int,
     model_format: str,
     quantization: str,
-    device: str,
 ):
     client = RESTfulClient(base_url=endpoint)
     model_uid = client.launch_model(
         model_name=model_name,
         model_size_in_billions=size_in_billions,
         model_format=model_format,
         quantization=quantization,
-        device=device,
     )

     print(f"Model uid: {model_uid}")
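
The --device/-d flag and its pass-through to launch_model are dropped, so the launch command now forwards only the model name, size, format, and quantization, presumably leaving device placement to the model implementation. A sketch of exercising the trimmed command with click's test runner, assuming model_launch is exposed as a click command (as the option decorators suggest) and that the --endpoint and --model-name flag names match the parameters shown above; both options sit outside this hunk, and a running endpoint is still required.

from click.testing import CliRunner

from xinference.deploy.cmdline import model_launch

runner = CliRunner()
result = runner.invoke(
    model_launch,
    [
        "--endpoint", "http://127.0.0.1:9997",      # hypothetical service address
        "--model-name", "baichuan-inc/Baichuan-7B",
        "--size-in-billions", "7",
        "--model-format", "pytorch",
    ],
)
print(result.exit_code)  # 0 on success; the command prints "Model uid: ..."

# The removed flag is no longer a recognized option, so click rejects it.
result = runner.invoke(model_launch, ["--device", "cuda"])
assert result.exit_code != 0
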
xinference/model/llm/__init__.py (9 changes: 5 additions & 4 deletions)

@@ -16,8 +16,9 @@
 def install():
     from .. import MODEL_FAMILIES, ModelFamily
     from .chatglm import ChatglmCppChatModel
-    from .core import LlamaCppModel, PytorchModel
+    from .core import LlamaCppModel
     from .orca import OrcaMiniGgml
+    from .pytorch.baichuan import BaichuanPytorch
     from .pytorch.vicuna import VicunaCensoredPytorch
     from .vicuna import VicunaCensoredGgml
     from .wizardlm import WizardlmGgml
@@ -182,12 +183,12 @@ def install():

     MODEL_FAMILIES.append(
         ModelFamily(
-            model_name="facebook/opt-125m",
-            model_sizes_in_billions=[1],
+            model_name="baichuan-inc/Baichuan-7B",
+            model_sizes_in_billions=[7],
             model_format="pytorch",
             quantizations=None,
             url_generator=None,
-            cls=PytorchModel,
+            cls=BaichuanPytorch,
         ),
     )

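
The registration above is data-driven: install() appends a ModelFamily describing the checkpoint name, the available sizes, the weight format, and the implementation class to dispatch to. A hypothetical sketch of registering another PyTorch-backed family following the same pattern; the model name and size are placeholders, and reusing BaichuanPytorch as the implementation class is an assumption made only to keep the sketch self-contained.

from xinference.model import MODEL_FAMILIES, ModelFamily
from xinference.model.llm.pytorch.baichuan import BaichuanPytorch

# Placeholder entry shaped like the Baichuan-7B registration in this commit.
MODEL_FAMILIES.append(
    ModelFamily(
        model_name="my-org/My-Pytorch-Model",  # hypothetical checkpoint id
        model_sizes_in_billions=[13],          # hypothetical size
        model_format="pytorch",
        quantizations=None,
        url_generator=None,
        cls=BaichuanPytorch,
    )
)

Note that get_model in client.py dispatches on the literal model name, so a family registered this way would also need a matching branch there.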