[Weight Loading] Expand quantized weight reloading support #28627
Conversation
```diff
-def reload_weights(self) -> None:
-    assert getattr(self, "model", None) is not None, (
-        "Cannot reload weights before model is loaded."
+def reload_weights(
```
so a lot of functionality is here, does this mean all runners have to be modified to include these?
All (well supported) runners already have a `reload_weights` method. Future work would expand them to support a weights iterator argument.
The reload_weights function can be generalized to runners outside of the GPUModelRunner. Future work could move this implementation to a base class or mixin to share functionality with runners such as tpu, xpu, etc.
Which approach should we take, #26327 or this PR? @kylesayrs @jerryzh168
@david6666666 see #26327 (comment); the current plan is that the API and where we do quantization are going to change a few times in this process, so we don't need to spend time improving the current API.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Support reloading weights whose kernel formats (parameters after `process_weights_after_loading` is called) do not match the model loading formats (parameters before `process_weights_after_loading` is called).

Background
When a model is loaded, its parameters are often modified by `process_weights_after_loading` to be better suited to the chosen kernel. This processing can involve online weight quantization or operations like padding and repacking. However, the new parameters after processing cannot be used to load new weights, because they are no longer in the format that they were in when they were loaded.

Proposed Solution
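To make the mismatch concrete, here is a toy illustration in plain Python (no torch or vLLM; the quantization step and `Param` type are stand-ins, not vLLM code) of why a parameter processed into kernel format can no longer accept checkpoint weights:

```python
# Toy illustration: after a hypothetical online-quantization step, the
# parameter's dtype no longer matches the checkpoint's load-time format.
from dataclasses import dataclass

@dataclass
class Param:
    data: list
    dtype: str

def process_weights_after_loading(param, scale=0.1):
    # Stand-in for kernel formatting: repack fp32 values as int8 + scale.
    return Param(data=[round(v / scale) for v in param.data], dtype="int8")

checkpoint = Param(data=[0.12, -0.34, 0.56], dtype="fp32")  # load format
loaded = Param(data=list(checkpoint.data), dtype="fp32")
kernel_param = process_weights_after_loading(loaded)

# The processed parameter is now in kernel format, not load format:
print(kernel_param.dtype, checkpoint.dtype)  # int8 fp32
```

Reloading the fp32 checkpoint into `kernel_param` directly would silently mix formats, which is exactly what the restoration step below avoids.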
In order to support reloading of model weights after kernel formatting occurs, information about the model state prior to `process_weights_after_loading` must be captured. This capture must include metadata like shape and dtype, as well as attributes such as the parameter's weight loader.

After the model format has been captured, the captured metadata can then be used to reconstruct the model load format whenever a weight reload occurs. Newly allocated parameters created by `process_weights_after_loading` are also deleted prior to restoration in order to avoid device memory overflows.

In the case that a user has already formatted their weights into kernel format, this system can be bypassed by calling `reload_weights(process_weights_after_loading=False)`.

Integration Plan
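The record/restore cycle can be sketched in miniature. This is plain Python with dicts standing in for the module's parameters; only the attribute name `weight_loading_metadata` comes from this PR, everything else is illustrative:

```python
# Minimal sketch of the record/restore cycle: capture load-format metadata
# before processing, then on reload delete processing-created parameters
# and rebuild the load-format ones from the captured metadata.
params = {"weight": ("fp32", [0.1, 0.2, 0.3])}

# 1. Record load-format metadata (here just dtype and length; the real
#    capture also keeps attributes like the parameter's weight loader).
weight_loading_metadata = {n: (d, len(v)) for n, (d, v) in params.items()}

# 2. Processing repacks the parameter and allocates a new one.
params["weight"] = ("int8", [1, 2, 3])
params["weight_scale"] = ("fp32", [0.1])      # created by processing

# 3. Restore: drop processing-created params first (avoids holding both
#    kernel-format and load-format copies in memory), then rebuild.
for name in list(params):
    if name not in weight_loading_metadata:
        del params[name]
for name, (dtype, length) in weight_loading_metadata.items():
    params[name] = (dtype, [0.0] * length)    # empty load-format buffer

print(sorted(params))       # ['weight']
print(params["weight"][0])  # fp32
```

Deleting the processing-created parameters before reallocating the load-format buffers is what keeps peak device memory bounded during a reload.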
Many user scripts already exist to reload weights by directly calling `model.load_weights` rather than `runner.reload_weights`. These changes do not affect the existing functionality of those scripts. However, those scripts will only work if the weights already match the kernel format.

If users want to support loading weights which are not in kernel format (for example, to let vLLM handle automatic weight quantization), they are encouraged to either use `runner.reload_weights`, or wrap their `model.load_weights` calls with `restore_weights_for_reloading` and `process_weights_after_loading`.

While this design requires some user code changes, I think that these changes are reasonable to add functionality which did not exist previously.
Changes
- `online_quantization.py`
  - Add `record_weights_for_reloading` and `restore_weights_for_reloading`
  - Remove the `process_weights_after_loading_already_called`, `weight_metadata_and_attr_saved`, `original_weights_rebuild_keys`, and `recorded_weight_attr` flags/attributes; instead use a single attribute, `weight_loading_metadata`
  - Handle cases where `process_weights_after_loading` functions create new parameters (these parameters have to be deleted on reload in order to avoid GPU OOM)
- Update `reload_weights` to support new arguments and quantization configs
  - `weights_iterator` allows a user to pass weights in memory. If none is provided, the weights are reloaded from disk
  - `process_weights_after_loading` allows a user to load weights directly from kernel format

Future Work

- The `reload_weights` function can be generalized to runners outside of the `GPUModelRunner`. Future work could move this implementation to a base class or mixin to share functionality with runners such as tpu, xpu, etc.
- `ONLINE_RELOAD_QUANT_CONFIGS`
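The base-class/mixin idea from Future Work could be sketched as follows. All class and method names here are illustrative stand-ins (the real `GPUModelRunner` and its model live in vLLM), showing only the shape of the shared reload logic:

```python
# Hypothetical sketch: shared reload logic lives in a mixin, and each
# runner (gpu, tpu, xpu, ...) supplies model access and a disk iterator.
class WeightReloadMixin:
    def reload_weights(self, weights_iterator=None,
                       process_weights_after_loading=True):
        assert getattr(self, "model", None) is not None, (
            "Cannot reload weights before model is loaded.")
        weights = (weights_iterator if weights_iterator is not None
                   else self._weights_from_disk())
        self.model.load_weights(weights)
        if process_weights_after_loading:
            self._process_weights_after_loading()

class StubModel:
    def load_weights(self, weights):
        self.weights = list(weights)

class GPUModelRunner(WeightReloadMixin):  # other runners would mix in too
    def __init__(self):
        self.model = StubModel()
        self.processed = False
    def _weights_from_disk(self):
        return [0.0]                       # stand-in for a disk iterator
    def _process_weights_after_loading(self):
        self.processed = True

runner = GPUModelRunner()
runner.reload_weights(weights_iterator=[1.0, 2.0])
print(runner.model.weights, runner.processed)  # [1.0, 2.0] True
```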
reload_weightsfunction can be generalized to runners outside of theGPUModelRunner. Future work could move this implementation to a base class or mixin to share functionality with runners such as tpu, xpu, etc.ONLINE_RELOAD_QUANT_CONFIGS.