Description
Your current environment
🐛 Describe the bug
Description:
When using vLLM with prefix caching enabled (i.e. `enable_prefix_caching=True`), the engine fails during inference with an assertion error in the sampler. The error occurs because the lengths of `next_token_ids` and `query_indices` do not match. Interestingly, if prefix caching is disabled by setting `enable_prefix_caching=False`, the error goes away.
Reproduction Steps:
- Create a minimal reproduction file: Save the following as `reproduction.py`:

```python
import asyncio

import torch

from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt


async def main():
    engine_args = AsyncEngineArgs(
        model="path/to/model",  # Replace with your actual model path
        tensor_parallel_size=1,
        gpu_memory_utilization=0.98,
        dtype=torch.bfloat16,
        enable_prefix_caching=True,  # Bug occurs when prefix caching is enabled.
        max_num_seqs=1,
        max_model_len=16384,
    )
    llm = AsyncLLMEngine.from_engine_args(engine_args)

    input_ids = [100264, 882, 100266, 4438, 1053]
    sampling_params = SamplingParams(
        n=1,
        max_tokens=15,  # Ensure max_tokens is larger than len(input_ids)
        prompt_logprobs=1,
    )

    # Call inference twice; the first call succeeds, the second fails.
    for i in range(2):
        request_id = f"test_request_{i}"
        print(f"Starting inference call {i}")
        generator = llm.generate(
            prompt=TokensPrompt(prompt_token_ids=input_ids),
            sampling_params=sampling_params,
            request_id=request_id,
        )
        result = None
        try:
            async for r in generator:
                result = r
            print(f"Inference call {i} succeeded: {result}")
        except Exception as e:
            print(f"Inference call {i} failed with exception: {e}")


if __name__ == "__main__":
    asyncio.run(main())
```
- Run the reproduction code: Execute the file with `python reproduction.py`.
- Observe the error: The output prints error logs similar to the following:

```text
  File "/workspace/home/lab/.conda/envs/grpo/lib/python3.12/site-packages/vllm/model_executor/layers/sampler.py", line 956, in get_logprobs
    assert len(next_token_ids) == len(query_indices)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```
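For context, the assertion that fails is a simple length-consistency check: the sampler expects one query index per token id whose logprob it gathers. A minimal, standalone illustration of that invariant follows; the values are made up for clarity and this is not vLLM's actual data or code.

```python
# Hypothetical, standalone illustration of the invariant behind the failing
# assert in get_logprobs: every token id whose logprob is requested needs a
# matching index into the logits rows. The values below are invented.
next_token_ids = [100264, 882, 100266]  # three token ids to look up
query_indices = [0, 1]                  # only two logits rows to gather from

try:
    assert len(next_token_ids) == len(query_indices)
except AssertionError:
    print("length mismatch: "
          f"{len(next_token_ids)} token ids vs {len(query_indices)} query indices")
```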
Expected Behavior:
The engine should complete inference without errors and return a valid generated result.
Actual Behavior:
With prefix caching enabled, vLLM fails during the sampling process with an `AssertionError` (due to the mismatch in token lengths), followed by an `AsyncEngineDeadError`.
Workaround:
Disabling prefix caching by setting `enable_prefix_caching=False` in the `AsyncEngineArgs` prevents the error, and the engine works as expected.
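For reference, a minimal sketch of the working configuration, using the same placeholder model path and arguments as the reproduction script above with only the caching flag changed:

```python
import torch
from vllm import AsyncEngineArgs

# Same arguments as the reproduction script, with prefix caching disabled.
engine_args = AsyncEngineArgs(
    model="path/to/model",        # placeholder; replace with your model path
    tensor_parallel_size=1,
    gpu_memory_utilization=0.98,
    dtype=torch.bfloat16,
    enable_prefix_caching=False,  # workaround: turn off prefix caching
    max_num_seqs=1,
    max_model_len=16384,
)
```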
Please let me know if further details are needed. Thanks!