Delay GPU->CPU sync in sampling #1337
Conversation
cc @zhuohan123
Updating the PR!
Yes, we could certainly optimize that as well - I would like to do that in a follow-up.
@zhuohan123 Updated, PTAL!
LGTM! Thank you for your contribution! Left some small comments based on recent PRs. Let's refactor the bloated input_metadata in a future PR.
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
Need to fix max_prompt_len in the worker
Any progress on this PR?
@hanzhi713 I am on holiday currently but will wrap it up next week!
Should be good now
LGTM! Thank you for your contribution!
@Yard1 @zhuohan123 Have you tested the sampling performance on an AWQ model? I found that the AWQ model's first sampling step is much slower than the FP16 model's, but it is faster for the following tokens. I can't figure out the exact bottleneck :(
@Yard1 @zhuohan123 Sorry, I am not familiar with this. Assume we have 3 prompts with lengths 9, 8, and 6. In the prompt (prefill) stage, selected_token_indices=[5, 14, 23], but I think selected_token_indices should be [8, 16, 22]. Below is a diagram, which is hopefully more intuitive. I have been struggling to understand this, and I would appreciate your interpretation.
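A minimal sketch of the arithmetic behind this question, assuming the prefill hidden states of the three prompts are simply concatenated without any padding; it only reproduces the indices the commenter expects, not vLLM's actual prefill layout:

```python
# Expected last-token indices for prompts of lengths 9, 8 and 6 when the
# prefill hidden states are concatenated back to back: cumulative length - 1.
prompt_lens = [9, 8, 6]
expected, offset = [], 0
for length in prompt_lens:
    expected.append(offset + length - 1)
    offset += length
print(expected)  # [8, 16, 22] -- the observed value in the question is [5, 14, 23]
```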
This PR preallocates the tensors used to prune the hidden states and to index the samples by sampling type, allowing us to delay the GPU->CPU sync slightly longer (up until torch.multinomial in the sampler). This should fix the slight performance regression introduced in #1309.
Technically, we can delay the sync even further by reordering some more operations, but that's left for a future PR.
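To illustrate the general idea, here is a minimal, hedged sketch (not the actual vLLM sampler) of precomputing index tensors during input preparation and keeping all selection on the GPU, so that nothing forces a device-to-host read before torch.multinomial. The helper names and shapes below are illustrative assumptions.

```python
import torch

def build_selected_token_indices(prompt_lens, device):
    # Index of the last token of each prompt in the flattened hidden-state
    # tensor (cumulative length minus one), built once on the CPU during
    # input preparation and moved to the GPU up front.
    indices, offset = [], 0
    for length in prompt_lens:
        indices.append(offset + length - 1)
        offset += length
    return torch.tensor(indices, dtype=torch.long, device=device)

def sample_last_tokens(hidden_states, selected_token_indices, lm_head_weight):
    # Everything here stays on the GPU; no .item()/.tolist() call reads a
    # value back to the CPU. Per the PR description, the first unavoidable
    # GPU->CPU sync happens inside torch.multinomial.
    pruned = hidden_states.index_select(0, selected_token_indices)
    logits = pruned @ lm_head_weight.t()
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Example with the prompt lengths discussed above (9, 8, 6).
device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(9 + 8 + 6, 16, device=device)  # [num_tokens, hidden]
lm_head_weight = torch.randn(32, 16, device=device)        # [vocab, hidden]
idx = build_selected_token_indices([9, 8, 6], device)
next_tokens = sample_last_tokens(hidden_states, idx, lm_head_weight)  # shape [3, 1]
```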