Changed scheduler to use deques instead of lists #2290

Merged
merged 4 commits into vllm-project:main on Jan 7, 2024

Conversation

@NadavShmayo (Contributor) commented on Dec 27, 2023

Currently the scheduler uses lists to store the running, waiting, and swapped requests.
When iterating over each state queue, we pop the first item and append it to a new list. This is bad for performance: each pop from the front of a list is O(N), and a pop happens for every item in the queue, so the scheduler currently runs in O(N^2) time.
Instead of using lists we could use deques, which let us pop the first item in O(1) time, bringing the scheduler down to O(N) overall.
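A minimal standalone sketch of the complexity difference (illustrative only, not the vLLM scheduler code):

```python
from collections import deque
import timeit

def drain_list(n: int) -> None:
    # list.pop(0) shifts every remaining element: O(N) per pop, O(N^2) overall.
    queue = list(range(n))
    while queue:
        queue.pop(0)

def drain_deque(n: int) -> None:
    # deque.popleft() removes the first element in O(1): O(N) overall.
    queue = deque(range(n))
    while queue:
        queue.popleft()

# The gap widens quickly as n grows.
print("list: ", timeit.timeit(lambda: drain_list(10_000), number=10))
print("deque:", timeit.timeit(lambda: drain_deque(10_000), number=10))
```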

I wasn't sure whether to keep the same types in the SchedulerOutputs class and cast the deques back to lists before returning the result, but since everything seems to work with SchedulerOutputs containing deques instead of lists, I decided to change the types.
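A hedged sketch of the two options, with illustrative field names rather than the real SchedulerOutputs definition:

```python
from collections import deque
from dataclasses import dataclass
from typing import Deque, List

@dataclass
class SchedulerOutputsWithLists:
    # Option A: keep list-typed fields and copy the deque before returning.
    scheduled_seq_groups: List[str]

@dataclass
class SchedulerOutputsWithDeques:
    # Option B (what this PR does): let the fields carry deques directly.
    scheduled_seq_groups: Deque[str]

running: Deque[str] = deque(["req-0", "req-1"])
a = SchedulerOutputsWithLists(scheduled_seq_groups=list(running))  # extra O(N) copy
b = SchedulerOutputsWithDeques(scheduled_seq_groups=running)       # no copy needed
```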

@WoosukKwon self-requested a review on January 2, 2024 17:25
@WoosukKwon (Collaborator) left a comment:

Hi @NadavShmayo, thanks for submitting the PR! Yes, it seems we should use deque instead of list for the queues in our scheduler. I left some minor comments on the PR. Please take a look at them!

Resolved review threads (outdated): vllm/core/policy.py (1), vllm/core/scheduler.py (8)
@WoosukKwon (Collaborator) left a comment:

LGTM. I've made some minor changes to accelerate the merge. Thanks again for the PR!

@WoosukKwon merged commit 05921a9 into vllm-project:main on Jan 7, 2024
2 of 3 checks passed
@NadavShmayo (Contributor, Author) commented:

> LGTM. I've made some minor changes to accelerate the merge. Thanks again for the PR!

Hey, sorry for the delayed response.
I hadn't gotten to fixing the code yet and meant to do it tomorrow, but it's good to see you already did!

Regarding the changes to the new_seq_lens logic: I thought my approach was better, since it avoids creating a new list each time to calculate the total batched tokens, which might be slow. But these are minor changes anyway.
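A hedged, self-contained sketch of that trade-off; FakeSeqGroup and its num_tokens field are hypothetical stand-ins, not the actual vLLM SequenceGroup API:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class FakeSeqGroup:
    # Hypothetical stand-in for a scheduled sequence group.
    num_tokens: int

def batched_tokens_rebuild(scheduled: List[FakeSeqGroup]) -> int:
    # Build a fresh list of lengths on every scheduling step, then sum it
    # (one extra allocation per step).
    new_seq_lens = [group.num_tokens for group in scheduled]
    return sum(new_seq_lens)

def batched_tokens_incremental(scheduled: List[FakeSeqGroup]) -> int:
    # Accumulate the total directly, skipping the intermediate list.
    total = 0
    for group in scheduled:
        total += group.num_tokens
    return total

groups = [FakeSeqGroup(7), FakeSeqGroup(16), FakeSeqGroup(3)]
assert batched_tokens_rebuild(groups) == batched_tokens_incremental(groups) == 26
```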

Thanks for the review! 😄

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Jan 18, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
njhill added a commit to njhill/vllm that referenced this pull request Jan 20, 2024
vllm-project#2290 changed the scheduler seq group lists to be deques for more efficient updates, but missed one place where the `running` deque gets converted back to a list.
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>