[Misc]: Throughput/Latency for guided_json with ~100% GPU cache utilization #3567
Is the JSON schema complex at all, and is it the same each time? I'm interested in fixing the performance here. |
Hi Simon, the JSON schema is the same every time, and it is as follows:
Thanks for looking into this 🫶 |
@simon-mo any update on this? 😊 |
Facing a similar issue here: I have a JSON schema with 14 fields, and the request gets stuck forever. |
My schema has only 2 fields and still shows significantly higher latency than running without guided_json. Would love to have this fixed, as model performance severely decreases without it. |
If testing lm-format-enforcer, I highly recommend adding the latest version of it to the image, as there have been performance improvements to the JsonSchemaParser. The next version of vLLM will include them, but until then, do |
what speeds are you getting @noamgat vs the outlines backend? |
I didn't test on A100/H100s, but on my dev setup (RTX 3090, Mistral7B), for simple schemas, I was getting less than a 2x reduction in tokens/s. |
+1, it doesn't seem GPU-related; I tested with A100 and V100 GPUs and both show a similar issue. Using a line profiler, I found that this get_guided_decoding_logits_processor call takes 93% of the time
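A minimal sketch of how such per-line timings can be collected with line_profiler (pip install line_profiler). The profiled function here is a stand-in for the vLLM call site, whose exact location varies between versions:

```python
from line_profiler import LineProfiler

def build_token_table(n: int) -> list:
    # Stand-in hot loop, similar in shape to the tokenizer scans shown below.
    table = []
    for i in range(n):
        table.append((i, str(i)))
    return table

profiler = LineProfiler()
profiler.add_function(build_token_table)
profiler.enable_by_count()
build_token_table(128_000)  # roughly vocabulary-sized
profiler.disable_by_count()
profiler.print_stats()      # prints Hits / Time / Per Hit / % Time per line
```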
Just tested; the speed-up is not obvious. Probably the main bottleneck is still get_guided_decoding_logits_processor. |
@noamgat here's a profiling when I use lm-format-enforcer 0.10.1.
The two decode calls in the for loop seem to take most of the time. Happy to run further tests if needed. |
Building the regular token list should only happen at the first request that uses LMFE. Does it happen every time? If so, maybe there is a problem with the lru caching not working.
On Tue, May 7, 2024, 20:30 nullpointer0xffff wrote:
/lib/python3.10/site-packages/lmformatenforcer/integrations/transformers.py
Function: _build_regular_tokens_list at line 58
Line # Hits Time Per Hit % Time Line Contents
==============================================================
58 @profile
59 def _build_regular_tokens_list(tokenizer: PreTrainedTokenizerBase) -> List[Tuple[int, str, bool]]:
60 1 912794903.0 9e+08 9.5 token_0 = tokenizer.encode("0")[-1]
61 1 8025.0 8025.0 0.0 regular_tokens = []
62 128257 28050361.0 218.7 0.3 for token_idx in range(len(tokenizer)):
63 128256 78294452.0 610.5 0.8 if token_idx in tokenizer.all_special_ids:
64 2 450.0 225.0 0.0 continue
65 # We prepend token 0 and skip the first letter of the result to get a space if the token is a start word.
66 128254 5319568501.0 41476.8 55.3 decoded_after_0 = tokenizer.decode([token_0, token_idx])[1:]
67 128254 3162992335.0 24661.9 32.9 decoded_regular = tokenizer.decode([token_idx])
68 128254 56427009.0 440.0 0.6 is_word_start_token = len(decoded_after_0) > len(decoded_regular)
69 128254 61975079.0 483.2 0.6 regular_tokens.append((token_idx, decoded_after_0, is_word_start_token))
70 1 240.0 240.0 0.0 return regular_tokens
|
Just clarifying - if possible, start the tokens/s measuring and/or profiling from the second request onwards. While the warm-up time is also something that can be optimized, the post-warmup performance matters much more for real-world use cases. This is true for all guided decoding backends. |
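A minimal sketch of that measurement discipline, with `generate` as a placeholder callable (an assumption, not a vLLM API) that runs one request and returns the number of tokens it produced:

```python
# Time from the second call onward so one-time FSM/tokenizer warm-up
# costs are excluded from the tokens/s figure.
import time

def tokens_per_second_post_warmup(generate, n_runs: int = 5) -> float:
    generate()  # warm-up request: pays compilation/caching costs once
    start = time.perf_counter()
    total_tokens = sum(generate() for _ in range(n_runs))
    return total_tokens / (time.perf_counter() - start)
```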
@nullpointer0xffff @jens-create I just confirmed that the caching of the LMFE tokenizer init (very, very slow) via @lru_cache is working, so |
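A sketch of the kind of caching being confirmed here: LMFE's tokenizer setup scans the whole vocabulary (the slow _build_regular_tokens_list profiled above), so it should run once per tokenizer and be reused across requests. Illustrative only; vLLM's actual cache lives in its guided-decoding module:

```python
from functools import lru_cache

from transformers import AutoTokenizer
from lmformatenforcer.integrations.transformers import (
    build_token_enforcer_tokenizer_data,
)

@lru_cache(maxsize=8)
def cached_tokenizer_data(model_name: str):
    # Expensive: iterates the whole vocabulary, decoding each token.
    # With lru_cache, only the first request per tokenizer pays this cost.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return build_token_enforcer_tokenizer_data(tokenizer)
```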
Maybe we can modify the __call__ method to separate the mask computation from the logits adjustment. This allows the mask to be computed once and reused. Let me know if this makes sense @simon-mo |
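A rough sketch of that separation, assuming an outlines-style FSM that exposes allowed_token_ids(state). Class and method names are illustrative, not vLLM's actual code; state advancement is omitted for brevity:

```python
import math
from typing import Dict

import torch

class MaskCachingLogitsProcessor:
    def __init__(self, fsm, vocab_size: int):
        self.fsm = fsm
        self.vocab_size = vocab_size
        self._masks: Dict[int, torch.Tensor] = {}  # state -> reusable mask

    def _mask_for(self, state: int) -> torch.Tensor:
        # Expensive part, done once per automaton state.
        if state not in self._masks:
            mask = torch.full((self.vocab_size,), -math.inf)
            mask[torch.tensor(self.fsm.allowed_token_ids(state))] = 0.0
            self._masks[state] = mask
        return self._masks[state]

    def __call__(self, state: int, scores: torch.Tensor) -> torch.Tensor:
        # Cheap part, done every decoding step: a single vectorized add.
        return scores + self._mask_for(state).to(scores.device)
```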
Just sharing my experience with this issue - it seems to align with the OP's experience. Summary: CPU-constrained guidance means that batching can't scale correctly. vLLM 0.4.2.
Single request: Outlines: ~70 tps - CPU 100%
Batched requests: Outlines: ~70 tps - CPU 100%
Example guidance:
regex: ~~~response\n# Content\\n([.\\W\\w]+)\\n{2}~{3}
json: {"type":"object","properties":{"test":{"type":"string"}},"required":["test"]} |
Here are line timings for model_executor/guided_decoding/outlines_logits_processors.py: |
Based on that timing breakdown, can you try to replace |
I've been doing some further perf analysis, breaking things out a bit to try and understand the bottleneck. It doesn't seem to be related to the indexer but rather to moving the allowed_tokens array around.
CPU first, then move to GPU:
Straight to GPU:
|
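A sketch of the two strategies being contrasted above, under the assumption that the allowed tokens arrive as a large Python list (either way the list must first be converted to a tensor, which is the cost discussed below). Sizes are illustrative and a CUDA device is required:

```python
import math
import time

import torch

VOCAB = 128_256
allowed = list(range(0, VOCAB, 2))  # large Python list, as the backends return

def cpu_first(scores: torch.Tensor) -> torch.Tensor:
    mask = torch.full((VOCAB,), -math.inf)          # built on CPU
    mask[torch.tensor(allowed)] = 0.0
    return scores + mask.to(scores.device)          # ...then moved to GPU

def straight_to_gpu(scores: torch.Tensor) -> torch.Tensor:
    mask = torch.full((VOCAB,), -math.inf, device=scores.device)
    mask[torch.tensor(allowed, device=scores.device)] = 0.0
    return scores + mask

scores = torch.zeros(VOCAB, device="cuda")
for fn in (cpu_first, straight_to_gpu):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        fn(scores)
    torch.cuda.synchronize()
    print(f"{fn.__name__}: {time.perf_counter() - start:.3f}s")
```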
|
Beyond this, I'm not sure I see a way forward without changes to outlines and lm-format-enforcer to provide the information in a more efficient structure than a List. Does anyone see any memoisation opportunities here to at least reduce the iteration counts? |
One thing I think we could do to make it faster is to use the fact that allowed_tokens is either almost all the tokens or none of the tokens. Currently the mask is created at -math.inf, but we could instead create the mask at 0 when the length of allowed_tokens is more than scores.shape[0]/2, and then fill_ the disallowed tokens with -math.inf instead? |
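A sketch of that polarity flip, assuming the allowed token IDs already live in a tensor; the function name and structure are illustrative, not existing vLLM code:

```python
import math

import torch

def apply_allowed_tokens(scores: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    # Flip the mask polarity based on how much of the vocabulary is allowed,
    # so the -inf fill always targets the smaller set.
    vocab_size = scores.shape[-1]
    if allowed.numel() > vocab_size // 2:
        # Almost everything allowed: start from 0, -inf only the complement.
        keep = torch.zeros(vocab_size, dtype=torch.bool, device=scores.device)
        keep[allowed] = True
        return scores.masked_fill(~keep, -math.inf)
    # Few tokens allowed: start from -inf, zero out the allowed set.
    mask = torch.full_like(scores, -math.inf)
    mask[allowed] = 0.0
    return scores + mask
```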
I went down that same line of thinking - I don't think the timings above support it, however. It's getting the Python list into a Tensor that seems to be 80%+ of the cost per iteration. So short of data structure changes upstream, my current thinking is we're left with iteration optimisations - can we avoid going back to fsm.allowed_token_ids in certain situations? Not sure on that yet - still learning how this all fits together. |
Are the PRs for this issue currently stalled due to competing priorities? |
I believe my PR mitigates the issue; would appreciate some testing to verify: dottxt-ai/outlines#1013 It decreases the worst-case Outlines structured-generation logits processor overhead from 50ms to 1ms. Please let me know if this doesn't resolve this thread's issue and another approach is needed. Edit: It's available in Outlines
|
I did not test the solution yet, but here is the same issue posted to Outlines: dottxt-ai/outlines#1011 (comment) I suspected it might be a threading issue in vLLM logit decoding, but it sounds like this thread would've picked up on that if that were the case. |
Just anecdotally, I am still seeing a 20x slowdown doing batch generation with |
Thanks for reporting back. Seems we've tackled one component of slowness, as reported by @lynkz-matt-psaltis: "It's getting the Python list into a Tensor that seems to be 80%+ of the cost per iteration". However, as suggested by @robcaulk, there may be another source of slowness involving vLLM's Outlines integration, involving the
I also have an alternative hypothesis: if the issue only occurs when using vLLM + ray, perhaps the state of these logits processors is expensive to communicate between ray workers. Could you please
|
Is anyone still experiencing this issue? If so, could you please provide the call you're making (if distinct from that of someone previously posting in this thread), and hardware info so I can reproduce? |
Did any of these changes make it into a release? Or is all of this still in PRs and such? |
Hey Andrew, yes the issue still exists - reproduction is made using the script I furnished in the following thread: Regarding vLLM+ray, the issue also occurs without ray. However, using lm-format-enforcer, we see an identical issue, and lm-format-enforcer uses the same ThreadPoolExecutor. So that would suggest the problem is likely something to do with the logit processing inside vLLM. |
Hey all, in terms of changes as I understand them: Outlines has moved towards tensors being passed between outlines guides and vllm, which is a fantastic step for this. If we look at the current outlines integration here: https://github.com/outlines-dev/outlines/blob/main/outlines/integrations/vllm.py#L112 we see the use of the caching strategy to reduce repetitive operations. So, depending on whether you use vllm with the outlines integration or use outlines directly, you're likely to get different behaviours.
Here's a draft PR taking the caching opportunity in outlines and reapplying it to vllm on top of latest: #6715 Would love feedback on whether this helps others. It may be we need the threadpool work as well to realise a true benefit. This helps us at least with batched requests on 1xA100 24-core compute. To reaffirm - the caching does not change individual request performance; it helps with batched requests.
I've definitely seen similar call stacks and flame graphs to what @robcaulk has described above. Hoping to find time for further investigations on all this around September, so if anyone beats me to it, fantastic! :D |
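A hedged sketch of what that caching strategy amounts to for batched requests: many sequences in a batch share the same automaton state, so the state-to-bias computation can be done once per unique state per step rather than per sequence. Illustrative only, not the code in the linked integration or PR; `guide.allowed_token_ids(state)` is an assumed interface:

```python
import math
from typing import Dict, List

import torch

def apply_guided_bias(guide, states: List[int], scores: torch.Tensor) -> torch.Tensor:
    # scores: [batch, vocab]; states[i] is sequence i's automaton state.
    biases: Dict[int, torch.Tensor] = {}
    for state in set(states):                 # expensive part: once per unique state
        mask = torch.full((scores.shape[-1],), -math.inf)
        mask[torch.tensor(guide.allowed_token_ids(state))] = 0.0
        biases[state] = mask.to(scores.device)
    for i, state in enumerate(states):        # cheap reuse for every sequence
        scores[i] += biases[state]
    return scores
```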
It's in
Thanks so much for investigating! That helps a lot.
Will review. |
New draft PR based on our discussion @lapp0 to remove the duplicate logits processors in vLLM. I'm seeing an upstream error with CachedLlamaTokenizerFast here: Will investigate if that's not a known issue. Cheers! |
@lynkz-matt-psaltis thanks for the info! We need a normalized tokenizer. Could you ensure
Also beware that |
Hi, I would like to ask if caching the outlines allowed tokens might cause the GPU to OOM under high load, since the cache is unbounded? Or is there another process within vLLM that will manage VRAM usage? |
For each state in the automaton, outlines stores a tensor with a list of legal token IDs. However, we don't store these tensors on GPU, so it shouldn't result in CUDA OOM. |
Anything you want to discuss about vllm.
Hi,
I am running some benchmarks on the vllm.entrypoints.openai.api_server, measuring latency and throughput with different numbers of concurrent requests.
Specs:
I am sending 1000 requests with random prompts of token length 512. These are the results I get (see attached image):
(Attached image: results for guided_json vs. non-guided_json runs.)
At 10 concurrent requests (GPU utilization << 100%):
Non-guided_json: ~20 ms median token time
guided_json: ~160 ms median token time
Currently the application I am building relies heavily on guided_json. However, to put it in an online setting, I would like to ask: 1) are the numbers I am experiencing sensible, and 2) what can be done to improve performance in the guided_json paradigm?
I am debating whether I should try to prompt my way to structured outputs and thus avoid constrained decoding.
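For reference, a rough reproduction sketch of the benchmark described above, under stated assumptions: the endpoint URL and model name are placeholders, the schema is the two-field example posted earlier in this thread, and guided_json is passed as a vLLM extension field to the OpenAI-compatible completions endpoint. The per-token time is approximated as request latency divided by completion tokens:

```python
import statistics
import time
from concurrent.futures import ThreadPoolExecutor

import requests

URL = "http://localhost:8000/v1/completions"   # placeholder endpoint
SCHEMA = {"type": "object",
          "properties": {"test": {"type": "string"}},
          "required": ["test"]}

def one_request(_) -> float:
    t0 = time.perf_counter()
    resp = requests.post(URL, json={
        "model": "mistralai/Mistral-7B-v0.1",  # placeholder model name
        "prompt": "x " * 512,                  # roughly 512-token prompt
        "max_tokens": 128,
        "guided_json": SCHEMA,                 # drop this line for the baseline
    }).json()
    latency = time.perf_counter() - t0
    return latency / resp["usage"]["completion_tokens"]  # seconds per token

for concurrency in (1, 10, 50):
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        per_token = list(pool.map(one_request, range(200)))
    print(concurrency, f"{1000 * statistics.median(per_token):.1f} ms median token time")
```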