Introduce seq_len as inference param, and improve warnings #15716
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15716
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 76b9c28 with merge base e774b77; one new job has failed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@abhinaykukkadapu has exported this pull request. If you are a Meta employee, you can view the originating Diff in D86696759.
winskuo-quic left a comment:
Hi @abhinaykukkadapu,
Thanks for the PR.
I would like to know if we could achieve the same thing with the combination of
--pre_gen_pte & --max_seq_len.
For example:
During compile time, you can provide:
--max_seq_len 1024 --compile_only
During inference, you can provide:
--max_seq_len 512 --pre_gen_pte ./path_to_pregen_pte
@winskuo-quic Thanks for the quick review. The goal of this additional param is to avoid confusing users of the script into thinking that --max_seq_len can be dynamic; it is a static param and is fixed during compilation. Currently, one can pass
parser.add_argument(
    "--seq_len",
    help="[Runtime-time] Maximum number of tokens to generate (prompt + output). If not specified, uses --max_seq_len. Will be clamped to compiled max_seq_len if exceeded.",
)
Maybe [Runtime-time] -> [Runtime]
I see.
@winskuo-quic Right, I think we should be transparent about this. I've already added the messages that I think would be helpful, but do suggest more if you have any in mind.
Yeah, I think this param is misleading: it clearly represents the max context a model can have, so using it during inference is confusing for someone new to the Qcom delegate, who might not know we use static llama and might think it can change the total context length of the model dynamically. Also, all I did is take the same param of the qnn runner (--seq_len) and bubble it up to the llama.py script. If you think this adds to the confusion, I'm also open to removing it and only keeping the warning messages in this PR.
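For concreteness, here is a minimal sketch of the fallback-plus-warning behavior being discussed, assuming hypothetical names (`args.seq_len`, `args.max_seq_len`, and a `compiled_max_seq_len` read from the exported model); the actual wiring in llama.py may differ:

```python
import logging


def resolve_seq_len(args, compiled_max_seq_len: int) -> int:
    """Pick the runtime sequence length, warning instead of silently clamping."""
    # Prefer the inference-only --seq_len if given, otherwise fall back to --max_seq_len.
    requested = args.seq_len if args.seq_len is not None else args.max_seq_len
    if requested > compiled_max_seq_len:
        logging.warning(
            "Requested seq_len %d exceeds the compiled max_seq_len %d; clamping to %d.",
            requested,
            compiled_max_seq_len,
            compiled_max_seq_len,
        )
        return compiled_max_seq_len
    return requested
```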
Force-pushed from 7e0bdb0 to 536047d.
Introduce seq_len as inference param, and improve warnings (#15716)

Summary:
Changes:
1. Add `--seq_len` param to llama script to distinguish max_seq_len, which is a compile-time param.
2. Add warnings in the runner when `seq_len` is clamped to `max_seq_len` to avoid silently clamping it.
3. Add warnings in the token generator when EOS is not reached due to insufficient `seq_len` or `max_seq_len`.

Differential Revision: D86696759
outputs.append(f.read())

seq_len = args.max_seq_len
# Use --seq_len if provided (inference-only), otherwise fall back to --max_seq_len
I don't quite follow why we need seq_len, can you share more? I feel like it might further cause confusion...
The runner itself uses a param named --seq_len; it is the llama.py script that repurposes --max_seq_len and passes it as seq_len to the runner. For context, if you've followed the internal discussion on benchmark numbers, we thought we swept the benchmarks over max_seq_len and prompt_length, but the sweep was not valid for max_seq_len because it is a compile-time param and is ignored if it is larger than what the model was compiled with.
Looking at CoreML, they use separate params as well: https://github.com/pytorch/executorch/blob/main/examples/apple/coreml/llama/run.py#L97-L103
Open to suggestions if there are better ways to distinguish this param during compile time vs runtime from a UX perspective.
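One possible way to make the split explicit at the CLI level, loosely following the separate-parameter approach in the CoreML example linked above; the argument-group names and defaults below are hypothetical, not the current llama.py interface:

```python
import argparse

parser = argparse.ArgumentParser()

# Compile-time: baked into the .pte at export; changing it requires recompiling.
compile_args = parser.add_argument_group("compile-time")
compile_args.add_argument(
    "--max_seq_len",
    type=int,
    default=1024,
    help="Context length the model is exported with (fixed after compilation).",
)

# Runtime: forwarded to the runner; can only shrink the compiled context.
runtime_args = parser.add_argument_group("runtime")
runtime_args.add_argument(
    "--seq_len",
    type=int,
    default=None,
    help="Max tokens to generate (prompt + output); clamped to the compiled max_seq_len.",
)
```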
I think we need to revisit the seq_len in the qnn llama runner because it seems it's only for debugging/profiling purposes. Users wouldn't need to use it, and I want to make sure we don't make it more confusing.
Force-pushed from 536047d to aa98ba0.
Force-pushed from aa98ba0 to 76b9c28.
winskuo-quic left a comment:
Thanks for the explanation.
I am slightly leaning toward leaving the warning messages in the runtime and reusing the --max_seq_len flag in llama.py, which aligns with your new commit.
Thanks for the help on improving the user experience!
Summary:
Changes:
1. Add `--seq_len` param to the llama script to distinguish it from `max_seq_len`, which is a compile-time param.
2. Add warnings when `seq_len` is clamped to `max_seq_len`, to avoid silently clamping it.

Differential Revision: D86696759
Tests
- Use --seq_len=600, prompt_len=512
- Use --seq_len=2048, prefill_ar_len=1024
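The commit message also mentions a warning in the token generator when EOS is not reached; the real check lives in the C++ runner, but a rough Python illustration of the idea (all names here are made up) could look like this:

```python
import logging


def warn_if_eos_not_reached(generated_tokens: list[int], eos_token_id: int, seq_len: int) -> None:
    """Warn when generation stops because seq_len ran out rather than on an EOS token."""
    hit_eos = bool(generated_tokens) and generated_tokens[-1] == eos_token_id
    if len(generated_tokens) >= seq_len and not hit_eos:
        logging.warning(
            "Stopped after %d tokens without reaching EOS; consider a larger --seq_len, "
            "or recompile with a larger --max_seq_len if the compiled context is the limit.",
            len(generated_tokens),
        )
```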