diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index fd4b992c3821..e3063ec2b8ab 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -15,10 +15,10 @@
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import ExecuteModelRequest, PoolerOutput
+from vllm.sequence import ExecuteModelRequest
 from vllm.tasks import SupportedTask
 from vllm.utils import make_async
-from vllm.v1.outputs import SamplerOutput
+from vllm.v1.outputs import PoolerOutput, SamplerOutput
 from vllm.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 4a97438b1bb2..139011ce10be 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -16,9 +16,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.models.adapters import _load_st_projector
 from vllm.pooling_params import PoolingParams
-from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
-from vllm.utils import current_stream, resolve_obj_by_qualname
+from vllm.utils import resolve_obj_by_qualname
+from vllm.v1.outputs import PoolerOutput
 from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
 
 logger = init_logger(__name__)
@@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
     return PoolerClassify()
 
 
-def build_output(
-    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
-    # Pooling models D2H & synchronize occurs here
-    if isinstance(all_data, list):
-        all_data = [d.to("cpu", non_blocking=True) for d in all_data]
-    else:
-        all_data = all_data.to("cpu", non_blocking=True)
-    current_stream().synchronize()
-
-    all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
-    return PoolerOutput(outputs=all_outputs)
-
-
 class PoolingMethod(nn.Module, ABC):
 
     @staticmethod
@@ -556,7 +543,7 @@ def forward(
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
         pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
 
 
 class StepPooler(Pooler):
@@ -607,7 +594,7 @@ def forward(
     ) -> PoolerOutput:
         pooled_data = self.extract_states(hidden_states, pooling_metadata)
         pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
 
 
 class ClassifierPooler(Pooler):
@@ -678,7 +665,7 @@ def forward(
         ]
 
         # scores shape: [batchsize, num_labels]
-        return build_output(scores)
+        return scores
 
 
 class DispatchPooler(Pooler):
@@ -708,7 +695,7 @@ def forward(
     ) -> PoolerOutput:
         poolers_by_task = self.poolers_by_task
 
-        outputs = list[PoolingSequenceGroupOutput]()
+        outputs = list[torch.Tensor]()
         offset = 0
         for task, group in groupby(get_tasks(pooling_metadata)):
             if not (pooler := poolers_by_task.get(task)):
@@ -722,10 +709,10 @@ def forward(
                 pooling_metadata[offset:offset + num_items],
             )
 
-            outputs.extend(group_output.outputs)
+            outputs.extend(group_output)
             offset += num_items
 
-        return PoolerOutput(outputs)
+        return outputs
 
     def extra_repr(self) -> str:
         s = f"supported_task={self.get_supported_tasks()}"
diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py
index a7b324f0a5b4..639d8f620f94 100644
--- a/vllm/model_executor/models/gritlm.py
+++ b/vllm/model_executor/models/gritlm.py
@@ -12,12 +12,12 @@
 from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
                                                PoolerHead, PoolerNormalize,
                                                PoolingParamsUpdate,
-                                               build_output, get_prompt_lens,
+                                               get_prompt_lens,
                                                get_prompt_token_ids)
 from vllm.model_executor.models.llama import LlamaForCausalLM
-from vllm.sequence import PoolerOutput
 from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.v1.outputs import PoolerOutput
 from vllm.v1.pool.metadata import PoolingMetadata
 
 from .interfaces_base import default_pooling_type
@@ -212,7 +212,7 @@ def forward(
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
         pooled_data = self.head(pooled_data, pooling_metadata)
-        return build_output(pooled_data)
+        return pooled_data
 
 
 @default_pooling_type("MEAN")
diff --git a/vllm/sequence.py b/vllm/sequence.py
index a6c194fbac0b..e5f23d47a660 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -11,7 +11,6 @@
     from vllm.v1.worker.kv_connector_model_runner_mixin import (
         KVConnectorOutput)
 else:
-    LoRARequest = Any
    KVConnectorOutput = Any
 
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -48,29 +47,6 @@ class RequestMetrics:
     model_execute_time: Optional[float] = None
 
 
-class PoolingSequenceGroupOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True,  # type: ignore[call-arg]
-):
-    """The model output associated with a pooling sequence group."""
-    # Annotated as Any to be compatible with msgspec
-    # The actual type is in SequenceGroup.pooled_data
-    data: Any
-
-    def get_data_nbytes(self) -> int:
-        data: torch.Tensor = self.data
-        return data.nbytes
-
-    def __repr__(self) -> str:
-        return f"PoolingSequenceGroupOutput(data={self.data}"
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, PoolingSequenceGroupOutput):
-            raise NotImplementedError()
-        return self.data == other.data
-
-
 # cannot use msgspec.Struct here because Dynamo does not support it
 @dataclass
 class IntermediateTensors:
