37 changes: 25 additions & 12 deletions docs/models/supported_models.md
@@ -17,26 +17,37 @@ These models are what we list in [supported-text-models][supported-text-models]

### Transformers

vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".

To check if the modeling backend is Transformers, you can simply do this:
Currently, the Transformers backend works for the following:

- Modalities: embedding models, language models and vision-language models*
- Architectures: encoder-only, decoder-only
- Attention types: full attention and/or sliding attention

_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._

If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM:

- All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature)
- Any combination of the following vLLM parallelisation schemes:
- Pipeline parallel
- Tensor parallel
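
For example, a minimal sketch of using the Transformers backend together with these parallelisation schemes (the model name is a placeholder; any checkpoint that loads via the Transformers backend can be substituted):

```python
from vllm import LLM

# Placeholder model name; substitute any checkpoint that loads
# via the Transformers backend.
llm = LLM(
    model="my-org/my-model",
    model_impl="transformers",   # force the Transformers backend
    tensor_parallel_size=2,      # tensor parallel across 2 GPUs
    # pipeline_parallel_size=2,  # optionally combine with pipeline parallel
)
```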

Checking if the modeling backend is Transformers is as simple as:

```python
from vllm import LLM
llm = LLM(model=...) # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```

If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
If the printed type starts with `Transformers...` then it's using the Transformers model implementation!
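
For a decoder-only language model, for example, this would typically print something like the following (the exact class depends on the model type):

```text
<class 'vllm.model_executor.models.transformers.TransformersForCausalLM'>
```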

!!! tip
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).

!!! note
vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
    If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for [online serving](../serving/openai_compatible_server.md).
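
For online serving, a minimal sketch of the equivalent command (the model name is a placeholder):

```bash
vllm serve my-org/my-model --model-impl transformers
```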

!!! note
In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
    For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with the config's `dtype` if it exists. In contrast, native Transformers respects the `dtype` attribute of each backbone in the model, which might cause a slight difference in performance.

#### Custom models

@@ -66,10 +77,11 @@ This section details the necessary modifications to make to a Transformers compa
To make your model compatible with the Transformers backend, it needs:

1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
2. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
3. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
4. `MyModel` must contain `_supports_attention_backend = True`.

<details>
<details class="code">
<summary>modeling_my_model.py</summary>

```python
@@ -78,6 +90,7 @@ from transformers import PreTrainedModel
from torch import nn

class MyAttention(nn.Module):
is_causal = False # Only do this for encoder-only models

def forward(self, hidden_states, **kwargs):
...
@@ -101,13 +114,13 @@ Here is what happens in the background when this model is loaded:

1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map` in the config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
3. `MyModel` is loaded into one of the Transformers backend classes in <gh-file:vllm/model_executor/models/transformers.py> which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.

That's it!
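
As a sketch, loading such a model then looks the same as for any other model (the repository name is hypothetical; `trust_remote_code=True` allows the `auto_map` class to be imported):

```python
from vllm import LLM

# Hypothetical Hub repository that ships modeling_my_model.py with MyModel
llm = LLM(model="my-org/my-custom-model", trust_remote_code=True)

# Should print one of the Transformers backend classes
llm.apply_model(lambda model: print(type(model)))
```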

For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class:
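
As an illustrative sketch only (assuming a typical decoder-only layout; the exact module names, sharding styles and pipeline stages depend on your model):

```python
from transformers import PretrainedConfig

class MyConfig(PretrainedConfig):
    # Hypothetical plans for a model with `embed_tokens`, `layers` and `norm`
    # submodules; adapt the keys to your architecture.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
```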

<details>
<details class="code">
<summary>configuration_my_model.py</summary>

```python
36 changes: 35 additions & 1 deletion tests/models/test_transformers.py
@@ -9,7 +9,7 @@

from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
from .utils import check_logprobs_close
from .utils import check_embeddings_close, check_logprobs_close


def check_implementation(
@@ -165,6 +165,40 @@ def test_embed_loading(vllm_runner, model):
assert model_config.using_transformers_backend()


@pytest.mark.parametrize(
"model",
[
# Encoder model
"BAAI/bge-base-en-v1.5",
])
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
pytest.skip("Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")

with vllm_runner(model, max_model_len=512,
model_impl="transformers") as vllm_model:
model_config = vllm_model.llm.llm_engine.model_config
assert model_config.using_transformers_backend()

vllm_outputs = vllm_model.embed(example_prompts)

with hf_runner(model, is_sentence_transformer=True) as hf_model:
hf_outputs = hf_model.encode(example_prompts)

check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
tol=1e-2,
)


@pytest.mark.parametrize(
"model",
["jason9693/Qwen2.5-1.5B-apeach"],
8 changes: 4 additions & 4 deletions vllm/attention/backends/abstract.py
@@ -23,14 +23,14 @@ class AttentionType:
Attention type.
Use string to be compatible with `torch.compile`.
"""
# Decoder attention between previous layer Q/K/V
DECODER = "decoder"
# Encoder attention between previous layer Q/K/V for encoder-decoder
"""Decoder attention between previous layer Q/K/V."""
ENCODER = "encoder"
# Encoder attention between previous layer Q/K/V
"""Encoder attention between previous layer Q/K/V for encoder-decoder."""
ENCODER_ONLY = "encoder_only"
# Attention between dec. Q and enc. K/V for encoder-decoder
"""Encoder attention between previous layer Q/K/V."""
ENCODER_DECODER = "encoder_decoder"
"""Attention between dec. Q and enc. K/V for encoder-decoder."""


class AttentionBackend(ABC):
54 changes: 47 additions & 7 deletions vllm/model_executor/models/transformers.py
@@ -27,7 +27,7 @@
PreTrainedModel)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
@@ -452,8 +452,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()

# To be updated in child classes for use in `load_weights`
self.skip_prefixes: Optional[list[str]] = None
# Weights to skip in `self.load_weights`
self.skip_prefixes: list[str] = []
self.skip_substrs: list[str] = []

# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
@@ -596,7 +597,10 @@ def _tensor_parallel(module: nn.Module,

_tensor_parallel(self.model)

def create_attention_instances(self) -> dict[int, Attention]:
def create_attention_instances(
self,
attn_type: AttentionType = AttentionType.DECODER
) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
@@ -625,7 +629,8 @@ def create_attention_instances(self) -> dict[int, Attention]:
cache_config=self.cache_config,
quant_config=self.quant_config,
per_layer_sliding_window=per_layer_sliding_window,
prefix=f"{i}.attn")
prefix=f"{i}.attn",
attn_type=attn_type)
return attention_instances

def init_parameters(self, module: nn.Module):
@@ -685,7 +690,11 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=self.skip_prefixes)
loader = AutoWeightsLoader(
self,
skip_prefixes=self.skip_prefixes,
skip_substrs=self.skip_substrs,
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


@@ -700,6 +709,37 @@ class TransformersModel(TransformersBase):
"model.score": "score",
})

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

# Some encoder models have the position_ids buffer in the checkpoint
# vLLM will always pass position_ids as an argument, so we skip loading
# the buffer if it exists
self.skip_substrs.append("position_ids")

def create_attention_instances(
self, attn_type: AttentionType = AttentionType.DECODER):
# TODO(hmellor): Better way to detect encoder models
# In encoder models, the attention layers will have `is_causal=False`
is_encoder = lambda m: not getattr(m, "is_causal", True)
# vLLM does not support encoder-decoder models, so if any encoder layer
# is found, we assume the whole model is an encoder model
if any(is_encoder(m) for m in self.model.modules()):
attn_type = AttentionType.ENCODER_ONLY

# Check minimum transformers version for encoder models support
if attn_type == AttentionType.ENCODER_ONLY:
import transformers
from packaging.version import Version
installed = Version(transformers.__version__)
required = Version("4.57.0.dev0")
if installed < required:
raise ValueError(
"Encoder models with the Transformers backend require "
f"transformers>={required}, but got {installed}")

return super().create_attention_instances(attn_type)


@support_torch_compile(enable_if=can_enable_torch_compile)
class TransformersForCausalLM(TransformersBase):
@@ -710,7 +750,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Tell `TransformersBase.load_weights` to skip
# `lm_head` if the model has tied word embeddings
if self.text_config.tie_word_embeddings:
self.skip_prefixes = ["lm_head."]
self.skip_prefixes.append("lm_head.")

if get_pp_group().is_last_rank:
self.unpadded_vocab_size = self.text_config.vocab_size