29 changes: 18 additions & 11 deletions vllm/model_executor/input_metadata.py
@@ -3,7 +3,7 @@
import torch
from xformers.ops import AttentionBias

from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData


@@ -29,6 +29,8 @@ def __init__(
context_lens: torch.Tensor,
max_context_len: int,
block_tables: torch.Tensor,
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
sliding_window: Optional[int] = None,
) -> None:
self.seq_groups = seq_groups
@@ -38,6 +40,8 @@ def __init__(
self.context_lens = context_lens
self.max_context_len = max_context_len
self.block_tables = block_tables
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices

self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
self.to_cache = None
@@ -72,13 +76,16 @@ def __init__(

def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}, '
f'slot_mapping={self.slot_mapping})')
return (
f'InputMetadata('
f'num_prompt_tokens={self.num_prompt_tokens}, '
f'num_prompts={self.num_prompts}, '
f'prompt_lens={self.prompt_lens}, '
f'num_generation_tokens={self.num_generation_tokens}, '
f'context_lens={self.context_lens}, '
f'max_context_len={self.max_context_len}, '
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
f'block_tables={self.block_tables}, '
f'selected_token_indices={self.selected_token_indices}, '
f'categorized_sample_indices={self.categorized_sample_indices}, '
f'slot_mapping={self.slot_mapping})')
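
For orientation, the sketch below works through what the two new `InputMetadata` fields would hold for a small, hypothetical prefill batch: prompt A (length 4, greedy, prompt logprobs requested) and prompt B (length 3, random sampling). The stand-in enum and all concrete values are illustrative assumptions, not values taken from this PR.

```python
# Hypothetical worked example of the two new InputMetadata fields.
# SamplingTypeStub stands in for vllm.sampling_params.SamplingType.
from enum import IntEnum

import torch


class SamplingTypeStub(IntEnum):
    GREEDY = 0
    RANDOM = 1


# Both prompts are padded to max_prompt_len = 4, so the flattened hidden
# states have 2 * 4 = 8 rows.
#
# selected_token_indices: rows of the flattened hidden states the sampler
# needs -- every position of prompt A (for prompt logprobs) plus the last
# position of each prompt.
selected_token_indices = torch.tensor([0, 1, 2, 3, 6], dtype=torch.long)

# categorized_sample_indices: positions *within the pruned rows above* that
# are actually sampled from, keyed by sampling type. Pruned row 3 is A's
# last prompt token (greedy); pruned row 4 is B's last prompt token (random).
categorized_sample_indices = {
    SamplingTypeStub.GREEDY: torch.tensor([3], dtype=torch.int),
    SamplingTypeStub.RANDOM: torch.tensor([4], dtype=torch.int),
}
```

The key invariant is that `categorized_sample_indices` indexes into the rows kept by `selected_token_indices`, not into the raw padded hidden states.
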
37 changes: 3 additions & 34 deletions vllm/model_executor/layers/sampler.py
@@ -108,29 +108,8 @@ def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
selected_token_indices: List[int] = []
start_idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
assert len(seq_ids) == 1, "Prompt input should have only one seq."
prompt_len = input_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(start_idx, start_idx + prompt_len - 1))
selected_token_indices.append(start_idx + prompt_len - 1)
start_idx += input_metadata.max_prompt_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs

selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device=hidden_states.device)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, selected_token_indices)
return hidden_states.index_select(0, input_metadata.selected_token_indices)


def _get_penalties(
@@ -408,21 +387,11 @@ def _sample(
input_metadata: InputMetadata,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = {t: [] for t in SamplingType}
start_idx = 0
categorized_sample_indices = input_metadata.categorized_sample_indices
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
_, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
if (i < input_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need sample, skip
prompt_len = input_metadata.prompt_lens[i]
start_idx += prompt_len - 1
categorized_seq_group_ids[sampling_type].append(i)
num_seqs = len(seq_ids)
categorized_sample_indices[sampling_type].extend(
range(start_idx, start_idx + num_seqs))
start_idx += num_seqs

sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType:
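
To make the new pruning path concrete, here is a self-contained sketch of what the simplified `_prune_hidden_states` does with a precomputed index tensor. `FakeInputMetadata` is a stand-in for vLLM's `InputMetadata`, and the flatten-before-select step is an assumption about surrounding code the diff truncates.

```python
# Standalone sketch of the simplified pruning path (illustrative only).
from dataclasses import dataclass

import torch


@dataclass
class FakeInputMetadata:
    selected_token_indices: torch.Tensor


def prune_hidden_states(hidden_states: torch.Tensor,
                        input_metadata: FakeInputMetadata) -> torch.Tensor:
    # Flatten (batch, padded_len, hidden) -> (batch * padded_len, hidden),
    # then keep only the rows the sampler actually needs.
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, input_metadata.selected_token_indices)


# Two prompts padded to length 4, hidden size 8 -> 8 flattened rows.
hidden = torch.randn(2, 4, 8)
meta = FakeInputMetadata(
    selected_token_indices=torch.tensor([0, 1, 2, 3, 6], dtype=torch.long))
pruned = prune_hidden_states(hidden, meta)
assert pruned.shape == (5, 8)
```
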
47 changes: 45 additions & 2 deletions vllm/worker/worker.py
@@ -10,7 +10,7 @@
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -161,6 +161,10 @@ def _prepare_inputs(
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0

# Add prompt tokens.
prompt_lens: List[int] = []
@@ -180,6 +184,14 @@
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)

if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need to be sampled; skip them.
categorized_sample_indices_start_idx += prompt_len - 1

categorized_sample_indices[sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
@@ -205,14 +217,37 @@
max_num_blocks_per_seq = 0
context_lens: List[int] = []
generation_block_tables: List[List[int]] = []
max_seq_len = max(prompt_lens) if prompt_lens else 1
for seq_group_metadata in seq_group_metadata_list:
if seq_group_metadata.is_prompt:
# We need to do this in this loop as we need to know max_seq_len
assert len(
seq_ids) == 1, "Prompt input should have only one seq."
sampling_params = seq_group_metadata.sampling_params
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
continue

seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))

num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs

categorized_sample_indices[sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs

for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
@@ -242,7 +277,6 @@ def _prepare_inputs(
block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table)

max_seq_len = max(prompt_lens) if prompt_lens else 1
padded_input_tokens = [
_pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
]
Expand Down Expand Up @@ -272,6 +306,13 @@ def _prepare_inputs(
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
block_tables_tensor = torch.tensor(padded_block_tables,
dtype=torch.int,
device="cuda")
@@ -288,6 +329,8 @@
context_lens=context_lens_tensor,
max_context_len=max_context_len,
block_tables=block_tables_tensor,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
sliding_window=self.sliding_window,
)
return tokens_tensor, positions_tensor, input_metadata
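
Taken together, the worker now derives both index sets once while preparing inputs, so the sampler no longer rebuilds Python index lists and re-uploads tensors on every step. The snippet below is a simplified, single-loop re-derivation of that bookkeeping (the PR splits it across the prompt and generation loops); the helper name and the group tuple layout are hypothetical, not vLLM's API.

```python
# Simplified re-derivation of the index bookkeeping added to _prepare_inputs
# (illustrative only).
from typing import Dict, List, Tuple

# Each group: (num_seqs, is_prompt, prompt_len, wants_prompt_logprobs, sampling_type)
Group = Tuple[int, bool, int, bool, str]


def build_sampling_indices(
        groups: List[Group],
        padded_prompt_len: int) -> Tuple[List[int], Dict[str, List[int]]]:
    selected: List[int] = []                # rows kept from the hidden states
    categorized: Dict[str, List[int]] = {}  # pruned-row ids per sampling type
    sel_cursor = 0                          # cursor over padded hidden-state rows
    cat_cursor = 0                          # cursor over pruned rows
    for num_seqs, is_prompt, prompt_len, wants_prompt_logprobs, stype in groups:
        if is_prompt:
            if wants_prompt_logprobs:
                # Keep every prompt position so prompt logprobs can be read off.
                selected.extend(range(sel_cursor, sel_cursor + prompt_len - 1))
                cat_cursor += prompt_len - 1  # those rows are not sampled from
            selected.append(sel_cursor + prompt_len - 1)  # last prompt token
            categorized.setdefault(stype, []).append(cat_cursor)
            cat_cursor += 1
            sel_cursor += padded_prompt_len   # prompts occupy padded rows
        else:
            selected.extend(range(sel_cursor, sel_cursor + num_seqs))
            categorized.setdefault(stype, []).extend(
                range(cat_cursor, cat_cursor + num_seqs))
            sel_cursor += num_seqs            # one row per decode token
            cat_cursor += num_seqs
    return selected, categorized


# Two prompts padded to length 4: A (len 4, greedy, prompt logprobs) and
# B (len 3, random sampling, no prompt logprobs).
selected, categorized = build_sampling_indices(
    [(1, True, 4, True, "greedy"), (1, True, 3, False, "random")],
    padded_prompt_len=4)
assert selected == [0, 1, 2, 3, 6]
assert categorized == {"greedy": [3], "random": [4]}
```

The prompt cursor advances by the padded prompt length because prompts occupy padded rows in the flattened hidden states, while each decode sequence contributes one row; this mirrors the `selected_token_start_idx` and `categorized_sample_indices_start_idx` updates in the diff above.
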