diff --git a/xinference/core/api.py b/xinference/core/api.py
index 394f0e168b..2789ca3af5 100644
--- a/xinference/core/api.py
+++ b/xinference/core/api.py
@@ -66,6 +66,11 @@ async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
         supervisor_ref = await self._get_supervisor_ref()
         return await supervisor_ref.get_model(model_uid)
 
+    async def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        supervisor_ref = await self._get_supervisor_ref()
+        return await supervisor_ref.is_local_deployment()
+
 
 class SyncSupervisorAPI:
     def __init__(self, supervisor_address: str):
@@ -124,3 +129,11 @@ async def _get_model():
             return await supervisor_ref.get_model(model_uid)
 
         return self._isolation.call(_get_model())
+
+    def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        async def _is_local_deployment():
+            supervisor_ref = await self._get_supervisor_ref()
+            return await supervisor_ref.is_local_deployment()
+
+        return self._isolation.call(_is_local_deployment())
diff --git a/xinference/core/gradio.py b/xinference/core/gradio.py
index c77c956d73..8550e8b206 100644
--- a/xinference/core/gradio.py
+++ b/xinference/core/gradio.py
@@ -301,7 +301,11 @@ def select_model(
             progress=gr.Progress(),
         ):
             match_result = match_llm(
-                _model_name, _model_format, int(_model_size_in_billions), _quantization
+                _model_name,
+                _model_format,
+                int(_model_size_in_billions),
+                _quantization,
+                self._api.is_local_deployment(),
             )
             if not match_result:
                 raise ValueError(
diff --git a/xinference/core/service.py b/xinference/core/service.py
index 97578e8685..753130fa46 100644
--- a/xinference/core/service.py
+++ b/xinference/core/service.py
@@ -182,6 +182,13 @@ async def list_models(self) -> Dict[str, Dict[str, Any]]:
             ret.update(await worker.list_models())
         return ret
 
+    def is_local_deployment(self) -> bool:
+        # TODO: temporary.
+        return (
+            len(self._worker_address_to_worker) == 1
+            and list(self._worker_address_to_worker)[0] == self.address
+        )
+
     @log
     async def add_worker(self, worker_address: str):
         assert worker_address not in self._worker_address_to_worker
@@ -290,8 +297,13 @@ async def launch_builtin_model(
 
         from ..model.llm import match_llm, match_llm_cls
 
+        assert self._supervisor_ref is not None
         match_result = match_llm(
-            model_name, model_format, model_size_in_billions, quantization
+            model_name,
+            model_format,
+            model_size_in_billions,
+            quantization,
+            await self._supervisor_ref.is_local_deployment(),
         )
         if not match_result:
             raise ValueError(
diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py
index 7b82ede1e8..22f93c9f32 100644
--- a/xinference/model/llm/__init__.py
+++ b/xinference/model/llm/__init__.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import json
+import logging
 import os
+import platform
 from typing import List, Optional, Tuple, Type
 
 from .core import LLM
@@ -29,12 +31,29 @@
 
 LLM_FAMILIES: List["LLMFamilyV1"] = []
 
+logger = logging.getLogger(__name__)
+
+
+def _is_linux():
+    return platform.system() == "Linux"
+
+
+def _has_cuda_device():
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if cuda_visible_devices:
+        return True
+    else:
+        from xorbits._mars.resource import cuda_count
+
+        return cuda_count() > 0
+
 
 def match_llm(
     model_name: str,
     model_format: Optional[str] = None,
     model_size_in_billions: Optional[int] = None,
     quantization: Optional[str] = None,
+    is_local_deployment: bool = False,
 ) -> Optional[Tuple[LLMFamilyV1, LLMSpecV1, str]]:
     """
     Find an LLM family, spec, and quantization that satisfy given criteria.
@@ -52,8 +71,25 @@ def match_llm(
                 and model_size_in_billions != spec.model_size_in_billions
                 or quantization
                 and quantization not in spec.quantizations
             ):
                 continue
-            # by default, choose the most coarse-grained quantization.
-            return family, spec, quantization or spec.quantizations[0]
+            if quantization:
+                return family, spec, quantization
+            else:
+                # by default, choose the most coarse-grained quantization.
+                # TODO: too hacky.
+                quantizations = list(spec.quantizations)
+                quantizations.sort()
+                for q in quantizations:
+                    if (
+                        is_local_deployment
+                        and not (_is_linux() and _has_cuda_device())
+                        and q == "4-bit"
+                    ):
+                        logger.warning(
+                            "Skipping %s for non-linux or non-cuda local deployment.",
+                            q,
+                        )
+                        continue
+                    return family, spec, q
     return None