FEAT: Refactor device related code and add initial Intel GPU support
notsyncing authored and notsyncing committed Feb 2, 2024
1 parent 749ef3f commit f7b70f9
Showing 16 changed files with 206 additions and 83 deletions.
3 changes: 3 additions & 0 deletions setup.cfg
@@ -96,6 +96,9 @@ all =
auto-gptq ; sys_platform!='darwin'
optimum
flash-attn
intel =
torch==2.1.0a0
intel_extension_for_pytorch==2.1.10+xpu
ggml =
llama-cpp-python>=0.2.25
ctransformers
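The new `intel` extras group pins an XPU-enabled PyTorch stack (`torch==2.1.0a0` plus `intel_extension_for_pytorch==2.1.10+xpu`). Assuming the project is installed from source with this setup.cfg, the extra would be pulled in with something like `pip install "xinference[intel]"`; the pinned XPU wheels normally come from Intel's own package index, which is not configured in this diff.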
3 changes: 2 additions & 1 deletion xinference/core/model.py
@@ -45,6 +45,7 @@
logger = logging.getLogger(__name__)

from .utils import json_dumps, log_async
from ..device_utils import empty_cache

try:
from torch.cuda import OutOfMemoryError
@@ -141,7 +142,7 @@ async def __pre_destroy__(self):

del self._model
gc.collect()
torch.cuda.empty_cache()
empty_cache()

def __init__(
self,
6 changes: 3 additions & 3 deletions xinference/core/supervisor.py
@@ -227,11 +227,11 @@ async def get_builtin_families() -> Dict[str, List[str]]:
}

async def get_devices_count(self) -> int:
from ..utils import cuda_count
from ..device_utils import gpu_count

if self.is_local_deployment():
return cuda_count()
# distributed deployment, choose a worker and return its cuda_count.
return gpu_count()
# distributed deployment, choose a worker and return its device_count.
# Assume that each worker has the same count of cards.
worker_ref = await self._choose_worker()
return await worker_ref.get_devices_count()
26 changes: 13 additions & 13 deletions xinference/core/worker.py
@@ -31,7 +31,7 @@
from ..core import ModelActor
from ..core.status_guard import LaunchStatus
from ..model.core import ModelDescription, create_model_instance
from ..utils import cuda_count
from ..device_utils import gpu_count
from .event import Event, EventCollectorActor, EventType
from .metrics import launch_metrics_export_server, record_metrics
from .resource import gather_node_info
@@ -54,13 +54,13 @@ def __init__(
self,
supervisor_address: str,
main_pool: MainActorPoolType,
cuda_devices: List[int],
gpu_devices: List[int],
metrics_exporter_host: Optional[str] = None,
metrics_exporter_port: Optional[int] = None,
):
super().__init__()
# static attrs.
self._total_cuda_devices = cuda_devices
self._total_gpu_devices = gpu_devices
self._supervisor_address = supervisor_address
self._supervisor_ref = None
self._main_pool = main_pool
@@ -244,9 +244,9 @@ async def __pre_destroy__(self):

@staticmethod
def get_devices_count():
from ..utils import cuda_count
from ..device_utils import gpu_count

return cuda_count()
return gpu_count()

@log_sync(logger=logger)
def get_model_count(self) -> int:
@@ -263,7 +263,7 @@ async def allocate_devices_for_embedding(self, model_uid: str) -> int:
we assume that embedding model only takes 1 GPU slot.
"""
candidates = []
for _dev in self._total_cuda_devices:
for _dev in self._total_gpu_devices:
if _dev not in self._gpu_to_model_uid:
candidates.append(_dev)
else:
@@ -291,11 +291,11 @@ async def allocate_devices_for_embedding(self, model_uid: str) -> int:
return device

def allocate_devices(self, model_uid: str, n_gpu: int) -> List[int]:
if n_gpu > len(self._total_cuda_devices) - len(self._gpu_to_model_uid):
if n_gpu > len(self._total_gpu_devices) - len(self._gpu_to_model_uid):
raise RuntimeError("No available slot found for the model")

devices: List[int] = [
dev for dev in self._total_cuda_devices if dev not in self._gpu_to_model_uid
dev for dev in self._total_gpu_devices if dev not in self._gpu_to_model_uid
][:n_gpu]
for dev in devices:
self._gpu_to_model_uid[int(dev)] = model_uid
@@ -324,7 +324,7 @@ async def _create_subpool(
) -> Tuple[str, List[str]]:
env = {}
devices = []
if isinstance(n_gpu, int) or (n_gpu == "auto" and cuda_count() > 0):
if isinstance(n_gpu, int) or (n_gpu == "auto" and gpu_count() > 0):
# Currently, n_gpu=auto means using 1 GPU
gpu_cnt = n_gpu if isinstance(n_gpu, int) else 1
devices = (
@@ -396,10 +396,10 @@ async def launch_speculative_model(
n_gpu: Optional[Union[int, str]] = "auto",
):
if n_gpu is not None:
if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > cuda_count()):
if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > gpu_count()):
raise ValueError(
f"The parameter `n_gpu` must be greater than 0 and "
f"not greater than the number of GPUs: {cuda_count()} on the machine."
f"not greater than the number of GPUs: {gpu_count()} on the machine."
)
if isinstance(n_gpu, str) and n_gpu != "auto":
raise ValueError("Currently `n_gpu` only supports `auto`.")
@@ -504,10 +504,10 @@ async def launch_builtin_model(
launch_args.pop("kwargs")
launch_args.update(kwargs)
if n_gpu is not None:
if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > cuda_count()):
if isinstance(n_gpu, int) and (n_gpu <= 0 or n_gpu > gpu_count()):
raise ValueError(
f"The parameter `n_gpu` must be greater than 0 and "
f"not greater than the number of GPUs: {cuda_count()} on the machine."
f"not greater than the number of GPUs: {gpu_count()} on the machine."
)
if isinstance(n_gpu, str) and n_gpu != "auto":
raise ValueError("Currently `n_gpu` only supports `auto`.")
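A rough standalone sketch of the slot-allocation rule used in allocate_devices above (hypothetical function name, not part of the diff): devices already mapped in the GPU-to-model table are treated as occupied, and the first n_gpu free indices are handed to the new model.

from typing import Dict, List

def allocate(total_gpu_devices: List[int], gpu_to_model_uid: Dict[int, str],
             model_uid: str, n_gpu: int) -> List[int]:
    # Devices already mapped to a model are occupied; take the first n_gpu free ones.
    free = [dev for dev in total_gpu_devices if dev not in gpu_to_model_uid]
    if n_gpu > len(free):
        raise RuntimeError("No available slot found for the model")
    chosen = free[:n_gpu]
    for dev in chosen:
        gpu_to_model_uid[int(dev)] = model_uid  # mark the slot as taken
    return chosen

# allocate([0, 1, 2, 3], {0: "model-a"}, "model-b", 2) -> [1, 2]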
14 changes: 7 additions & 7 deletions xinference/deploy/worker.py
@@ -21,7 +21,7 @@
from xoscar import MainActorPoolType

from ..core.worker import WorkerActor
from ..utils import cuda_count
from ..device_utils import gpu_count

logger = logging.getLogger(__name__)

@@ -33,20 +33,20 @@ async def start_worker_components(
metrics_exporter_host: Optional[str],
metrics_exporter_port: Optional[int],
):
cuda_device_indices = []
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if cuda_visible_devices:
cuda_device_indices.extend([int(i) for i in cuda_visible_devices.split(",")])
gpu_device_indices = []
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices is not None and cuda_visible_devices != "-1":
gpu_device_indices.extend([int(i) for i in cuda_visible_devices.split(",")])
else:
cuda_device_indices = list(range(cuda_count()))
gpu_device_indices = list(range(gpu_count()))

await xo.create_actor(
WorkerActor,
address=address,
uid=WorkerActor.uid(),
supervisor_address=supervisor_address,
main_pool=main_pool,
cuda_devices=cuda_device_indices,
gpu_devices=gpu_device_indices,
metrics_exporter_host=metrics_exporter_host,
metrics_exporter_port=metrics_exporter_port,
)
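The reworked env handling above treats CUDA_VISIBLE_DEVICES=-1 as "no GPUs" rather than parsing it as a device index, and falls back to gpu_count() when the variable is unset. A minimal sketch of the same decision (hypothetical helper; it additionally treats an empty value as unset, a small deviation from the diff):

import os

from xinference.device_utils import gpu_count

def resolve_gpu_indices() -> list:
    env = os.environ.get("CUDA_VISIBLE_DEVICES")
    if env and env != "-1":
        # e.g. "0,2" -> [0, 2]
        return [int(i) for i in env.split(",")]
    # unset, empty, or "-1": defer to whatever backend gpu_count() can see
    return list(range(gpu_count()))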
105 changes: 105 additions & 0 deletions xinference/device_utils.py
@@ -0,0 +1,105 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Union

import torch
from typing_extensions import Literal

try:
    import intel_extension_for_pytorch  # noqa: F401  # side effect: registers the torch.xpu backend
except ImportError:
    pass


DeviceType = Literal["cuda", "mps", "xpu", "cpu"]


def is_xpu_available() -> bool:
return hasattr(torch, "xpu") and torch.xpu.is_available()


def get_available_device() -> DeviceType:
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
return "mps"
elif is_xpu_available():
return "xpu"
return "cpu"


def is_device_available(device: str) -> bool:
if device == "cuda":
return torch.cuda.is_available()
elif device == "mps":
return torch.backends.mps.is_available()
elif device == "xpu":
return is_xpu_available()
elif device == "cpu":
return True

return False


def move_model_to_available_device(model):
device = get_available_device()

if device == "cpu":
return model

return model.to(device)


def get_device_preferred_dtype(device: str) -> Union[torch.dtype, None]:
if device == "cpu":
return torch.float32
elif device == "cuda" or device == "mps":
return torch.float16
elif device == "xpu":
return torch.bfloat16

return None


def is_hf_accelerate_supported(device: str) -> bool:
return device == "cuda" or device == "xpu"


def empty_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
if torch.backends.mps.is_available():
torch.mps.empty_cache()
if is_xpu_available():
torch.xpu.empty_cache()


def gpu_count():
if torch.cuda.is_available():
cuda_visible_devices_env = os.getenv("CUDA_VISIBLE_DEVICES", None)

if cuda_visible_devices_env is None:
return torch.cuda.device_count()

cuda_visible_devices = (
cuda_visible_devices_env.split(",") if cuda_visible_devices_env else []
)

return min(torch.cuda.device_count(), len(cuda_visible_devices))
elif torch.backends.mps.is_available():
return 1
elif is_xpu_available():
return torch.xpu.device_count()
else:
return 0
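Taken together, these helpers let callers stay backend-agnostic instead of branching on torch.cuda directly. A minimal usage sketch (illustrative only; the linear layer stands in for a real model):

import torch

from xinference.device_utils import (
    empty_cache,
    get_available_device,
    get_device_preferred_dtype,
    gpu_count,
    move_model_to_available_device,
)

print(gpu_count())                            # visible CUDA / MPS / XPU device count
device = get_available_device()               # "cuda", "mps", "xpu", or "cpu"
dtype = get_device_preferred_dtype(device)    # fp16 on cuda/mps, bf16 on xpu, fp32 on cpu

model = torch.nn.Linear(8, 8)                 # stand-in for a real model
model = move_model_to_available_device(model).to(dtype)

# ... run inference ...

del model
empty_cache()                                 # frees cached memory on whichever backend is active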
11 changes: 9 additions & 2 deletions xinference/model/audio/whisper.py
@@ -14,6 +14,8 @@
import logging
from typing import TYPE_CHECKING, Dict, Optional

from xinference.device_utils import get_available_device, is_device_available, get_device_preferred_dtype

if TYPE_CHECKING:
from .core import AudioModelFamilyV1

@@ -40,8 +42,13 @@ def load(self):
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
if device is None:
device = get_available_device()
else:
if not is_device_available(device):
raise ValueError(f"Device {device} is not available!")

torch_dtype = get_device_preferred_dtype(device)

model = AutoModelForSpeechSeq2Seq.from_pretrained(
self._model_path,
6 changes: 2 additions & 4 deletions xinference/model/image/stable_diffusion/core.py
@@ -23,6 +23,7 @@
from io import BytesIO
from typing import List, Optional, Union

from ....device_utils import move_model_to_available_device
from ....constants import XINFERENCE_IMAGE_DIR
from ....types import Image, ImageList

@@ -57,10 +58,7 @@ def load(self):
# torch_dtype=torch.float16,
# use_safetensors=True,
)
if torch.cuda.is_available():
self._model = self._model.to("cuda")
elif torch.backends.mps.is_available():
self._model = self._model.to("mps")
self._model = move_model_to_available_device(self._model)
# Recommended if your computer has < 64 GB of RAM
self._model.enable_attention_slicing()

4 changes: 3 additions & 1 deletion xinference/model/llm/pytorch/compression.py
@@ -25,6 +25,8 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from ....device_utils import empty_cache


@dataclasses.dataclass
class CompressionConfig:
@@ -153,7 +155,7 @@ def load_compress_model(
tmp_state_dict[name] = None
tensor = None
gc.collect()
torch.cuda.empty_cache()
empty_cache()

for name in model.state_dict():
if name not in linear_weights:
(Diff for the remaining 7 changed files not loaded.)
