diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 3a6738a27be0..cbc0a56a645e 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -17,9 +17,24 @@ These models are what we list in [supported-text-models][supported-text-models]
 
 ### Transformers
 
-vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
+vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".
 
-To check if the modeling backend is Transformers, you can simply do this:
+Currently, the Transformers backend works for the following:
+
+- Modalities: embedding models, language models and vision-language models*
+- Architectures: encoder-only, decoder-only
+- Attention types: full attention and/or sliding attention
+
+_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
+
+If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM:
+
+- All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature)
+- Any combination of the following vLLM parallelisation schemes:
+    - Pipeline parallel
+    - Tensor parallel
+
+Checking if the modeling backend is Transformers is as simple as:
 
 ```python
 from vllm import LLM
@@ -27,16 +42,12 @@ llm = LLM(model=...)  # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
 ```
 
-If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
+If the printed type starts with `Transformers...` then it's using the Transformers model implementation!
 
-!!! tip
-    You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
-
-!!! note
-    vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
+If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for [online serving](../serving/openai_compatible_server.md).
 
 !!! note
-    In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
+    For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast, the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
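For context (this sketch is not part of the diff above, and the model name is only an illustrative placeholder), opting in to the Transformers backend offline and verifying it looks like this; the online equivalent is `vllm serve <model> --model-impl transformers`:

```python
from vllm import LLM

# Any Hugging Face causal LM should work here; the name is only an example.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", model_impl="transformers")

# Prints a class whose name starts with "Transformers..."
# (e.g. TransformersForCausalLM), confirming the Transformers backend is in use.
llm.apply_model(lambda model: print(type(model)))
```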
 
 #### Custom models
 
@@ -66,10 +77,11 @@ This section details the necessary modifications to make to a Transformers compa
 To make your model compatible with the Transformers backend, it needs:
 
 1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
+    1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
 2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
 3. `MyModel` must contain `_supports_attention_backend = True`.
 
-
+
 modeling_my_model.py
 
 ```python
@@ -78,6 +90,7 @@ from transformers import PreTrainedModel
 from torch import nn
 
 class MyAttention(nn.Module):
+    is_causal = False  # Only do this for encoder-only models
 
     def forward(self, hidden_states, **kwargs):
         ...
@@ -101,13 +114,13 @@ Here is what happens in the background when this model is loaded:
 
 1. The config is loaded.
 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
-3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
+3. `MyModel` is loaded into one of the Transformers backend classes, which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
 
 That's it!
 
 For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class:
 
-
+
 configuration_my_model.py
 
 ```python
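The hunk's trailing context stops at the opening code fence of `configuration_my_model.py`, so the example body is not shown here. As a rough sketch (the class name and layer patterns are invented; the dictionary shapes follow the `base_model_tp_plan`/`base_model_pp_plan` conventions used by Transformers), such a config class could look like:

```python
from transformers import PretrainedConfig


class MyConfig(PretrainedConfig):
    # Tensor parallel plan: layer-name patterns mapped to a sharding style.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    # Pipeline parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
```

Here "colwise"/"rowwise" tell the backend how to shard each matching linear layer across tensor-parallel ranks, while the pipeline plan names the tensors crossing each stage boundary.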
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index ba9c3bebc437..1817d4aeee9f 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -9,7 +9,7 @@
 from ..conftest import HfRunner, VllmRunner
 from ..utils import multi_gpu_test, prep_prompts
-from .utils import check_logprobs_close
+from .utils import check_embeddings_close, check_logprobs_close
 
 
 def check_implementation(
@@ -165,6 +165,40 @@ def test_embed_loading(vllm_runner, model):
         assert model_config.using_transformers_backend()
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        # Encoder model
+        "BAAI/bge-base-en-v1.5",
+    ])
+def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
+    import transformers
+    from packaging.version import Version
+    installed = Version(transformers.__version__)
+    required = Version("4.57.0.dev0")
+    if installed < required:
+        pytest.skip("Encoder models with the Transformers backend require "
+                    f"transformers>={required}, but got {installed}")
+
+    with vllm_runner(model, max_model_len=512,
+                     model_impl="transformers") as vllm_model:
+        model_config = vllm_model.llm.llm_engine.model_config
+        assert model_config.using_transformers_backend()
+
+        vllm_outputs = vllm_model.embed(example_prompts)
+
+    with hf_runner(model, is_sentence_transformer=True) as hf_model:
+        hf_outputs = hf_model.encode(example_prompts)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
+
+
 @pytest.mark.parametrize(
     "model",
     ["jason9693/Qwen2.5-1.5B-apeach"],
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 75bcdc4bbcf0..dfde67e1713c 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -23,14 +23,14 @@ class AttentionType:
     """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
-    # Decoder attention between previous layer Q/K/V
     DECODER = "decoder"
-    # Encoder attention between previous layer Q/K/V for encoder-decoder
+    """Decoder attention between previous layer Q/K/V."""
     ENCODER = "encoder"
-    # Encoder attention between previous layer Q/K/V
+    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
     ENCODER_ONLY = "encoder_only"
-    # Attention between dec. Q and enc. K/V for encoder-decoder
+    """Encoder attention between previous layer Q/K/V."""
     ENCODER_DECODER = "encoder_decoder"
+    """Attention between dec. Q and enc. K/V for encoder-decoder."""
 
 
 class AttentionBackend(ABC):
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 4f51441e28ef..f40a20dee63d 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -27,7 +27,7 @@
                           PreTrainedModel)
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                          ParallelConfig, VllmConfig)
@@ -452,8 +452,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.pp_rank = self.pp_group.rank_in_group
         self.tp_size = get_tensor_model_parallel_world_size()
 
-        # To be updated in child classes for use in `load_weights`
-        self.skip_prefixes: Optional[list[str]] = None
+        # Weights to skip in `self.load_weights`
+        self.skip_prefixes: list[str] = []
+        self.skip_substrs: list[str] = []
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
         # TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -596,7 +597,10 @@ def _tensor_parallel(module: nn.Module,
 
         _tensor_parallel(self.model)
 
-    def create_attention_instances(self) -> dict[int, Attention]:
+    def create_attention_instances(
+        self,
+        attn_type: AttentionType = AttentionType.DECODER
+    ) -> dict[int, Attention]:
         """
         Create `Attention` instances to inform KV cache allocation.
         """
@@ -625,7 +629,8 @@ def create_attention_instances(self) -> dict[int, Attention]:
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
                 per_layer_sliding_window=per_layer_sliding_window,
-                prefix=f"{i}.attn")
+                prefix=f"{i}.attn",
+                attn_type=attn_type)
         return attention_instances
 
     def init_parameters(self, module: nn.Module):
@@ -685,7 +690,11 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=self.skip_prefixes,
+            skip_substrs=self.skip_substrs,
+        )
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
@@ -700,6 +709,37 @@ class TransformersModel(TransformersBase):
         "model.score": "score",
     })
 
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+        # Some encoder models have the position_ids buffer in the checkpoint
+        # vLLM will always pass position_ids as an argument, so we skip loading
+        # the buffer if it exists
+        self.skip_substrs.append("position_ids")
+
+    def create_attention_instances(
+            self, attn_type: AttentionType = AttentionType.DECODER):
+        # TODO(hmellor): Better way to detect encoder models
+        # In encoder models, the attention layers will have `is_causal=False`
+        is_encoder = lambda m: not getattr(m, "is_causal", True)
+        # vLLM does not support encoder-decoder models, so if any encoder layer
+        # is found, we assume the whole model is an encoder model
+        if any(is_encoder(m) for m in self.model.modules()):
+            attn_type = AttentionType.ENCODER_ONLY
+
+        # Check minimum transformers version for encoder models support
+        if attn_type == AttentionType.ENCODER_ONLY:
+            import transformers
+            from packaging.version import Version
+            installed = Version(transformers.__version__)
+            required = Version("4.57.0.dev0")
+            if installed < required:
+                raise ValueError(
+                    "Encoder models with the Transformers backend require "
+                    f"transformers>={required}, but got {installed}")
+
+        return super().create_attention_instances(attn_type)
+
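As a quick illustration of the heuristic above (not part of the patch; the module classes are invented for the example), a module that defines `is_causal = False` flips the whole model to `ENCODER_ONLY`, while modules that never set the attribute are treated as causal:

```python
import torch.nn as nn


# Hypothetical attention modules used only to illustrate the detection rule.
class CausalAttention(nn.Module):
    is_causal = True


class BidirectionalAttention(nn.Module):
    is_causal = False


is_encoder = lambda m: not getattr(m, "is_causal", True)

decoder_like = nn.Sequential(CausalAttention(), nn.Linear(8, 8))
encoder_like = nn.Sequential(BidirectionalAttention(), nn.Linear(8, 8))

print(any(is_encoder(m) for m in decoder_like.modules()))  # False -> DECODER
print(any(is_encoder(m) for m in encoder_like.modules()))  # True  -> ENCODER_ONLY
```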
f"transformers>={required}, but got {installed}") + + return super().create_attention_instances(attn_type) + @support_torch_compile(enable_if=can_enable_torch_compile) class TransformersForCausalLM(TransformersBase): @@ -710,7 +750,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Tell `TransformersBase.load_weights` to skip # `lm_head` if the model has tied word embeddings if self.text_config.tie_word_embeddings: - self.skip_prefixes = ["lm_head."] + self.skip_prefixes.append("lm_head.") if get_pp_group().is_last_rank: self.unpadded_vocab_size = self.text_config.vocab_size