Remove full hidden states capturing in custom backend #337
Conversation
Summary of Changes
Hello @yubofredwang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances memory efficiency in custom backend model implementations by enabling selective capturing of hidden states. Previously, models captured all hidden states, which could lead to out-of-memory errors, especially with larger batch sizes. The changes introduce a new mechanism to specify exactly which layers' hidden states are needed, thereby reducing the memory footprint and improving overall performance for debugging, analysis, and advanced use cases.
Code Review
This pull request introduces a useful feature for selectively capturing hidden states to reduce memory usage. The overall approach is sound, but there are critical inconsistencies in the implementation across different models that need to be addressed. Specifically, the custom backend models collect layer inputs instead of outputs, which is inconsistent with the Hugging Face backend implementation and will lead to incorrect behavior. Additionally, there are some minor performance and style issues. My review provides detailed feedback on how to fix these issues to ensure correctness and consistency.
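To make the reviewer's input-versus-output point concrete, here is a toy loop (purely illustrative, not the repository's code) showing the two possible capture points in a stack of layers:

```python
# Toy illustration of the capture-point difference flagged in the review above.
layers = [lambda x, k=k: x + k for k in (1, 2, 3)]  # stand-ins for decoder layers
wanted = {1}   # indices whose hidden states we want to keep
hidden = 0

inputs_captured, outputs_captured = (), ()
for idx, layer in enumerate(layers):
    if idx in wanted:
        inputs_captured += (hidden,)   # value *entering* layer idx
    hidden = layer(hidden)
    if idx in wanted:
        outputs_captured += (hidden,)  # value *leaving* layer idx

print(inputs_captured, outputs_captured)  # (1,) vs (3,): not the same value
```

Capturing before the layer call records the layer's input, capturing after records its output; per the review above, the custom backend models should capture outputs to stay consistent with the Hugging Face backend.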
It looks like you might have intended to request a code review. The correct command for that is `/gemini review`.
/gemini review
Code Review
This pull request introduces a valuable optimization by allowing selective capturing of hidden states, which should help mitigate OOM errors with large batch sizes. The changes are implemented across several custom backend models by adding a layers_to_output_hidden_states parameter.
The overall approach is sound and the refactoring in eagle3_target_model.py to abstract backend-specific logic is a good improvement. However, I've found a critical bug in qwen3_moe.py that will cause a NameError, and a high-severity issue in qwen2.py where a performance optimization is not correctly applied. There are also several opportunities for code simplification and ensuring consistency across the different model implementations, which I've detailed in the comments.
Commits:
* support checkpoint
* lint
* capture only required hidden states
* revert regen
* fix llama
* backward compatible
* Update specforge/modeling/target/custom_backend/qwen3_moe.py
* gemini suggests
* fix
* fix phi

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Motivation
We always capture the full hidden states from the HF target model, which keeps more hidden states in memory than necessary. This can lead to OOM with larger batch sizes.
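As a rough back-of-the-envelope illustration (the shapes below are assumptions for illustration, not measurements from this PR), capturing every layer's hidden states scales with the layer count, while capturing only the handful of layers a drafter needs stays small:

```python
# Back-of-the-envelope estimate; shapes are assumed for illustration only.
def hidden_states_bytes(num_layers, batch, seq_len, hidden_size, bytes_per_elem=2):
    """Memory needed to keep one hidden-states tensor per captured layer (bf16 by default)."""
    return num_layers * batch * seq_len * hidden_size * bytes_per_elem

# Example: a 32-layer model, batch 8, sequence length 4096, hidden size 4096.
full = hidden_states_bytes(32, 8, 4096, 4096)       # capture every layer: ~8 GiB
selective = hidden_states_bytes(3, 8, 4096, 4096)   # capture only 3 layers: ~0.75 GiB
print(f"all layers: {full / 2**30:.1f} GiB, selected layers: {selective / 2**30:.2f} GiB")
```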
Modifications
This pull request introduces support for selectively outputting hidden states from specific layers in several custom backend model implementations. It does so by adding a new argument, `layers_to_output_hidden_states`, to the `forward` methods of multiple model classes. Additionally, it removes the `_can_record_outputs` attribute from these models, which previously specified which intermediate outputs could be recorded. These changes provide more granular control over which hidden states are returned during model execution and simplify the output recording logic. A sketch of how the argument could be threaded through a forward loop follows the list below.

Key changes by theme:

Selective Hidden State Output:
- Added a `layers_to_output_hidden_states` argument to the `forward` methods of the following models: `Llama`, `Llama4`, `Phi3`, `GptOss`, `Qwen2`, `Qwen3`, and `Qwen3_moe`. This allows users to specify a list of layer indices for which hidden states should be returned, rather than returning all or none. The logic for collecting hidden states was updated accordingly in each model.
- Updated the hidden state accumulation logic in each model's forward loop to check for the presence of `layers_to_output_hidden_states` and only collect hidden states for the specified layers.
- Ensured that the output objects (`BaseModelOutputWithPast`, `MoeModelOutputWithPast`) include the filtered `hidden_states` tuple in their return values.

Simplification and Cleanup:
- Removed the `_can_record_outputs` attribute from all affected model classes, which previously defined which intermediate outputs could be recorded. This streamlines the code and delegates output selection to the new argument.

Type and Import Adjustments:
- Added an import of `List` from the `typing` module in all affected files, supporting the new argument's type annotation.

These changes collectively provide more flexible and efficient access to intermediate model representations, which is useful for debugging, analysis, and advanced use cases.
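The sketch below shows one way the new argument could be threaded through a model's forward loop. Everything except the `layers_to_output_hidden_states` name is illustrative and not copied from the PR's diff; the tiny module is a stand-in for the real custom-backend models.

```python
from typing import List, Optional

import torch
from torch import nn


class TinyDecoderStack(nn.Module):
    """Illustrative stand-in for a custom-backend decoder; not the PR's actual code."""

    def __init__(self, num_layers: int = 4, hidden_size: int = 32) -> None:
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        layers_to_output_hidden_states: Optional[List[int]] = None,
    ):
        # Keep hidden states only for the requested layer indices; the default of
        # None keeps nothing extra, which preserves backward compatibility.
        collected = () if layers_to_output_hidden_states else None
        for idx, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states)
            if layers_to_output_hidden_states and idx in layers_to_output_hidden_states:
                collected += (hidden_states,)
        return self.norm(hidden_states), collected


# Hypothetical usage: request only the layers a drafter actually consumes.
model = TinyDecoderStack()
out, kept = model(torch.randn(2, 8, 32), layers_to_output_hidden_states=[0, 2, 3])
print(len(kept))  # 3 tensors kept instead of one per layer
```

In the real models the filtered tuple would be returned through the `hidden_states` field of `BaseModelOutputWithPast` or `MoeModelOutputWithPast`, and the layer indices would come from the caller (for example, the target-model wrapper) rather than being hard-coded.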
Related Issues
Accuracy Test
Tested locally; online training for both Qwen and Llama runs fine.
Benchmark & Profiling
Checklist