
Use of logits_processors has become very slow in v0.3.2 #3087

Closed
saattrupdan opened this issue Feb 28, 2024 · 5 comments · Fixed by #3099
saattrupdan commented Feb 28, 2024

I am using vLLM together with outlines for structured generation.

After upgrading to v0.3.2, generation became very slow, and RAM usage now leads to OOM crashes.

Here is a minimal example:

from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from pydantic import BaseModel, conlist
import datetime as dt

class Output(BaseModel):
    names: conlist(str, max_length=5)
    organizations: conlist(str, max_length=5)
    locations: conlist(str, max_length=5)
    miscellaneous: conlist(str, max_length=5)

llm = LLM('mistralai/Mistral-7B-v0.1', max_model_len=10_000, gpu_memory_utilization=0.9)
logits_processor = JSONLogitsProcessor(schema=Output, llm=llm.llm_engine)
logits_processor.fsm.vocabulary = list(logits_processor.fsm.vocabulary)
prompt = """
Locate all the names, organizations, locations and other miscellaneous entities in the following sentence: 
"Charles went and saw Anna at the coffee shop Starbucks, which was based in a small town in Germany called Essen."
"""
sampling_params = SamplingParams(max_tokens=128, temperature=0, logits_processors=[logits_processor])

t0 = dt.datetime.now()
llm.generate([prompt] * 256, sampling_params=sampling_params)
time_elapsed = (dt.datetime.now() - t0).total_seconds()
print(f"Generation took {time_elapsed:,} seconds.")

When I run the above with vllm==0.3.1, generation takes 58 seconds and uses ~6GB of memory. If I upgrade vllm to v0.3.2 (with no other packages changed), generation suddenly takes 418 seconds and uses ~18GB of memory. Almost all of that time is spent stalling: nothing is generated while memory usage slowly climbs, until generation finally begins.

I tried installing a forked version of outlines to check whether the stalling came from the internals of the JSONLogitsProcessor, but the processor is only called after the stalling finishes, so this appears to be a vLLM issue.

saattrupdan (Author) commented Feb 28, 2024

This seems to be due to the deepcopy of the SamplingParams on this line in the LLMEngine, which also copies the logits processors; in my case these take up a considerable amount of memory. The deepcopy was added two weeks ago in this PR, which is part of the v0.3.2 release.
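To make the cost concrete, here is a minimal, self-contained sketch (using hypothetical stand-in classes, not vLLM's actual SamplingParams or processor types) of how a per-request deepcopy of sampling params holding a heavy logits processor multiplies both time and memory:

```python
import copy
import time

class HeavyProcessor:
    """Stand-in for a logits processor that caches large state (e.g. an FSM vocabulary)."""
    def __init__(self, size=1_000_000):
        self.table = list(range(size))  # megabytes of cached state

class Params:
    """Stand-in for a SamplingParams-like object holding a list of processors."""
    def __init__(self, logits_processors):
        self.temperature = 0.0
        self.logits_processors = logits_processors

proc = HeavyProcessor()
params = Params([proc])

t0 = time.perf_counter()
# One deepcopy per sequence in the batch duplicates the processor's state each time.
copies = [copy.deepcopy(params) for _ in range(8)]
elapsed = time.perf_counter() - t0

# Every copy holds its own duplicate of the heavy state rather than a shared reference:
assert all(c.logits_processors[0] is not proc for c in copies)
```

With 256 prompts, as in the reproduction above, this duplication would happen hundreds of times, which is consistent with the observed stall and memory growth.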

Tagging relevant people to that PR: @njhill @Yard1

simon-mo (Collaborator) commented

@Yard1 I think this should be fixed by the next release, especially since we are shipping #2819

Yard1 (Collaborator) commented Feb 28, 2024

Hmm, I see. We should probably make the logits processors exempt from the deepcopy (unless #2819 already fixes that)
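One way such an exemption could look (a sketch only, not vLLM's actual fix, which landed in #3099): temporarily detach the processors, deepcopy the rest of the params, then re-attach the same processor objects so they are shared rather than duplicated. The helper name and the `logits_processors` attribute shape here are assumptions for illustration:

```python
import copy

def clone_params_sharing_processors(params):
    """Deep-copy a params-like object while sharing, not copying, its logits processors.

    `params` is any object with a `logits_processors` attribute; this is a
    hypothetical helper, not vLLM's API.
    """
    processors = params.logits_processors
    try:
        params.logits_processors = None           # detach before the deepcopy
        new_params = copy.deepcopy(params)
    finally:
        params.logits_processors = processors     # always restore the original
    new_params.logits_processors = processors     # shared reference, no duplicated state
    return new_params
```

Sharing is only safe if the processors are either stateless across sequences or explicitly designed to be shared; stateful FSM-based processors may still need per-sequence handling.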

njhill (Member) commented Feb 28, 2024

Ah, yes, sorry about this. I can open a PR to do what @Yard1 suggests.

njhill (Member) commented Feb 29, 2024

@Yard1 @simon-mo @saattrupdan fix is in #3099
