format fix
megha95 committed Mar 27, 2024
1 parent fc1fc91 commit 6a27282
Showing 2 changed files with 165 additions and 110 deletions.
197 changes: 122 additions & 75 deletions vllm/model_executor/models/dbrx.py
@@ -13,14 +13,22 @@
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
ParallelLMHead,
DEFAULT_VOCAB_PADDING_SIZE,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -41,11 +49,13 @@ def __init__(
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.ffn_config.moe_num_experts
self.d_model = config.d_model
self.layer = ReplicatedLinear(
self.d_model,
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
linear_method=None,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
router_logits, _ = self.layer(hidden_states)
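
The router is a ReplicatedLinear rather than a tensor-parallel layer: every rank computes the full vector of expert logits, so all ranks make identical routing decisions. Top-k selection itself happens later inside fused_moe; a minimal stand-alone sketch of that routing step (plain PyTorch, toy token count — DBRX's config does use 16 experts with top-4 routing):

import torch

logits = torch.randn(5, 16)                            # [tokens, num_experts]
probs = logits.softmax(dim=-1)
weights, experts = torch.topk(probs, k=4, dim=-1)      # pick top-4 experts per token
weights = weights / weights.sum(dim=-1, keepdim=True)  # mirrors renormalize=True below
print(experts.shape, weights.shape)                    # torch.Size([5, 4]) twice
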
@@ -71,35 +81,50 @@ def __init__(
self.num_total_experts = config.ffn_config.moe_num_experts
self.top_k = config.ffn_config.moe_top_k
self.d_model = config.d_model
self.intermediate_size = (
config.ffn_config.ffn_hidden_size // self.tp_size
)

if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype

self.router = DbrxRouter(config, self.params_dtype)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
2 * self.intermediate_size,
self.d_model,
device="cuda",
dtype=self.params_dtype,
)
)
self.w2s = nn.Parameter(
torch.empty(
self.num_total_experts,
self.d_model,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype,
)
)

set_weight_attrs(
self.ws,
{
"weight_loader": self.weight_loader,
},
)
set_weight_attrs(
self.w2s,
{
"weight_loader": self.weight_loader,
},
)

def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str
):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
@@ -109,45 +134,50 @@ def weight_loader(
if weight_name.endswith("w1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
if weight_name.endswith("v1"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[
:, shard, :
]
if weight_name.endswith("w2"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
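
Each expert's weights are thus split across tensor-parallel ranks along the intermediate (ffn) dimension: w1 and v1 fill the lower and upper halves of the fused ws parameter, while w2 is transposed and sliced along its input dimension. A minimal sketch of the slicing arithmetic, with made-up sizes rather than DBRX's real dimensions:

import torch

num_experts, d_model, tp_size = 4, 6, 2
shard = 4  # per-rank intermediate_size, e.g. ffn_hidden_size 8 // tp_size 2

# A hypothetical "w1" checkpoint tensor, reshaped as in weight_loader above.
w1 = torch.randn(num_experts, shard * tp_size, d_model)
for tp_rank in range(tp_size):
    rows = slice(tp_rank * shard, (tp_rank + 1) * shard)
    print(tp_rank, w1[:, rows, :].shape)  # every rank keeps [4, 4, 6]
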

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.router(hidden_states)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
)

if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)

return final_hidden_states.view(
batch_size, sequence_length, hidden_size
)
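
With tensor parallelism, each rank's fused_moe sees only its shard of every expert's intermediate dimension, so the per-rank output is a partial sum and the all-reduce above completes it. A toy demonstration that sharded matmuls compose this way (plain PyTorch, hypothetical sizes):

import torch

x = torch.randn(3, 8)
w_up, w_down = torch.randn(8, 10), torch.randn(10, 8)

full = x @ w_up @ w_down
partials = [
    (x @ w_up[:, s]) @ w_down[s, :]  # each "rank" holds one slice s
    for s in (slice(0, 5), slice(5, 10))
]
print(torch.allclose(full, sum(partials), atol=1e-4))  # True: the all-reduce is a sum
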


class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
@@ -226,7 +256,6 @@ def forward(


class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
@@ -247,18 +276,26 @@ def forward(
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = residual + x
residual = hidden_states
hidden_states = self.norm_2(hidden_states)
return hidden_states, residual
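
The block follows the standard pre-norm residual layout: normalize, attend, add the residual, then normalize again before the MoE FFN (the FFN's own residual add happens in DbrxBlock). A schematic sketch with a stand-in attention module, not the vLLM implementation:

import torch
import torch.nn as nn

d_model = 8
norm_1, norm_2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn = nn.Linear(d_model, d_model)  # stand-in for the real attention

hidden_states = torch.randn(2, d_model)
residual = hidden_states
hidden_states = residual + attn(norm_1(hidden_states))
residual = hidden_states
hidden_states = norm_2(hidden_states)  # handed to the MoE FFN next
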


class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
@@ -287,7 +324,6 @@ def forward(


class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
@@ -299,11 +335,13 @@ def __init__(
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)]
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(
module.bias, nn.Parameter
):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)

@@ -328,7 +366,6 @@


class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
@@ -339,10 +376,12 @@ def __init__(
self.linear_method = linear_method
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, linear_method)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

def forward(
@@ -352,32 +391,39 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids, positions, kv_caches, attn_metadata
)
return hidden_states

def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, sampling_metadata
)
return next_tokens

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
expert_params_mapping = [
(
"ws" if weight_name in ["w1", "v1"] else "w2s",
f"experts.mlp.{weight_name}",
)
for weight_name in ["w1", "v1", "w2"]
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
@@ -388,6 +434,7 @@ def load_weights(
break
else:
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
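
For reference, the expert_params_mapping comprehension above expands to three (fused param, checkpoint suffix) pairs, routing w1 and v1 into ws and w2 into w2s; it can be evaluated stand-alone:

expert_params_mapping = [
    ("ws" if weight_name in ["w1", "v1"] else "w2s", f"experts.mlp.{weight_name}")
    for weight_name in ["w1", "v1", "w2"]
]
print(expert_params_mapping)
# [('ws', 'experts.mlp.w1'), ('ws', 'experts.mlp.v1'), ('w2s', 'experts.mlp.w2')]
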
