
[Bug]: AssertionError in Sampler with Prefix Caching and Prompt Logprobs Enabled. #13105

Open
@aldopareja



🐛 Describe the bug

Description:
When vLLM runs with prefix caching enabled (enable_prefix_caching=True) and prompt_logprobs requested in the SamplingParams, inference fails with an assertion error in the sampler: the lengths of next_token_ids and query_indices do not match. If prefix caching is disabled (enable_prefix_caching=False), the error goes away.
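
For context on the failing check: the assertion in get_logprobs expects one token id per logprob query index. Below is a minimal, self-contained sketch (a toy illustration, not vLLM's sampler code; the function name and data are made up) of that invariant and of how a count mismatch, which is apparently what the prefix-cached second request produces when prompt_logprobs is set, trips the same kind of AssertionError:

    # Toy sketch only -- NOT vLLM's implementation. It mirrors the shape of the
    # check in sampler.py:get_logprobs that fires in this issue: the sampler
    # expects exactly one token id per logprob query index.
    def pair_tokens_with_queries(next_token_ids, query_indices):
        assert len(next_token_ids) == len(query_indices)
        return list(zip(query_indices, next_token_ids))

    # Consistent bookkeeping: one token id per queried position -> OK.
    print(pair_tokens_with_queries([882, 100266, 4438, 1053], [0, 1, 2, 3]))

    # If the second (prefix-cached) request yields fewer query indices than
    # token ids (or vice versa), the same kind of AssertionError is raised.
    try:
        pair_tokens_with_queries([882, 100266, 4438, 1053], [3])
    except AssertionError:
        print("length mismatch between next_token_ids and query_indices")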

Reproduction Steps:

  1. Create a minimal reproduction file:
    Save the following as reproduction.py:

     import asyncio
     import torch
     from vllm import AsyncLLMEngine, AsyncEngineArgs, SamplingParams, TokensPrompt
    
     async def main():
         engine_args = AsyncEngineArgs(
             model="path/to/model",  # Replace with your actual model path
             tensor_parallel_size=1,
             gpu_memory_utilization=0.98,
             dtype=torch.bfloat16,
             enable_prefix_caching=True,  # Bug occurs when prefix caching is enabled.
             max_num_seqs=1,
             max_model_len=16384,
         )
         llm = AsyncLLMEngine.from_engine_args(engine_args)
    
         input_ids = [100264, 882, 100266, 4438, 1053]
         sampling_params = SamplingParams(
             n=1,
             max_tokens=15,  # Ensure max_tokens is larger than len(input_ids)
             prompt_logprobs=1,
         )
    
         # Call inference twice; the first call succeeds, the second fails.
         for i in range(2):
             request_id = f"test_request_{i}"
             print(f"Starting inference call {i}")
             generator = llm.generate(
                 prompt=TokensPrompt(prompt_token_ids=input_ids),
                 sampling_params=sampling_params,
                 request_id=request_id,
             )
             result = None
             try:
                 async for r in generator:
                     result = r
                 print(f"Inference call {i} succeeded: {result}")
             except Exception as e:
                 print(f"Inference call {i} failed with exception: {e}")
    
     if __name__ == "__main__":
         asyncio.run(main())
  2. Run the reproduction code:
    Execute the file using:

    python reproduction.py
  3. Observe the Error:
    The output prints error logs similar to the following:

  File "/workspace/home/lab/.conda/envs/grpo/lib/python3.12/site-packages/vllm/model_executor/layers/sampler.py", line 956, in get_logprobs
    assert len(next_token_ids) == len(query_indices)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

Expected Behavior:
The engine should complete inference without errors and return a valid generated result.

Actual Behavior:
With prefix caching enabled, the first call succeeds but the second fails during sampling with an AssertionError (the lengths of next_token_ids and query_indices do not match), followed by an AsyncEngineDeadError.

Workaround:
Disabling prefix caching by setting enable_prefix_caching=False in the AsyncEngineArgs prevents the error and the engine works as expected.
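
For reference, a sketch of the working configuration, identical to the reproduction above except for the prefix-caching flag (the model path remains a placeholder):

    import torch
    from vllm import AsyncEngineArgs, AsyncLLMEngine

    engine_args = AsyncEngineArgs(
        model="path/to/model",           # placeholder, as in reproduction.py
        tensor_parallel_size=1,
        gpu_memory_utilization=0.98,
        dtype=torch.bfloat16,
        enable_prefix_caching=False,     # workaround: disable prefix caching
        max_num_seqs=1,
        max_model_len=16384,
    )
    llm = AsyncLLMEngine.from_engine_args(engine_args)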

Please let me know if further details are needed. Thanks!

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Labels

bug (Something isn't working), stale (Over 90 days of inactivity)
