vllm/model_executor/model_loader/bitsandbytes_loader.py (11 changes: 6 additions & 5 deletions)
@@ -326,7 +326,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,

         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
-
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
                 org_weight_name,
                 mapped_weight_name,
@@ -347,12 +348,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                    ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
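
In short, the change swaps prefix matching for exact matching when routing a weight tensor to a sharding branch: `check_match` strips the trailing `.weight` from the weight name and requires the result to equal the module name exactly, whereas the old `startswith` test also fired for any module whose name is a proper prefix of another module's name. Below is a minimal standalone sketch of the new rule; `check_match` is copied from the diff, while the module names are hypothetical and used only for illustration.

# Minimal sketch of the matching rule introduced by this diff.
# "model.fc1" / "model.fc12" are hypothetical module names.
check_match = (lambda weight_name, module_name: weight_name.
               removesuffix(".weight") == module_name)

# Exact match after stripping the ".weight" suffix still succeeds:
assert check_match("model.fc1.weight", "model.fc1")

# The old prefix test also matched a different module whose name merely
# starts with "model.fc1"; the exact comparison now rejects it:
assert "model.fc12.weight".startswith("model.fc1")        # old check: spurious hit
assert not check_match("model.fc12.weight", "model.fc1")  # new check: rejected

Note that `str.removesuffix` requires Python 3.9 or newer, and that this code path is only reached for names ending in `.weight` (the `endswith(".weight")` guard visible in the second hunk), so stripping the suffix is always meaningful here.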