FEAT: support fp4 and int8 quantization for pytorch model (#238)
pangyoki committed Jul 26, 2023
1 parent b95df01 commit e4115d1
Showing 8 changed files with 318 additions and 38 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
@@ -91,6 +91,7 @@ jobs:
pip install accelerate
pip install sentencepiece
pip install transformers_stream_generator
+ pip install bitsandbytes
pip install -e ".[dev]"
working-directory: .

6 changes: 3 additions & 3 deletions setup.cfg
@@ -60,12 +60,12 @@ dev =
all =
chatglm-cpp
llama-cpp-python
- transformers
+ transformers>=4.31.0
torch
- accelerate
+ accelerate>=0.20.3
sentencepiece
transformers_stream_generator
cpm_kernels; platform_system != "Darwin"
+ bitsandbytes
doc =
ipython>=6.5.0
sphinx>=3.0.0,<5.0.0
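The dependency pins above are what the new quantization paths rely on: bitsandbytes provides the int8 and fp4 kernels named in the commit title, while transformers>=4.31.0 and accelerate>=0.20.3 are the minimum versions that expose them. The PyTorch loader changes themselves sit in a file whose diff did not load on this page, so the snippet below is only a hedged sketch of how these options are typically passed to transformers; the model id is a placeholder, not something taken from this commit.

    # Hedged sketch, not code from this commit: the usual way bitsandbytes-backed
    # int8 and fp4 loading is requested through transformers >= 4.31.0.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    model_id = "some-org/some-causal-lm"  # placeholder

    # int8: weights go through bitsandbytes' LLM.int8() linear layers
    model_8bit = AutoModelForCausalLM.from_pretrained(
        model_id, load_in_8bit=True, device_map="auto"
    )

    # fp4: 4-bit loading with the bitsandbytes "fp4" quantization type
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="fp4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model_fp4 = AutoModelForCausalLM.from_pretrained(
        model_id, quantization_config=bnb_config, device_map="auto"
    )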
4 changes: 1 addition & 3 deletions xinference/core/gradio.py
@@ -272,9 +272,7 @@ def select_model(
cache_path = model_family.generate_cache_path(
int(_model_size_in_billions), _quantization
)
- if not (os.path.exists(cache_path)):
- if os.path.exists(cache_path):
-     os.remove(cache_path)
+ if _model_format != "pytorch" and not (os.path.exists(cache_path)):
url = model_family.url_generator(
int(_model_size_in_billions), _quantization
)
9 changes: 0 additions & 9 deletions xinference/core/service.py
@@ -13,7 +13,6 @@
# limitations under the License.

import asyncio
- import platform
import time
from dataclasses import dataclass
from logging import getLogger
@@ -247,13 +246,6 @@ def _choose_subpool(self) -> str:

raise RuntimeError("No available slot found")

- def _check_model_is_valid(self, model_name):
-     # baichuan-base and baichuan-chat depend on `cpm_kernels` module,
-     # but `cpm_kernels` cannot run on Darwin system.
-     if platform.system() == "Darwin":
-         if model_name in ["baichuan-base", "baichuan-chat"]:
-             raise ValueError(f"{model_name} model can't run on Darwin system.")

@log
async def launch_builtin_model(
self,
@@ -265,7 +257,6 @@ async def launch_builtin_model(
**kwargs,
) -> xo.ActorRefType["ModelActor"]:
assert model_uid not in self._model_uid_to_model
- self._check_model_is_valid(model_name)

from ..model import MODEL_FAMILIES

6 changes: 3 additions & 3 deletions xinference/model/__init__.py
@@ -150,6 +150,9 @@ def cache(
url = self.url_generator(model_size_in_billions, quantization)
rp_url = self.rp_url_generator(model_size_in_billions, quantization)

+ if self.model_format == "pytorch":
+     return url

try:
rp_fetch = requests.get(rp_url)
except RequestException:
@@ -167,9 +170,6 @@
str(splitted_res_content[index + 1], encoding="utf-8")
)

- if self.model_format == "pytorch":
-     return url

full_name = f"{str(self)}-{model_size_in_billions}b-{quantization}"
save_path = self.generate_cache_path(model_size_in_billions, quantization)

10 changes: 5 additions & 5 deletions xinference/model/llm/__init__.py
@@ -376,7 +376,7 @@ def install():
model_name="baichuan",
model_sizes_in_billions=[7],
model_format="pytorch",
quantizations=["none"],
quantizations=["8-bit", "4-bit", "none"],
url_generator=pytorch_baichuan_name_generator,
rp_url_generator=lambda model_size, quantization: "",
cls=BaichuanPytorch,
@@ -391,7 +391,7 @@ def install():
model_name="baichuan-base",
model_sizes_in_billions=[13],
model_format="pytorch",
quantizations=["int4", "int8", "none"],
quantizations=["8-bit", "4-bit", "none"],
url_generator=pytorch_baichuan_base_name_generator,
rp_url_generator=lambda model_size, quantization: "",
cls=BaichuanPytorch,
@@ -406,7 +406,7 @@ def install():
model_name="baichuan-chat",
model_sizes_in_billions=[13],
model_format="pytorch",
quantizations=["int4", "int8", "none"],
quantizations=["8-bit", "4-bit", "none"],
url_generator=pytorch_baichuan_chat_name_generator,
rp_url_generator=lambda model_size, quantization: "",
cls=BaichuanPytorchChat,
@@ -421,9 +421,9 @@ def install():
MODEL_FAMILIES.append(
ModelFamily(
model_name="vicuna-v1.3",
- model_sizes_in_billions=[7, 13],
+ model_sizes_in_billions=[7, 13, 33],
model_format="pytorch",
quantizations=["none"],
quantizations=["8-bit", "4-bit", "none"],
url_generator=pytorch_vicuna_v1_3_name_generator,
rp_url_generator=lambda model_size, quantization: "",
cls=VicunaCensoredPytorch,
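With the registry entries above, the built-in PyTorch models now accept "8-bit", "4-bit", or "none" as their quantization value at launch time, and vicuna-v1.3 also gains a 33B size. The sketch below shows how that option would be passed through the Python client; the client class and argument names are assumptions based on the API of this release, not code shown in this diff.

    # Assumed client usage; endpoint, class, and argument names may differ.
    from xinference.client import Client

    client = Client("http://127.0.0.1:9997")
    model_uid = client.launch_model(
        model_name="vicuna-v1.3",
        model_format="pytorch",
        model_size_in_billions=7,
        quantization="4-bit",  # one of the newly registered options
    )
    print(model_uid)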
261 changes: 261 additions & 0 deletions xinference/model/llm/pytorch/compression.py
@@ -0,0 +1,261 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import gc
import glob
import os

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from huggingface_hub import snapshot_download
from torch import Tensor
from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from ....constants import XINFERENCE_CACHE_DIR


@dataclasses.dataclass
class CompressionConfig:
"""Group-wise quantization."""

num_bits: int
group_size: int
group_dim: int
symmetric: bool
enabled: bool = True


default_compression_config = CompressionConfig(
num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
)


class CLinear(nn.Module):
"""Compressed Linear Layer."""

def __init__(self, weight=None, bias=None, device=None):
super().__init__()
if weight is None:
self.weight = None
elif isinstance(weight, Tensor):
self.weight = compress(weight.data.to(device), default_compression_config)
else:
self.weight = weight
self.bias = bias

def forward(self, input: Tensor) -> Tensor:
weight = decompress(self.weight, default_compression_config)
if self.bias is None:
return F.linear(input.to(weight.dtype), weight)
return F.linear(input.to(weight.dtype), weight, self.bias.to(weight.dtype))


def get_compressed_list(module, prefix=""):
compressed_list = []
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if type(target_attr) == torch.nn.Linear:
full_name = (
f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
)
compressed_list.append(full_name)
for name, child in module.named_children():
child_prefix = f"{prefix}.{name}" if prefix else name
for each in get_compressed_list(child, child_prefix):
compressed_list.append(each)
return compressed_list


def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
for attr_str in dir(module):
target_attr = getattr(module, attr_str)
if type(target_attr) == torch.nn.Linear:
full_name = (
f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
)
setattr(
module,
attr_str,
CLinear(
compressed_state_dict[full_name], target_attr.bias, target_device
),
)
for name, child in module.named_children():
child_prefix = f"{prefix}.{name}" if prefix else name
apply_compressed_weight(
child, compressed_state_dict, target_device, child_prefix
)


def load_compress_model(
model_path: str,
device: str,
torch_dtype: torch.dtype,
use_fast: bool,
revision: str = "main",
):
# partially load model
tokenizer = AutoTokenizer.from_pretrained(
model_path,
use_fast=use_fast,
trust_remote_code=True,
revision=revision,
cache_dir=XINFERENCE_CACHE_DIR,
)

with init_empty_weights():
config = AutoConfig.from_pretrained(
model_path,
low_cpu_mem_usage=True,
torch_dtype=torch_dtype,
trust_remote_code=True,
revision=revision,
cache_dir=XINFERENCE_CACHE_DIR,
)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
linear_weights = get_compressed_list(model)

if os.path.exists(model_path):
# `model_path` is a local folder
base_pattern = os.path.join(model_path, "pytorch_model*.bin")
else:
# `model_path` is a cached Hugging Face repo
model_path = snapshot_download(
model_path, revision=revision, cache_dir=XINFERENCE_CACHE_DIR
)
base_pattern = os.path.join(model_path, "pytorch_model*.bin")

files = glob.glob(base_pattern)

compressed_state_dict = {}

for filename in tqdm(files):
tmp_state_dict = torch.load(filename, map_location=torch.device(device))
for name in tmp_state_dict:
if name in linear_weights:
tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
compressed_state_dict[name] = compress(
tensor, default_compression_config
)
else:
compressed_state_dict[name] = tmp_state_dict[name].to(device)
tmp_state_dict[name] = None
tensor = None
gc.collect()
torch.cuda.empty_cache()

for name in model.state_dict():
if name not in linear_weights:
set_module_tensor_to_device(
model, name, device, value=compressed_state_dict[name]
)
apply_compressed_weight(model, compressed_state_dict, device)

model.to(device)

return model, tokenizer


def compress(tensor, config):
"""Simulate group-wise quantization."""
if not config.enabled:
return tensor

group_size, num_bits, group_dim, symmetric = (
config.group_size,
config.num_bits,
config.group_dim,
config.symmetric,
)
assert num_bits <= 8

original_shape = tensor.shape
num_groups = (original_shape[group_dim] + group_size - 1) // group_size
new_shape = (
original_shape[:group_dim]
+ (num_groups, group_size)
+ original_shape[group_dim + 1 :]
)

# Pad
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len != 0:
pad_shape = (
original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1 :]
)
tensor = torch.cat(
[tensor, torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)],
dim=group_dim,
)
data = tensor.view(new_shape)

# Quantize
if symmetric:
B = 2 ** (num_bits - 1) - 1
scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
data = data * scale
data = data.clamp_(-B, B).round_().to(torch.int8)
return data, scale, original_shape
else:
B = 2**num_bits - 1
mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

scale = B / (mx - mn)
data = data - mn
data.mul_(scale)

data = data.clamp_(0, B).round_().to(torch.uint8)
return data, mn, scale, original_shape


def decompress(packed_data, config):
"""Simulate group-wise dequantization."""
if not config.enabled:
return packed_data

group_size, _, group_dim, symmetric = (
config.group_size,
config.num_bits,
config.group_dim,
config.symmetric,
)

# Dequantize
if symmetric:
data, scale, original_shape = packed_data
data = data / scale
else:
data, mn, scale, original_shape = packed_data
data = data / scale
data.add_(mn)

# Unpad
pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
if pad_len:
padded_original_shape = (
original_shape[:group_dim]
+ (original_shape[group_dim] + pad_len,)
+ original_shape[group_dim + 1 :]
)
data = data.reshape(padded_original_shape)
indices = [slice(0, x) for x in original_shape]
return data[indices].contiguous()
else:
return data.view(original_shape)
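The compress and decompress helpers above simulate group-wise quantization in plain PyTorch: a tensor is split into groups of group_size elements along group_dim, each group is scaled into the int8 (or uint8) range, and decompression reverses the scaling. The round trip below makes the default 8-bit symmetric config concrete; the tensor shape is arbitrary and merely chosen so dimension 1 splits evenly into groups of 256.

    # Round-trip sketch for the helpers defined in compression.py; shape is arbitrary.
    import torch

    from xinference.model.llm.pytorch.compression import (
        compress,
        decompress,
        default_compression_config,
    )

    weight = torch.randn(4, 512)  # stand-in for a linear layer weight
    packed = compress(weight, default_compression_config)    # (int8 data, scale, shape)
    restored = decompress(packed, default_compression_config)

    print(packed[0].dtype)                   # torch.int8
    print((weight - restored).abs().max())   # small group-wise rounding error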
(diff for the 1 remaining changed file did not load on this page)
