From 124b43d7c9279915647323b872e031d9c734ff2e Mon Sep 17 00:00:00 2001 From: "wq.chu" Date: Mon, 7 Aug 2023 10:23:27 +0800 Subject: [PATCH] add QWen-7b --- vllm/model_executor/model_loader.py | 1 + vllm/model_executor/models/__init__.py | 16 +- vllm/model_executor/models/qwen.py | 316 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/qwen.py | 71 +++++ 6 files changed, 396 insertions(+), 11 deletions(-) create mode 100644 vllm/model_executor/models/qwen.py create mode 100644 vllm/transformers_utils/configs/qwen.py diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index a1bcd1591936..a98c5b19f56e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -23,6 +23,7 @@ "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* "MPTForCausalLM": MPTForCausalLM, "OPTForCausalLM": OPTForCausalLM, + "QWenLMHeadModel": QWenLMHeadModel, "RWForCausalLM": FalconForCausalLM, } diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 6d61f95452c1..9e89c463593a 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -9,17 +9,11 @@ from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.mpt import MPTForCausalLM from vllm.model_executor.models.opt import OPTForCausalLM +from vllm.model_executor.models.qwen import QWenLMHeadModel __all__ = [ - "BaiChuanForCausalLM", - "BaichuanForCausalLM", - "BloomForCausalLM", - "FalconForCausalLM", - "GPT2LMHeadModel", - "GPTBigCodeForCausalLM", - "GPTJForCausalLM", - "GPTNeoXForCausalLM", - "LlamaForCausalLM", - "MPTForCausalLM", - "OPTForCausalLM", + "BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM", + "FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM", + "GPTJForCausalLM", "GPTNeoXForCausalLM", "LlamaForCausalLM", + "MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel" ] diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py new file mode 100644 index 000000000000..d81940ed28b0 --- /dev/null +++ b/vllm/model_executor/models/qwen.py @@ -0,0 +1,316 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE +"""Inference-only QWen model compatible with HuggingFace weights. + +The input of the model is flattened to a 1D tensor of tokens. The model uses +InputMetadata to extract the original 2D shape of the input. +""" +from typing import Dict, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.weight_utils import ( + hf_model_weights_iterator, + load_tensor_parallel_weights, +) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + VocabParallelEmbedding, + ColumnParallelLinear, + RowParallelLinear, +) +from vllm.sequence import SequenceOutputs +from vllm.transformers_utils.configs.qwen import QWenConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class QWenMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + ): + super().__init__() + self.gate_up_proj = ColumnParallelLinear( + hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class QWenAttention(nn.Module): + + def __init__(self, hidden_size: int, num_heads: int, + max_position_embeddings: int): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + + # pylint: disable=invalid-name + self.c_attn = ColumnParallelLinear( + hidden_size, + 3 * hidden_size, + bias=True, + gather_output=False, + perform_initialization=False, + ) + self.c_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + ) + self.scaling = self.head_dim**-0.5 + self.attn = PagedAttentionWithRoPE( + self.num_heads, + self.head_dim, + self.scaling, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + k_cache, v_cache = kv_cache + attn_output = self.attn(positions, q, k, v, k_cache, v_cache, + input_metadata, cache_event) + + output, _ = self.c_proj(attn_output) + return output + + +class QWenBlock(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) + + self.attn = QWenAttention(config.n_embd, config.num_attention_heads, + config.max_position_embeddings) + + self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) + + self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + cache_event: Optional[torch.cuda.Event], + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + cache_event=cache_event, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class QWenModel(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.wte = VocabParallelEmbedding(vocab_size, + config.n_embd, + perform_initialization=False) + self.h = nn.ModuleList( + [QWenBlock(config) for _ in range(config.num_hidden_layers)]) + self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + for i in range(len(self.h)): + if cache_events is None: + cache_event = None + else: + cache_event = cache_events[i] + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + cache_event, + ) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class QWenLMHeadModel(nn.Module): + + def __init__(self, config: QWenConfig): + super().__init__() + self.config = config + self.transformer = QWenModel(config) + vocab_size = ((config.vocab_size + 63) // 64) * 64 + self.lm_head = ColumnParallelLinear( + config.n_embd, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + ) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + cache_events: Optional[List[torch.cuda.Event]], + ) -> Dict[int, SequenceOutputs]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata, cache_events) + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + input_metadata) + return next_tokens + + _column_parallel_weights = ["wte.weight", "lm_head.weight"] + _row_parallel_weights = ["c_proj.weight"] + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + use_np_cache: bool = False, + ): + tp_world_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + state_dict = self.state_dict() + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, use_np_cache): + if "rotary_emb.inv_freq" in name: + continue + + if "wte" in name or "lm_head" in name: + # Consider padding in the vocab size. + param = state_dict[name] + padded_vocab_size = param.shape[0] * tp_world_size + num_extra_rows = padded_vocab_size - self.config.vocab_size + extra_rows = torch.empty(num_extra_rows, + loaded_weight.shape[1]) + extra_rows = extra_rows.to(loaded_weight) + loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0) + + if "c_attn" in name: + total_num_heads = self.config.num_attention_heads + hidden_size = self.config.hidden_size + head_size = hidden_size // total_num_heads + num_heads = total_num_heads // tp_world_size + head_start = tp_rank * num_heads + head_end = (tp_rank + 1) * num_heads + + if "weight" in name: + loaded_weight = loaded_weight.view(3, total_num_heads, + head_size, hidden_size) + loaded_weight = loaded_weight[:, head_start:head_end, :, :] + loaded_weight = loaded_weight.reshape(-1, hidden_size) + elif "bias" in name: + loaded_weight = loaded_weight.view(3, total_num_heads, + head_size) + loaded_weight = loaded_weight[:, head_start:head_end, :] + loaded_weight = loaded_weight.reshape(-1) + + is_gate_up_weight = False + for stride_id, weight_name in enumerate(["w2", "w1"]): + if weight_name not in name: + continue + param = state_dict[name.replace(weight_name, "gate_up_proj")] + shard_size = param.shape[0] // 2 + loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * + (tp_rank + 1)] + param_slice = param.data[shard_size * stride_id:shard_size * + (stride_id + 1)] + assert param_slice.shape == loaded_weight.shape + param_slice.copy_(loaded_weight) + is_gate_up_weight = True + break + if is_gate_up_weight: + continue + + param = state_dict[name] + load_tensor_parallel_weights( + param, + loaded_weight, + name, + self._column_parallel_weights, + self._row_parallel_weights, + tp_rank, + ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index b7b3da63a578..7447039d80bf 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -5,6 +5,7 @@ _CONFIG_REGISTRY = { "mpt": MPTConfig, "baichuan": BaiChuanConfig, + "qwen": QWenConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) } diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index b98c797c25a4..c1b03fba0566 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,5 +1,6 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.baichuan import BaiChuanConfig +from vllm.transformers_utils.configs.qwen import QWenConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -8,5 +9,6 @@ __all__ = [ "MPTConfig", "BaiChuanConfig", + "QWenConfig", "RWConfig", ] diff --git a/vllm/transformers_utils/configs/qwen.py b/vllm/transformers_utils/configs/qwen.py new file mode 100644 index 000000000000..916bb4c77bc0 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen.py @@ -0,0 +1,71 @@ +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE + +from transformers import PretrainedConfig + + +class QWenConfig(PretrainedConfig): + model_type = "qwen" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "num_attention_heads": "n_head", + "max_position_embeddings": "n_positions", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=151851, + n_embd=4096, + n_layer=32, + n_head=32, + n_inner=None, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + eos_token_id=151643, + apply_residual_connection_post_layernorm=False, + bf16=True, + kv_channels=128, + rotary_pct=1.0, + rotary_emb_base=10000, + use_dynamic_ntk=False, + use_logn_attn=False, + use_flash_attn=True, + ffn_hidden_size=22016, + no_bias=True, + tie_word_embeddings=False, + **kwargs, + ): + self.eos_token_id = eos_token_id + super().__init__(eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs) + + self.vocab_size = vocab_size + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.bf16 = bf16 + self.kv_channels = kv_channels + self.rotary_pct = rotary_pct + self.rotary_emb_base = rotary_emb_base + self.use_dynamic_ntk = use_dynamic_ntk + self.use_logn_attn = use_logn_attn + self.use_flash_attn = use_flash_attn + self.ffn_hidden_size = ffn_hidden_size + self.no_bias = no_bias + self.tie_word_embeddings = tie_word_embeddings