[Distributed] Make decode in flight optional #1180
Conversation
🔗 Test artifacts and rendered results: hud.pytorch.org/pr/pytorch/torchchat/1180
✅ No failures as of commit faec017 with merge base 3aba730.
res = [[] for _ in range(total_prompts)]
num_tokens = 40
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
General comment: do we really need int64 for the dtype, given that int32 holds positive values up to ~2B? (Meant to ask this earlier.)
It seems like a minor savings, and I don't see any vocab getting to 2B anytime soon.
We can leave it as is for consistency, with a single follow-up PR to sweep this up if we agree int32 is fine.
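A back-of-the-envelope sketch of the headroom and savings being discussed (the vocabulary size and tensor shape below are illustrative assumptions, not values from this repo):

```python
# int32 token ids can index any realistic vocabulary, and halving the
# dtype width halves the memory of every token-id tensor.
INT32_MAX = 2**31 - 1          # ~2.1B, the ceiling the comment refers to
vocab_size = 128_256           # assumed: roughly a modern LLM vocabulary
assert vocab_size < INT32_MAX  # plenty of headroom

batch_size, seq_len = 8, 4096  # assumed shape, for illustration only
bytes_int64 = batch_size * seq_len * 8
bytes_int32 = batch_size * seq_len * 4
print(bytes_int64 - bytes_int32)  # → 131072 bytes saved for this one tensor
```

The savings per tensor are small, which is why the comment frames this as a future consistency sweep rather than a change for this PR.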
Looks good!
Longer term, I already started a PR similar to this one, but with the idea of chunked decoding, where a variable controls how many tokens to generate between decodes.
The idea is that we don't want to wait until the entire generation is done before showing anything to the user. But as this PR points out, we pay a sync every time we decode. So we might generate, say, 20 tokens, then decode and display, then another round of generate/display, and so on, balancing decoding speed against keeping the user updated on the response.
Anyway, we can discuss and build on the option in this PR to make it more tunable.
Thanks for adding this!
Ah, sorry, I didn't mean to "overwrite" your PR. I was just doing code cleanup and happened to touch the decode path. The chunked-decode idea definitely sounds good! My "disable-in-flight-decode" flag can be expressed as chunk size = -1 (infinity) in your scheme.
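The chunked-decoding idea above can be sketched roughly as follows (the function and its arguments are hypothetical illustrations, not torchchat APIs): decode and display every `chunk_size` generated tokens, with `chunk_size == -1` meaning "decode only once, at the end", which is the behavior this PR's flag selects.

```python
# Hypothetical sketch: generate_one() produces the next token id,
# decode() converts a list of token ids to text (the step that forces
# a CPU sync in the real pipeline).
def generate_with_chunked_decode(generate_one, decode, num_tokens, chunk_size):
    tokens, shown, pieces = [], 0, []
    for _ in range(num_tokens):
        tokens.append(generate_one())  # stays on device; no sync needed
        if chunk_size > 0 and len(tokens) - shown >= chunk_size:
            pieces.append(decode(tokens[shown:]))  # sync + display here
            shown = len(tokens)
    pieces.append(decode(tokens[shown:]))  # flush the tail (or everything,
    return "".join(pieces)                 # when chunk_size == -1)
```

With `chunk_size=20`, a 100-token generation pays a handful of syncs instead of one per decoded token; with `chunk_size=-1`, exactly one decode at the end.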
Force-pushed from 711eb41 to 3364c72.
Decoding tensor-form token IDs to strings involves a CPU sync.
This PR adds a flag to disable in-flight conversion.