tests/worker/test_worker.py → tests/worker/test_model_runner.py (23 changes: 14 additions & 9 deletions)
@@ -2,18 +2,19 @@
 import torch
 
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
-def test_worker_prepare_inputs_for_prompt():
-    worker = Worker(None, None, None)
-    worker.block_size = 16
+
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
 
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
         seq_group_metadata_list.append(
@@ -24,19 +25,23 @@ def test_worker_prepare_inputs_for_prompt():
                 sampling_params=SamplingParams(temperature=0),
                 block_tables={0: [1]},
             ))
+
     expected_selected_token_indices = []
     selected_token_start_idx = 0
     max_seq_len = max(prompt_lens)
     for prompt_len in prompt_lens:
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+    input_tokens, input_positions, _ = model_runner._prepare_prompt(
         seq_group_metadata_list)
-    assert input_tokens.shape == input_positions.shape == (batch_size,
-                                                           max_seq_len)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
-    actual = input_metadata.selected_token_indices
+
+    actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,
                             dtype=actual.dtype)
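For reference, the index arithmetic this test asserts can be checked in isolation: prompts are right-padded to max_seq_len, so after row-major flattening of the (batch_size, max_seq_len) input buffer, the last prompt token of sequence i sits at i * max_seq_len + prompt_len_i - 1, which is where selected_token_indices should point. Below is a minimal standalone sketch of that arithmetic; the prompt lengths are made-up example values, not taken from this PR.

    # Standalone sketch (no vLLM imports); prompt_lens values are hypothetical.
    prompt_lens = [3, 1, 2]          # per-sequence prompt lengths in one batch
    max_seq_len = max(prompt_lens)   # inputs are right-padded to this length

    # Row-major flattening of the padded (batch_size, max_seq_len) buffer puts
    # the last prompt token of sequence i at i * max_seq_len + prompt_lens[i] - 1.
    expected = [i * max_seq_len + plen - 1 for i, plen in enumerate(prompt_lens)]
    assert expected == [2, 3, 7]     # last-token positions in the padded buffer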