ENH: Refactoring the LoRA adaptation method for the LLM model. #1470

Merged · 32 commits · May 16, 2024
Commits
6ce7268 Enable multi-LoRa fine-tuning for enhanced performance (hainaweiben, Apr 10, 2024)
c89bc1a Merge branch 'xorbitsai:main' into lora-multi-support (hainaweiben, Apr 10, 2024)
93717cd first commit (hainaweiben, Apr 11, 2024)
e39803e Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 11, 2024)
d1b96f0 commit (hainaweiben, Apr 11, 2024)
05c66a3 commit (hainaweiben, Apr 11, 2024)
a4b0617 add rst (hainaweiben, Apr 11, 2024)
4c9c134 Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 11, 2024)
db9f946 fix (hainaweiben, Apr 11, 2024)
9171384 fix (hainaweiben, Apr 12, 2024)
e16bd88 Modify the interfaces and documentation of the RESTful client. (hainaweiben, Apr 12, 2024)
3076876 fix mypy error (hainaweiben, Apr 12, 2024)
5b2233d fixes (hainaweiben, Apr 12, 2024)
b33f811 Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 16, 2024)
e080622 fix (hainaweiben, Apr 16, 2024)
aa56884 fixes code style (hainaweiben, Apr 16, 2024)
b23f004 Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 16, 2024)
c070122 remove log (hainaweiben, Apr 16, 2024)
baf6a12 Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 17, 2024)
f2c8228 Merge branch 'lora-multi-support' of https://github.com/hainaweiben/i… (hainaweiben, Apr 17, 2024)
90e2deb Merge branch 'main' of https://github.com/hainaweiben/inference into … (hainaweiben, Apr 17, 2024)
53cb112 fixed (hainaweiben, Apr 18, 2024)
b2d3ce6 Merge branch 'xorbitsai:main' into support-vllm-lora (hainaweiben, Apr 18, 2024)
e5c6a5b Merge branch 'support-vllm-lora' of https://github.com/hainaweiben/in… (hainaweiben, Apr 18, 2024)
2aef458 fix (hainaweiben, Apr 18, 2024)
eb33a69 fixed (hainaweiben, Apr 18, 2024)
d87b18b Re-implementing LoRa support for VLLM. (hainaweiben, May 7, 2024)
6fc7509 Merge branch 'support-vllm-lora' of https://github.com/hainaweiben/in… (hainaweiben, May 7, 2024)
7ae2e34 Provide Lora support outside of the VLLM. (hainaweiben, May 7, 2024)
647a189 fix bug (hainaweiben, May 8, 2024)
2f429bd fix 🐛 (hainaweiben, May 9, 2024)
e9f75b8 Merge branch 'xorbitsai:main' into support-vllm-lora (hainaweiben, May 10, 2024)
2 changes: 2 additions & 0 deletions doc/source/models/lora.rst
@@ -54,6 +54,8 @@ Note
They correspond to the parameters in the ``load_lora_weights`` and ``fuse_lora`` interfaces of the ``diffusers`` library.
If launching an LLM model, these parameters are not required.

* You need to pass the ``lora_name`` parameter during inference to specify the corresponding LoRA model. You can set it in the Additional Inputs option.

* For LLM chat models, only LoRA models that do not change the prompt style are currently supported.

* When using GPUs, both LoRA and its base model occupy the same devices.
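
As a quick illustration of the documented workflow, here is a minimal client-side sketch of passing ``lora_name`` at inference time. The endpoint, model UID, and adapter name are placeholders, and it assumes the model was launched with a LoRA adapter named ``my_adapter``.

```python
# Minimal sketch (placeholder endpoint, model UID, and adapter name).
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-llm-uid")  # hypothetical model UID

completion = model.generate(
    prompt="Hello",
    generate_config={
        "max_tokens": 128,
        "temperature": 0.7,
        "lora_name": "my_adapter",  # selects which LoRA adapter to apply for this request
    },
)
print(completion)
```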
14 changes: 10 additions & 4 deletions xinference/core/chat_interface.py
@@ -109,6 +109,7 @@ def generate_wrapper(
history: List[List[str]],
max_tokens: int,
temperature: float,
lora_name: str,
) -> Generator:
from ..client import RESTfulClient

@@ -127,6 +128,7 @@ def generate_wrapper(
"max_tokens": int(max_tokens),
"temperature": temperature,
"stream": True,
"lora_name": lora_name,
},
):
assert isinstance(chunk, dict)
@@ -152,6 +154,7 @@ def generate_wrapper(
gr.Slider(
minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
),
gr.Text(label="LoRA Name"),
],
title=f"🚀 Xinference Chat Bot : {self.model_name} 🚀",
css="""
@@ -331,7 +334,7 @@ def clear(text, hist):
history: hist,
}

def complete(text, hist, max_tokens, temperature) -> Generator:
def complete(text, hist, max_tokens, temperature, lora_name) -> Generator:
from ..client import RESTfulClient

client = RESTfulClient(self.endpoint)
@@ -349,6 +352,7 @@ def complete(text, hist, max_tokens, temperature) -> Generator:
"max_tokens": max_tokens,
"temperature": temperature,
"stream": True,
"lora_name": lora_name,
},
):
assert isinstance(chunk, dict)
@@ -368,7 +372,7 @@ def complete(text, hist, max_tokens, temperature) -> Generator:
history: hist,
}

def retry(text, hist, max_tokens, temperature) -> Generator:
def retry(text, hist, max_tokens, temperature, lora_name) -> Generator:
from ..client import RESTfulClient

client = RESTfulClient(self.endpoint)
@@ -387,6 +391,7 @@ def retry(text, hist, max_tokens, temperature) -> Generator:
"max_tokens": max_tokens,
"temperature": temperature,
"stream": True,
"lora_name": lora_name,
},
):
assert isinstance(chunk, dict)
@@ -470,10 +475,11 @@ def retry(text, hist, max_tokens, temperature) -> Generator:
temperature = gr.Slider(
minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
)
lora_name = gr.Text(label="LoRA Name")

btn_generate.click(
fn=complete,
inputs=[textbox, history, length, temperature],
inputs=[textbox, history, length, temperature, lora_name],
outputs=[textbox, history],
)

@@ -485,7 +491,7 @@ def retry(text, hist, max_tokens, temperature) -> Generator:

btn_retry.click(
fn=retry,
inputs=[textbox, history, length, temperature],
inputs=[textbox, history, length, temperature, lora_name],
outputs=[textbox, history],
)

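Outside of this diff, the wiring follows the standard Gradio pattern: components listed as additional inputs are passed positionally into the callback after the message and history. A minimal illustrative sketch (not the PR's exact interface code):

```python
# Illustrative Gradio sketch of the additional-input pattern used above.
import gradio as gr

def generate_wrapper(message, history, max_tokens, temperature, lora_name):
    # lora_name arrives as the value of the extra gr.Text component and would be
    # forwarded to the model's generate_config, e.g. {"lora_name": lora_name}.
    return f"(would generate with lora_name={lora_name!r})"

demo = gr.ChatInterface(
    fn=generate_wrapper,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=256, label="Max Tokens"),
        gr.Slider(minimum=0, maximum=2, value=1, step=0.01, label="Temperature"),
        gr.Text(label="LoRA Name"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```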
29 changes: 23 additions & 6 deletions xinference/model/llm/pytorch/core.py
@@ -143,12 +143,17 @@ def _apply_lora(self):
f"Failed to import 'PeftModel' from 'peft'. Please make sure 'peft' is installed.\n\n"
)

for peft_model in self._peft_model:
# Apply LoRA
self._model = PeftModel.from_pretrained(
self._model,
peft_model.local_path,
)
for i, peft_model in enumerate(self._peft_model):
if i == 0:
self._model = PeftModel.from_pretrained(
self._model,
peft_model.local_path,
adapter_name=peft_model.lora_name,
)
else:
self._model.load_adapter(
peft_model.local_path, adapter_name=peft_model.lora_name
)
logger.info(
f"PEFT adaptor '{peft_model.lora_name}' successfully loaded for model '{self.model_uid}'."
)
@@ -302,6 +307,18 @@ def generator_wrapper(
assert self._model is not None
assert self._tokenizer is not None

lora_model = generate_config.pop("lora_name")

if lora_model is not None and self._peft_model is not None:
for lora in self._peft_model:
if lora_model == lora.lora_name:
self._model.set_adapter(lora_model)
logger.info(f"Set lora model to {lora_model}")
break
else:
self._model.disable_adapter()
logger.info(f"No lora model {lora_model} found, skip setting")

stream = generate_config.get("stream", False)
if not stream:
if "falcon" in model_family_name:
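The multi-adapter handling in ``_apply_lora`` mirrors the usual ``peft`` workflow: wrap the base model once, then attach further adapters and switch between them by name. A standalone sketch with hypothetical model and adapter paths:

```python
# Standalone sketch of the multi-adapter workflow (paths and names are hypothetical).
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/base-model")

# The first adapter wraps the base model under an explicit adapter name.
model = PeftModel.from_pretrained(base, "path/to/lora_a", adapter_name="lora_a")
# Additional adapters are attached to the same PeftModel.
model.load_adapter("path/to/lora_b", adapter_name="lora_b")

# At request time, activate the adapter whose name matches the requested lora_name.
model.set_adapter("lora_b")

# peft exposes "no adapter" as a context manager for running with base weights only.
with model.disable_adapter():
    pass  # inference with the base model
```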
54 changes: 52 additions & 2 deletions xinference/model/llm/vllm/core.py
@@ -37,6 +37,7 @@
CompletionChoice,
CompletionChunk,
CompletionUsage,
LoRA,
ToolCallFunction,
ToolCalls,
)
@@ -64,6 +65,7 @@ class VLLMModelConfig(TypedDict, total=False):


class VLLMGenerateConfig(TypedDict, total=False):
lora_name: Optional[str]
n: int
best_of: Optional[int]
presence_penalty: float
@@ -143,16 +145,30 @@ def __init__(
quantization: str,
model_path: str,
model_config: Optional[VLLMModelConfig],
peft_model: Optional[List[LoRA]] = None,
):
try:
from vllm.lora.request import LoRARequest
except ImportError:
error_message = "Failed to import module 'vllm'"
installation_guide = [
"Please make sure 'vllm' is installed. ",
"You can install it by `pip install vllm`\n",
]

raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}")
super().__init__(model_uid, model_family, model_spec, quantization, model_path)
self._model_config = model_config
self._engine = None
self.lora_modules = peft_model
self.lora_requests: List[LoRARequest] = []

def load(self):
try:
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
except ImportError:
error_message = "Failed to import module 'vllm'"
installation_guide = [
@@ -171,11 +187,33 @@ def load(self):
multiprocessing.set_start_method("fork", force=True)

self._model_config = self._sanitize_model_config(self._model_config)

if self.lora_modules is None:
self.lora_requests = []
else:
self.lora_requests = [
LoRARequest(
lora_name=lora.lora_name,
lora_int_id=i,
lora_local_path=lora.local_path,
)
for i, lora in enumerate(self.lora_modules, start=1)
]

enable_lora = len(self.lora_requests) > 0
max_loras = len(self.lora_requests)

logger.info(
f"Loading {self.model_uid} with following model config: {self._model_config}"
f"Enable lora: {enable_lora}. Lora count: {max_loras}."
)

engine_args = AsyncEngineArgs(model=self.model_path, **self._model_config)
engine_args = AsyncEngineArgs(
model=self.model_path,
enable_lora=enable_lora,
max_loras=max_loras,
**self._model_config,
)
self._engine = AsyncLLMEngine.from_engine_args(engine_args)

def _sanitize_model_config(
@@ -206,6 +244,7 @@ def _sanitize_generate_config(
generate_config = {}

sanitized = VLLMGenerateConfig()
sanitized.setdefault("lora_name", generate_config.get("lora_name", None))
sanitized.setdefault("n", generate_config.get("n", 1))
sanitized.setdefault("best_of", generate_config.get("best_of", None))
sanitized.setdefault(
@@ -338,12 +377,23 @@ async def async_generate(
"Enter generate, prompt: %s, generate config: %s", prompt, generate_config
)

lora_model = sanitized_generate_config.pop("lora_name")

lora_request = None
if lora_model is not None:
for lora in self.lora_requests:
if lora_model == lora.lora_name:
lora_request = lora
break

stream = sanitized_generate_config.pop("stream")
sampling_params = SamplingParams(**sanitized_generate_config)
request_id = str(uuid.uuid1())

assert self._engine is not None
results_generator = self._engine.generate(prompt, sampling_params, request_id)
results_generator = self._engine.generate(
prompt, sampling_params, request_id, lora_request=lora_request
)

async def stream_results() -> AsyncGenerator[CompletionChunk, None]:
previous_texts = [""] * sanitized_generate_config["n"]
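For context, the same LoRA plumbing through vLLM's offline ``LLM`` API looks roughly like the sketch below at the vLLM version this PR targets. The model path, adapter path, and names are placeholders, and the PR itself uses ``AsyncLLMEngine`` rather than ``LLM``.

```python
# Rough sketch of vLLM's LoRA request flow (placeholder paths/names; offline LLM API
# rather than the AsyncLLMEngine used in the PR).
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="path/to/base-model", enable_lora=True, max_loras=1)

lora_request = LoRARequest(
    lora_name="my_adapter",          # matched against the user-supplied lora_name
    lora_int_id=1,                   # unique positive integer per adapter
    lora_local_path="path/to/lora",  # field name at the vLLM version targeted here
)

outputs = llm.generate(
    ["Hello"],
    SamplingParams(max_tokens=64),
    lora_request=lora_request,       # pass None to generate with the base weights
)
print(outputs[0].outputs[0].text)
```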
3 changes: 3 additions & 0 deletions xinference/types.py
@@ -187,6 +187,7 @@ class ChatglmCppGenerateConfig(TypedDict, total=False):
top_p: float
temperature: float
stream: bool
lora_name: Optional[str]


class QWenCppModelConfig(TypedDict, total=False):
@@ -279,6 +280,7 @@ class PytorchGenerateConfig(TypedDict, total=False):
stream_interval: int
model: Optional[str]
tools: Optional[List[Dict]]
lora_name: Optional[str]


class PytorchModelConfig(TypedDict, total=False):
@@ -354,6 +356,7 @@ class CreateCompletionTorch(BaseModel):
temperature: float = temperature_field
top_p: float = top_p_field
top_k: int = top_k_field
lora_name: Optional[str]


CreateCompletionLlamaCpp: BaseModel
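Because these config classes use ``total=False`` or ``Optional`` fields, ``lora_name`` is simply an optional key. A caller-side sketch, with a hypothetical adapter name:

```python
from xinference.types import PytorchGenerateConfig

# lora_name is optional; omitting it (or passing None) keeps the base model.
config: PytorchGenerateConfig = {
    "max_tokens": 256,
    "temperature": 0.7,
    "lora_name": "my_adapter",  # hypothetical adapter name registered at launch
}
```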