
Commit 8115fa9

lmdeploy: support dtype configuration

1 parent: 264b612

File tree: 2 files changed (+8 −3 lines)

gpt_server/model_backend/lmdeploy_backend.py

Lines changed: 8 additions & 1 deletion
@@ -29,14 +29,21 @@ def __init__(self, model_path) -> None:
         backend = backend_map[os.getenv("backend")]
         enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
         max_model_len = os.getenv("max_model_len", None)
+        dtype = os.getenv("dtype", "auto")
         logger.info(f"后端 {backend}")
         if backend == "pytorch":
-            backend_config = PytorchEngineConfig(tp=int(os.getenv("num_gpus", "1")))
+            backend_config = PytorchEngineConfig(
+                tp=int(os.getenv("num_gpus", "1")),
+                dtype=dtype,
+                session_len=int(max_model_len) if max_model_len else None,
+                enable_prefix_caching=enable_prefix_caching,
+            )
         if backend == "turbomind":
             backend_config = TurbomindEngineConfig(
                 tp=int(os.getenv("num_gpus", "1")),
                 enable_prefix_caching=enable_prefix_caching,
                 session_len=int(max_model_len) if max_model_len else None,
+                dtype=dtype,
             )
         pipeline_type, pipeline_class = get_task(model_path)
         logger.info(f"模型架构:{pipeline_type}")
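For reference, a minimal usage sketch of the environment-driven configuration above. Only the environment variable names come from the diff; the class name LMDeployBackend and the model path are placeholders assumed from the file name.

import os

# Hypothetical launch snippet: set the environment variables that the
# backend's __init__ reads (variable names taken from the diff above).
os.environ["backend"] = "turbomind"          # or "pytorch"
os.environ["num_gpus"] = "1"                 # tensor parallelism degree
os.environ["dtype"] = "bfloat16"             # new in this commit; defaults to "auto"
os.environ["max_model_len"] = "8192"         # forwarded as session_len
os.environ["enable_prefix_caching"] = "1"    # note: bool(os.getenv(...)) is True for any non-empty string

# Class name assumed from the file name; adjust to the real one.
from gpt_server.model_backend.lmdeploy_backend import LMDeployBackend
backend = LMDeployBackend("/path/to/model")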

gpt_server/model_backend/vllm_backend.py

Lines changed: 0 additions & 2 deletions
@@ -27,9 +27,7 @@ class VllmBackend(ModelBackend):
     def __init__(self, model_path) -> None:
         lora = os.getenv("lora", None)
         enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
-
         max_model_len = os.getenv("max_model_len", None)
-
         tensor_parallel_size = int(os.getenv("num_gpus", "1"))
         dtype = os.getenv("dtype", "auto")
         max_loras = 1
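The vLLM backend reads the same dtype variable. Below is a minimal sketch, assuming these values are ultimately forwarded to vLLM's AsyncEngineArgs; the exact wiring inside VllmBackend is not shown in this diff, and the model path is a placeholder.

import os
from vllm.engine.arg_utils import AsyncEngineArgs

# Sketch of the mapping from environment variables to vLLM engine arguments;
# field names are standard AsyncEngineArgs fields.
max_model_len = os.getenv("max_model_len", None)
engine_args = AsyncEngineArgs(
    model="/path/to/model",
    tensor_parallel_size=int(os.getenv("num_gpus", "1")),
    dtype=os.getenv("dtype", "auto"),        # same env var as the lmdeploy backend
    max_model_len=int(max_model_len) if max_model_len else None,
    enable_prefix_caching=bool(os.getenv("enable_prefix_caching", False)),
)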