@@ -119,30 +95,6 @@ def __repr__(self) -> str:
         return f"IntermediateTensors(tensors={self.tensors})"
 
 
-class PoolerOutput(
-        msgspec.Struct,
-        omit_defaults=True,  # type: ignore[call-arg]
-        array_like=True):  # type: ignore[call-arg]
-    """The output from a pooling operation in the pooling model."""
-    outputs: list[PoolingSequenceGroupOutput]
-
-    def get_data_nbytes(self) -> int:
-        return sum(o.get_data_nbytes() for o in self.outputs)
-
-    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
-        return self.outputs[idx]
-
-    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
-        self.outputs[idx] = value
-
-    def __len__(self):
-        return len(self.outputs)
-
-    def __eq__(self, other: object):
-        return isinstance(other,
-                          self.__class__) and self.outputs == other.outputs
-
-
 class ExecuteModelRequest(
         msgspec.Struct,
         array_like=True,  # type: ignore[call-arg]
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index e6cc6019b172..01f3676abd92 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -3,7 +3,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, NamedTuple, Optional
+from typing import TYPE_CHECKING, NamedTuple, Optional, Union
 
 import torch
 
@@ -65,6 +65,11 @@ def empty_cpu(num_positions: int,
         )
 
 
+# [num_reqs, ]
+# The shape of each element depends on the pooler used
+PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]
+
+
 @dataclass
 class SamplerOutput:
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ee339e22cea9..1261a9c38038 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -52,13 +52,14 @@
 from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
-from vllm.sequence import IntermediateTensors, PoolerOutput
+from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
                         is_pin_memory_available,
                         length_from_prompt_token_ids_or_embeds, round_up,
                         supports_dynamo)
+from vllm.utils.jsontree import json_map_leaves
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
@@ -79,7 +80,7 @@
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                              DraftTokenIds, LogprobsLists, LogprobsTensors,
-                             ModelRunnerOutput, SamplerOutput)
+                             ModelRunnerOutput, PoolerOutput, SamplerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -1823,15 +1824,22 @@ def _pool(
             device=hidden_states.device)
         seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
 
-        # Pooling models D2H & synchronize occurs in pooler.py:build_output
-        raw_pooler_output = self.model.pooler(
-            hidden_states=hidden_states, pooling_metadata=pooling_metadata)
+        model = cast(VllmModelForPooling, self.model)
+        raw_pooler_output: PoolerOutput = model.pooler(
+            hidden_states=hidden_states,
+            pooling_metadata=pooling_metadata,
+        )
+        raw_pooler_output = json_map_leaves(
+            lambda x: x.to("cpu", non_blocking=True),
+            raw_pooler_output,
+        )
+        self._sync_device()
 
         pooler_output: list[Optional[torch.Tensor]] = []
         for raw_output, seq_len, prompt_len in zip(
                 raw_pooler_output, seq_lens_cpu,
                 pooling_metadata.prompt_lens):
-            output = raw_output.data if seq_len == prompt_len else None
+            output = raw_output if seq_len == prompt_len else None
             pooler_output.append(output)
 
         return ModelRunnerOutput(
@@ -3233,7 +3241,7 @@ def _dummy_pooler_run(
         for task in self.get_supported_pooling_tasks():
             # Run a full batch with each task to ensure none of them OOMs
             output = self._dummy_pooler_run_task(hidden_states, task)
-            output_size[task] = output.get_data_nbytes()
+            output_size[task] = sum(o.nbytes for o in output)
             del output  # Allow GC
 
         max_task = max(output_size.items(), key=lambda x: x[1])[0]