diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index 265f3d357ad4..b3b5852e4876 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -3,7 +3,7 @@
 import torch
 from xformers.ops import AttentionBias
 
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
 
 
@@ -29,6 +29,8 @@ def __init__(
         context_lens: torch.Tensor,
         max_context_len: int,
         block_tables: torch.Tensor,
+        selected_token_indices: torch.Tensor,
+        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
         sliding_window: Optional[int] = None,
     ) -> None:
         self.seq_groups = seq_groups
@@ -38,6 +40,8 @@ def __init__(
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
+        self.selected_token_indices = selected_token_indices
+        self.categorized_sample_indices = categorized_sample_indices
 
         self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
         self.to_cache = None
@@ -72,13 +76,16 @@ def __init__(
 
     def __repr__(self) -> str:
         # Print only useful metadata.
-        return (f'InputMetadata('
-                f'num_prompt_tokens={self.num_prompt_tokens}, '
-                f'num_prompts={self.num_prompts}, '
-                f'prompt_lens={self.prompt_lens}, '
-                f'num_generation_tokens={self.num_generation_tokens}, '
-                f'context_lens={self.context_lens}, '
-                f'max_context_len={self.max_context_len}, '
-                f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-                f'block_tables={self.block_tables}, '
-                f'slot_mapping={self.slot_mapping})')
+        return (
+            f'InputMetadata('
+            f'num_prompt_tokens={self.num_prompt_tokens}, '
+            f'num_prompts={self.num_prompts}, '
+            f'prompt_lens={self.prompt_lens}, '
+            f'num_generation_tokens={self.num_generation_tokens}, '
+            f'context_lens={self.context_lens}, '
+            f'max_context_len={self.max_context_len}, '
+            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
+            f'block_tables={self.block_tables}, '
+            f'selected_token_indices={self.selected_token_indices}, '
+            f'categorized_sample_indices={self.categorized_sample_indices}, '
+            f'slot_mapping={self.slot_mapping})')
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index a12c82a21f46..54fca2bfa68a 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -108,29 +108,8 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     input_metadata: InputMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < input_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = input_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += input_metadata.max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0, input_metadata.selected_token_indices)
 
 
 def _get_penalties(
@@ -408,21 +387,11 @@ def _sample(
     input_metadata: InputMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = input_metadata.categorized_sample_indices
     for i, seq_group in enumerate(input_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < input_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need sample, skip
-            prompt_len = input_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 1b1f116a2ad9..fd6faecccbfb 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -10,7 +10,7 @@
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
@@ -161,6 +161,10 @@ def _prepare_inputs(
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
+        selected_token_indices: List[int] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0
 
         # Add prompt tokens.
         prompt_lens: List[int] = []
@@ -180,6 +184,14 @@ def _prepare_inputs(
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
 
+            if sampling_params.prompt_logprobs is not None:
+                # NOTE: prompt token positions do not need sampling, skip them.
+                categorized_sample_indices_start_idx += prompt_len - 1
+
+            categorized_sample_indices[sampling_params.sampling_type].append(
+                categorized_sample_indices_start_idx)
+            categorized_sample_indices_start_idx += 1
+
             input_tokens.append(prompt_tokens)
             # NOTE(woosuk): Here we assume that the first token in the prompt
             # is always the first token in the sequence.
@@ -205,14 +217,39 @@ def _prepare_inputs(
         max_num_blocks_per_seq = 0
         context_lens: List[int] = []
         generation_block_tables: List[List[int]] = []
+        max_seq_len = max(prompt_lens) if prompt_lens else 1
         for seq_group_metadata in seq_group_metadata_list:
             if seq_group_metadata.is_prompt:
+                # We need to do this in this loop as we need to know max_seq_len
+                seq_ids = list(seq_group_metadata.seq_data.keys())
+                assert len(
+                    seq_ids) == 1, "Prompt input should have only one seq."
+                sampling_params = seq_group_metadata.sampling_params
+                prompt_len = seq_group_metadata.seq_data[seq_ids[0]].get_len()
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += max_seq_len
                 continue
 
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
             seq_groups.append((seq_ids, sampling_params))
 
+            num_seqs = len(seq_ids)
+            selected_token_indices.extend(
+                range(selected_token_start_idx,
+                      selected_token_start_idx + num_seqs))
+            selected_token_start_idx += num_seqs
+
+            categorized_sample_indices[sampling_params.sampling_type].extend(
+                range(categorized_sample_indices_start_idx,
+                      categorized_sample_indices_start_idx + num_seqs))
+            categorized_sample_indices_start_idx += num_seqs
+
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -242,7 +279,6 @@ def _prepare_inputs(
                 block_table = block_table[-sliding_window_blocks:]
             generation_block_tables.append(block_table)
 
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
         padded_input_tokens = [
             _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
         ]
@@ -272,6 +308,13 @@ def _prepare_inputs(
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
                                            device="cuda")
+        selected_token_indices = torch.tensor(selected_token_indices,
+                                              dtype=torch.long,
+                                              device="cuda")
+        categorized_sample_indices = {
+            t: torch.tensor(indices, dtype=torch.int, device="cuda")
+            for t, indices in categorized_sample_indices.items()
+        }
         block_tables_tensor = torch.tensor(padded_block_tables,
                                            dtype=torch.int,
                                            device="cuda")
@@ -288,6 +331,8 @@ def _prepare_inputs(
             context_lens=context_lens_tensor,
             max_context_len=max_context_len,
             block_tables=block_tables_tensor,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
             sliding_window=self.sliding_window,
         )
        return tokens_tensor, positions_tensor, input_metadata
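Note on the index layout: the precomputed selected_token_indices assume that prompt hidden states are padded to max_seq_len while generation-phase hidden states are unpadded. The standalone sketch below is illustrative only (the helper name and the plain-list arguments are not part of vLLM); it mirrors the computation added to _prepare_inputs above so the layout can be checked in isolation.

# Standalone sketch of the worker-side index computation (illustrative, not vLLM code).
from typing import List


def compute_selected_token_indices(
    prompt_lens: List[int],             # one entry per prompt seq group, in batch order
    prompt_logprobs: List[bool],        # whether each prompt requests prompt logprobs
    generation_group_sizes: List[int],  # number of sequences per generation-phase group
) -> List[int]:
    selected: List[int] = []
    start_idx = 0
    max_seq_len = max(prompt_lens) if prompt_lens else 1
    for prompt_len, want_prompt_logprobs in zip(prompt_lens, prompt_logprobs):
        if want_prompt_logprobs:
            # Keep every prompt position so prompt logprobs can be computed.
            selected.extend(range(start_idx, start_idx + prompt_len - 1))
        # The last prompt token is always kept: it samples the first output token.
        selected.append(start_idx + prompt_len - 1)
        # Prompts are padded to max_seq_len in the flattened hidden states.
        start_idx += max_seq_len
    for num_seqs in generation_group_sizes:
        # Generation tokens are unpadded: one hidden state per sequence.
        selected.extend(range(start_idx, start_idx + num_seqs))
        start_idx += num_seqs
    return selected


# Two prompts of lengths 3 and 5; only the second requests prompt logprobs.
print(compute_selected_token_indices([3, 5], [False, True], []))
# -> [2, 5, 6, 7, 8, 9]: the last token of the first (padded) prompt,
#    then every position of the second prompt.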