This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@kwen2501
Contributor

Decoding tensor-form token IDs into strings involves a CPU sync.

This PR adds a flag to disable that in-flight conversion:

torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 --disable-in-flight-decode
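
For illustration, here is a minimal sketch of what such a flag can gate, assuming a greedy decoding loop and a tokenizer with a decode(ids) method; the function and variable names are illustrative, not the actual code in dist_run.py (passing --disable-in-flight-decode would correspond to in_flight_decode=False):

```python
import torch

# Sketch only: names are illustrative, not taken from dist_run.py.
def generate(model, tokenizer, tokens, num_new_tokens, in_flight_decode=True):
    for _ in range(num_new_tokens):
        logits = model(tokens)  # (batch, seq, vocab)
        next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        if in_flight_decode:
            # .tolist() copies the ids to the CPU, forcing a GPU sync every step.
            print(tokenizer.decode(next_token[0].tolist()), end="", flush=True)
    if not in_flight_decode:
        # A single sync at the end instead of one per generated token.
        print(tokenizer.decode(tokens[0].tolist()))
    return tokens
```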

@pytorch-bot

pytorch-bot bot commented Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1180

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit faec017 with merge base 3aba730:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kwen2501 requested a review from lessw2020 September 23, 2024 20:16
@facebook-github-bot added the CLA Signed label Sep 23, 2024
@kwen2501 changed the base branch from main to cache_lanes September 23, 2024 20:17
res = [[] for _ in range(total_prompts)]
num_tokens = 40
# need a row dimension for each prompt in the batch
new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64)
Contributor

General comment: do we really need int64 for the dtype, given that int32 holds positive values up to ~2B? (Meant to ask this earlier.)
It seems like a minor saving, and I don't see any vocabs getting to 2B anytime soon.
We can leave it as is for consistency, and maybe do a single PR in the future to sweep this up if we agree int32 is fine.
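
For context, a quick sketch of the dtype comparison (not part of this PR): int32 tops out at 2,147,483,647, far above any LLM vocabulary size (Llama 2 uses a 32,000-token vocabulary), and each token id would occupy 4 bytes instead of 8.

```python
import torch

batch_size = 4
# Current allocation: 8 bytes per token id.
new_token_i64 = torch.zeros(batch_size, 1, dtype=torch.int64)
# Alternative: 4 bytes per token id, still far larger than any vocabulary.
new_token_i32 = torch.zeros(batch_size, 1, dtype=torch.int32)

print(new_token_i64.element_size(), new_token_i32.element_size())  # 8 4
print(torch.iinfo(torch.int32).max)                                # 2147483647
```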

Contributor

@lessw2020 left a comment

Looks good!
Longer term, I already started a similar PR with the idea of doing chunked decoding, where a variable controls how many tokens to generate between decodes.
The idea is that we don't want to wait until the entire generation is done before showing anything to the user, but as this PR points out, we incur a sync every time we decode. So maybe generate 20 tokens, then decode and display, then another round of generate/display, and so on, balancing the tension between decoding speed and keeping the user updated on the response (see the sketch below).
Anyway, we can discuss and build on the option in this PR to make it more tunable.
Thanks for adding this!
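
To make the chunked idea concrete, here is a rough sketch assuming a hypothetical decode_every setting and a tokenizer with a decode(ids) method; none of these names come from the existing code:

```python
import torch

def generate_chunked(model, tokenizer, tokens, num_new_tokens, decode_every=20):
    # Decode and print only every `decode_every` steps, so the GPU -> CPU
    # sync happens once per chunk instead of once per generated token.
    pending = []  # tokens generated since the last decode
    for step in range(num_new_tokens):
        logits = model(tokens)
        next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
        pending.append(next_token)
        if decode_every > 0 and (step + 1) % decode_every == 0:
            chunk = torch.cat(pending, dim=1)[0].tolist()  # one sync per chunk
            print(tokenizer.decode(chunk), end="", flush=True)
            pending = []
    if pending:
        # Flush whatever is left; with decode_every <= 0 this is the only
        # decode, i.e. in-flight decoding is effectively disabled.
        print(tokenizer.decode(torch.cat(pending, dim=1)[0].tolist()), flush=True)
    return tokens
```

Setting decode_every to a non-positive value makes the final flush the only decode, which lines up with treating the existing flag as "chunk size = infinity".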

@kwen2501
Contributor Author

Ah, sorry, I didn't mean to "overwrite" your PR. I was just doing code cleanup and ended up touching the decode path.

The chunked-decode idea definitely sounds good! My "disable-in-flight-decode" flag could become chunk size = -1 (infinity) in your scheme.

@kwen2501 changed the base branch from cache_lanes to main September 24, 2024 19:55
@kwen2501 merged commit 6fd90bc into main Sep 25, 2024
51 checks passed