
Add CFG-guided generation to the vLLM integration #541

Closed
wants to merge 4 commits

Conversation

@mory91 (Contributor) commented Jan 15, 2024

This is the second attempt to add CFG support to vLLM.

Context:

The vLLM serving integration currently supports Regex and JSON Schema guided generation. This PR adds CFG support to vLLM serving.

vLLM has the vllm.LLM class for offline inference and vllm._AsyncLLMEngine for async multi-node serving. vllm._AsyncLLMEngine is a subclass of vllm.LLMEngine, and vllm.LLM has an underlying engine property that is an instance of vllm.LLMEngine. This is important because Outlines needs to get the tokenizer from vLLM and tailor it to its needs, and vllm.LLM and vllm.LLMEngine expose the tokenizer in two distinct ways.

Outlines has some special requirements for the tokenizer API that differ slightly from vLLM's tokenizer. Because of that, it has an adapt_tokenizer function to tailor the tokenizer to its needs.

Current bug (Thanks to #536 #535):

  • The current Outlines code (on main) has logits processors that bias vLLM's logits, and these classes expect a vllm._AsyncLLMEngine (as demonstrated in outlines/serve/serve.py), although the documentation lists vllm.LLM as the argument type. So the documentation is currently wrong, and because of this expectation the code in examples/vllm_integration.py fails.

Proposed Solution:

Since vLLM states that vllm.LLMEngine is the main class of the vLLM engine, Outlines should expect a vllm.LLMEngine in its logits processors and get the tokenizer from it.
Another solution would be to support both and branch on a hasattr check, as sketched below.
Currently the logits processors do not accept a vllm.LLM instance as input, and that is the cause of the error in the example file.
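For illustration, here is a minimal sketch of that hasattr-based branching. The helper name and the attribute paths (llm_engine, the nested tokenizer attribute) are assumptions drawn from the discussion in this thread, not the PR's actual code:

```python
def get_engine_tokenizer(llm_or_engine):
    """Return the underlying Hugging Face tokenizer from either a vllm.LLM
    or a vllm.LLMEngine / vllm._AsyncLLMEngine instance."""
    if hasattr(llm_or_engine, "llm_engine"):
        # vllm.LLM wraps the engine; unwrap it first.
        engine = llm_or_engine.llm_engine
    else:
        # Assume we were handed an engine (LLMEngine or _AsyncLLMEngine) directly.
        engine = llm_or_engine
    tokenizer = engine.tokenizer
    # Some vLLM versions wrap the Hugging Face tokenizer one level deeper
    # (see the "llm.tokenizer.tokenizer" review comments later in this thread).
    return getattr(tokenizer, "tokenizer", tokenizer)
```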

@SupreethRao99 commented:

Would this be compatible with distributed inference using Ray? I'm trying to run Mistral 7B across 4 GPUs using vLLM with tensor_parallel_size=4, and I see the following issue:

TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'
(RayWorkerVllm pid=24725) Could not apply nest_asyncio: Can't patch loop of type <class 'uvloop.Loop'> [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)

@mory91 (Contributor, Author) commented Jan 16, 2024

> Would this be compatible with distributed inference using Ray? I'm trying to run Mistral 7B across 4 GPUs using vLLM with tensor_parallel_size=4, and I see the following issue:
>
> TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'
> (RayWorkerVllm pid=24725) Could not apply nest_asyncio: Can't patch loop of type <class 'uvloop.Loop'> [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)

This is addressed by #539

@rlouf changed the title from "Vllm cfg" to "Add CFG-guided generation to the vLLM integration" on Jan 16, 2024
@rlouf added the enhancement and vLLM labels on Jan 16, 2024
@lapp0 (Contributor) commented Jan 16, 2024

>> Would this be compatible with distributed inference using Ray? I'm trying to run Mistral 7B across 4 GPUs using vLLM with tensor_parallel_size=4, and I see the following issue:
>>
>> TypeError: RegexLogitsProcessor.__call__() missing 1 required positional argument: 'scores'
>> (RayWorkerVllm pid=24725) Could not apply nest_asyncio: Can't patch loop of type <class 'uvloop.Loop'> [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
>
> This is addressed by #539

This PR addresses generation of multiple sequences concurrently, which can take place on a single GPU without tensor parallelism.

However, I'll get around to it since I need guided generation on multiple GPUs as well :) Tensor parallel / Ray is discussed in #524.

Review threads on outlines/serve/serve.py and docs/reference/vllm.md (outdated, resolved)
@rlouf (Member) commented Jan 18, 2024

Thank you for opening a PR! A couple of questions:

  1. Did you have to add adapt_tokenizer because vLLM expects a list when decoding? Does batching work with this change?
  2. Could you please rebase on main instead of merging? This helps keep a linear history.

@mory91 (Contributor, Author) commented Jan 18, 2024

Hi,
Outlines' implementation of model.decode uses the tokenizer's batch_decode, which returns a list, and in https://github.com/outlines-dev/outlines/blob/65ed0f7e9f900701ba62479982ee29bb47cc2738/outlines/fsm/fsm.py#L345 the FSM generation indexes element zero of the returned list. vLLM's model.decode, however, uses the tokenizer's decode, which does not return a list, and that causes the error.

I don't think the serve.py file supports batching. Batching is exercised in the examples/vllm_integration.py file, and it works.
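For illustration, a minimal sketch of the decode mismatch being described, assuming the simplest possible wrapper; the real _adapt_tokenizer in this PR does more than this, so treat it as a sketch rather than the actual implementation:

```python
from typing import List


def adapt_decode(tokenizer):
    """Wrap a tokenizer so that decode() returns a list of strings, matching
    what Outlines' FSM code expects when it does tokenizer.decode([token_id])[0]."""
    original_decode = tokenizer.decode

    def decode(token_ids: List[int]) -> List[str]:
        # Hugging Face decode() returns a single string; Outlines indexes the
        # result, so wrap it in a list.
        return [original_decode(token_ids)]

    tokenizer.decode = decode
    return tokenizer
```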

@lapp0 (Contributor) left a comment


I ran this command twice and it crashed the server:

curl http://127.0.0.1:8000/generate \
    -d '{
        "prompt": "What is Pi? Give me the first 15 digits: ",
        "grammar": "start: DECIMAL \r\nDIGIT: \"0\"..\"9\"\r\nINT: DIGIT+\r\nDECIMAL: INT \".\" INT? | \".\" INT"
        }'

vllm.engine.async_llm_engine.AsyncEngineDeadError: Task finished unexpectedly. This should never happen! Please open an issue on Github. See stack trace above for the actual cause.

Full traceback

INFO:     127.0.0.1:41880 - "POST /generate HTTP/1.1" 500 Internal Server Error                               
ERROR:    Exception in ASGI application                                                                                                                                                                                      
Traceback (most recent call last):                                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 28, in _raise_exception_on_finish                                                                                                     
    task.result()                                                                                                                         
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 363, in run_engine_loop
    has_requests_in_progress = await self.engine_step()                                     
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 342, in engine_step                                                                              
    request_outputs = await self.engine.step_async()                                                                                                                                                                                                                                 
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 190, in step_async
    all_outputs = await self._run_workers_async(                                                                                                                                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 231, in _run_workers_async                                                                       
    all_outputs = await asyncio.gather(*coros)                                                                                            
  File "/usr/lib/python3.10/concurrent/futures/thread.py", line 58, in run                                                                                                              
    result = self.fn(*self.args, **self.kwargs)                                                               
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                                                                                   
    return func(*args, **kwargs)                                                                                                          
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 189, in execute_model                                                                                                                           
    output = self.model_runner.execute_model(seq_group_metadata_list,                                                                                                                                                                                                                
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context                                                                                                                   
    return func(*args, **kwargs)                                                                                                          
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 461, in execute_model                                                                                                                                                                             
    output = self.model.sample(                                                                                                                                                                                              
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/mistral.py", line 291, in sample                                                                                                                                                                          
    next_tokens = self.sampler(self.lm_head.weight, hidden_states,                                                                                                                                                           
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                                                                                                                        
    return self._call_impl(*args, **kwargs)                                                                                                                                                                                  
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                                                                        
    return forward_call(*args, **kwargs)                                                                                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/sampler.py", line 59, in forward                                                                                                                                                                          
    logits = _apply_logits_processors(logits, sampling_metadata)                            
  File "/root/outlines/outlines/serve/vllm.py", line 33, in _patched_apply_logits_processors                                                                                                                                                                                         
    logits_row = logits_processor(seq_id, token_ids, logits_row)                                                                          
  File "/root/outlines/outlines/serve/vllm.py", line 93, in __call__                                                                                                                                                                                                                 
    self.fsm_state[seq_id] = self.fsm.next_state(                                                                                                                                                                                                                                    
  File "/root/outlines/outlines/fsm/fsm.py", line 316, in next_state                                                                                                                    
    self.generation += self.tokenizer.decode([token_id])[0]                                 
TypeError: can only concatenate str (not "list") to str                              
                                                                                                                                                                                        
The above exception was the direct cause of the following exception:                                                                                                                    
                                                                                                                                                                                        
Traceback (most recent call last):                                                          
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/protocols/http/httptools_impl.py", line 419, in run_asgi                                                                        
    result = await app(  # type: ignore[func-returns-value]                                                                                                                                                                                                                                                                                                                      
  File "/usr/local/lib/python3.10/dist-packages/uvicorn/middleware/proxy_headers.py", line 84, in __call__                                                                                                                                                                           
    return await self.app(scope, receive, send)                                                                                                                                         
  File "/usr/local/lib/python3.10/dist-packages/fastapi/applications.py", line 1054, in __call__                                                                                                                                                                                     
    await super().__call__(scope, receive, send)                                                                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/starlette/applications.py", line 123, in __call__                                                                                                                                                                                    
    await self.middleware_stack(scope, receive, send)             
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 186, in __call__                  
    raise exc                                                                               
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/errors.py", line 164, in __call__                  
    await self.app(scope, receive, _send)                                                   
  File "/usr/local/lib/python3.10/dist-packages/starlette/middleware/exceptions.py", line 62, in __call__                                                                           
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)                                                                                                                                                                                                                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app                                                                                                                                                                            
    raise exc                                                                                                                             
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app                                                                                                                                                                            
    await app(scope, receive, sender)                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 762, in __call__                                              
    await self.middleware_stack(scope, receive, send)                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 782, in app                                                   
    await route.handle(scope, receive, send)                                                                                              
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 297, in handle                                                
    await self.app(scope, receive, send)                                                                                                  
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 77, in app                                                    
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)                                                                
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 64, in wrapped_app                                                                                                                                                                            
    raise exc                                                                                                                             
  File "/usr/local/lib/python3.10/dist-packages/starlette/_exception_handler.py", line 53, in wrapped_app                                                                                                                                                                            
    await app(scope, receive, sender)                                                                                                     
  File "/usr/local/lib/python3.10/dist-packages/starlette/routing.py", line 72, in app                                                    
    response = await func(request)                                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 299, in app                                                     
    raise e                                                                                                                               
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 294, in app                                                     
    raw_response = await run_endpoint_function(                                                                                           
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 191, in run_endpoint_function                                                                                                                                                                              
    return await dependant.call(**values)                                                                                                 
  File "/root/outlines/outlines/serve/serve.py", line 100, in generate                                                                    
    async for request_output in results_generator:                                                                                        
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 449, in generate                                                                                                                                                                              
    raise e                                                                                                                               
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 443, in generate                                                                                                                                                                              
    async for request_output in stream:                                                                                                   
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 70, in __anext__                                                                                                                                                                              
    raise result                                                                                                                          
  File "uvloop/cbhandles.pyx", line 63, in uvloop.loop.Handle._run                                                                        
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 37, in _raise_exception_on_finish                                                                                                                                                             
    raise exc                                                                                                                             
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 32, in _raise_exception_on_finish                                                                                                                                                             
    raise AsyncEngineDeadError(                                                                                                           
vllm.engine.async_llm_engine.AsyncEngineDeadError: Task finished unexpectedly. This should never happen! Please open an issue on Github. See stack trace above for the actual cause.                                                                                                 

Environment:

python3 -c "from outlines import _version; print(_version.version)"
0.0.25.dev14+ge16d986.d20240125

python3 -c "import sys; print('Python', sys.version)"
Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]

pip3 freeze follows

```
accelerate==0.26.1 aiohttp==3.9.1 aioprometheus==23.12.0 aiosignal==1.3.1 annotated-types==0.6.0 anyio==4.2.0 async-timeout==4.0.3 attrs==23.2.0 blinker==1.4 certifi==2023.11.17 charset-normalizer==3.3.2 click==8.1.7 cloudpickle==3.0.0 cryptography==3.4.8 datasets==2.16.1 dbus-python==1.2.18 dill==0.3.7 diskcache==5.6.3 distro==1.7.0 exceptiongroup==1.2.0 fastapi==0.109.0 filelock==3.13.1 frozenlist==1.4.1 fsspec==2023.10.0 h11==0.14.0 httplib2==0.20.2 httptools==0.6.1 huggingface-hub==0.20.3 idna==3.6 importlib-metadata==4.6.4 interegular==0.3.3 jeepney==0.7.1 Jinja2==3.1.3 joblib==1.3.2 jsonschema==4.21.1 jsonschema-specifications==2023.12.1 keyring==23.5.0 lark==1.1.9 launchpadlib==1.10.16 lazr.restfulclient==0.14.4 lazr.uri==1.0.6 llvmlite==0.41.1 MarkupSafe==2.1.4 more-itertools==8.10.0 mpmath==1.3.0 msgpack==1.0.7 multidict==6.0.4 multiprocess==0.70.15 nest-asyncio==1.6.0 networkx==3.2.1 ninja==1.11.1.1 numba==0.58.1 numpy==1.26.3 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==8.9.2.26 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.18.1 nvidia-nvjitlink-cu12==12.3.101 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.0 orjson==3.9.12 packaging==23.2 pandas==2.2.0 protobuf==4.25.2 psutil==5.9.8 pyarrow==15.0.0 pyarrow-hotfix==0.6 pydantic==1.10.13 pydantic_core==2.14.6 PyGObject==3.42.1 PyJWT==2.3.0 pyparsing==2.4.7 python-apt==2.4.0+ubuntu2 python-dateutil==2.8.2 python-dotenv==1.0.1 pytz==2023.3.post1 PyYAML==6.0.1 quantile-python==1.1 ray==2.9.1 referencing==0.32.1 regex==2023.12.25 requests==2.31.0 rpds-py==0.17.1 safetensors==0.4.2 scipy==1.12.0 SecretStorage==3.3.1 sentencepiece==0.1.99 six==1.16.0 sniffio==1.3.0 starlette==0.35.1 sympy==1.12 tokenizers==0.15.1 torch==2.1.2 tqdm==4.66.1 transformers==4.37.1 triton==2.1.0 typing_extensions==4.9.0 tzdata==2023.4 UNKNOWN @ file:///root/outlines urllib3==2.1.0 uvicorn==0.27.0 uvloop==0.19.0 vllm==0.2.7 wadllib==1.3.6 watchfiles==0.21.0 websockets==12.0 xformers==0.0.23.post1 xxhash==3.4.1 yarl==1.9.4 zipp==1.0.0
```

Review threads on docs/reference/vllm.md (outdated, resolved) and outlines/serve/vllm.py (resolved)
@mory91 (Contributor, Author) commented Jan 25, 2024

Thanks for pointing out the error. The cause was that the tokenizer is modified the first time a request is made; on the second request the modifications are applied on top of the already-adapted tokenizer, which makes it incompatible with the Outlines model behavior and API. I added a flag to the vLLM decoder to check whether it has already been adapted for the Outlines model; if it has, the adaptation is skipped. This may not look like the cleanest choice, but it keeps the adaptation code in the vllm.py file and doesn't leak it into the serve file. It also avoids running the adaptation code on every request.
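A sketch of that guard, continuing the hypothetical wrapper shown earlier in the thread; the flag name _outlines_adapted is illustrative, not necessarily the attribute the PR uses:

```python
def adapt_tokenizer_once(tokenizer):
    """Adapt the tokenizer for Outlines only if it has not been adapted yet."""
    if getattr(tokenizer, "_outlines_adapted", False):
        # Already adapted on a previous request; wrapping decode() a second
        # time is what broke the second request reported above.
        return tokenizer

    original_decode = tokenizer.decode
    tokenizer.decode = lambda token_ids: [original_decode(token_ids)]
    tokenizer._outlines_adapted = True
    return tokenizer
```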

@lapp0 (Contributor) commented Jan 25, 2024

> Thanks for pointing out the error. The cause was that the tokenizer is modified the first time a request is made; on the second request the modifications are applied on top of the already-adapted tokenizer, which makes it incompatible with the Outlines model behavior and API. I added a flag to the vLLM decoder to check whether it has already been adapted for the Outlines model; if it has, the adaptation is skipped. This may not look like the cleanest choice, but it keeps the adaptation code in the vllm.py file and doesn't leak it into the serve file. It also avoids running the adaptation code on every request.

Thanks for the explanation.

Could you write a test reproducing this problem?

The test should be guarded by @pytest.mark.skipif, skipping if not torch.cuda.is_available()
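A minimal sketch of the requested guard; the test name and body are placeholders, not the test added in this PR:

```python
import pytest
import torch


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA GPU")
def test_tokenizer_adaptation_is_idempotent():
    # Placeholder body: load a vLLM engine, adapt its tokenizer twice, and
    # assert that the second adaptation is a no-op.
    ...
```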

@mory91 (Contributor, Author) commented Jan 26, 2024

done

@lapp0 (Contributor) commented Jan 26, 2024

Will review some time today or tomorrow. Preliminary note: it seems you designed your test case in a way that doesn't actually require a GPU, so you can remove the skipif.

@mory91 (Contributor, Author) commented Jan 26, 2024

I agree it should be executed in all cases.

@lapp0 (Contributor) left a comment


Ran a variety of smoke tests, they were successful!

One problem I saw is that the model always ends on ., but I don't see anything in your change-set that would cause this. Seems like a separate issue. Are you able to reproduce?

Otherwise, the change-set looks good to me!

@mory91 (Contributor, Author) commented Jan 28, 2024

I cannot run Mistral on my machine, and with other models I sometimes see the JSON end without a dot (.). But one problem I do see with the grammar you mentioned in your comment on the linked issue is that the connection is not closed after the JSON is finished: I can see the result and the JSON is in the correct form, but the curl connection is not closed. I don't think this problem is related to these changes either.

@lapp0 (Contributor) commented Jan 28, 2024

> I cannot run Mistral on my machine, and with other models I sometimes see the JSON end without a dot (.). But one problem I do see with the grammar you mentioned in your comment on the linked issue is that the connection is not closed after the JSON is finished: I can see the result and the JSON is in the correct form, but the curl connection is not closed. I don't think this problem is related to these changes either.

If other models don't always end on the dot, then ending on a period must be the model's fault, not that of Outlines or of the Outlines integration with vLLM.

@mory91 force-pushed the vllm-cfg branch 5 times, most recently from 504c015 to 9bad41b on February 5, 2024 at 17:59
@mory91 (Contributor, Author) commented Feb 9, 2024 via email



Review comment (Member) on the diff context:

@pytest.mark.parametrize("logit_processor, fsm_str", LOGIT_PROCESSORS)
def test_logit_processor(logit_processor, fsm_str: str):

What is this test doing?


Review comment on the diff context:

string = tokenizer.convert_tokens_to_string([token])
"""
adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)

I think this should be llm.tokenizer.tokenizer


Review comment on the diff context:

return tokenizer
"""
adapted_tokenizer = _adapt_tokenizer(llm.tokenizer)

Same as above: this should be llm.tokenizer.tokenizer

@br3no (Contributor) commented Feb 19, 2024

This PR is currently blocking the integration of outlines into vLLM because the decision was made to support CFG-guided generation from the beginning.

@lapp0, @rlouf, is there anything I can do to support this effort?

@simon-mo commented Mar 6, 2024

I added CFG support to vLLM in this PR: vllm-project/vllm#3211

@rlouf (Member) commented Mar 6, 2024

We can close this PR as soon as the PR on vLLM is merged. We will also need to remove the vLLM-related code in the repo.

@rlouf (Member) commented Mar 7, 2024

Thank you for contributing! vLLM is currently implementing this on their end, so I will close this PR for now.

@rlouf closed this on Mar 7, 2024
Labels: enhancement, grammar, structured generation, vLLM
Projects: None yet
7 participants