
Conversation

@kylesayrs (Contributor) commented Nov 13, 2025

Purpose

  • Support quantized weight reloading to enable QeRL-style quantized rollouts for RL training
  • Support weight reloading when the kernel format (parameters after process_weights_after_loading is called) does not match the model loading format (parameters before process_weights_after_loading is called)

Background

When a model is loaded, its parameters are often modified by process_weights_after_loading to better suit the chosen kernel. This processing can involve online weight quantization or operations like padding and repacking. However, the processed parameters cannot be used to load new weights, because they are no longer in the format they were in when the model was first loaded.
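To make the failure mode concrete, here is a minimal sketch in plain PyTorch (not the actual vLLM code paths) of how post-load processing breaks reloading:

```python
import torch

# Simplified illustration only; vLLM's real processing is per-quant-config.
linear = torch.nn.Linear(4096, 4096, bias=False)

# At load time the parameter matches the checkpoint: [4096, 4096] floats.
checkpoint_tensor = torch.empty(4096, 4096)

# process_weights_after_loading-style processing may quantize and repack,
# e.g. replacing the weight with a packed int8 tensor of a different shape.
packed = torch.randint(-128, 127, (4096, 2048), dtype=torch.int8)
linear.weight = torch.nn.Parameter(packed, requires_grad=False)

# Reloading the checkpoint tensor now fails: shape and dtype no longer
# match, and any weight_loader attribute attached at load time is gone.
assert linear.weight.shape != checkpoint_tensor.shape
```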

Proposed Solution

In order to support reloading of model weights after kernel formatting occurs, information about the model state prior to process_weights_after_loading must be captured. This capture must include metadata like shape and dtype, as well as attributes such as the parameter's weight loader.

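A hedged sketch of the recording step (the function name record_weights_for_reloading and the weight_loading_metadata attribute come from this PR; the body is a simplification, not the exact implementation):

```python
import torch

def record_weights_for_reloading(model: torch.nn.Module) -> None:
    """Capture each parameter's load-time format before kernel processing.

    Sketch only: a meta tensor preserves shape and dtype without holding
    device memory, and load-time attributes (e.g. weight_loader) are saved
    so they can be re-attached on restore.
    """
    metadata: dict[str, tuple[torch.Tensor, object]] = {}
    for name, param in model.named_parameters():
        meta = param.data.to(device="meta")  # shape/dtype only, no storage
        loader = getattr(param, "weight_loader", None)
        metadata[name] = (meta, loader)
    # A single consolidated attribute, mirroring the PR's
    # weight_loading_metadata design.
    model.weight_loading_metadata = metadata
```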

After the model format has been captured, the captured metadata can then be used to reconstruct the model load format whenever a weight reload occurs. Newly allocated parameters created by process_weights_after_loading are also deleted prior to restoration in order to avoid device memory overflows.

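A hedged sketch of the restore step (names beyond restore_weights_for_reloading and weight_loading_metadata are simplifications invented for illustration):

```python
import torch

def restore_weights_for_reloading(model: torch.nn.Module) -> None:
    """Sketch: rebuild the load-time parameter format from recorded metadata.

    Parameters created by process_weights_after_loading (i.e. absent from
    the recorded metadata) are deleted first, so the kernel format and the
    load format never coexist in device memory.
    """
    metadata = model.weight_loading_metadata
    for mod_name, module in model.named_modules():
        for pname in list(module._parameters):
            full = f"{mod_name}.{pname}" if mod_name else pname
            if full not in metadata:
                del module._parameters[pname]  # newly created param: free it

    for full, (meta, loader) in metadata.items():
        mod_name, _, pname = full.rpartition(".")
        module = model.get_submodule(mod_name)
        current = module._parameters.get(pname)
        # Skip reallocation when a matching parameter survives, reducing
        # reload latency (mirrors the PR's optimization).
        if (current is None or current.shape != meta.shape
                or current.dtype != meta.dtype):
            empty = torch.empty_like(meta, device="cuda")  # assumes CUDA
            current = torch.nn.Parameter(empty, requires_grad=False)
            setattr(module, pname, current)
        if loader is not None:
            current.weight_loader = loader  # re-attach load-time attribute
```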

If a user's weights are already in kernel format, this system can be bypassed by calling reload_weights(process_weights_after_loading=False).

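For example (the argument name comes from this PR's description; the runner setup is elided):

```python
# Weights are already in kernel format (e.g. quantized offline), so skip
# the restore / re-process round trip entirely.
runner.reload_weights(process_weights_after_loading=False)
```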

Integration Plan

Many existing user scripts reload weights by calling model.load_weights directly rather than runner.reload_weights. These changes do not affect the existing functionality of those scripts; however, such scripts only work if the weights already match the kernel format.

If users want to support loading weights which are not in kernel format (for example, to let vLLM handle automatic weight quantization), they are encouraged to either use runner.reload_weights, or wrap their model.load_weights calls with restore_weights_for_reloading and process_weights_after_loading.

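A hedged sketch of that wrapping pattern; the import path and the process_weights_after_loading call site are assumptions about the final API, not exact code from this PR:

```python
# The wrapping pattern described above; the import path is a guess.
from vllm.model_executor.model_loader.online_quantization import (
    restore_weights_for_reloading,
)

restore_weights_for_reloading(model)   # back to load-time format
model.load_weights(weights_iterator)   # existing direct call, unchanged
process_weights_after_loading(model)   # re-derive the kernel format
```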

While this design requires some user code changes, I think these changes are a reasonable cost for functionality that did not exist previously.

Changes

  • Refactor online_quantization.py
    • Rename core functions to record_weights_for_reloading and restore_weights_for_reloading
    • Remove the process_weights_after_loading_already_called, weight_metadata_and_attr_saved, original_weights_rebuild_keys, and recorded_weight_attr flags/attributes in favor of a single weight_loading_metadata attribute
    • Utilize meta tensors to handle recording of parameter attributes
    • Expand support to handle quantization configs whose process_weights_after_loading functions create new parameters (these parameters have to be deleted on reload in order to avoid GPU OOM)
    • Skip reallocation of parameters which already exist (and have the same weight attributes) in order to reduce latency from model reloading
  • Expand reload_weights to support new arguments and quantization configs (see the sketch after this list)
    • weights_iterator allows a user to pass weights in memory; if none is provided, the weights are reloaded from disk
    • process_weights_after_loading allows a user to load weights directly in kernel format
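As a rough sketch of the expanded entry point (the signature is inferred from this description; the class stub, _load_weights_from_disk, and _apply_process_weights_after_loading are hypothetical):

```python
from collections.abc import Iterator

import torch

class GPUModelRunnerSketch:
    """Stand-in for the real runner; only the reload path is shown."""

    def reload_weights(
        self,
        weights_iterator: Iterator[tuple[str, torch.Tensor]] | None = None,
        process_weights_after_loading: bool = True,
    ) -> None:
        assert getattr(self, "model", None) is not None, (
            "Cannot reload weights before model is loaded."
        )
        if process_weights_after_loading:
            # Rebuild load-time parameter formats from recorded metadata
            # (the restore step sketched above).
            restore_weights_for_reloading(self.model)
        if weights_iterator is not None:
            # In-memory weights supplied by the caller, e.g. an RL trainer
            # pushing updated policy weights for quantized rollouts.
            self.model.load_weights(weights_iterator)
        else:
            # Hypothetical helper: re-read the original checkpoint from disk.
            self._load_weights_from_disk()
        if process_weights_after_loading:
            # Hypothetical wrapper around vLLM's per-config processing.
            self._apply_process_weights_after_loading()
```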

Future Work

  • The reload_weights function can be generalized to runners beyond the GPUModelRunner. Future work could move this implementation to a base class or mixin to share functionality with runners such as TPU, XPU, etc. (see the sketch after this list)
  • This implementation theoretically does not impose any requirements on quantization configs (quantization postprocessing is free to allocate new parameters, etc.). However, the new functionality of this PR needs to be tested with more configs, after which they can be added to the list of ONLINE_RELOAD_QUANT_CONFIGS.
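A hedged illustration of the base-class direction (all names below are hypothetical, not code from this PR):

```python
# Hypothetical sketch of the shared-runner direction; none of these names
# are final, and the device-specific hooks are invented for illustration.
class WeightReloadMixin:
    """Reload logic that GPU/TPU/XPU runners could share."""

    def reload_weights(self, weights_iterator=None,
                       process_weights_after_loading: bool = True) -> None:
        assert getattr(self, "model", None) is not None, (
            "Cannot reload weights before model is loaded."
        )
        if process_weights_after_loading:
            self._restore_load_format()   # device-agnostic restore
        self._load(weights_iterator)      # device-specific loading hook
        if process_weights_after_loading:
            self._process_weights()       # device-specific kernel formatting

class TPUModelRunnerSketch(WeightReloadMixin):
    def _restore_load_format(self) -> None: ...
    def _load(self, weights_iterator) -> None: ...
    def _process_weights(self) -> None: ...
```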

```python
def reload_weights(self) -> None:
    assert getattr(self, "model", None) is not None, (
        "Cannot reload weights before model is loaded."
    )

def reload_weights(
```
Contributor

So a lot of the functionality is here; does this mean all runners have to be modified to include it?

@kylesayrs (Contributor Author)

All (well supported) runners already have a reload_weights method. Future work would extend them to support a weights_iterator argument.

> The reload_weights function can be generalized to runners outside of the GPUModelRunner. Future work could move this implementation to a base class or mixin to share functionality with runners such as tpu, xpu, etc.

@david6666666 (Contributor) commented

Which approach should we take, #26327 or this PR? @kylesayrs @jerryzh168

@jerryzh168 (Contributor) commented Nov 18, 2025

@david6666666 see #26327 (comment); the current plan is:

> the API and where we do quantization are going to change a few times in this process, so we don't need to spend time improving the current API.

mergify bot commented Nov 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @kylesayrs.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Nov 19, 2025
