diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 9c34159f9a26..4edf193b54ac 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -326,7 +326,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
         global_tp_size = get_tensor_model_parallel_world_size()
         global_tp_rank = get_tensor_model_parallel_rank()
-
+        check_match = (lambda weight_name, module_name: weight_name.
+                       removesuffix(".weight") == module_name)
         for (
             org_weight_name,
             mapped_weight_name,
@@ -347,12 +348,12 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
             ) and mapped_weight_name.endswith(".weight"):
                 # Without sharding
                 if any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
@@ -362,14 +363,14 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors,
                 # Weights have fused on disk. In this case, we assume that the
                 # weight and module use same name.
                 elif any(
-                        mapped_weight_name.startswith(module)
+                        check_match(mapped_weight_name, module)
                         for module in self.maybe_fused_weights_modules):
                     # special case for fused weights
                     # get the size of each shard weight tensor
                     total_shard_sizes = next(
                         (sizes for module, sizes in
                          self.maybe_fused_weights_modules.items()
-                         if mapped_weight_name.startswith(module)))
+                         if check_match(mapped_weight_name, module)))
                     total_size = weight_tensor.size(0)
                     assert total_size == sum(total_shard_sizes)
                     # get the start/end index of each shard weight tensor
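
For reference, a minimal standalone sketch of what the new `check_match` predicate changes relative to the old `startswith` check. The weight and module names below are hypothetical and only illustrate the prefix-collision case that the exact match (weight name with its `.weight` suffix stripped) avoids:

```python
# Illustrative sketch only; names are made up, not taken from a real checkpoint.

def old_match(weight_name: str, module_name: str) -> bool:
    # Previous behaviour: a module "claims" a weight if its name is a
    # prefix of the weight name.
    return weight_name.startswith(module_name)

def check_match(weight_name: str, module_name: str) -> bool:
    # New behaviour (same semantics as the lambda in the diff): strip the
    # ".weight" suffix and require an exact match with the module name.
    return weight_name.removesuffix(".weight") == module_name

weight = "visual.proj_extra.weight"  # hypothetical weight name

# "visual.proj" is a prefix of "visual.proj_extra.weight", so the old
# prefix check wrongly attributes the weight to the "visual.proj" module.
assert old_match(weight, "visual.proj")

# The exact check only accepts the actual owning module.
assert not check_match(weight, "visual.proj")
assert check_match(weight, "visual.proj_extra")
```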