
[Frontend][Core] Update Outlines Integration from FSM to Guide #4109

Merged
merged 17 commits into vllm-project:main on Jun 5, 2024

Conversation

@br3no (Contributor) commented Apr 16, 2024

This PR updates the Outlines Integration from FSM to the new Guide interface.

Since I'm not sure where to place the change, I added both labels [Frontend] and [Core].

FIX #3715

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title should be prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] for changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.).
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.).
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to the Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behavior of vLLM. It helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not review the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@simon-mo (Collaborator) left a comment:


LGTM.

requirements-common.txt: review comment (outdated, resolved)
@simon-mo (Collaborator) commented:

I think the test failure in the entrypoints test might be related, and there's a merge conflict. 🙏

@br3no (Contributor, Author) commented Apr 18, 2024

I'll have a look at the failing frontend test.

@br3no (Contributor, Author) commented Apr 19, 2024

I can reproduce the error in the test on my dev environment. The generation does not stop when it should, producing IP addresses like this: 100.101.102.10319216. I'm investigating why this happens and have reached out to @rlouf in the Outlines Discord.

I have pushed some small improvements to the test code.

@navster888 commented:

Hey @br3no @simon-mo, any updates on getting this PR merged? The pinned version of outlines is preventing us from picking up a bug fix included in 0.0.40. Is there any way we can pull the relaxed version constraint into its own PR to unblock?

@br3no (Contributor, Author) commented Apr 30, 2024

I'm having a call with Outlines contributors on Thursday. While there is no guarantee we will have a solution for the problem, I propose waiting until then. If there's no progress by the end of the week, I'll open a separate PR for the unpinning.

What do you think?

@br3no (Contributor, Author) commented May 2, 2024

I have opened #4558 because moving to the Guide API will require outlines-dev/outlines#856 to be fixed first.

@br3no (Contributor, Author) commented May 10, 2024

I have closed #4558 in favor of this PR. I expect to make progress on this next week. Waiting for outlines-dev/outlines#874.

@rlouf commented May 11, 2024

outlines-dev/outlines#874 merged, thank you for your patience!

@br3no (Contributor, Author) commented May 12, 2024

Great, thanks for the support @rlouf!

Can you already say when you plan to release?

@rlouf commented May 19, 2024

Next week I think. I need to make sure other downstream libraries have pinned the outlines version to avoid surprises.

@br3no (Contributor, Author) commented May 29, 2024

@saattrupdan thanks for pointing this out.

I believe this fix will not work in vLLM. The thing is that the logits processors are cached here:

def _get_cached_logits_processor(guide: str,

Not resetting the state means that the dictionary will grow with every generation and will never be cleaned up.

Or am I missing something?

@br3no (Contributor, Author) commented May 29, 2024

Let me summarize the issue raised by @saattrupdan:

  • when n > 1, all sequences in the batch share one logits processor, e.g. in a chat completion request:
    await get_guided_decoding_logits_processor(
  • the logits processor is stateful, caching the state information for each sequence prefix
  • in outlines-dev/outlines#757 ([vLLM integration bug] Generated output is stopped for all samples in batch) it was observed that for large enough values of n, some sequences break
  • because we cache the logits processor, even separate requests (with the same guide and tokenizer, which won't be uncommon) will share the same logits processor
  • @saattrupdan's fix was to never reset the state
  • never resetting the state will lead to a memory leak, as the cache will grow indefinitely
  • but if we keep cleaning up the state, this can and will break in different ways
  • we need to find a way to tie the cache's life-cycle to the request: it should start clean for a new request and only be cleared after the request is finished

The expensive part, though, is building the Guide for a particular tokenizer and guide specification (regex, JSON schema, or grammar), not the logits processor itself. So I believe that pushing the cache one level lower (caching only the Guide and not the logits processor) should solve the issue.
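Roughly the shape I have in mind (a simplified sketch, not the literal vLLM code; the class and function names are illustrative, and the tokenizer is assumed to be hashable so it can serve as a cache key):

from functools import lru_cache

from outlines.fsm.guide import Guide, RegexGuide


@lru_cache(maxsize=32)
def _get_cached_guide(regex_string: str, tokenizer) -> Guide:
    # Cache only the Guide, keyed by guide spec + tokenizer. Compiling the
    # FSM here is the expensive step we don't want to repeat per request.
    return RegexGuide(regex_string, tokenizer)


class RegexLogitsProcessor:
    # Illustrative stand-in for the vLLM logits processor class.

    def __init__(self, regex_string: str, tokenizer):
        # The processor itself is created fresh for every request, so this
        # per-sequence FSM-state cache dies with the request and is never
        # shared with another request.
        self._guide = _get_cached_guide(regex_string, tokenizer)
        self._fsm_state: dict = {}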

@simon-mo should I open a new issue for this, or should I just fix this together with the change to the Guide API, subject of this PR?

@simon-mo (Collaborator) commented:

Same PR works since this is small enough. Also cc @njhill, I think you mentioned a similar issue.

@dongxiaolong mentioned this pull request Jun 3, 2024
@simon-mo mentioned this pull request Jun 3, 2024
@njhill (Collaborator) commented Jun 3, 2024

Sorry, missed @simon-mo tagging me above. Yes, we encountered the same problem and had been thinking of introducing a LogitsProcessorFactory abstract class that can be included in the SamplingParams instead of LogitsProcessors (but allowing those too for backwards compatibility). It could have both create_processor() and return_processor(lp) methods; the latter is not required but could be used for pooling.

vLLM would then ensure to call this separately for each sequence. Stateless LP factories can just return a constant LP from the method.

Note this is currently also a problem if a list of prompts is passed in the API, and/or if n > 1 like you said (including beam search).

WDYT?

@maxdebayser has started prototyping this.
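For concreteness, something along these lines (a rough sketch only; names and signatures are illustrative, not an agreed API):

from abc import ABC, abstractmethod
from typing import Callable, List

import torch

# In vLLM, a logits processor maps (generated token ids, logits) -> new logits.
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]


class LogitsProcessorFactory(ABC):
    # Carried in SamplingParams instead of (or alongside) bare LogitsProcessors.

    @abstractmethod
    def create_processor(self) -> LogitsProcessor:
        """Return a logits processor dedicated to one sequence."""

    def return_processor(self, lp: LogitsProcessor) -> None:
        """Optional hook: hand the processor back, e.g. for pooling."""


class ConstantFactory(LogitsProcessorFactory):
    # Stateless case: every sequence can share the same processor.

    def __init__(self, lp: LogitsProcessor):
        self._lp = lp

    def create_processor(self) -> LogitsProcessor:
        return self._lp

vLLM would call create_processor() once per sequence, so a stateful processor never sees tokens from two different sequences.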

@br3no (Contributor, Author) commented Jun 4, 2024

@njhill I like the direction of your proposal. This would allow us to invert control and get rid of the 10 lines setting up guided decoding in the create_chat_completion method. Is there a PR with @maxdebayser's sketch?

While I think this is the right thing to do in the long run, I believe we should fix this problem ASAP. Would you mind having a look at the code in this PR, which handles the issue by pushing the cache one level down? I believe this would work nicely as a band-aid until the refactoring you proposed is implemented.

The core changes are:

  • we are no longer caching the LogitsProcessors; they are now created anew on each request.
  • the LogitsProcessors no longer reset state, so all n sequences can safely share the same state cache within the lifetime of one request (rough sketch below).
  • the Guide object is cached globally for every guide/tokenizer pair. This is the expensive thing we don't want to recompute on every request (it is how we were previously caching the LogitsProcessors).
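As a rough illustration of the second point (a simplified sketch of the state handling, not the exact code in this PR; the Guide calls follow the outlines Guide interface as I understand it):

import math
from typing import Dict, List

import torch


class BaseLogitsProcessor:
    def __init__(self, guide):
        self._guide = guide
        # FSM state per sequence, keyed by a hash of the tokens generated so far.
        # There is intentionally no reset: the processor lives for one request only.
        self._fsm_state: Dict[int, int] = {}

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        seq_key = hash(tuple(input_ids))
        if input_ids:
            # Advance this sequence's state from its previous prefix.
            prev_key = hash(tuple(input_ids[:-1]))
            self._fsm_state[seq_key] = self._guide.get_next_state(
                state=self._fsm_state.get(prev_key, 0), token_id=input_ids[-1])
        instruction = self._guide.get_next_instruction(
            state=self._fsm_state.get(seq_key, 0))
        allowed = instruction.tokens
        if allowed is None:
            # Some instructions allow any token.
            return scores
        # Mask out everything the Guide does not allow next.
        mask = torch.full_like(scores, -math.inf)
        mask[allowed] = 0
        return scores + mask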

@br3no (Contributor, Author) commented Jun 4, 2024

PS: this PR is ready for review @simon-mo

I'm just waiting for outlines to be released, so that we can get rid of the regression in the tests.

@maxdebayser commented Jun 4, 2024

Hi @br3no, we also found a problem with the FSM state being shared between sequences. This curl command causes a crash:

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "<MY_MODEL>",
    "prompt": ["An example of a json document: ", "Another example of a json document: "],
    "max_tokens": 100,
    "temperature": 0,
    "guided_decoding_backend": "outlines",
    "response_format": {"type":"json_object"},
    "logit_bias": {"100": -100}
  }'

We have a sketch for a PR here: IBM/vllm#38. It uses factories like @njhill mentioned, so that each sequence can have its own logits processor copy.

The changes in our PR solve this particular issue, but I think the CFGLogitsProcessor would still crash if the sequence is preempted with the recompute policy. But I don't know yet how to test this hypothesis.

@maxdebayser commented:

@br3no, I've tested your changes with the curl command above. The code doesn't crash anymore, but I think the output ends prematurely:

{
  "id": "cmpl-d1e40aa8af734fcc98238fa5e0c2ecac",
  "object": "text_completion",
  "created": 1717517730,
  "choices": [
    {
      "index": 0,
      "text": "\n\n\n{\n",
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null
    },
    {
      "index": 1,
      "text": "\n\n\n{",
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 18,
    "total_tokens": 29,
    "completion_tokens": 11
  }
}

So while I do think that your changes are an improvement, I believe we still need the factory PR as well to deal with stateful LogitsProcessors properly. Just for reference, the same curl request returns this on our PR:

{
  "id": "cmpl-861186441f9c459c981fcab5abd33f5b",
  "object": "text_completion",
  "created": 1717515348,
  "choices": [
    {
      "index": 0,
      "text": "\n\n\n{\n\"name\" \n: \n\"John Doe\"\n,\n\"age\" \n: 30\n,\n\"city\" \n: \n\"New York\"\n}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null
    },
    {
      "index": 1,
      "text": "\n\n\n{\n\"name\" \n: \n\"John Doe\"\n,\n\"age\" \n: 30\n,\n\"city\" \n: \n\"New York\"\n}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 18,
    "total_tokens": 218,
    "completion_tokens": 200
  }
}

@maxdebayser left a comment:


LGTM

@maxdebayser left a comment:


Actually, if I run

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "<MY_MODEL>",
    "prompt": ["An example of a json document: ", "Another example of a json document: "],
    "max_tokens": 100,
    "temperature": 0,
    "guided_decoding_backend": "outlines",
    "response_format": {"type":"json_object"},
    "logit_bias": {"100": -100}
  }'

followed by

curl http://localhost:8000/v1/completions   -H "Content-Type: application/json"   -d '{
    "model": "<MY_MODEL>",
    "prompt": ["An example of a json document: "],                                        
    "max_tokens": 100,
    "temperature": 0,
    "guided_decoding_backend": "outlines",
    "response_format": {"type":"json_object"},
    "logit_bias": {"100": -100}
  }' | jq

I get this crash:

    File "/home/develop/.local/lib/python3.11/site-packages/vllm/model_executor/guided_decoding/outlines_logits_processors.py", line 47, in __call__
    instruction = self._guide.get_next_instruction(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/develop/.local/lib/python3.11/site-packages/outlines/fsm/guide.py", line 349, in get_next_instruction
    interactive.exhaust_lexer()
  File "/opt/vllm/lib/python3.11/site-packages/lark/parsers/lalr_interactive_parser.py", line 52, in exhaust_lexer
    return list(self.iter_parse())
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/vllm/lib/python3.11/site-packages/lark/parsers/lalr_interactive_parser.py", line 43, in iter_parse
    for token in self.lexer_thread.lex(self.parser_state):
  File "/opt/vllm/lib/python3.11/site-packages/lark/lexer.py", line 674, in lex
    raise UnexpectedToken(token, e.allowed, state=parser_state, token_history=[last_token], terminals_by_name=self.root_lexer.terminals_by_name)
lark.exceptions.UnexpectedToken: Unexpected token Token('LBRACE', '{') at line 7, column 2.
Expected one of: 
	* UNESCAPED_STRING
	* RBRACE
Previous tokens: [Token('LBRACE', '{')]

If I restart the server and just send the curl request with a single prompt several times, only the first request generates useful JSON. After that it returns:

{
  "id": "cmpl-948cb1c4033b40f1a07846bed5ac4de9",
  "object": "text_completion",
  "created": 1717518689,
  "model": "/llama_eval_storage/LLaMa/models/hf/7B-F",
  "choices": [
    {
      "index": 0,
      "text": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n",
      "logprobs": null,
      "finish_reason": "length",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 9,
    "total_tokens": 109,
    "completion_tokens": 100
  }
}

I'm testing with Llama 2 7B and outlines==0.0.41.

@br3no (Contributor, Author) commented Jun 4, 2024

@maxdebayser thanks for looking into it!

we also found a problem with the FSM state being shared between sequences

I have just looked into Outlines, and while RegexGuide is thread-safe, CFGGuide is not.

I'll check if I can increase the size of the band-aid a bit to deal with this case...

@br3no (Contributor, Author) commented Jun 5, 2024

I have pushed a commit that comes close to what was there before and, at the same time, does not lead to crashes when n > 1. It's still not correct, though...

Each CFGGuide can only be used by one sequence at a time. We need the factory idea from IBM/vllm#38. Note that the performance will still be horrible, since we will create a new Guide for every sequence; this is unavoidable, unfortunately.

It probably makes sense to pre-build a (large) pool of CFGGuides for the use-case where request.response_format.type == "json_object". I believe this is a pragmatic and useful solution, since this will be the most common case.
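Something along these lines is what I have in mind (a hedged sketch, not part of this PR; the class, the pool size, and the threading details are all made up for illustration):

import queue
import threading

from outlines.fsm.guide import CFGGuide


class CFGGuidePool:
    # Pre-builds CFGGuides for the common response_format == "json_object" case.
    # Each CFGGuide can serve only one sequence, so guides are handed out once
    # and a background thread keeps the pool topped up with fresh ones.

    def __init__(self, cfg_string: str, tokenizer, size: int = 32):
        self._cfg_string = cfg_string
        self._tokenizer = tokenizer
        self._pool: queue.Queue = queue.Queue(maxsize=size)
        threading.Thread(target=self._fill, daemon=True).start()

    def _fill(self) -> None:
        while True:
            # Building a CFGGuide is the slow part; put() blocks while the pool is full.
            self._pool.put(CFGGuide(self._cfg_string, self._tokenizer))

    def acquire(self) -> CFGGuide:
        # Handed to exactly one sequence and never returned to the pool.
        return self._pool.get()

A pool like this would hide most of the construction latency as long as the request rate stays below the rate at which fresh guides can be built.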

To really solve this, Outlines would need to be changed to make the CFGGuide thread-safe. I believe this would require a large effort.

@rlouf, could you give us your expert opinion on this?

@maxdebayser commented:

@br3no , I can confirm that your latest commit fixes the problem where state from previous single-sequence requests is carried over to new requests.

@simon-mo (Collaborator) left a comment:


Thank you for doing this. Please let me know when this PR is ready to be merged!

@simon-mo (Collaborator) commented Jun 5, 2024

Actually, it seems complete given @maxdebayser's comment. I will merge now.

@simon-mo merged commit 7b0a0df into vllm-project:main on Jun 5, 2024
101 of 103 checks passed
blinkbear pushed a commit to blinkbear/vllm that referenced this pull request Jun 6, 2024
[Frontend][Core] Update Outlines Integration from FSM to Guide (vllm-project#4109)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>

robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request Jun 11, 2024
[Frontend][Core] Update Outlines Integration from FSM to Guide (vllm-project#4109)

Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Breno Faria <breno.faria@intrafind.com>

Successfully merging this pull request may close these issues.

[Feature]: Update Outlines Integration from FSM to Guide