Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vllm/executor/executor_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.tasks import SupportedTask
from vllm.utils import make_async
from vllm.v1.outputs import SamplerOutput
from vllm.v1.outputs import PoolerOutput, SamplerOutput
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)
Expand Down
29 changes: 8 additions & 21 deletions vllm/model_executor/layers/pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from vllm.logger import init_logger
from vllm.model_executor.models.adapters import _load_st_projector
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
from vllm.tasks import PoolingTask
from vllm.utils import current_stream, resolve_obj_by_qualname
from vllm.utils import resolve_obj_by_qualname
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata

logger = init_logger(__name__)
Expand Down Expand Up @@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return PoolerClassify()


def build_output(
    all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
    """Wrap pooled tensors in a ``PoolerOutput``, moving them to the CPU.

    The pooling models' device-to-host transfer and stream synchronization
    happen here.
    """
    if isinstance(all_data, list):
        moved = [item.to("cpu", non_blocking=True) for item in all_data]
    else:
        moved = all_data.to("cpu", non_blocking=True)
    # Block until the asynchronous D2H copies above have completed.
    current_stream().synchronize()

    return PoolerOutput(
        outputs=[PoolingSequenceGroupOutput(item) for item in moved])


class PoolingMethod(nn.Module, ABC):

@staticmethod
Expand Down Expand Up @@ -556,7 +543,7 @@ def forward(
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
return pooled_data


class StepPooler(Pooler):
Expand Down Expand Up @@ -607,7 +594,7 @@ def forward(
) -> PoolerOutput:
pooled_data = self.extract_states(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
return pooled_data


class ClassifierPooler(Pooler):
Expand Down Expand Up @@ -678,7 +665,7 @@ def forward(
]

# scores shape: [batchsize, num_labels]
return build_output(scores)
return scores


class DispatchPooler(Pooler):
Expand Down Expand Up @@ -708,7 +695,7 @@ def forward(
) -> PoolerOutput:
poolers_by_task = self.poolers_by_task

outputs = list[PoolingSequenceGroupOutput]()
outputs = list[torch.Tensor]()
offset = 0
for task, group in groupby(get_tasks(pooling_metadata)):
if not (pooler := poolers_by_task.get(task)):
Expand All @@ -722,10 +709,10 @@ def forward(
pooling_metadata[offset:offset + num_items],
)

outputs.extend(group_output.outputs)
outputs.extend(group_output)
offset += num_items

return PoolerOutput(outputs)
return outputs

def extra_repr(self) -> str:
s = f"supported_task={self.get_supported_tasks()}"
Expand Down
6 changes: 3 additions & 3 deletions vllm/model_executor/models/gritlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
PoolerHead, PoolerNormalize,
PoolingParamsUpdate,
build_output, get_prompt_lens,
get_prompt_lens,
get_prompt_token_ids)
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.sequence import PoolerOutput
from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.v1.outputs import PoolerOutput
from vllm.v1.pool.metadata import PoolingMetadata

from .interfaces_base import default_pooling_type
Expand Down Expand Up @@ -212,7 +212,7 @@ def forward(
) -> PoolerOutput:
pooled_data = self.pooling(hidden_states, pooling_metadata)
pooled_data = self.head(pooled_data, pooling_metadata)
return build_output(pooled_data)
return pooled_data


@default_pooling_type("MEAN")
Expand Down
48 changes: 0 additions & 48 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from vllm.v1.worker.kv_connector_model_runner_mixin import (
KVConnectorOutput)
else:
LoRARequest = Any
KVConnectorOutput = Any

VLLM_TOKEN_ID_ARRAY_TYPE = "l"
Expand Down Expand Up @@ -48,29 +47,6 @@ class RequestMetrics:
model_execute_time: Optional[float] = None


class PoolingSequenceGroupOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True,  # type: ignore[call-arg]
):
    """The model output associated with a pooling sequence group."""
    # Annotated as Any to be compatible with msgspec
    # The actual type is in SequenceGroup.pooled_data
    data: Any

    def get_data_nbytes(self) -> int:
        """Byte size of the pooled payload."""
        # NOTE(review): assumes `data` is a torch.Tensor despite the `Any`
        # annotation — confirm callers never store a non-tensor payload.
        data: torch.Tensor = self.data
        return data.nbytes

    def __repr__(self) -> str:
        # Fix: the closing parenthesis was missing from the repr string.
        return f"PoolingSequenceGroupOutput(data={self.data})"

    def __eq__(self, other: object) -> bool:
        # NOTE(review): raising (instead of returning NotImplemented) makes
        # `==` against unrelated types crash rather than fall back to
        # identity comparison; kept as-is in case callers rely on it.
        if not isinstance(other, PoolingSequenceGroupOutput):
            raise NotImplementedError()
        return self.data == other.data


# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
Expand Down Expand Up @@ -119,30 +95,6 @@ def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"


class PoolerOutput(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """The output from a pooling operation in the pooling model."""
    outputs: list[PoolingSequenceGroupOutput]

    def get_data_nbytes(self) -> int:
        """Total byte size across all per-sequence-group payloads."""
        total = 0
        for output in self.outputs:
            total += output.get_data_nbytes()
        return total

    def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
        return self.outputs[idx]

    def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
        self.outputs[idx] = value

    def __len__(self):
        return len(self.outputs)

    def __eq__(self, other: object):
        if not isinstance(other, self.__class__):
            return False
        return self.outputs == other.outputs


class ExecuteModelRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, NamedTuple, Optional
from typing import TYPE_CHECKING, NamedTuple, Optional, Union

import torch

Expand Down Expand Up @@ -65,6 +65,11 @@ def empty_cpu(num_positions: int,
)


# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]


@dataclass
class SamplerOutput:

Expand Down
22 changes: 15 additions & 7 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
is_pin_memory_available,
length_from_prompt_token_ids_or_embeds, round_up,
supports_dynamo)
from vllm.utils.jsontree import json_map_leaves
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
Expand All @@ -79,7 +80,7 @@
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
DraftTokenIds, LogprobsLists, LogprobsTensors,
ModelRunnerOutput, SamplerOutput)
ModelRunnerOutput, PoolerOutput, SamplerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
Expand Down Expand Up @@ -1823,15 +1824,22 @@ def _pool(
device=hidden_states.device)
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]

# Pooling models D2H & synchronize occurs in pooler.py:build_output
raw_pooler_output = self.model.pooler(
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
model = cast(VllmModelForPooling, self.model)
raw_pooler_output: PoolerOutput = model.pooler(
hidden_states=hidden_states,
pooling_metadata=pooling_metadata,
)
raw_pooler_output = json_map_leaves(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes more sense for model runner to handle device transfer

lambda x: x.to("cpu", non_blocking=True),
raw_pooler_output,
)
self._sync_device()

pooler_output: list[Optional[torch.Tensor]] = []
for raw_output, seq_len, prompt_len in zip(
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):

output = raw_output.data if seq_len == prompt_len else None
output = raw_output if seq_len == prompt_len else None
pooler_output.append(output)

return ModelRunnerOutput(
Expand Down Expand Up @@ -3233,7 +3241,7 @@ def _dummy_pooler_run(
for task in self.get_supported_pooling_tasks():
# Run a full batch with each task to ensure none of them OOMs
output = self._dummy_pooler_run_task(hidden_states, task)
output_size[task] = output.get_data_nbytes()
output_size[task] = sum(o.nbytes for o in output)
del output # Allow GC

max_task = max(output_size.items(), key=lambda x: x[1])[0]
Expand Down