diff --git a/tests/worker/test_worker.py b/tests/worker/test_model_runner.py
similarity index 68%
rename from tests/worker/test_worker.py
rename to tests/worker/test_model_runner.py
index b2c61e24efdd..949a7e2292a4
--- a/tests/worker/test_worker.py
+++ b/tests/worker/test_model_runner.py
@@ -2,18 +2,19 @@
 import torch
 
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.worker import Worker
+from vllm.worker.model_runner import ModelRunner
 
 
-def test_worker_prepare_inputs_for_prompt():
-    worker = Worker(None, None, None)
-    worker.block_size = 16
+def test_prepare_prompt():
+    model_runner = ModelRunner(None, None, None)
+    model_runner.set_block_size(16)
+
     batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
         # make sure all tokens fit into one block
-        prompt_len = i % (worker.block_size - 1) + 1
+        prompt_len = i % (model_runner.block_size - 1) + 1
         prompt_lens.append(prompt_len)
         seq_data = list(range(prompt_len))
         seq_group_metadata_list.append(
@@ -24,6 +25,7 @@ def test_worker_prepare_inputs_for_prompt():
             sampling_params=SamplingParams(temperature=0),
             block_tables={0: [1]},
         ))
+
     expected_selected_token_indices = []
     selected_token_start_idx = 0
     max_seq_len = max(prompt_lens)
@@ -31,12 +33,15 @@ def test_worker_prepare_inputs_for_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, input_metadata = worker._prepare_inputs(
+    input_tokens, input_positions, _ = model_runner._prepare_prompt(
         seq_group_metadata_list)
-    assert input_tokens.shape == input_positions.shape == (batch_size,
-                                                           max_seq_len)
+    sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+    assert input_tokens.shape == (batch_size, max_seq_len)
+    assert input_positions.shape == (batch_size, max_seq_len)
     torch.testing.assert_close(input_tokens, input_positions)
-    actual = input_metadata.selected_token_indices
+
+    actual = sampling_metadata.selected_token_indices
     expected = torch.tensor(expected_selected_token_indices,
                             device=actual.device,
                             dtype=actual.dtype)