Conversation

@Tabrizian Tabrizian commented Jul 10, 2024

What does the PR do?

Triton's output token throughput for the generate endpoint increases by 18% at concurrency 50. There is still a small gap between the vLLM-only and the vLLM + Triton solutions.

The model is llama-2-7b.

Changes:

  • Use add_request instead of generate for the vLLM backend. It appears that with generate, token generation is delayed until the next iteration of the engine loop (see the first sketch after this list).
  • Delegate response sending to a separate thread (see the second sketch below).
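
A minimal sketch of the first change, assuming the AsyncLLMEngine API of vLLM as of mid-2024 (exact signatures vary across vLLM versions, and handle_output is a hypothetical callback, not part of this PR):

from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine

def handle_output(output):
    # Hypothetical sink for partial RequestOutput objects.
    print(output.outputs[0].text)

async def via_generate(engine: AsyncLLMEngine, prompt: str, request_id: str):
    # Before: generate() wraps add_request() plus an internal stream; the
    # observation in this PR is that tokens for a new request may not be
    # produced until the next iteration of the engine loop.
    async for output in engine.generate(prompt, SamplingParams(), request_id):
        handle_output(output)

async def via_add_request(engine: AsyncLLMEngine, prompt: str, request_id: str):
    # After: enqueue the request directly and iterate the returned stream.
    stream = await engine.add_request(request_id, prompt, SamplingParams())
    async for output in stream:
        handle_output(output)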

Next steps:

  • In the long term, it might be better to have an async_send API to avoid creating a separate thread for sending responses (the thread-based pattern is sketched below).
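
The second change follows roughly this pattern (a sketch under assumptions, not the actual diff: ResponseWorker and enqueue are hypothetical names, while response_sender.send() and the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag are the real Triton Python backend API):

import queue
import threading

import triton_python_backend_utils as pb_utils

class ResponseWorker:
    """Send responses from a dedicated thread so the asyncio loop that
    drives the vLLM engine never blocks on response_sender.send()."""

    _SHUTDOWN = object()  # sentinel telling the worker thread to exit

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def enqueue(self, response_sender, response, flags=0):
        # Called from the event loop; returns immediately. Pass
        # flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL with the
        # last response so Triton closes the stream.
        self._queue.put((response_sender, response, flags))

    def _loop(self):
        while True:
            item = self._queue.get()
            if item is self._SHUTDOWN:
                break
            response_sender, response, flags = item
            response_sender.send(response, flags=flags)

    def shutdown(self):
        self._queue.put(self._SHUTDOWN)
        self._thread.join()

An async_send API, as suggested in the next steps, would let the event loop hand responses to Triton without this extra thread.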

[Benchmark result screenshots]

Checklist

  • PR title reflects the change and is of format <commit_type>: <Title>
  • Changes are described in the pull request.
  • Related issues are referenced.
  • Populated github labels field
  • Added test plan and verified test passes.
  • Verified that the PR passes existing CI.
  • Verified copyright is correct on all changed files.
  • Added succinct git squash message before merging ref.
  • All template sections are filled out.
  • Optional: Additional screenshots for behavior/output changes with before/after.

Commit Type:

Check the conventional commit type box here and add the label to the GitHub PR.

  • build
  • ci
  • docs
  • feat
  • fix
  • perf
  • refactor
  • revert
  • style
  • test

Related PRs:

N/A

Where should the reviewer start?

N/A

Test plan:

This is a performance improvement; the existing test cases should be sufficient to cover any possible issues.

  • CI Pipeline IDs: 16917857, 16926866

Caveats:

N/A

Background

N/A

Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to)

N/A

@Tabrizian Tabrizian marked this pull request as draft July 10, 2024 17:53
@Tabrizian Tabrizian requested review from kthui and oandreeva-nv July 18, 2024 18:06
@kthui kthui requested a review from tanmayv25 July 23, 2024 18:33
kthui previously approved these changes Jul 23, 2024
@Tabrizian Tabrizian marked this pull request as ready for review July 23, 2024 19:10
@tanmayv25
Contributor

@Tabrizian Can you add a description of the code changes in the PR? Please also include the performance improvement you observed and in which cases.

@Tabrizian
Member Author

@tanmayv25 I updated the PR description.

@kthui kthui mentioned this pull request Jul 25, 2024
kthui previously approved these changes Jul 25, 2024
@kthui kthui changed the title from "Improve vLLM backend performance by using a separate thread for responses" to "perf: Improve vLLM backend performance by using a separate thread for responses" Jul 25, 2024
tanmayv25 previously approved these changes Jul 25, 2024
@oandreeva-nv
Contributor

@kthui So you got perf results close to Iman's after syncing?

@kthui
Contributor

kthui commented Jul 26, 2024

@kthui So you got perf results close to Iman's after syncing?

yes

@kthui kthui dismissed stale reviews from tanmayv25 and themself via c54dfef July 26, 2024 17:46
@kthui kthui requested review from oandreeva-nv and tanmayv25 July 26, 2024 17:46
Contributor

@oandreeva-nv oandreeva-nv left a comment

LGTM!

@kthui kthui merged commit 128abc3 into main Jul 26, 2024
@kthui kthui added the PR: perf A code change that improves performance label Jul 31, 2024
@zhaotyer

zhaotyer commented Aug 2, 2024

I tested Qwen2-7B-chat on an A100 80GB. The PR did not help: the gap compared to a vLLM-only deployment is nearly 40% at concurrency 64.

vLLM only

============ Serving Benchmark Result ============
Successful requests:                     64        
Benchmark duration (s):                  15.35     
Total input tokens:                      14970     
Total generated tokens:                  15304     
Request throughput (req/s):              4.17      
Input token throughput (tok/s):          975.53    
Output token throughput (tok/s):         997.30    
---------------Time to First Token----------------
Mean TTFT (ms):                          1222.87   
Median TTFT (ms):                        1168.27   
P99 TTFT (ms):                           2284.71   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          55.58     
Median TPOT (ms):                        55.96     
P99 TPOT (ms):                           56.88     
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.51     
Median ITL (ms):                         51.57     
P99 ITL (ms):                            78.99     
==================================================

Triton + vLLM (streaming)

============ Serving Benchmark Result ============
Successful requests:                     64        
Benchmark duration (s):                  24.85     
Total input tokens:                      14970     
Total generated tokens:                  15178     
Request throughput (req/s):              2.58      
Input token throughput (tok/s):          602.44    
Output token throughput (tok/s):         610.81    
---------------Time to First Token----------------
Mean TTFT (ms):                          1762.53   
Median TTFT (ms):                        1687.10   
P99 TTFT (ms):                           3342.62   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          91.16     
Median TPOT (ms):                        91.76     
P99 TPOT (ms):                           93.63     
---------------Inter-token Latency----------------
Mean ITL (ms):                           138.93    
Median ITL (ms):                         93.49     
P99 ITL (ms):                            368.58    
==================================================

Triton + vLLM (no streaming)

============ Serving Benchmark Result ============
Successful requests:                     64        
Benchmark duration (s):                  18.48
Total input tokens:                      14784     
Total generated tokens:                  15232
Request throughput (req/s):              3.46
Input token throughput (tok/s):          801.91
Output token throughput (tok/s):         826.32

@zhaotyer

zhaotyer commented Aug 2, 2024

Triton + vLLM GPU utilization is significantly lower than with vLLM alone.

