diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index a384e4aadcf4..ff15af70b88b 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -807,8 +807,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
         engine_core_outputs[0].scheduler_stats if engine_core_outputs else None
     )
     if expected[0] == 0:
+        assert scheduler_stats is not None
         assert scheduler_stats.spec_decoding_stats is None
     else:
+        assert scheduler_stats is not None
         assert scheduler_stats.spec_decoding_stats is not None
         stats = scheduler_stats.spec_decoding_stats
         assert stats.num_drafts == expected[0]
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 49b683a1a9f9..63358a0c07d8 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1229,10 +1229,10 @@ def weight_loader(
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = self.tp_rank
+                shard_rank = self.tp_rank
             else:
-                shard_id = self.tp_rank // self.num_kv_head_replicas
-            start_idx = shard_id * shard_size
+                shard_rank = self.tp_rank // self.num_kv_head_replicas
+            start_idx = shard_rank * shard_size
             if not is_sharded_weight:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
index 068eecf5e026..93a50a377ee5 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
@@ -49,16 +49,16 @@ def __init__(
         self.quantized = quantized
         self.weight_quant = weight_quant
         self.input_quant = input_quant
-        self.model_compressor = (
-            ModelCompressor.from_compression_config(model_compression_config)
-            if model_compression_config is not None
-            else None
+        model_compressor = ModelCompressor.from_compression_config(
+            model_compression_config
         )
         self.do_sparse_decompress = (
-            self.model_compressor is not None
-            and self.model_compressor.sparsity_config.format
+            model_compressor is not None
+            and model_compressor.sparsity_config.format
             == CompressionFormat.sparse_24_bitmask.value
         )
+        if self.do_sparse_decompress:
+            self.model_compressor = model_compressor

         if (
             quantized
diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py
index 6ae2db0f428c..8dc237f8232d 100644
--- a/vllm/model_executor/layers/resampler.py
+++ b/vllm/model_executor/layers/resampler.py
@@ -200,12 +200,10 @@ def __init__(
         self.ln_q = norm_layer(embed_dim)
         self.ln_kv = norm_layer(embed_dim)
         self.do_post_projection = do_post_projection
-        self.ln_post = norm_layer(embed_dim) if do_post_projection else None
-        self.proj = (
-            nn.Parameter((embed_dim**-0.5) * torch.empty(embed_dim, embed_dim))
-            if do_post_projection
-            else None
-        )
+        if self.do_post_projection:
+            self.ln_post = norm_layer(embed_dim)
+            data = (embed_dim**-0.5) * torch.empty(embed_dim, embed_dim)
+            self.proj = nn.Parameter(data=data)

     def _repeat(self, query, N: int):
         return query.unsqueeze(1).repeat(1, N, 1)
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index 8c1ff0300b24..d41b8ae55ea5 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -542,8 +542,7 @@ def _verify_model_compatibility(
         )

         quant_config = getattr(model_config.hf_config, "quantization_config", None)
-        if quant_config is not None:
-            quant_method = quant_config.get("quant_method")
+        if quant_config and (quant_method := quant_config.get("quant_method")):
             if quant_method == "bitsandbytes":
                 self.pre_quant = True
             else:
@@ -558,7 +557,7 @@ def _verify_model_compatibility(
                 "Prequant BitsAndBytes models with tensor parallelism is not "
                 "supported. Please try with pipeline parallelism."
             )
-        if self.pre_quant:
+        if quant_config and self.pre_quant:
             self.load_8bit = quant_config.get("load_in_8bit", False)

     def _initialize_loader_state(
diff --git a/vllm/transformers_utils/processors/ovis2_5.py b/vllm/transformers_utils/processors/ovis2_5.py
index fba26d1d0304..bacc58c78b3f 100644
--- a/vllm/transformers_utils/processors/ovis2_5.py
+++ b/vllm/transformers_utils/processors/ovis2_5.py
@@ -397,6 +397,8 @@ def preprocess_multidata(
                 images.append(image)
         elif isinstance(video, list):
             images = video
+        else:
+            raise ValueError("Either images or video should be provided.")
         min_pixels = min(
             max_pixels if max_pixels is not None else MAX_PIXELS,
             min_pixels if min_pixels is not None else MIN_PIXELS,