FEAT: support baichuan-chat pytorch model #190

Merged: 5 commits, Jul 18, 2023
2 changes: 2 additions & 0 deletions .github/workflows/python.yaml
@@ -90,6 +90,8 @@ jobs:
pip install torch
pip install accelerate
pip install sentencepiece
pip install transformers_stream_generator
pip install cpm_kernels
pip install -e ".[dev]"
working-directory: .

2 changes: 2 additions & 0 deletions setup.cfg
@@ -58,6 +58,8 @@ all =
torch
accelerate
sentencepiece
transformers_stream_generator
cpm_kernels

[options.entry_points]
console_scripts =
10 changes: 2 additions & 8 deletions xinference/client.py
@@ -262,10 +262,7 @@ def get_model(self, model_uid: str) -> "ModelHandle":

if model_spec.model_name == "chatglm" or model_spec.model_name == "chatglm2":
return ChatglmCppChatModelHandle(model_ref, self._isolation)
elif (
model_spec.model_name == "baichuan"
or model_spec.model_name == "baichuan-inc/Baichuan-7B"
):
elif model_spec.model_name == "baichuan":
return GenerateModelHandle(model_ref, self._isolation)
else:
return ChatModelHandle(model_ref, self._isolation)
@@ -352,10 +349,7 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
or model_spec["model_name"] == "chatglm2"
):
return RESTfulChatglmCppChatModelHandle(model_uid, self.base_url)
elif (
model_spec["model_name"] == "baichuan"
or model_spec["model_name"] == "baichuan-inc/Baichuan-7B"
):
elif model_spec["model_name"] == "baichuan":
return RESTfulGenerateModelHandle(model_uid, self.base_url)
else:
return RESTfulChatModelHandle(model_uid, self.base_url)
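
For context, a minimal usage sketch (not part of the PR) of how the new model resolves through the client after this change: "baichuan-chat" is neither a chatglm variant nor plain "baichuan", so get_model falls through to the chat handle. The endpoint URL and the launch_model parameter names are assumptions based on the surrounding client code.

```python
# Hypothetical end-to-end check: launch the new "baichuan-chat" pytorch model
# and chat through the RESTful client; get_model returns RESTfulChatModelHandle.
from xinference.client import RESTfulClient

client = RESTfulClient("http://localhost:9997")  # assumed local endpoint
model_uid = client.launch_model(
    model_name="baichuan-chat",
    model_format="pytorch",
    model_size_in_billions=13,
    quantization="int4",
)
model = client.get_model(model_uid)
print(model.chat("What is the largest animal?"))
```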
29 changes: 8 additions & 21 deletions xinference/model/__init__.py
@@ -85,31 +85,18 @@ def __str__(self):

def __iter__(self):
model_specs = []
if self.model_format == "pytorch":
for model_size in self.model_sizes_in_billions:
for model_size in self.model_sizes_in_billions:
for quantization in self.quantizations:
model_specs.append(
ModelSpec(
model_name=self.model_name,
model_size_in_billions=model_size,
model_format=self.model_format,
quantization=None,
url=None,
quantization=quantization,
url=self.url_generator(model_size, quantization),
)
)
return iter(model_specs)
else:
for model_size in self.model_sizes_in_billions:
for quantization in self.quantizations:
model_specs.append(
ModelSpec(
model_name=self.model_name,
model_size_in_billions=model_size,
model_format=self.model_format,
quantization=quantization,
url=self.url_generator(model_size, quantization),
)
)
return iter(model_specs)
return iter(model_specs)

def match(
self,
@@ -146,9 +133,6 @@ def cache(
model_size_in_billions: Optional[int] = None,
quantization: Optional[str] = None,
) -> str:
if self.model_format == "pytorch":
return self.model_name

# by default, choose the smallest size.
model_size_in_billions = (
model_size_in_billions or self.model_sizes_in_billions[0]
@@ -158,6 +142,9 @@

url = self.url_generator(model_size_in_billions, quantization)

if self.model_format == "pytorch":
return url

full_name = f"{str(self)}-{model_size_in_billions}b-{quantization}"
save_path, meta_path = self.generate_cache_path(
model_size_in_billions, quantization
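
To illustrate the unified __iter__ and cache logic above, here is a stand-alone stub (not the real ModelFamily class) of what iteration now yields: pytorch families walk sizes x quantizations like every other format, and their url_generator returns a Hugging Face repo id that cache() hands back directly.

```python
def iter_specs(model_name, sizes, quantizations, url_generator):
    # Stand-alone stub of the unified ModelFamily.__iter__: every family,
    # pytorch or ggml, produces one spec per (size, quantization) pair.
    for size in sizes:
        for quantization in quantizations:
            yield (model_name, size, quantization, url_generator(size, quantization))

gen = lambda size, quantization: f"baichuan-inc/Baichuan-{size}B-Chat"
print(list(iter_specs("baichuan-chat", [13], ["int4", "int8"], gen)))
# [('baichuan-chat', 13, 'int4', 'baichuan-inc/Baichuan-13B-Chat'),
#  ('baichuan-chat', 13, 'int8', 'baichuan-inc/Baichuan-13B-Chat')]
```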
34 changes: 27 additions & 7 deletions xinference/model/llm/__init__.py
@@ -18,7 +18,7 @@ def install():
from .chatglm import ChatglmCppChatModel
from .core import LlamaCppModel
from .orca import OrcaMiniGgml
from .pytorch.baichuan import BaichuanPytorch
from .pytorch.baichuan import BaichuanPytorch, BaichuanPytorchChat
from .pytorch.vicuna import VicunaCensoredPytorch
from .vicuna import VicunaCensoredGgml
from .wizardlm import WizardlmGgml
@@ -211,24 +211,44 @@ def install():
)
)

pytorch_baichuan_name_generator = lambda model_size, quantization: (
f"baichuan-inc/Baichuan-{model_size}B"
)
MODEL_FAMILIES.append(
ModelFamily(
model_name="baichuan-inc/Baichuan-7B",
model_name="baichuan",
model_sizes_in_billions=[7],
model_format="pytorch",
quantizations=None,
url_generator=None,
quantizations=[None],
url_generator=pytorch_baichuan_name_generator,
cls=BaichuanPytorch,
),
)

pytorch_baichuan_chat_name_generator = lambda model_size, quantization: (
f"baichuan-inc/Baichuan-{model_size}B-Chat"
)
MODEL_FAMILIES.append(
ModelFamily(
model_name="baichuan-chat",
model_sizes_in_billions=[13],
model_format="pytorch",
quantizations=["int4", "int8"],
url_generator=pytorch_baichuan_chat_name_generator,
cls=BaichuanPytorchChat,
),
)

pytorch_vicuna_v1_3_name_generator = lambda model_size, quantization: (
f"lmsys/vicuna-{model_size}b-v1.3"
)
MODEL_FAMILIES.append(
ModelFamily(
model_name="lmsys/vicuna-7b-v1.3",
model_name="vicuna-v1.3",
model_sizes_in_billions=[7, 13],
model_format="pytorch",
quantizations=None,
url_generator=None,
quantizations=[None],
url_generator=pytorch_vicuna_v1_3_name_generator,
cls=VicunaCensoredPytorch,
),
)
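
As a quick sanity check of the name generators registered above (re-stated here, not new behavior): they map a model size to a Hugging Face repo id and ignore the quantization argument, since pytorch models download the full-precision checkpoint and quantize at load time (see core.py below).

```python
# Quick check of the repo ids produced by the generators above.
baichuan_chat = lambda model_size, quantization: f"baichuan-inc/Baichuan-{model_size}B-Chat"
vicuna_v1_3 = lambda model_size, quantization: f"lmsys/vicuna-{model_size}b-v1.3"

assert baichuan_chat(13, "int4") == "baichuan-inc/Baichuan-13B-Chat"
assert vicuna_v1_3(7, None) == "lmsys/vicuna-7b-v1.3"
```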
61 changes: 60 additions & 1 deletion xinference/model/llm/pytorch/baichuan.py
@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Optional

from ....constants import XINFERENCE_CACHE_DIR
from .core import PytorchModel, PytorchModelConfig
from .core import PytorchChatModel, PytorchModel, PytorchModelConfig

if TYPE_CHECKING:
from ... import ModelSpec
@@ -62,3 +62,62 @@ def _load_model(self, kwargs: dict):
**kwargs,
)
return model, tokenizer


class BaichuanPytorchChat(PytorchChatModel):
_system_prompt = (
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
)
_sep = "\n###"
_user_name = "User"
_assistant_name = "Assistant"
_stop = "###"

def __init__(
self,
model_uid: str,
model_spec: "ModelSpec",
model_path: str,
pytorch_model_config: Optional[PytorchModelConfig] = None,
):
super().__init__(
model_uid,
model_spec,
model_path,
system_prompt=self._system_prompt,
sep=self._sep,
user_name=self._user_name,
assistant_name=self._assistant_name,
stop=self._stop,
pytorch_model_config=pytorch_model_config,
)

def _load_model(self, kwargs: dict):
try:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig
except ImportError:
error_message = "Failed to import module 'transformers'"
installation_guide = [
"Please make sure 'transformers' is installed. ",
"You can install it by `pip install transformers`\n",
]

raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")

tokenizer = AutoTokenizer.from_pretrained(
self._model_path,
use_fast=False,
trust_remote_code=True,
revision=kwargs["revision"],
cache_dir=XINFERENCE_CACHE_DIR,
)
model = AutoModelForCausalLM.from_pretrained(
self._model_path,
trust_remote_code=True,
cache_dir=XINFERENCE_CACHE_DIR,
**kwargs,
)
model.generation_config = GenerationConfig.from_pretrained(self._model_path)
return model, tokenizer
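
For reference, the same transformers calls used in _load_model above can be exercised outside xinference; this is only an illustrative sketch assuming access to the baichuan-inc/Baichuan-13B-Chat repo and enough memory for the 13B checkpoint, not part of the PR.

```python
# Stand-alone sketch mirroring BaichuanPytorchChat._load_model (assumed setup:
# transformers and sentencepiece installed, access to the Hugging Face Hub).
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import GenerationConfig

repo = "baichuan-inc/Baichuan-13B-Chat"
tokenizer = AutoTokenizer.from_pretrained(repo, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
# Baichuan ships its own generation config (sampling defaults, chat tokens),
# which the class above attaches in the same way.
model.generation_config = GenerationConfig.from_pretrained(repo)
```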
10 changes: 10 additions & 0 deletions xinference/model/llm/pytorch/core.py
@@ -148,6 +148,12 @@ def load(self):

self._model, self._tokenizer = self._load_model(kwargs)

quantization = self.model_spec.quantization
if quantization == "int4":
self._model = self._model.quantize(4)
elif quantization == "int8":
self._model = self._model.quantize(8)

if (
device == "cuda" and num_gpus == 1 and not cpu_offloading
) or device == "mps":
@@ -204,13 +210,15 @@ def __init__(
sep: str,
user_name: str,
assistant_name: str,
stop: Optional[str] = None,
pytorch_model_config: Optional[PytorchModelConfig] = None,
):
super().__init__(model_uid, model_spec, model_path, pytorch_model_config)
self._system_prompt: str = system_prompt
self._sep: str = sep
self._user_name: str = user_name
self._assistant_name: str = assistant_name
self._stop: Optional[str] = stop

def chat(
self,
@@ -224,6 +232,8 @@ def chat(
full_prompt = self._to_prompt(prompt, system_prompt, chat_history=chat_history)

generate_config = self._sanitize_generate_config(generate_config)
if "stop" not in generate_config and self._stop is not None:
generate_config["stop"] = self._stop

stream = generate_config.get("stream", False)
if stream:
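
The two behaviors added in this file can be summarized as a pair of small helpers; this is an illustrative sketch of the logic only, and it assumes (as Baichuan's trust_remote_code implementation does) that the model object exposes a quantize(bits) method, which is why cpm_kernels joins the dependencies.

```python
from typing import Dict, Optional

def maybe_quantize(model, quantization: Optional[str]):
    # Mirrors the new branch in PytorchModel.load(): quantize in place when the
    # spec asked for int4/int8, assuming the model class provides quantize(bits).
    if quantization == "int4":
        return model.quantize(4)
    if quantization == "int8":
        return model.quantize(8)
    return model

def apply_default_stop(generate_config: Dict, stop: Optional[str]) -> Dict:
    # Mirrors the new chat() logic: inject the family's default stop string
    # (e.g. "###" for baichuan-chat) only when the caller didn't set one.
    if "stop" not in generate_config and stop is not None:
        generate_config["stop"] = stop
    return generate_config
```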