
Conversation

yanbing-j
Contributor

This PR adds max-autotune support for CPU in torch.compile. It also splits first-token and next-token metrics in the log output.
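
A minimal sketch of how a max-autotune option can be threaded into torch.compile on CPU (the helper name compile_decode_fn and its max_autotune parameter are illustrative, not this PR's exact code):

import torch

def compile_decode_fn(decode_one_token, max_autotune: bool = False):
    # When max-autotune is requested, ask Inductor to benchmark candidate
    # kernels; this mode is supported on CPU as well as GPU.
    kwargs = {"mode": "max-autotune"} if max_autotune else {}
    return torch.compile(decode_one_token, fullgraph=True, **kwargs)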


pytorch-bot bot commented Aug 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1055

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 74f921c with merge base 8cb8a35:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 23, 2024
@Jack-Khuu
Contributor

Thanks for plumbing max_autotune through, @yanbing-j. That part looks great to me.

@vmpuri Can you give the profiling/token calculation a quick pass though?

@yanbing-j
Contributor Author

@Jack-Khuu @vmpuri Thanks for the review! Let me clarify the profiling update and the next-token calculation fix.

For profiling, I added logic to print the profiling table for both CPU and GPU. For the next-token calculation, t covers both the first token (prefill) and the subsequent tokens (decode_n_tokens), while num_tokens_generated counts only the subsequent tokens, so t and num_tokens_generated do not match. I suspect this was an oversight introduced when first-token time was added. I also print first-token latency and next-token latency separately in the log.
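
A worked example of the corrected arithmetic (dummy numbers; the variable names follow the diffs below, not a verbatim excerpt from generate.py):

t = 2.5                      # total generation time: prefill + decode (s)
num_tokens_generated = 20    # decoded tokens only, excludes the first token
time_to_first_token = 0.5    # prefill latency (s)

# Overall rate must count the prefill token as well.
tokens_sec = (num_tokens_generated + 1) / t                         # 8.4 tok/s
# Decode-only rate excludes the prefill time from the denominator.
next_tokens_sec = num_tokens_generated / (t - time_to_first_token)  # 10.0 tok/s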

@yanbing-j
Contributor Author

@Jack-Khuu @vmpuri Could you please help review and merge this PR?

Contributor

@Jack-Khuu left a comment


Thanks @yanbing-j for the changes and pinging again

Some minor changes, then we're gtg

generate.py Outdated
with {'sequential' if generator_args.sequential_prefill else 'parallel'} prefill,\n\
generate {num_tokens_generated} tokens, in total {tokens_sec:.02f} tokens/sec, \n\
latency_per_token_seconds: {1 / tokens_sec:.04f} s/token\n\
first_token_latency_seconds: {aggregate_metrics.get('time_to_first_token', -1.0):.02f} s/token \n\
Contributor


This is the same as the time to first token

Suggested change (remove the duplicated line):
-    first_token_latency_seconds: {aggregate_metrics.get('time_to_first_token', -1.0):.02f} s/token \n\

generate.py Outdated
@@ -831,7 +847,8 @@ def callback(x, *, done_generating=False):
)
print("---------------------------------------------------")

-    tokens_sec = num_tokens_generated / t
+    tokens_sec = (num_tokens_generated + 1) / t
+    next_tokens_sec = num_tokens_generated / (t - aggregate_metrics.get('time_to_first_token', -1.0))
Contributor


The .get(..., -1) in the denominator logically influences t.

Suggested change
-    next_tokens_sec = num_tokens_generated / (t - aggregate_metrics.get('time_to_first_token', -1.0))
+    next_tokens_sec = num_tokens_generated / (t - aggregate_metrics.get('time_to_first_token', 0))

generate.py Outdated
generate {num_tokens_generated} tokens, in total {tokens_sec:.02f} tokens/sec, \n\
latency_per_token_seconds: {1 / tokens_sec:.04f} s/token\n\
first_token_latency_seconds: {aggregate_metrics.get('time_to_first_token', -1.0):.02f} s/token \n\
next_token_latency_seconds: {1 / next_tokens_sec:.04f} s/token \n\
Contributor


Let's also log next_tokens_sec on the row above.

So ultimately it'll be:

toks/sec (with first token)
sec/toks (with first token)
toks/sec (w/o first token)
sec/toks (w/o first token)
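
Roughly, a sketch of that final block (the function name print_throughput and the variable names are illustrative; the exact strings in generate.py may differ):

def print_throughput(tokens_sec: float, next_tokens_sec: float) -> None:
    # Four-line summary: throughput and latency with and without the first token.
    print(
        f"Tokens/sec (incl. first token): {tokens_sec:.02f}\n"
        f"Sec/token  (incl. first token): {1 / tokens_sec:.04f}\n"
        f"Tokens/sec (excl. first token): {next_tokens_sec:.02f}\n"
        f"Sec/token  (excl. first token): {1 / next_tokens_sec:.04f}"
    )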

@yanbing-j
Contributor Author

@Jack-Khuu Thanks for the comments! Please review again!

Contributor

@Jack-Khuu left a comment


Thanks for updating the logging; everything looks good.

Give the merge conflict (should be minor) a look and it's set

@yanbing-j
Contributor Author

yanbing-j commented Sep 3, 2024

@Jack-Khuu Thanks for the review!

I have rebased onto the main branch. I also exclude the tokens_sec samples that include JIT compilation time, so the average throughput is more accurate. The log now also prints the averages of total throughput, first-token throughput, and next-token throughput.
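
A minimal sketch of the idea (assuming the first sample's tokens_sec includes torch.compile time; the numbers and list name are made up, not this PR's exact code):

# Drop the sample whose timing includes JIT compilation, then average the rest.
per_sample_tokens_sec = [3.1, 12.8, 13.0, 12.9, 13.1]  # sample 0 includes compile time

steady_state = per_sample_tokens_sec[1:] or per_sample_tokens_sec
avg_tokens_sec = sum(steady_state) / len(steady_state)
print(f"Average tokens/sec (excluding compile warm-up): {avg_tokens_sec:.2f}")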

@yanbing-j force-pushed the yanbing/update branch 2 times, most recently from 6d49401 to 74f921c on September 4, 2024 at 06:39
@yanbing-j
Contributor Author

Hi @Jack-Khuu , please help merge this PR. Thanks!

@yanbing-j
Contributor Author

Hi @Jack-Khuu , please help review and merge this PR. Thanks!

@Jack-Khuu
Contributor

Thanks for following up. I'm debugging some weird behavior with the output messages at the moment (on main)

Will merge this in once that's resolved

@yanbing-j
Contributor Author

yanbing-j commented Sep 5, 2024

@Jack-Khuu Thanks! All the CI checks pass. Could you update the branch on your side? If I rebase, all the CI jobs would need to run again.

)

self.decode_one_token = torch.compile(
-    self.decode_one_token, mode="reduce-overhead", fullgraph=True
+    self.decode_one_token, fullgraph=True, **kwargs
)

if generator_args.compile_prefill:


Pass kwargs to compile_prefill model too?

Contributor Author


Yes, I missed this. I created another PR to support it: #1112.
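
A rough sketch of that follow-up, assuming the same compile kwargs are simply reused for the prefill path (names mirror the snippet above; not the exact code in #1112):

# Reuse the compile kwargs (e.g. mode="max-autotune") for prefill as well.
if generator_args.compile_prefill:
    self.prefill = torch.compile(self.prefill, fullgraph=True, **kwargs)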

@Jack-Khuu
Contributor

Thanks again for the changes @yanbing-j

Merging in (I'll tweak some nits in a separate PR)

@Jack-Khuu merged commit d58923e into pytorch:main on Sep 5, 2024
51 checks passed
@sanchitintel
Contributor

sanchitintel commented Sep 5, 2024

@yanbing-j, with these changes I observed different behavior than before while running generate.py.
I'm not sure whether it's because of this PR or because of other changes introduced in torchchat.

With this PR's commits merged onto torchchat's main branch, I see a lot of auto-tuning benchmarking results, even for the same shapes, after I run python3 torchchat.py generate llama3.1 --prompt 'Hello my name is' --quantize '{"linear:int8": {"bitwidth": 8, "groupsize": 0}}' --compile --num-samples 5 --device cpu --tokenizer-path /localdisk/sanchitj/llama_3.1/original/tokenizer.model --max-autotune

Is this expected behavior? Thanks!

@sanchitintel
Contributor

sanchitintel commented Sep 5, 2024

@yanbing-j, it turns out torch._inductor.config.trace.log_autotuning_results = True simply displays more auto-tuning results. That's fine, since auto-tuning is not re-run for duplicate input shapes; enabling this logging just prints duplicate data.

@yanbing-j
Contributor Author

@sanchitintel The autotuning logs you observed are printed because torch._inductor.config.trace.log_autotuning_results = True is set.
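
For anyone reproducing this, the flag can be set back to its default (a minimal sketch):

import torch._inductor.config

# Turning this off suppresses the extra per-choice autotuning output; the
# per-shape benchmarking logs from max-autotune itself still appear.
torch._inductor.config.trace.log_autotuning_results = False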

@sanchitintel
Contributor

sanchitintel commented Sep 6, 2024

Thanks, @yanbing-j! That's what I meant.

Should we disable it, as it's too verbose? Even without torch._inductor.config.trace.log_autotuning_results = True, we get benchmarking logs for all unique input shapes. Thanks!

@yanbing-j
Contributor Author

@sanchitintel This config is removed in #1112.
