From e4115d1fb2e1299ae383593644b70b0ae30c47a7 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 26 Jul 2023 11:05:53 +0800 Subject: [PATCH] FEAT: support fp4 and int8 quantization for pytorch model (#238) --- .github/workflows/python.yaml | 1 + setup.cfg | 6 +- xinference/core/gradio.py | 4 +- xinference/core/service.py | 9 - xinference/model/__init__.py | 6 +- xinference/model/llm/__init__.py | 10 +- xinference/model/llm/pytorch/compression.py | 261 ++++++++++++++++++++ xinference/model/llm/pytorch/core.py | 59 +++-- 8 files changed, 318 insertions(+), 38 deletions(-) create mode 100644 xinference/model/llm/pytorch/compression.py diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index d565679cce..e4e9bdbb05 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -91,6 +91,7 @@ jobs: pip install accelerate pip install sentencepiece pip install transformers_stream_generator + pip install bitsandbytes pip install -e ".[dev]" working-directory: . diff --git a/setup.cfg b/setup.cfg index 61e0edbe8a..a92f0d1654 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,12 +60,12 @@ dev = all = chatglm-cpp llama-cpp-python - transformers + transformers>=4.31.0 torch - accelerate + accelerate>=0.20.3 sentencepiece transformers_stream_generator - cpm_kernels; platform_system != "Darwin" + bitsandbytes doc = ipython>=6.5.0 sphinx>=3.0.0,<5.0.0 diff --git a/xinference/core/gradio.py b/xinference/core/gradio.py index 083f663a0f..da450ac606 100644 --- a/xinference/core/gradio.py +++ b/xinference/core/gradio.py @@ -272,9 +272,7 @@ def select_model( cache_path = model_family.generate_cache_path( int(_model_size_in_billions), _quantization ) - if not (os.path.exists(cache_path)): - if os.path.exists(cache_path): - os.remove(cache_path) + if _model_format != "pytorch" and not (os.path.exists(cache_path)): url = model_family.url_generator( int(_model_size_in_billions), _quantization ) diff --git a/xinference/core/service.py b/xinference/core/service.py index be7d05f4ae..bf8452f592 100644 --- a/xinference/core/service.py +++ b/xinference/core/service.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio -import platform import time from dataclasses import dataclass from logging import getLogger @@ -247,13 +246,6 @@ def _choose_subpool(self) -> str: raise RuntimeError("No available slot found") - def _check_model_is_valid(self, model_name): - # baichuan-base and baichuan-chat depend on `cpm_kernels` module, - # but `cpm_kernels` cannot run on Darwin system. 
- if platform.system() == "Darwin": - if model_name in ["baichuan-base", "baichuan-chat"]: - raise ValueError(f"{model_name} model can't run on Darwin system.") - @log async def launch_builtin_model( self, @@ -265,7 +257,6 @@ async def launch_builtin_model( **kwargs, ) -> xo.ActorRefType["ModelActor"]: assert model_uid not in self._model_uid_to_model - self._check_model_is_valid(model_name) from ..model import MODEL_FAMILIES diff --git a/xinference/model/__init__.py b/xinference/model/__init__.py index e6302f7740..0019887438 100644 --- a/xinference/model/__init__.py +++ b/xinference/model/__init__.py @@ -150,6 +150,9 @@ def cache( url = self.url_generator(model_size_in_billions, quantization) rp_url = self.rp_url_generator(model_size_in_billions, quantization) + if self.model_format == "pytorch": + return url + try: rp_fetch = requests.get(rp_url) except RequestException: @@ -167,9 +170,6 @@ def cache( str(splitted_res_content[index + 1], encoding="utf-8") ) - if self.model_format == "pytorch": - return url - full_name = f"{str(self)}-{model_size_in_billions}b-{quantization}" save_path = self.generate_cache_path(model_size_in_billions, quantization) diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index c96ec6930f..0c5e8c1885 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -376,7 +376,7 @@ def install(): model_name="baichuan", model_sizes_in_billions=[7], model_format="pytorch", - quantizations=["none"], + quantizations=["8-bit", "4-bit", "none"], url_generator=pytorch_baichuan_name_generator, rp_url_generator=lambda model_size, quantization: "", cls=BaichuanPytorch, @@ -391,7 +391,7 @@ def install(): model_name="baichuan-base", model_sizes_in_billions=[13], model_format="pytorch", - quantizations=["int4", "int8", "none"], + quantizations=["8-bit", "4-bit", "none"], url_generator=pytorch_baichuan_base_name_generator, rp_url_generator=lambda model_size, quantization: "", cls=BaichuanPytorch, @@ -406,7 +406,7 @@ def install(): model_name="baichuan-chat", model_sizes_in_billions=[13], model_format="pytorch", - quantizations=["int4", "int8", "none"], + quantizations=["8-bit", "4-bit", "none"], url_generator=pytorch_baichuan_chat_name_generator, rp_url_generator=lambda model_size, quantization: "", cls=BaichuanPytorchChat, @@ -421,9 +421,9 @@ def install(): MODEL_FAMILIES.append( ModelFamily( model_name="vicuna-v1.3", - model_sizes_in_billions=[7, 13], + model_sizes_in_billions=[7, 13, 33], model_format="pytorch", - quantizations=["none"], + quantizations=["8-bit", "4-bit", "none"], url_generator=pytorch_vicuna_v1_3_name_generator, rp_url_generator=lambda model_size, quantization: "", cls=VicunaCensoredPytorch, diff --git a/xinference/model/llm/pytorch/compression.py b/xinference/model/llm/pytorch/compression.py new file mode 100644 index 0000000000..3244a9ac35 --- /dev/null +++ b/xinference/model/llm/pytorch/compression.py @@ -0,0 +1,261 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import dataclasses +import gc +import glob +import os + +import torch +import torch.nn as nn +from accelerate import init_empty_weights +from accelerate.utils import set_module_tensor_to_device +from huggingface_hub import snapshot_download +from torch import Tensor +from torch.nn import functional as F +from tqdm import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + +from ....constants import XINFERENCE_CACHE_DIR + + +@dataclasses.dataclass +class CompressionConfig: + """Group-wise quantization.""" + + num_bits: int + group_size: int + group_dim: int + symmetric: bool + enabled: bool = True + + +default_compression_config = CompressionConfig( + num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True +) + + +class CLinear(nn.Module): + """Compressed Linear Layer.""" + + def __init__(self, weight=None, bias=None, device=None): + super().__init__() + if weight is None: + self.weight = None + elif isinstance(weight, Tensor): + self.weight = compress(weight.data.to(device), default_compression_config) + else: + self.weight = weight + self.bias = bias + + def forward(self, input: Tensor) -> Tensor: + weight = decompress(self.weight, default_compression_config) + if self.bias is None: + return F.linear(input.to(weight.dtype), weight) + return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype)) + + +def get_compressed_list(module, prefix=""): + compressed_list = [] + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + compressed_list.append(full_name) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + for each in get_compressed_list(child, child_prefix): + compressed_list.append(each) + return compressed_list + + +def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""): + for attr_str in dir(module): + target_attr = getattr(module, attr_str) + if type(target_attr) == torch.nn.Linear: + full_name = ( + f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight" + ) + setattr( + module, + attr_str, + CLinear( + compressed_state_dict[full_name], target_attr.bias, target_device + ), + ) + for name, child in module.named_children(): + child_prefix = f"{prefix}.{name}" if prefix else name + apply_compressed_weight( + child, compressed_state_dict, target_device, child_prefix + ) + + +def load_compress_model( + model_path: str, + device: str, + torch_dtype: torch.dtype, + use_fast: bool, + revision: str = "main", +): + # partially load model + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=use_fast, + trust_remote_code=True, + revision=revision, + cache_dir=XINFERENCE_CACHE_DIR, + ) + + with init_empty_weights(): + config = AutoConfig.from_pretrained( + model_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + trust_remote_code=True, + revision=revision, + cache_dir=XINFERENCE_CACHE_DIR, + ) + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + linear_weights = get_compressed_list(model) + + if os.path.exists(model_path): + # `model_path` is a local folder + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + else: + # `model_path` is a cached Hugging Face repo + model_path = snapshot_download( + model_path, revision=revision, cache_dir=XINFERENCE_CACHE_DIR + ) + base_pattern = os.path.join(model_path, "pytorch_model*.bin") + + files = 
glob.glob(base_pattern) + + compressed_state_dict = {} + + for filename in tqdm(files): + tmp_state_dict = torch.load(filename, map_location=torch.device(device)) + for name in tmp_state_dict: + if name in linear_weights: + tensor = tmp_state_dict[name].to(device).data.to(torch_dtype) + compressed_state_dict[name] = compress( + tensor, default_compression_config + ) + else: + compressed_state_dict[name] = tmp_state_dict[name].to(device) + tmp_state_dict[name] = None + tensor = None + gc.collect() + torch.cuda.empty_cache() + + for name in model.state_dict(): + if name not in linear_weights: + set_module_tensor_to_device( + model, name, device, value=compressed_state_dict[name] + ) + apply_compressed_weight(model, compressed_state_dict, device) + + model.to(device) + + return model, tokenizer + + +def compress(tensor, config): + """Simulate group-wise quantization.""" + if not config.enabled: + return tensor + + group_size, num_bits, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + assert num_bits <= 8 + + original_shape = tensor.shape + num_groups = (original_shape[group_dim] + group_size - 1) // group_size + new_shape = ( + original_shape[:group_dim] + + (num_groups, group_size) + + original_shape[group_dim + 1 :] + ) + + # Pad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len != 0: + pad_shape = ( + original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :] + ) + tensor = torch.cat( + [tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)], + dim=group_dim, + ) + data = tensor.view(new_shape) + + # Quantize + if symmetric: + B = 2 ** (num_bits - 1) - 1 + scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0] + data = data * scale + data = data.clamp_(-B, B).round_().to(torch.int8) + return data, scale, original_shape + else: + B = 2**num_bits - 1 + mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0] + mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0] + + scale = B / (mx - mn) + data = data - mn + data.mul_(scale) + + data = data.clamp_(0, B).round_().to(torch.uint8) + return data, mn, scale, original_shape + + +def decompress(packed_data, config): + """Simulate group-wise dequantization.""" + if not config.enabled: + return packed_data + + group_size, _, group_dim, symmetric = ( + config.group_size, + config.num_bits, + config.group_dim, + config.symmetric, + ) + + # Dequantize + if symmetric: + data, scale, original_shape = packed_data + data = data / scale + else: + data, mn, scale, original_shape = packed_data + data = data / scale + data.add_(mn) + + # Unpad + pad_len = (group_size - original_shape[group_dim] % group_size) % group_size + if pad_len: + padded_original_shape = ( + original_shape[:group_dim] + + (original_shape[group_dim] + pad_len,) + + original_shape[group_dim + 1 :] + ) + data = data.reshape(padded_original_shape) + indices = [slice(0, x) for x in original_shape] + return data[indices].contiguous() + else: + return data.view(original_shape) diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index a8684fc011..b369bc267a 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -28,6 +28,7 @@ ) from ..core import Model from ..utils import ChatModelDataProcessorMixin +from .compression import load_compress_model from .utils import generate_stream if TYPE_CHECKING: @@ -56,8 +57,6 @@ class PytorchModelConfig(TypedDict, total=False): 
gpus: Optional[str] num_gpus: int max_gpu_memory: str - load_8bit: bool - cpu_offloading: bool gptq_ckpt: Optional[str] gptq_wbits: int gptq_groupsize: int @@ -87,14 +86,23 @@ def _sanitize_model_config( pytorch_model_config.setdefault("revision", "main") pytorch_model_config.setdefault("gpus", None) pytorch_model_config.setdefault("num_gpus", 1) - pytorch_model_config.setdefault("load_8bit", False) - pytorch_model_config.setdefault("cpu_offloading", False) pytorch_model_config.setdefault("gptq_ckpt", None) pytorch_model_config.setdefault("gptq_wbits", 16) pytorch_model_config.setdefault("gptq_groupsize", -1) pytorch_model_config.setdefault("gptq_act_order", False) if self._is_darwin_and_apple_silicon(): pytorch_model_config.setdefault("device", "mps") + if ( + self.model_spec.model_name in ["baichuan-chat", "baichuan-base"] + and self.model_spec.quantization != "none" + ): + # dtype of parameters in `baichuan-chat` and `baichuan-base` model + # is `torch.bfloat16` which is not supported on MPS. + logger.warning( + f"Model {self.model_spec.model_name} can't use quantization method on MPS device. " + "Continuing with CPU device" + ) + pytorch_model_config["device"] = "cpu" else: pytorch_model_config.setdefault("device", "cuda") return pytorch_model_config @@ -139,8 +147,8 @@ def _load_model(self, kwargs: dict): return model, tokenizer def load(self): + quantization = self.model_spec.quantization num_gpus = self._pytorch_model_config.get("num_gpus", 1) - cpu_offloading = self._pytorch_model_config.get("cpu_offloading", False) if self._is_darwin_and_apple_silicon(): device = self._pytorch_model_config.get("device", "mps") else: @@ -150,27 +158,48 @@ def load(self): kwargs = {"torch_dtype": torch.float32} elif device == "cuda": kwargs = {"torch_dtype": torch.float16} - if cpu_offloading: - kwargs["device_map"] = "auto" elif device == "mps": kwargs = {"torch_dtype": torch.float16} else: raise ValueError(f"Device {device} is not supported in temporary") kwargs["revision"] = self._pytorch_model_config.get("revision", "main") - self._model, self._tokenizer = self._load_model(kwargs) + if quantization != "none": + if device == "cuda" and self._is_linux(): + kwargs["device_map"] = "auto" + if quantization == "4-bit": + kwargs["load_in_4bit"] = True + elif quantization == "8-bit": + kwargs["load_in_8bit"] = True + else: + raise ValueError( + f"Quantization {quantization} is not supported in temporary" + ) + else: + if num_gpus != 1: + raise ValueError(f"Quantization is not supported for multi-gpu") + elif quantization != "8-bit": + raise ValueError( + f"Only 8-bit quantization is supported if it is not linux system or cuda device" + ) + else: + self._model, self._tokenizer = load_compress_model( + model_path=self._model_path, + device=device, + torch_dtype=kwargs["torch_dtype"], + use_fast=self._use_fast_tokenizer, + revision=kwargs["revision"], + ) + logger.debug(f"Model Memory: {self._model.get_memory_footprint()}") + return - quantization = self.model_spec.quantization - if quantization == "int4": - self._model = self._model.quantize(4) - elif quantization == "int8": - self._model = self._model.quantize(8) + self._model, self._tokenizer = self._load_model(kwargs) if ( - device == "cuda" and num_gpus == 1 and not cpu_offloading + device == "cuda" and num_gpus == 1 and quantization == "none" ) or device == "mps": self._model.to(device) - print(self._model) + logger.debug(f"Model Memory: {self._model.get_memory_footprint()}") def generate( self, prompt: str, generate_config: 
Optional[PytorchGenerateConfig] = None
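
Usage sketches for the two quantization paths added above:

On Linux with a CUDA device, the new "4-bit" and "8-bit" options become bitsandbytes flags that `load()` places into `kwargs`, which `_load_model` presumably forwards to `transformers`' `from_pretrained`. A minimal sketch of that path, assuming `transformers>=4.31.0`, `accelerate`, and `bitsandbytes` are installed; the repo id below is only an example stand-in for the model path xinference resolves internally:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = "lmsys/vicuna-7b-v1.3"  # example repo id, not resolved by xinference here

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map="auto",   # lets accelerate/bitsandbytes place the quantized weights
        load_in_4bit=True,   # the "4-bit" option; use load_in_8bit=True for "8-bit"
    )
    print(model.get_memory_footprint())

On other systems or devices, `load()` falls back to the group-wise 8-bit scheme in `compression.py` via `load_compress_model`. The `compress`/`decompress` pair can also be exercised on its own; a small round-trip under `default_compression_config` (8 bits, symmetric, groups of 256 along dim 1):

    import torch
    from xinference.model.llm.pytorch.compression import (
        compress,
        decompress,
        default_compression_config,
    )

    weight = torch.randn(4, 512)
    packed = compress(weight, default_compression_config)     # (int8 data, per-group scale, original shape)
    restored = decompress(packed, default_compression_config)

    # Per-group scaling keeps the reconstruction error small relative to the weight range.
    print((weight - restored).abs().max())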