Merged
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader.py
@@ -23,6 +23,7 @@
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"RWForCausalLM": FalconForCausalLM,
}
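With QWenLMHeadModel registered in the architecture map above, a Qwen checkpoint can be served through vLLM's usual entry points. A minimal usage sketch; the Qwen/Qwen-7B model name, the trust_remote_code flag (for Qwen's custom tokenizer), and the sampling settings are assumptions for illustration, not part of this PR:

from vllm import LLM, SamplingParams

# Assumed checkpoint name and flags; treat these as illustrative settings.
llm = LLM(model="Qwen/Qwen-7B", trust_remote_code=True)
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)
outputs = llm.generate(["Give me a short introduction to large language models."], params)
print(outputs[0].outputs[0].text)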

16 changes: 5 additions & 11 deletions vllm/model_executor/models/__init__.py
@@ -9,17 +9,11 @@
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel

__all__ = [
"BaiChuanForCausalLM",
"BaichuanForCausalLM",
"BloomForCausalLM",
"FalconForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTJForCausalLM",
"GPTNeoXForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"OPTForCausalLM",
"BaiChuanForCausalLM", "BaichuanForCausalLM", "BloomForCausalLM",
"FalconForCausalLM", "GPT2LMHeadModel", "GPTBigCodeForCausalLM",
"GPTJForCausalLM", "GPTNeoXForCausalLM", "LlamaForCausalLM",
"MPTForCausalLM", "OPTForCausalLM", "QWenLMHeadModel"
]
316 changes: 316 additions & 0 deletions vllm/model_executor/models/qwen.py
@@ -0,0 +1,316 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
Contributor:
I noticed that here we introduce a new license different from this project's. Will this license bring potential legal risks for commercial use?

Contributor Author @Sanster (Aug 8, 2023):
Qwen's license allows commercial use. The following is the original text from the Qwen license:

"If you are commercially using the Materials, and your product or service has more than 100 million monthly active users, You shall request a license from Us."

I noticed that the Baichuan model does not include the original project's license. Do we need to add Qwen's license in this PR?

Baichuan's original license: for commercial use, you need to send an email to apply for a license.

Contributor:
Maybe we can ask the owners for some advice 🤔
Hi @WoosukKwon and @zhuohan123, should we follow the model's license when adding support for new models in vLLM?

Member:
I believe we need to follow the license of the model code rather than the license of the model weights. For Baichuan (and LLaMA), the code is Apache 2 but the model has a special license; in that case we can include the code with no problem. However, for Qwen, both the code and the model are under their restricted license. I believe we should include a link to Qwen's license at the top of Qwen's source file, for safety.

"""Inference-only QWen model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator,
    load_tensor_parallel_weights,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding,
    ColumnParallelLinear,
    RowParallelLinear,
)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.qwen import QWenConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


class QWenMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
    ):
        super().__init__()
        self.gate_up_proj = ColumnParallelLinear(
            hidden_size,
            2 * intermediate_size,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x

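The fused gate_up_proj above packs the gate and up projections of the SwiGLU MLP into a single matmul, and SiluAndMul then splits that output in half and returns silu(gate) * up. A standalone sketch of that activation, not part of this file, assuming the first half of the fused output is the gated branch (which matches the w2/w1 packing order used in load_weights further down):

import torch
import torch.nn.functional as F

def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Split the fused projection back into its two halves and apply SwiGLU.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 2 * 11008)                 # 11008 is an assumed intermediate_size
assert silu_and_mul(x).shape == (4, 11008)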

class QWenAttention(nn.Module):

    def __init__(self, hidden_size: int, num_heads: int,
                 max_position_embeddings: int):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads

        # pylint: disable=invalid-name
        self.c_attn = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=True,
            gather_output=False,
            perform_initialization=False,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)

        output, _ = self.c_proj(attn_output)
        return output

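Because c_attn is a ColumnParallelLinear with gather_output=False, each tensor-parallel rank only holds its own slice of attention heads, so the qkv.chunk(3, dim=-1) call above splits a per-rank tensor of width 3 * num_heads * head_dim. A standalone shape check, not part of this file, with assumed sizes (hidden_size 4096, 32 heads, 2-way tensor parallelism):

import torch

hidden_size, total_num_heads, tp_world_size = 4096, 32, 2
head_dim = hidden_size // total_num_heads      # 128
num_heads = total_num_heads // tp_world_size   # 16 heads per rank

# Per-rank output of the fused QKV projection for 5 flattened tokens.
qkv = torch.randn(5, 3 * num_heads * head_dim)
q, k, v = qkv.chunk(3, dim=-1)
assert q.shape == k.shape == v.shape == (5, num_heads * head_dim)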

class QWenBlock(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
                                  config.max_position_embeddings)

        self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.mlp = QWenMLP(config.n_embd, config.ffn_hidden_size // 2)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class QWenModel(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.wte = VocabParallelEmbedding(vocab_size,
                                          config.n_embd,
                                          perform_initialization=False)
        self.h = nn.ModuleList(
            [QWenBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

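The ((config.vocab_size + 63) // 64) * 64 expression rounds the vocabulary up to the next multiple of 64 so the embedding (and later the lm_head) splits evenly across tensor-parallel ranks; the padding rows themselves are filled in by load_weights below. A standalone illustration of the arithmetic, with assumed vocabulary sizes:

def pad_vocab_size(vocab_size: int, alignment: int = 64) -> int:
    # Round up to the next multiple of `alignment`.
    return ((vocab_size + alignment - 1) // alignment) * alignment

# Illustrative numbers only; the real value comes from the checkpoint's config.
assert pad_vocab_size(151851) == 151872   # 21 padding rows get appended
assert pad_vocab_size(151936) == 151936   # already aligned, nothing to pad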

class QWenLMHeadModel(nn.Module):

    def __init__(self, config: QWenConfig):
        super().__init__()
        self.config = config
        self.transformer = QWenModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.lm_head = ColumnParallelLinear(
            config.n_embd,
            vocab_size,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "lm_head.weight"]
    _row_parallel_weights = ["c_proj.weight"]

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        use_np_cache: bool = False,
    ):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            if "wte" in name or "lm_head" in name:
                # Consider padding in the vocab size.
                param = state_dict[name]
                padded_vocab_size = param.shape[0] * tp_world_size
                num_extra_rows = padded_vocab_size - self.config.vocab_size
                extra_rows = torch.empty(num_extra_rows,
                                         loaded_weight.shape[1])
                extra_rows = extra_rows.to(loaded_weight)
                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

            if "c_attn" in name:
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                if "weight" in name:
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif "bias" in name:
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["w2", "w1"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]
            load_tensor_parallel_weights(
                param,
                loaded_weight,
                name,
                self._column_parallel_weights,
                self._row_parallel_weights,
                tp_rank,
            )
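The w2/w1 branch above stacks Qwen's two separate MLP projections into the fused gate_up_proj parameter, and each tensor-parallel rank copies only its own row shard into the matching half. A standalone sketch of that copy, not part of this file, using made-up shapes (intermediate_size 8, hidden_size 4, 2 ranks, rank 0):

import torch

intermediate_size, hidden_size = 8, 4
tp_world_size, tp_rank = 2, 0

# Full (unsharded) w2 and w1 weights as they appear in the checkpoint.
w2 = torch.randn(intermediate_size, hidden_size)
w1 = torch.randn(intermediate_size, hidden_size)

# Per-rank fused parameter: the w2 shard fills the first half, the w1 shard the second.
gate_up = torch.empty(2 * intermediate_size // tp_world_size, hidden_size)
shard_size = gate_up.shape[0] // 2
for stride_id, full_weight in enumerate([w2, w1]):
    shard = full_weight[shard_size * tp_rank:shard_size * (tp_rank + 1)]
    gate_up[shard_size * stride_id:shard_size * (stride_id + 1)].copy_(shard)

assert torch.equal(gate_up[:shard_size], w2[:shard_size])
assert torch.equal(gate_up[shard_size:], w1[:shard_size])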
1 change: 1 addition & 0 deletions vllm/transformers_utils/config.py
@@ -5,6 +5,7 @@
_CONFIG_REGISTRY = {
"mpt": MPTConfig,
"baichuan": BaiChuanConfig,
"qwen": QWenConfig,
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
}
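_CONFIG_REGISTRY maps a checkpoint's model_type string to the config class vLLM should instantiate, so a config with model_type "qwen" now resolves to QWenConfig. A minimal lookup sketch; the lookup_config_class helper is made up for illustration and is not vLLM's actual get_config, which also handles the AutoConfig fallback:

from vllm.transformers_utils.configs import MPTConfig, QWenConfig

# Hypothetical helper, shown only to illustrate how the registry is consulted.
_CONFIG_REGISTRY = {"mpt": MPTConfig, "qwen": QWenConfig}

def lookup_config_class(model_type: str):
    try:
        return _CONFIG_REGISTRY[model_type]
    except KeyError as exc:
        raise ValueError(f"No registered config for model_type={model_type!r}") from exc

assert lookup_config_class("qwen") is QWenConfig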
2 changes: 2 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -1,5 +1,6 @@
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
from vllm.transformers_utils.configs.qwen import QWenConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -8,5 +9,6 @@
__all__ = [
"MPTConfig",
"BaiChuanConfig",
"QWenConfig",
"RWConfig",
]