add QWen-7b support #685
Merged
@@ -0,0 +1,316 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator,
    load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding,
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.qwen import QWenConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(
            hidden_size,
            2 * intermediate_size,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x


class QWenAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int,
                 max_position_embeddings: int):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads

        # pylint: disable=invalid-name
        self.c_attn = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)

        output, _ = self.c_proj(attn_output)
        return output


class QWenBlock(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
                                  config.max_position_embeddings)

        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class QWenModel(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(vocab_size,
                                          config.n_embd,
                                          perform_initialization=False)
        self.h = nn.ModuleList(
            [QWenBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class QWenLMHeadModel(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(
            config.n_embd,
            vocab_size,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        use_np_cache: bool = False,
    ):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            if "wte" in name or "lm_head" in name:
                # Consider padding in the vocab size.
                param = state_dict[name]
                padded_vocab_size = param.shape[0] * tp_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows,
                                         loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            if "c_attn" in name:
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                if "weight" in name:
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "bias" in name:
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["w2", "w1"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(
                param,
                loaded_weight,
                name,
                self._column_parallel_weights,
                self._row_parallel_weights,
                tp_rank,
            )
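Not part of this diff, but for reviewers who want to try the new model: a minimal sketch of how the Qwen support could be exercised through vLLM's offline LLM API once this PR is merged. The model name, prompt, and sampling parameters below are assumptions for illustration; trust_remote_code=True is assumed to be needed because the Qwen tokenizer lives in the model repository.

from vllm import LLM, SamplingParams

# Hypothetical smoke test for the new QWen model; adjust the model path as needed.
llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

# Generate a single completion and print it.
outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)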
I noticed that this PR introduces a license different from the project's own. Could this license bring potential legal risks for commercial use?
Qwen's license allows commercial use; the following is the original text from the Qwen license.
I noticed that the Baichuan model does not include the original project's license. Do we need to add Qwen's license in this PR?
Maybe we can ask the owners for some advice 🤔.
Hi @WoosukKwon and @zhuohan123, should we follow the model's license when adding support for new models in vLLM?
I believe we need to follow the license of the model's code rather than the license of the model weights. For Baichuan (and LLaMA), the code is Apache 2.0 but the weights have a special license, so in that case we can include the code with no problem. However, for Qwen, both the code and the weights are under its restricted license. I believe we should include a link to Qwen's license at the top of Qwen's source file to be safe.