From dbf10e5999c211dd5e7fe7e36be373b2d31cb15c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 11 Nov 2025 13:37:14 -0500 Subject: [PATCH 01/11] no new parameters Signed-off-by: Kyle Sayers --- .../model_executor/layers/quantization/fp8.py | 74 +++++++------------ .../layers/quantization/kv_cache.py | 7 ++ 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cb065eb68b66..f83248b82e8f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,7 +7,6 @@ import torch from torch.nn import Module -from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -525,13 +524,10 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = weight.t() # Update layer with new values. - layer.weight = Parameter(weight.data, requires_grad=False) - layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) - layer.input_scale = ( - Parameter(input_scale, requires_grad=False) - if input_scale is not None - else None - ) + layer.weight.copy_(weight.data) + layer.weight_scale.copy_(weight_scale.data) + if input_scale is not None: + layer.input_scale.copy_(input_scale) if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) @@ -827,22 +823,18 @@ def process_weights_after_loading(self, layer: Module) -> None: w2_weight_scale_inv = layer.w2_weight_scale_inv # torch.compile() cannot use Parameter subclasses. - layer.w13_weight = Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale_inv = Parameter( - w13_weight_scale_inv, requires_grad=False - ) - layer.w2_weight = Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale_inv = Parameter( - w2_weight_scale_inv, requires_grad=False - ) + layer.w13_weight.copy_(w13_weight) + layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv) + layer.w2_weight.copy_(w2_weight) + layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + layer.w13_weight.copy_(shuffled_w13) + layer.w2_weight.copy_(shuffled_w2) # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. @@ -864,7 +856,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale = torch.nn.Parameter( + layer.w13_weight_scale.copy_( torch.ones( layer.local_num_experts, dtype=torch.float32, @@ -879,16 +871,16 @@ def process_weights_after_loading(self, layer: Module) -> None: w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w13_weight.copy_(w13_weight) + layer.w2_weight.copy_(w2_weight) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + layer.w13_weight.copy_(shuffled_w13) + layer.w2_weight.copy_(shuffled_w2) # If checkpoint is fp8, we need to handle that the # MoE kernels require single activation scale and single weight # scale for w13 per expert. @@ -909,12 +901,8 @@ def process_weights_after_loading(self, layer: Module) -> None: "fp8 MoE layer. Using the maximum across experts " "for each layer." ) - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False - ) + layer.w13_input_scale.copy_(layer.w13_input_scale.max()) + layer.w2_input_scale.copy_(layer.w2_input_scale.max()) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( @@ -928,22 +916,14 @@ def process_weights_after_loading(self, layer: Module) -> None: ) ) # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) + layer.w13_weight.copy_(w13_weight) + layer.w13_weight_scale.copy_(w13_weight_scale) if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) + layer.w13_input_scale.copy_(w13_input_scale) + layer.w2_weight.copy_(w2_weight) + layer.w2_weight_scale.copy_(w2_weight_scale) if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) + layer.w2_input_scale.copy_(w2_input_scale) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
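The dequant-and-requant step referenced in the comment above is worth spelling out, since the in-place `copy_` pattern this patch introduces only works if the requantized values land back in tensors of the same shape and dtype. The sketch below is illustrative only and is not part of the patch: the helper, tensor names, and shapes are assumptions, and the real code in this file uses helpers such as `ops.scaled_fp8_quant` with the per-expert max scale rather than raw casts.

```python
import torch

def requantize_w13_to_max_scale(
    w13_q: torch.Tensor,       # (num_experts, 2 * shard_n, hidden), fp8 values
    w13_scales: torch.Tensor,  # (num_experts, 2), one scale per gate/up shard
    shard_n: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # One shared scale per expert: the max over the gate/up shard scales.
    max_scales = w13_scales.max(dim=1).values
    for e in range(w13_q.shape[0]):
        for s in range(2):  # gate shard, up shard
            rows = slice(s * shard_n, (s + 1) * shard_n)
            # Dequantize with the shard's old scale, requantize with the max.
            # Safe without clamping: max_scales[e] >= w13_scales[e, s].
            dq = w13_q[e, rows].to(torch.float32) * w13_scales[e, s]
            w13_q[e, rows] = (dq / max_scales[e]).to(w13_q.dtype)
    return w13_q, max_scales
```

Because the result keeps the shape and dtype of the existing tensors, it can be written back with `layer.w13_weight.copy_(...)` and `layer.w13_weight_scale.copy_(max_scales)`, which is what lets this patch avoid allocating new `Parameter` objects.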
@@ -967,12 +947,10 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight, layer.w2_weight ) - layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) + layer.w13_weight.copy_(shuffled_w13) + layer.w2_weight.copy_(shuffled_w2) - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) + layer.w13_weight_scale.copy_(max_w13_scales) if self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 78456dcf1ca5..cf4a778af0b9 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -45,6 +45,13 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # skip if there are no weights to process (for examplle, weight reloading) + if not hasattr(layer, "q_scale"): + assert not hasattr(layer, "k_scale") + assert not hasattr(layer, "v_scale") + assert not hasattr(layer, "prob_scale") + return + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. # No need to process kv scales after loading if we are going to From 157640958f08398e3cc46063a78316511fe96140 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 11 Nov 2025 13:42:35 -0500 Subject: [PATCH 02/11] fix typo Signed-off-by: Kyle Sayers --- vllm/model_executor/layers/quantization/kv_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index cf4a778af0b9..f0497a872290 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -45,7 +45,7 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # skip if there are no weights to process (for examplle, weight reloading) + # skip if there are no weights to process (for example, weight reloading) if not hasattr(layer, "q_scale"): assert not hasattr(layer, "k_scale") assert not hasattr(layer, "v_scale") From 749c91c8ff5ce897c85663369a1806f5ffcb28c5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 11 Nov 2025 17:09:41 -0500 Subject: [PATCH 03/11] register weight scale in create params, still issue with reloading from disk Signed-off-by: Kyle Sayers --- .../model_executor/layers/quantization/fp8.py | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f83248b82e8f..d42034c1ebd7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -483,6 +483,18 @@ def create_weights( else: layer.register_parameter("input_scale", None) + # create per-tensor qparams populated by process_weights_after_loading + else: + scale = create_fp8_scale_parameter( + PerTensorScaleParameter, + output_partition_sizes, + input_size_per_partition, + None, + weight_loader, + ) + 
set_weight_attrs(scale, {"scale_type": "weight_scale"}) + layer.register_parameter("weight_scale", scale) + def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True input_scale = None @@ -494,8 +506,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight, weight_scale = process_fp8_weight_block_strategy( layer.weight, layer.weight_scale_inv ) - # Delete the weight_scale_inv parameter to avoid confusion - # with the weight_scale parameter + # Rename weight_scale_inv parameter for consistency + layer.weight_scale = layer.weight_scale_inv del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. @@ -755,12 +767,10 @@ def create_weights( if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # add weight loaders to support loading (and reloading) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": From 4fc904316ca27c29174a277aee516f61114095f4 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 14:31:16 -0500 Subject: [PATCH 04/11] WIP: reload Signed-off-by: Kyle Sayers --- .../model_loader/default_loader.py | 43 +--- .../model_loader/online_quantization.py | 223 ++++++------------ vllm/v1/worker/gpu_model_runner.py | 31 ++- 3 files changed, 108 insertions(+), 189 deletions(-) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index c06ac550a94a..23fb81420074 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -5,7 +5,7 @@ import os import time from collections.abc import Generator, Iterable -from typing import cast +from typing import cast, Optional import torch from torch import nn @@ -272,44 +272,21 @@ def download_model(self, model_config: ModelConfig) -> None: allow_patterns_overrides=None, ) - def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None) -> None: if model_config.quantization == "torchao" and torchao_version_at_least( "0.14.0" ): self.load_config.safetensors_load_strategy = "torchao" - weights_to_load = {name for name, _ in model.named_parameters()} - - # if we don't have `model.weight_metadata_and_attr_saved` defined and - # set to True, it means that this is either offline quantization case - # or the first run of online quantization - # see online_quantization.py for detailed notes - offline_quantization_or_first_run_of_online_quantization = not getattr( - model, "weight_metadata_and_attr_saved", False - ) + + # use provided weights or load from disk + if weights_iterator is None: + weights_iterator = self.get_all_weights(model_config, model) - if model_config.quantization is None: - # model is not quantized - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model) - ) - elif offline_quantization_or_first_run_of_online_quantization: - # case 1: offline quantized checkpoint - # case 2: Step I1 
first run of weight loading with - # online quantization - # see online_quantization.py for detailed notes - loaded_weights = model.load_weights( - self.get_all_weights(model_config, model) - ) - else: - # to avoid circular dependency - from vllm.model_executor.model_loader.online_quantization import ( - load_weights_and_online_quantize, - ) - - # subsequent runs of weight loading with online - # quantization - loaded_weights = load_weights_and_online_quantize(self, model, model_config) + # load weights into model + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(weights_iterator) + # logging and validation self.counter_after_loading_weights = time.perf_counter() logger.info_once( "Loading weights took %.2f seconds", diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index 890dd7231a0e..8fbfd284bd7d 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -5,6 +5,7 @@ import torch from torch import nn +from copy import deepcopy from vllm.config import ModelConfig from vllm.logger import init_logger @@ -13,6 +14,11 @@ logger = init_logger(__name__) +SUPPORTED_QUANT_CONFIGS = { + "torchao", + "fp8", +} + # Notes for Online Quantization # In terms of state of checkpoints, quantization config and their # correspondance to online quantization: @@ -64,161 +70,74 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( model: nn.Module, model_config: ModelConfig ): - # following is to support on the fly quantization, currently only supported - # for torchao - if model_config.quantization != "torchao": - return - - if getattr(model, "process_weights_after_loading_already_called", False): - # In case `process_weights_after_loading` is called multiple times - # we'll skip it at later times - logger.warning( - "process_weights_after_loading already called for model %s", model - ) - return - + # assume this is called right after weight loading and before/ at the start of process_weights_after_loading from vllm.model_executor.model_loader.weight_utils import get_quant_config quant_config = get_quant_config(model_config, None) - - # If checkpoint is already torchao serialized, this means it's - # pre-quantized quantization case, we'll skip saving the metadata - # Otherwise, this is Step I2 of initialization steps of - # online quantization - # This step record the weights metadata and weight attributes so we can - # restore the bfloat16 model weights during the relad step (R1 and R2) - # see Notes in online_quantization.py for more details - if not ( - hasattr(quant_config, "is_checkpoint_torchao_serialized") - and not quant_config.is_checkpoint_torchao_serialized - ): + if quant_config.get_name() not in SUPPORTED_QUANT_CONFIGS: return - - # This is the I2 step of online quantiztion that saves - # metadata and attributes of weights so they can be used in R1 and - # R2 step, note that we only save these during initialization - - # Includes two things - # 1. save floating point metadata (shape, dtype, device) for init - # 2. save weight attributes, e.g. 
`output_dim`, `weight_loader` for init - - if getattr(model, "weight_metadata_and_attr_saved", False): - return - - # save the dtype, shape and device for model parameter, used for - # restoring the model high precision parameters before - # reloading the weights - assert not hasattr(model, "original_weights_rebuild_keys") - model.original_weights_rebuild_keys = {} - for name, p in model.named_parameters(): - model.original_weights_rebuild_keys[name] = { - "shape": p.shape, - "dtype": p.dtype, - "device": p.device, - } - - # record the weight attributes (loader functions etc.) - # so these can be recovered later when we reload the weights - # structure: {"weight_name": {"weight_attr_key": attr}} - assert not hasattr(model, "recorded_weight_attr") - model.recorded_weight_attr = {} - for name, param in model.named_parameters(): - model.recorded_weight_attr[name] = {} - for key in param.__dict__: - if hasattr(param, key): - attr = getattr(param, key) - if not callable(attr): - model.recorded_weight_attr[name][key] = attr - elif hasattr(attr, "__self__") and param is attr.__self__: - # if attr is a bonded method for an instance, and - # attr.__self__ points to the instance (param) - # we'll record the underlying function object - model.recorded_weight_attr[name][key] = attr.__func__ - else: - model.recorded_weight_attr[name][key] = attr - # mark the metadata and attributes saved so we don't run it again - model.weight_metadata_and_attr_saved = True - - -def _bond_method_to_cls(func, obj): - if hasattr(func, "__self__") or not callable(func): - # If the function is already bound to an instance, return it as is - return func - else: - return types.MethodType(func, obj) - - -def load_weights_and_online_quantize( - model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig -) -> set[str]: - # online quantization, right now only enabled for - # torchao - # R1, R2, R3, R4 in the Notes - - # TODO: Add fp8 support - assert model_config.quantization == "torchao", ( - "online quantization is only enabled for torchao currently" + + if not hasattr(model, "weight_loading_metadata"): + setattr(model, "weight_loading_metadata", { + name: _copy_to_meta_tensor(param) + for name, param in model.named_parameters() + }) + + return getattr(model, "weight_loading_metadata") + + +def restore_weights_for_loading(model: nn.Module): + assert hasattr(model, "weight_loading_metadata") + metadata: dict[str, torch.Tensor] = getattr(model, "weight_loading_metadata") + model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() + + # remove parameters which were not present at load time + params_to_remove = model_param_names - metadata.keys() + for param_fqn in params_to_remove: + module_name, param_name = param_fqn.rsplit(".", 1) + module = model.get_submodule(module_name) + delattr(module, param_name) + + # restore parameters that were present at load time + for param_fqn, meta_tensor in metadata.items(): + module_name, param_name = param_fqn.rsplit(".", 1) + module = model.get_submodule(module_name) + + # for faster runtime, skip materialization if the tensors match + original_tensor = getattr(module, param_name, None) + if _tensors_alike(original_tensor, meta_tensor): + continue + + param = _materialize_meta_tensor(meta_tensor) + setattr(module, param) + + +def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: + new_tensor = tensor.to("meta") + new_tensor.__class__ = tensor.__class__ + new_tensor.__dict__ = deepcopy(tensor.__dict__) + new_tensor._original_device = 
tensor.device + return new_tensor + + +def _tensors_alike(tensor: torch.Tensor | None, meta: torch.Tensor) -> bool: + if tensor is None: + return False + + return ( + tensor.device == meta._original_device + and tensor.dtype == meta.dtype + and tensor.shape == meta.shape + and tensor.__dict__ == meta.__dict__ ) - # TODO: use create_weights to restore the weights to original state - - # Step R1: First restore the quantized weights to original bfloat16 - # weights, with original metadata (shape, dtype, device) - # and attributes, so that bfloat16 weights can be loaded properly - existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() - named_modules = dict(model.named_modules(remove_duplicate=False)) - model_device = None - - # Step R2: recover the parameter to the state before first loading - for name, d in model.original_weights_rebuild_keys.items(): - _shape = d["shape"] - _dtype = d["dtype"] - _device = d["device"] - if model_device is not None: - assert model_device == _device, ( - "Expecting all weights " - "to be in the same device for now, got both: " - f"{model_device} and {_device}" - ) - else: - model_device = _device - - if name in existing_param_names: - module_name, weight_name = name.rsplit(".", 1) - module = named_modules[module_name] - setattr( - module, - weight_name, - torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)), - ) - - # recorded_weight_attr is - # {"weight_name": {"weight_attr_key": attr}} - # e.g. - # { - # { - # "layer.0.weight": { - # "weight_loader": weight_loader_function_object, - # "input_dim": 0, ... - # }, - # "layer.1.weight": ..., - # } - # } - for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items(): - for attr_name, attr in weight_attr_dict.items(): - module_name, weight_name = full_weight_name.rsplit(".", 1) - module = named_modules[module_name] - weight = getattr(module, weight_name) - if not hasattr(weight, attr_name): - setattr(weight, attr_name, _bond_method_to_cls(attr, weight)) - - # Step I1: reload bfloat16 / high precision weights - loaded_weights = model.load_weights( - model_loader.get_all_weights(model_config, model) + + + +def _materialize_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: + return torch.empty_strided( + size=tuple(tensor.size()), + stride=tuple(tensor.stride()), + dtype=tensor.dtype, + device=tensor._original_device, + requires_grad=False, # set below to match input ) - - # Step I2: online quantize the weights - # manually process weights after loading - model.process_weights_after_loading_already_called = False - process_weights_after_loading(model, model_config, model_device) - model.process_weights_after_loading_already_called = True - return loaded_weights diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fbd3e5f31316..26bfde6e283f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,7 +10,7 @@ from copy import deepcopy from functools import reduce from itertools import product -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast, Optional, Iterable import numpy as np import torch @@ -3177,13 +3177,36 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] 
| None: return None - def reload_weights(self) -> None: + def reload_weights( + self, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None, + process_weights_after_loading: bool = True) -> None: + from vllm.model_executor.model_loader.utils import process_weights_after_loading as _process + from vllm.model_executor.model_loader.online_quantization import ( + restore_weights_for_loading, + ) + assert getattr(self, "model", None) is not None, ( "Cannot reload weights before model is loaded." ) - model_loader = get_model_loader(self.load_config) + model = self.get_model() + + # for select quant configs, regenerate weights for proper weight loading + if process_weights_after_loading and hasattr("weight_loading_metadata"): + restore_weights_for_loading(model) + logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), model_config=self.model_config) + model_loader = get_model_loader(self.load_config) + model_loader.load_weights(model, model_config=self.model_config, weights_iterator=weights_iterator) + + if process_weights_after_loading: + device_config = self.vllm_config.device_config + load_config = self.vllm_config.load_config + load_device = ( + device_config.device if load_config.device is None else load_config.device + ) + _process(model, self.model_config, load_device) + + # TODO: logging total reload time def save_tensorized_model( self, From d2504cf127902f30f5e9c98c4452e2f8c74d2aba Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 15:09:18 -0500 Subject: [PATCH 05/11] WIP: standardize formats, style Signed-off-by: Kyle Sayers --- .../model_loader/default_loader.py | 9 +-- .../model_loader/online_quantization.py | 29 +++---- vllm/v1/worker/gpu_model_runner.py | 80 +++++++++++++++---- 3 files changed, 79 insertions(+), 39 deletions(-) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 23fb81420074..f0f68000da29 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -5,7 +5,7 @@ import os import time from collections.abc import Generator, Iterable -from typing import cast, Optional +from typing import cast import torch from torch import nn @@ -272,18 +272,15 @@ def download_model(self, model_config: ModelConfig) -> None: allow_patterns_overrides=None, ) - def load_weights(self, model: nn.Module, model_config: ModelConfig, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None) -> None: + def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: if model_config.quantization == "torchao" and torchao_version_at_least( "0.14.0" ): self.load_config.safetensors_load_strategy = "torchao" - - # use provided weights or load from disk - if weights_iterator is None: - weights_iterator = self.get_all_weights(model_config, model) # load weights into model weights_to_load = {name for name, _ in model.named_parameters()} + weights_iterator = self.get_all_weights(model_config, model) loaded_weights = model.load_weights(weights_iterator) # logging and validation diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index 8fbfd284bd7d..c24b4b7dc6f2 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -1,20 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
-import types + +from copy import deepcopy import torch from torch import nn -from copy import deepcopy from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader.default_loader import DefaultModelLoader -from vllm.model_executor.model_loader.utils import process_weights_after_loading logger = init_logger(__name__) -SUPPORTED_QUANT_CONFIGS = { +ONLINE_RELOAD_QUANT_CONFIGS = { "torchao", "fp8", } @@ -70,25 +68,25 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( model: nn.Module, model_config: ModelConfig ): - # assume this is called right after weight loading and before/ at the start of process_weights_after_loading + # this function should be called at the start of `process_weights_after_loading` from vllm.model_executor.model_loader.weight_utils import get_quant_config quant_config = get_quant_config(model_config, None) - if quant_config.get_name() not in SUPPORTED_QUANT_CONFIGS: + if quant_config.get_name() not in ONLINE_RELOAD_QUANT_CONFIGS: return - + if not hasattr(model, "weight_loading_metadata"): - setattr(model, "weight_loading_metadata", { + model.weight_loading_metadata = { name: _copy_to_meta_tensor(param) for name, param in model.named_parameters() - }) + } - return getattr(model, "weight_loading_metadata") + return model.weight_loading_metadata def restore_weights_for_loading(model: nn.Module): assert hasattr(model, "weight_loading_metadata") - metadata: dict[str, torch.Tensor] = getattr(model, "weight_loading_metadata") + metadata: dict[str, torch.Tensor] = model.weight_loading_metadata model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() # remove parameters which were not present at load time @@ -109,7 +107,7 @@ def restore_weights_for_loading(model: nn.Module): continue param = _materialize_meta_tensor(meta_tensor) - setattr(module, param) + setattr(module, param_name, param) def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: @@ -123,15 +121,14 @@ def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: def _tensors_alike(tensor: torch.Tensor | None, meta: torch.Tensor) -> bool: if tensor is None: return False - + return ( tensor.device == meta._original_device and tensor.dtype == meta.dtype and tensor.shape == meta.shape and tensor.__dict__ == meta.__dict__ ) - - + def _materialize_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: return torch.empty_strided( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 26bfde6e283f..34bdb633bfe7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,12 +5,12 @@ import itertools import time from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import contextmanager from copy import deepcopy from functools import reduce from itertools import product -from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast, Optional, Iterable +from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np import torch @@ -3178,35 +3178,81 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] 
| None: return None def reload_weights( - self, weights_iterator: Optional[Iterable[tuple[str, torch.Tensor]]] = None, - process_weights_after_loading: bool = True) -> None: - from vllm.model_executor.model_loader.utils import process_weights_after_loading as _process + self, + weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, + process_weights_after_loading: bool = True, + ) -> None: from vllm.model_executor.model_loader.online_quantization import ( restore_weights_for_loading, ) - - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." + from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading as _process_after_loading, ) - model = self.get_model() - # for select quant configs, regenerate weights for proper weight loading - if process_weights_after_loading and hasattr("weight_loading_metadata"): - restore_weights_for_loading(model) + # argument validation + if weights_iterator is None and not process_weights_after_loading: + logger.warning( + "Loading from disk means that weights will be in checkpoint format" + ) + + if getattr(self, "model", None) is not None: + raise ValueError("Cannot reload weights before model is loaded.") + model = self.get_model() logger.info("Reloading weights inplace...") - model_loader = get_model_loader(self.load_config) - model_loader.load_weights(model, model_config=self.model_config, weights_iterator=weights_iterator) + counter_before_loading_weights = time.perf_counter() + + # maybe load weights from disk + if weights_iterator is None: + model_loader = get_model_loader(self.load_config) + weights_iterator = model_loader.get_all_weights(self.model_config, model) + weights_iterator = cast( + Iterable[tuple[str, torch.Tensor]], weights_iterator + ) if process_weights_after_loading: + # restore model to checkpoint format + if hasattr(model, "weight_loading_metadata"): + restore_weights_for_loading(model) + else: + logger.warning("Quant config is not supported") + + # load weights from checkpoint format + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(weights_iterator) + + # process weights into kernel format device_config = self.vllm_config.device_config load_config = self.vllm_config.load_config load_device = ( - device_config.device if load_config.device is None else load_config.device + device_config.device + if load_config.device is None + else load_config.device ) - _process(model, self.model_config, load_device) + _process_after_loading(model, self.model_config, load_device) + + else: + # load weights from kernel format + for name, weight in weights_iterator: + param = model.get_parameter(name) + param.weight_loader(param, weight) + + # logging + counter_after_loading_weights = time.perf_counter() + diff_seconds = counter_after_loading_weights - counter_before_loading_weights + logger.info_once( + "Loading weights took %.2f seconds", diff_seconds, scope="local" + ) - # TODO: logging total reload time + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. 
+ if self.model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) def save_tensorized_model( self, From 82a25cc6a18a7ad098fc35d7c425634846b9a5d5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 15:15:52 -0500 Subject: [PATCH 06/11] rename Signed-off-by: Kyle Sayers --- .../model_loader/online_quantization.py | 47 ++++++++++--------- vllm/model_executor/model_loader/utils.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 4 +- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index c24b4b7dc6f2..5977adf3e863 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -65,9 +65,7 @@ # load_weights -def maybe_save_metadata_and_attributes_for_weight_reloading( - model: nn.Module, model_config: ModelConfig -): +def record_weights_for_reloading(model: nn.Module, model_config: ModelConfig): # this function should be called at the start of `process_weights_after_loading` from vllm.model_executor.model_loader.weight_utils import get_quant_config @@ -81,10 +79,8 @@ def maybe_save_metadata_and_attributes_for_weight_reloading( for name, param in model.named_parameters() } - return model.weight_loading_metadata - -def restore_weights_for_loading(model: nn.Module): +def restore_weights_for_reloading(model: nn.Module): assert hasattr(model, "weight_loading_metadata") metadata: dict[str, torch.Tensor] = model.weight_loading_metadata model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys() @@ -111,30 +107,35 @@ def restore_weights_for_loading(model: nn.Module): def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: - new_tensor = tensor.to("meta") - new_tensor.__class__ = tensor.__class__ - new_tensor.__dict__ = deepcopy(tensor.__dict__) - new_tensor._original_device = tensor.device - return new_tensor + meta_tensor = tensor.to("meta") + meta_tensor.__class__ = tensor.__class__ + meta_tensor.__dict__ = deepcopy(tensor.__dict__) + meta_tensor._original_device = tensor.device + + return meta_tensor -def _tensors_alike(tensor: torch.Tensor | None, meta: torch.Tensor) -> bool: +def _tensors_alike(tensor: torch.Tensor | None, meta_tensor: torch.Tensor) -> bool: if tensor is None: return False return ( - tensor.device == meta._original_device - and tensor.dtype == meta.dtype - and tensor.shape == meta.shape - and tensor.__dict__ == meta.__dict__ + tensor.device == meta_tensor._original_device + and tensor.dtype == meta_tensor.dtype + and tensor.shape == meta_tensor.shape + and tensor.__dict__ == meta_tensor.__dict__ ) -def _materialize_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: - return torch.empty_strided( - size=tuple(tensor.size()), - stride=tuple(tensor.stride()), - dtype=tensor.dtype, - device=tensor._original_device, - requires_grad=False, # set below to match input +def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor: + tensor = torch.empty_strided( + size=tuple(meta_tensor.size()), + stride=tuple(meta_tensor.stride()), + dtype=meta_tensor.dtype, + device=meta_tensor._original_device, + requires_grad=meta_tensor.requires_grad, ) + tensor.__class__ = meta_tensor.__class__ + tensor.__dict__ = deepcopy(meta_tensor.__dict__) + + return 
tensor diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index ba708a098c0d..b6e7560d8c4d 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -90,10 +90,10 @@ def process_weights_after_loading( ) -> None: # to avoid circular dependency from vllm.model_executor.model_loader.online_quantization import ( - maybe_save_metadata_and_attributes_for_weight_reloading, + record_weights_for_reloading, ) - maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config) + record_weights_for_reloading(model, model_config) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 34bdb633bfe7..1e89a428f6a8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3183,7 +3183,7 @@ def reload_weights( process_weights_after_loading: bool = True, ) -> None: from vllm.model_executor.model_loader.online_quantization import ( - restore_weights_for_loading, + restore_weights_for_reloading, ) from vllm.model_executor.model_loader.utils import ( process_weights_after_loading as _process_after_loading, @@ -3213,7 +3213,7 @@ def reload_weights( if process_weights_after_loading: # restore model to checkpoint format if hasattr(model, "weight_loading_metadata"): - restore_weights_for_loading(model) + restore_weights_for_reloading(model) else: logger.warning("Quant config is not supported") From af5772e9fdb262a6407993495c42d699a2867a28 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 19:53:11 -0500 Subject: [PATCH 07/11] base runner Signed-off-by: Kyle Sayers --- .../model_loader/online_quantization.py | 76 ++++++--------- vllm/v1/worker/base.py | 95 +++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 86 ++--------------- 3 files changed, 130 insertions(+), 127 deletions(-) create mode 100644 vllm/v1/worker/base.py diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index 5977adf3e863..6ec73a22fbe9 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -3,6 +3,7 @@ from copy import deepcopy +from types import MethodType import torch from torch import nn @@ -17,56 +18,29 @@ "fp8", } -# Notes for Online Quantization -# In terms of state of checkpoints, quantization config and their -# correspondance to online quantization: -# | Use Case | Checkpoints | model_config.quantization | -# | no quant | high precision | None | -# | offline quant | quantized | fp8, torchao etc. | -# | online quant | high precision | torchao etc. | -# -# The process for loading non-quantized checkpoint -# 1. load non-quantized weights (load_weights) -# 2. do any additional post processing (process_weights_after_loading) -# -# The process for loading offline quantized checkpoint -# 1. load offline-quantized weights (load_weights) -# 2. do any additional post processing (process_weights_after_loading) - -# The process for unquantized model reloading -# (repeated run in RL training loop) -# first run -# UI1. load_weights: load bfloat16 weights -# UI2. 
process_weights_after_loading: any additional post processing -# subsequent run -# UC1: load_weights: load bfloat16 weights -# (shouldn't be any issues since we didn't change any attributes -# of the weights) -# UC2: process_weights_after_loading: any additional post processing - -# The process for weight reloading with online quantization -# (repeated run in RL training loop) -# first run -# I1. load_weights: load bfloat16 weights -# I2. process_weights_after_loading: -# record weight metadata and attributes for R1 and R2 -# quantize weights to fp8 -# subsequent run -# (beginning model weight is in fp8) -# load_weights: -# R1. restore bfloat16 model weight metadata -# R2. restore the model weight attributes -# R3. reload bfloat16 weights -# R4. quantize weights (by calling process_weights_after_loading), -# also set `process_weights_after_loading_already_called` to -# True to stop it from running again -# process_weights_after_loading (if called): -# this will be skipped since it's already ran in -# load_weights +""" + +First time loading lifecycle +1. Model checkpoint is loaded by `ModelLoader.get_all_weights` into `weights_iterator` +2. `weights_iterator` is loaded into model by `model.load_weights` +3. Model state is captured by `record_weights_for_reloading` +4. `process_weights_after_loading` converts model state into kernel format +5. Model can run now that weights are in kernel format + + +Subsequent reloading lifecycle +1. Model weights updates are packed into an async/chunked `weights_iterator` +or model checkpoint is loaded from disk into `weights_iterator` +2. Model state is restored to by `restore_weights_for_reloading` +3. + + +""" def record_weights_for_reloading(model: nn.Module, model_config: ModelConfig): - # this function should be called at the start of `process_weights_after_loading` + # this function should be called before `process_weights_after_loading` + # in practice, this happens at the very start of `process_weights_after_loading` from vllm.model_executor.model_loader.weight_utils import get_quant_config quant_config = get_quant_config(model_config, None) @@ -76,7 +50,7 @@ def record_weights_for_reloading(model: nn.Module, model_config: ModelConfig): if not hasattr(model, "weight_loading_metadata"): model.weight_loading_metadata = { name: _copy_to_meta_tensor(param) - for name, param in model.named_parameters() + for name, param in model.named_parameters(remove_duplicate=False) } @@ -138,4 +112,10 @@ def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor: tensor.__class__ = meta_tensor.__class__ tensor.__dict__ = deepcopy(meta_tensor.__dict__) + # rebind any references to the original tensor + # assume that methods are bound to the original tensor + for key, value in tensor.__dict__.items(): + if isinstance(value, MethodType): + tensor[key] = MethodType(value.__func__, tensor) + return tensor diff --git a/vllm/v1/worker/base.py b/vllm/v1/worker/base.py new file mode 100644 index 000000000000..bebc547bf112 --- /dev/null +++ b/vllm/v1/worker/base.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from abc import abstractmethod +from collections.abc import Iterable +from typing import cast + +import torch + +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.model_executor.model_loader.online_quantization import ( + 
restore_weights_for_reloading, +) +from vllm.model_executor.model_loader.utils import ( + process_weights_after_loading as _process_after_loading, +) + +logger = init_logger(__name__) + + +class ModelRunnerBase: + vllm_config: VllmConfig + load_config: LoadConfig + model_config: ModelConfig + + @abstractmethod + def get_model(self) -> torch.nn.Module: + raise NotImplementedError() + + def reload_weights( + self, + weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, + process_weights_after_loading: bool = True, + ) -> None: + """ """ + # argument validation + weights_from_disk = weights_iterator is None + if weights_from_disk and not process_weights_after_loading: + logger.warning( + "Loading from disk means that weights will be in checkpoint format" + ) + + if getattr(self, "model", None) is not None: + raise ValueError("Cannot reload weights before model is loaded.") + + model = self.get_model() + logger.info("Reloading weights inplace...") + counter_before_loading_weights = time.perf_counter() + + # load weights from disk if none are provided + if weights_iterator is None: + model_loader = get_model_loader(self.load_config) + weights_iterator = model_loader.get_all_weights(self.model_config, model) + weights_iterator = cast( + Iterable[tuple[str, torch.Tensor]], weights_iterator + ) + + if process_weights_after_loading: + # restore model to checkpoint format + if hasattr(model, "weight_loading_metadata"): + restore_weights_for_reloading(model) + else: + logger.warning("Quant config is not supported") + + # load weights from checkpoint format + loaded_weights = model.load_weights(weights_iterator) + + # process weights into kernel format + device_config = self.vllm_config.device_config + load_config = self.vllm_config.load_config + load_device = ( + device_config.device + if load_config.device is None + else load_config.device + ) + _process_after_loading(model, self.model_config, load_device) + + else: + # load weights from kernel format + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) + param.weight_loader(param, loaded_weight) + loaded_weights.add(loaded_weight) + + # logging + counter_after_loading_weights = time.perf_counter() + diff_seconds = counter_after_loading_weights - counter_before_loading_weights + logger.info_once( + f"Loading {len(loaded_weights)} weights took %.2f seconds", + diff_seconds, + scope="local", + ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1e89a428f6a8..a24c3090de76 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,7 +5,7 @@ import itertools import time from collections import defaultdict -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from contextlib import contextmanager from copy import deepcopy from functools import reduce @@ -133,6 +133,9 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext +from vllm.v1.worker.base import ( + ModelRunnerBase, +) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper @@ -239,7 +242,9 @@ class ExecuteModelState(NamedTuple): kv_connector_output: KVConnectorOutput | None -class GPUModelRunner(LoRAModelRunnerMixin, 
KVConnectorModelRunnerMixin): +class GPUModelRunner( + LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ModelRunnerBase +): def __init__( self, vllm_config: VllmConfig, @@ -3177,83 +3182,6 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None - def reload_weights( - self, - weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, - process_weights_after_loading: bool = True, - ) -> None: - from vllm.model_executor.model_loader.online_quantization import ( - restore_weights_for_reloading, - ) - from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading as _process_after_loading, - ) - - # argument validation - if weights_iterator is None and not process_weights_after_loading: - logger.warning( - "Loading from disk means that weights will be in checkpoint format" - ) - - if getattr(self, "model", None) is not None: - raise ValueError("Cannot reload weights before model is loaded.") - - model = self.get_model() - logger.info("Reloading weights inplace...") - counter_before_loading_weights = time.perf_counter() - - # maybe load weights from disk - if weights_iterator is None: - model_loader = get_model_loader(self.load_config) - weights_iterator = model_loader.get_all_weights(self.model_config, model) - weights_iterator = cast( - Iterable[tuple[str, torch.Tensor]], weights_iterator - ) - - if process_weights_after_loading: - # restore model to checkpoint format - if hasattr(model, "weight_loading_metadata"): - restore_weights_for_reloading(model) - else: - logger.warning("Quant config is not supported") - - # load weights from checkpoint format - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights(weights_iterator) - - # process weights into kernel format - device_config = self.vllm_config.device_config - load_config = self.vllm_config.load_config - load_device = ( - device_config.device - if load_config.device is None - else load_config.device - ) - _process_after_loading(model, self.model_config, load_device) - - else: - # load weights from kernel format - for name, weight in weights_iterator: - param = model.get_parameter(name) - param.weight_loader(param, weight) - - # logging - counter_after_loading_weights = time.perf_counter() - diff_seconds = counter_after_loading_weights - counter_before_loading_weights - logger.info_once( - "Loading weights took %.2f seconds", diff_seconds, scope="local" - ) - - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. 
- if self.model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}" - ) - def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", From bf4dcee5517d91ee6d444420a55098512b522f76 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 21:16:05 -0500 Subject: [PATCH 08/11] leave base class for later, fix some typos, small regression tested and looks good Signed-off-by: Kyle Sayers --- .../model_loader/online_quantization.py | 8 +- vllm/model_executor/model_loader/utils.py | 9 +- vllm/v1/worker/base.py | 95 ------------------- vllm/v1/worker/gpu_model_runner.py | 79 +++++++++++++-- 4 files changed, 78 insertions(+), 113 deletions(-) delete mode 100644 vllm/v1/worker/base.py diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index 6ec73a22fbe9..d14c93726f41 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -33,18 +33,13 @@ or model checkpoint is loaded from disk into `weights_iterator` 2. Model state is restored to by `restore_weights_for_reloading` 3. - - """ def record_weights_for_reloading(model: nn.Module, model_config: ModelConfig): # this function should be called before `process_weights_after_loading` # in practice, this happens at the very start of `process_weights_after_loading` - from vllm.model_executor.model_loader.weight_utils import get_quant_config - - quant_config = get_quant_config(model_config, None) - if quant_config.get_name() not in ONLINE_RELOAD_QUANT_CONFIGS: + if model_config.quantization not in ONLINE_RELOAD_QUANT_CONFIGS: return if not hasattr(model, "weight_loading_metadata"): @@ -76,6 +71,7 @@ def restore_weights_for_reloading(model: nn.Module): if _tensors_alike(original_tensor, meta_tensor): continue + delattr(module, param_name) # delete before materialization to avoid oom param = _materialize_meta_tensor(meta_tensor) setattr(module, param_name, param) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index b6e7560d8c4d..7955d8eeb5cf 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -19,6 +19,9 @@ QuantizationConfig, QuantizeMethodBase, ) +from vllm.model_executor.model_loader.online_quantization import ( + record_weights_for_reloading, +) from vllm.model_executor.models.adapters import ( as_embedding_model, as_reward_model, @@ -88,11 +91,7 @@ def initialize_model( def process_weights_after_loading( model: nn.Module, model_config: ModelConfig, target_device: torch.device ) -> None: - # to avoid circular dependency - from vllm.model_executor.model_loader.online_quantization import ( - record_weights_for_reloading, - ) - + # weight reloading: must be called before weights are processed record_weights_for_reloading(model, model_config) for _, module in model.named_modules(): diff --git a/vllm/v1/worker/base.py b/vllm/v1/worker/base.py deleted file mode 100644 index bebc547bf112..000000000000 --- a/vllm/v1/worker/base.py +++ /dev/null @@ -1,95 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time -from abc import abstractmethod -from collections.abc import Iterable -from typing import cast - -import torch - -from 
vllm.config import LoadConfig, ModelConfig, VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.online_quantization import ( - restore_weights_for_reloading, -) -from vllm.model_executor.model_loader.utils import ( - process_weights_after_loading as _process_after_loading, -) - -logger = init_logger(__name__) - - -class ModelRunnerBase: - vllm_config: VllmConfig - load_config: LoadConfig - model_config: ModelConfig - - @abstractmethod - def get_model(self) -> torch.nn.Module: - raise NotImplementedError() - - def reload_weights( - self, - weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, - process_weights_after_loading: bool = True, - ) -> None: - """ """ - # argument validation - weights_from_disk = weights_iterator is None - if weights_from_disk and not process_weights_after_loading: - logger.warning( - "Loading from disk means that weights will be in checkpoint format" - ) - - if getattr(self, "model", None) is not None: - raise ValueError("Cannot reload weights before model is loaded.") - - model = self.get_model() - logger.info("Reloading weights inplace...") - counter_before_loading_weights = time.perf_counter() - - # load weights from disk if none are provided - if weights_iterator is None: - model_loader = get_model_loader(self.load_config) - weights_iterator = model_loader.get_all_weights(self.model_config, model) - weights_iterator = cast( - Iterable[tuple[str, torch.Tensor]], weights_iterator - ) - - if process_weights_after_loading: - # restore model to checkpoint format - if hasattr(model, "weight_loading_metadata"): - restore_weights_for_reloading(model) - else: - logger.warning("Quant config is not supported") - - # load weights from checkpoint format - loaded_weights = model.load_weights(weights_iterator) - - # process weights into kernel format - device_config = self.vllm_config.device_config - load_config = self.vllm_config.load_config - load_device = ( - device_config.device - if load_config.device is None - else load_config.device - ) - _process_after_loading(model, self.model_config, load_device) - - else: - # load weights from kernel format - loaded_weights = set() - for name, loaded_weight in weights_iterator: - param = model.get_parameter(name) - param.weight_loader(param, loaded_weight) - loaded_weights.add(loaded_weight) - - # logging - counter_after_loading_weights = time.perf_counter() - diff_seconds = counter_after_loading_weights - counter_before_loading_weights - logger.info_once( - f"Loading {len(loaded_weights)} weights took %.2f seconds", - diff_seconds, - scope="local", - ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a24c3090de76..b4b39302e29a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,7 +5,7 @@ import itertools import time from collections import defaultdict -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import contextmanager from copy import deepcopy from functools import reduce @@ -51,6 +51,12 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.model_loader.online_quantization import ( + restore_weights_for_reloading, +) +from vllm.model_executor.model_loader.utils import ( + 
process_weights_after_loading as _process_weights_after_loading, +) from vllm.model_executor.models.interfaces import ( SupportsMultiModal, is_mixture_of_experts, @@ -133,9 +139,6 @@ from vllm.v1.spec_decode.suffix_decoding import SuffixDecodingProposer from vllm.v1.structured_output.utils import apply_grammar_bitmask from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext -from vllm.v1.worker.base import ( - ModelRunnerBase, -) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper @@ -242,9 +245,7 @@ class ExecuteModelState(NamedTuple): kv_connector_output: KVConnectorOutput | None -class GPUModelRunner( - LoRAModelRunnerMixin, KVConnectorModelRunnerMixin, ModelRunnerBase -): +class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): def __init__( self, vllm_config: VllmConfig, @@ -3182,6 +3183,70 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None: return None + def reload_weights( + self, + weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None, + process_weights_after_loading: bool = True, + ) -> None: + # argument validation + if weights_iterator is None and not process_weights_after_loading: + logger.warning( + "Reloading from disk means that weights will be in checkpoint format. " + "Please use `process_weights_after_loading=True` " + "to avoid weight reloading errors" + ) + if getattr(self, "model", None) is None: + raise ValueError("Cannot reload weights before model is loaded.") + + model = self.get_model() + logger.info("Reloading weights inplace...") + counter_before_loading_weights = time.perf_counter() + + # load weights from disk if none are provided + if weights_iterator is None: + model_loader = get_model_loader(self.load_config) + weights_iterator = model_loader.get_all_weights(self.model_config, model) + weights_iterator = cast( + Iterable[tuple[str, torch.Tensor]], weights_iterator + ) + + if process_weights_after_loading: + # restore to original model format + if hasattr(model, "weight_loading_metadata"): + restore_weights_for_reloading(model) + else: + logger.warning("Quant config is not supported") + + # load weights from checkpoint/ original model format + loaded_weights = model.load_weights(weights_iterator) + + # process weights into kernel format + device_config = self.vllm_config.device_config + load_config = self.vllm_config.load_config + load_device = ( + device_config.device + if load_config.device is None + else load_config.device + ) + _process_weights_after_loading(model, self.model_config, load_device) + + else: + # load weights from kernel format + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) + param.weight_loader(param, loaded_weight) + loaded_weights.add(loaded_weight) + + # logging + counter_after_loading_weights = time.perf_counter() + diff_seconds = counter_after_loading_weights - counter_before_loading_weights + logger.info_once( + f"Reloading {len(loaded_weights)} weights took %.2f seconds", + diff_seconds, + scope="local", + ) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", From d91b6f22c54f8402e274f694ab365143e8f6fb85 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 12 Nov 2025 22:27:58 -0500 Subject: [PATCH 09/11] timing, general cleanup Signed-off-by: Kyle Sayers --- .../model_loader/online_quantization.py | 72 ++++++++++++------- 
vllm/model_executor/model_loader/utils.py | 23 +++++-
 vllm/v1/worker/gpu_model_runner.py | 47 +++++++-----
 3 files changed, 95 insertions(+), 47 deletions(-)

diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py
index d14c93726f41..1311b7a04922 100644
--- a/vllm/model_executor/model_loader/online_quantization.py
+++ b/vllm/model_executor/model_loader/online_quantization.py
@@ -1,47 +1,65 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Utilities for enabling weight reloading and online quantization.
+For more information and diagrams, see https://github.com/neuralmagic/vllm/pull/128
+
+## Model Reloading Lifecycle ##
+1. Model is loaded for the first time
+   a. Checkpoint is loaded by `ModelLoader.get_all_weights` into `weights_iterator`
+   b. `weights_iterator` is loaded into the model by `model.load_weights`
+   c. Model state is captured by `record_weights_for_reloading`
+   d. `process_weights_after_loading` converts model state into kernel format.
+      The model is no longer loadable while its weights are in kernel format
+
+2. Model is reloaded via `reload_weights`
+   a. A `weights_iterator` is provided, which may be async/chunked/sharded
+   b. The original model state is restored by `restore_weights_for_reloading`,
+      using the metadata recorded by `record_weights_for_reloading`
+   c. `weights_iterator` is loaded into the model by `model.load_weights`
+   d. `process_weights_after_loading` converts model state into kernel format.
+      The model is no longer loadable while its weights are in kernel format
+
+Alternatively, a user who does not want to call `reload_weights` can perform the
+recording, loading, and processing steps manually:
+
+```python
+record_weights_for_reloading(model)
+
+for weights in weights_iterator:  # may be async/chunked/sharded
+    model.load_weights(weights)
+
+process_weights_after_loading(model, model_config, device)
+```
+"""
-
-from copy import deepcopy
 from types import MethodType
 
 import torch
 from torch import nn
 
-from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-ONLINE_RELOAD_QUANT_CONFIGS = {
+__all__ = [
+    "RELOADABLE_QUANT_CONFIGS",
+    "record_weights_for_reloading",
+    "restore_weights_for_reloading",
+]
+
+# in theory, this implementation of weight recording/restoring
+# should support any quantization config
+RELOADABLE_QUANT_CONFIGS = {
+    None,
     "torchao",
     "fp8",
 }
-"""
-
-First time loading lifecycle
-1. Model checkpoint is loaded by `ModelLoader.get_all_weights` into `weights_iterator`
-2. `weights_iterator` is loaded into model by `model.load_weights`
-3. Model state is captured by `record_weights_for_reloading`
-4. `process_weights_after_loading` converts model state into kernel format
-5. Model can run now that weights are in kernel format
-
-Subsequent reloading lifecycle
-1. Model weights updates are packed into an async/chunked `weights_iterator`
-or model checkpoint is loaded from disk into `weights_iterator`
-2. Model state is restored to by `restore_weights_for_reloading`
-3. 
-""" - - -def record_weights_for_reloading(model: nn.Module, model_config: ModelConfig): +def record_weights_for_reloading(model: nn.Module): # this function should be called before `process_weights_after_loading` # in practice, this happens at the very start of `process_weights_after_loading` - if model_config.quantization not in ONLINE_RELOAD_QUANT_CONFIGS: - return - if not hasattr(model, "weight_loading_metadata"): model.weight_loading_metadata = { name: _copy_to_meta_tensor(param) @@ -79,7 +97,7 @@ def restore_weights_for_reloading(model: nn.Module): def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor: meta_tensor = tensor.to("meta") meta_tensor.__class__ = tensor.__class__ - meta_tensor.__dict__ = deepcopy(tensor.__dict__) + meta_tensor.__dict__ = tensor.__dict__ meta_tensor._original_device = tensor.device return meta_tensor @@ -106,7 +124,7 @@ def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor: requires_grad=meta_tensor.requires_grad, ) tensor.__class__ = meta_tensor.__class__ - tensor.__dict__ = deepcopy(meta_tensor.__dict__) + tensor.__dict__ = meta_tensor.__dict__ # rebind any references to the original tensor # assume that methods are bound to the original tensor diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 7955d8eeb5cf..8d49f14d115c 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -3,7 +3,9 @@ """Utilities for selecting and loading models.""" import inspect +import time import warnings +from collections.abc import Iterable from contextlib import contextmanager from dataclasses import dataclass, field @@ -20,6 +22,7 @@ QuantizeMethodBase, ) from vllm.model_executor.model_loader.online_quantization import ( + RELOADABLE_QUANT_CONFIGS, record_weights_for_reloading, ) from vllm.model_executor.models.adapters import ( @@ -88,12 +91,26 @@ def initialize_model( return model_class(**kwargs) +def default_model_weight_loader( + model: torch.nn.Module, weights_iterator: Iterable[tuple[str, torch.Tensor]] +) -> set[str]: + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) + param.weight_loader(param, loaded_weight) + loaded_weights.add(name) + + return loaded_weights + + def process_weights_after_loading( model: nn.Module, model_config: ModelConfig, target_device: torch.device ) -> None: # weight reloading: must be called before weights are processed - record_weights_for_reloading(model, model_config) + if model_config.quantization in RELOADABLE_QUANT_CONFIGS: + record_weights_for_reloading(model) + counter_before_processing_weights = time.perf_counter() for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if isinstance(quant_method, QuantizeMethodBase): @@ -115,6 +132,10 @@ def process_weights_after_loading( # of process_weights_after_loading module.process_weights_after_loading(model_config.dtype) + counter_after_processing_weights = time.perf_counter() + diff_seconds = counter_after_processing_weights - counter_before_processing_weights + logger.debug("Processing weights took %.2f seconds", diff_seconds) + @contextmanager def device_loading_context(module: torch.nn.Module, target_device: torch.device): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index b4b39302e29a..3449794a4afa 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -52,8 +52,12 @@ from 
vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader.online_quantization import ( + RELOADABLE_QUANT_CONFIGS, restore_weights_for_reloading, ) +from vllm.model_executor.model_loader.utils import ( + default_model_weight_loader, +) from vllm.model_executor.model_loader.utils import ( process_weights_after_loading as _process_weights_after_loading, ) @@ -3199,8 +3203,8 @@ def reload_weights( raise ValueError("Cannot reload weights before model is loaded.") model = self.get_model() - logger.info("Reloading weights inplace...") - counter_before_loading_weights = time.perf_counter() + weights_to_load = {name for name, _ in model.named_parameters()} + counter_before_reloading = time.perf_counter() # load weights from disk if none are provided if weights_iterator is None: @@ -3210,12 +3214,18 @@ def reload_weights( Iterable[tuple[str, torch.Tensor]], weights_iterator ) + # begin loading weights + logger.info("Reloading weights inplace...") if process_weights_after_loading: - # restore to original model format - if hasattr(model, "weight_loading_metadata"): + if self.model_config.quantization in RELOADABLE_QUANT_CONFIGS: restore_weights_for_reloading(model) else: - logger.warning("Quant config is not supported") + logger.warning_once( + "Given quantization %s does not support weight reloading. " + "Consider adding to list of supported configs: %s", + self.model_config.quantization, + f"{RELOADABLE_QUANT_CONFIGS}", + ) # load weights from checkpoint/ original model format loaded_weights = model.load_weights(weights_iterator) @@ -3232,20 +3242,19 @@ def reload_weights( else: # load weights from kernel format - loaded_weights = set() - for name, loaded_weight in weights_iterator: - param = model.get_parameter(name) - param.weight_loader(param, loaded_weight) - loaded_weights.add(loaded_weight) - - # logging - counter_after_loading_weights = time.perf_counter() - diff_seconds = counter_after_loading_weights - counter_before_loading_weights - logger.info_once( - f"Reloading {len(loaded_weights)} weights took %.2f seconds", - diff_seconds, - scope="local", - ) + default_model_weight_loader(model, weights_iterator) + + # logging and validation + counter_after_reloading = time.perf_counter() + diff_seconds = counter_after_reloading - counter_before_reloading + logger.info("Reloading and processing weights took %.2f seconds", diff_seconds) + if self.model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) def save_tensorized_model( self, From 4e5270ec99315eb741c7f823ce41a4503a2bd47b Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 13 Nov 2025 02:46:40 -0500 Subject: [PATCH 10/11] restore fp8 changes, fix shared modules and attached methods Signed-off-by: Kyle Sayers --- .../model_executor/layers/quantization/fp8.py | 100 ++++++++++-------- .../model_loader/online_quantization.py | 14 +-- vllm/v1/worker/gpu_model_runner.py | 8 +- 3 files changed, 66 insertions(+), 56 deletions(-) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d42034c1ebd7..cb065eb68b66 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,6 +7,7 @@ import torch from torch.nn 
import Module +from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -483,18 +484,6 @@ def create_weights( else: layer.register_parameter("input_scale", None) - # create per-tensor qparams populated by process_weights_after_loading - else: - scale = create_fp8_scale_parameter( - PerTensorScaleParameter, - output_partition_sizes, - input_size_per_partition, - None, - weight_loader, - ) - set_weight_attrs(scale, {"scale_type": "weight_scale"}) - layer.register_parameter("weight_scale", scale) - def process_weights_after_loading(self, layer: Module) -> None: size_k_first = True input_scale = None @@ -506,8 +495,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight, weight_scale = process_fp8_weight_block_strategy( layer.weight, layer.weight_scale_inv ) - # Rename weight_scale_inv parameter for consistency - layer.weight_scale = layer.weight_scale_inv + # Delete the weight_scale_inv parameter to avoid confusion + # with the weight_scale parameter del layer.weight_scale_inv # If checkpoint not serialized fp8, quantize the weights. @@ -536,10 +525,13 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = weight.t() # Update layer with new values. - layer.weight.copy_(weight.data) - layer.weight_scale.copy_(weight_scale.data) - if input_scale is not None: - layer.input_scale.copy_(input_scale) + layer.weight = Parameter(weight.data, requires_grad=False) + layer.weight_scale = Parameter(weight_scale.data, requires_grad=False) + layer.input_scale = ( + Parameter(input_scale, requires_grad=False) + if input_scale is not None + else None + ) if self.use_marlin: prepare_fp8_layer_for_marlin(layer, size_k_first) @@ -767,10 +759,12 @@ def create_weights( if self.block_quant else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) - - # add weight loaders to support loading (and reloading) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -833,18 +827,22 @@ def process_weights_after_loading(self, layer: Module) -> None: w2_weight_scale_inv = layer.w2_weight_scale_inv # torch.compile() cannot use Parameter subclasses. - layer.w13_weight.copy_(w13_weight) - layer.w13_weight_scale_inv.copy_(w13_weight_scale_inv) - layer.w2_weight.copy_(w2_weight) - layer.w2_weight_scale_inv.copy_(w2_weight_scale_inv) + layer.w13_weight = Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + w13_weight_scale_inv, requires_grad=False + ) + layer.w2_weight = Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + w2_weight_scale_inv, requires_grad=False + ) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. 
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight.data, layer.w2_weight.data ) - layer.w13_weight.copy_(shuffled_w13) - layer.w2_weight.copy_(shuffled_w2) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. @@ -866,7 +864,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Re-initialize w13_scale because we directly quantize # merged w13 weights and generate a single scaling factor. - layer.w13_weight_scale.copy_( + layer.w13_weight_scale = torch.nn.Parameter( torch.ones( layer.local_num_experts, dtype=torch.float32, @@ -881,16 +879,16 @@ def process_weights_after_loading(self, layer: Module) -> None: w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) ) - layer.w13_weight.copy_(w13_weight) - layer.w2_weight.copy_(w2_weight) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) if self.rocm_aiter_moe_enabled: # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights( layer.w13_weight, layer.w2_weight ) - layer.w13_weight.copy_(shuffled_w13) - layer.w2_weight.copy_(shuffled_w2) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) # If checkpoint is fp8, we need to handle that the # MoE kernels require single activation scale and single weight # scale for w13 per expert. @@ -911,8 +909,12 @@ def process_weights_after_loading(self, layer: Module) -> None: "fp8 MoE layer. Using the maximum across experts " "for each layer." ) - layer.w13_input_scale.copy_(layer.w13_input_scale.max()) - layer.w2_input_scale.copy_(layer.w2_input_scale.max()) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( @@ -926,14 +928,22 @@ def process_weights_after_loading(self, layer: Module) -> None: ) ) # Reset the parameter - layer.w13_weight.copy_(w13_weight) - layer.w13_weight_scale.copy_(w13_weight_scale) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) if w13_input_scale is not None: - layer.w13_input_scale.copy_(w13_input_scale) - layer.w2_weight.copy_(w2_weight) - layer.w2_weight_scale.copy_(w2_weight_scale) + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) if w2_input_scale is not None: - layer.w2_input_scale.copy_(w2_input_scale) + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. 
@@ -957,10 +967,12 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight, layer.w2_weight ) - layer.w13_weight.copy_(shuffled_w13) - layer.w2_weight.copy_(shuffled_w2) + layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - layer.w13_weight_scale.copy_(max_w13_scales) + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) if self.flashinfer_moe_backend is not None: # NOTE: weights have to be swapped since the activation is diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py index 1311b7a04922..73b46a8b99b7 100644 --- a/vllm/model_executor/model_loader/online_quantization.py +++ b/vllm/model_executor/model_loader/online_quantization.py @@ -33,8 +33,6 @@ ``` """ -from types import MethodType - import torch from torch import nn @@ -77,7 +75,10 @@ def restore_weights_for_reloading(model: nn.Module): for param_fqn in params_to_remove: module_name, param_name = param_fqn.rsplit(".", 1) module = model.get_submodule(module_name) - delattr(module, param_name) + + # sometimes modules are shared, as is the case for `shared_experts` + if hasattr(module, param_name): + delattr(module, param_name) # restore parameters that were present at load time for param_fqn, meta_tensor in metadata.items(): @@ -89,7 +90,6 @@ def restore_weights_for_reloading(model: nn.Module): if _tensors_alike(original_tensor, meta_tensor): continue - delattr(module, param_name) # delete before materialization to avoid oom param = _materialize_meta_tensor(meta_tensor) setattr(module, param_name, param) @@ -126,10 +126,4 @@ def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor: tensor.__class__ = meta_tensor.__class__ tensor.__dict__ = meta_tensor.__dict__ - # rebind any references to the original tensor - # assume that methods are bound to the original tensor - for key, value in tensor.__dict__.items(): - if isinstance(value, MethodType): - tensor[key] = MethodType(value.__func__, tensor) - return tensor diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3449794a4afa..d9eb6e6df05f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3215,7 +3215,7 @@ def reload_weights( ) # begin loading weights - logger.info("Reloading weights inplace...") + logger.info_once("Reloading weights inplace...", scope="local") if process_weights_after_loading: if self.model_config.quantization in RELOADABLE_QUANT_CONFIGS: restore_weights_for_reloading(model) @@ -3247,7 +3247,11 @@ def reload_weights( # logging and validation counter_after_reloading = time.perf_counter() diff_seconds = counter_after_reloading - counter_before_reloading - logger.info("Reloading and processing weights took %.2f seconds", diff_seconds) + logger.info_once( + "Reloading and processing weights took %.2f seconds", + diff_seconds, + scope="local", + ) if self.model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: From a964ddffee1a3770effa7a863ea7e86baf457b08 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Thu, 13 Nov 2025 11:31:56 -0500 Subject: [PATCH 11/11] reduce diff Signed-off-by: Kyle Sayers --- .../layers/quantization/kv_cache.py | 7 ------- vllm/model_executor/model_loader/utils.py | 15 +-------------- vllm/v1/worker/gpu_model_runner.py | 14 ++++++++++---- 3 
files changed, 11 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py
index f0497a872290..78456dcf1ca5 100644
--- a/vllm/model_executor/layers/quantization/kv_cache.py
+++ b/vllm/model_executor/layers/quantization/kv_cache.py
@@ -45,13 +45,6 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        # skip if there are no weights to process (for example, weight reloading)
-        if not hasattr(layer, "q_scale"):
-            assert not hasattr(layer, "k_scale")
-            assert not hasattr(layer, "v_scale")
-            assert not hasattr(layer, "prob_scale")
-            return
-
         # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
         # regardless whether the kv-scale is available in the checkpoint.
         # No need to process kv scales after loading if we are going to
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 8d49f14d115c..8c90eda51e0b 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -5,7 +5,6 @@
 import inspect
 import time
 import warnings
-from collections.abc import Iterable
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 
@@ -91,18 +90,6 @@ def initialize_model(
     return model_class(**kwargs)
 
 
-def default_model_weight_loader(
-    model: torch.nn.Module, weights_iterator: Iterable[tuple[str, torch.Tensor]]
-) -> set[str]:
-    loaded_weights = set()
-    for name, loaded_weight in weights_iterator:
-        param = model.get_parameter(name)
-        param.weight_loader(param, loaded_weight)
-        loaded_weights.add(name)
-
-    return loaded_weights
-
-
 def process_weights_after_loading(
     model: nn.Module, model_config: ModelConfig, target_device: torch.device
 ) -> None:
@@ -134,7 +121,7 @@ def process_weights_after_loading(
 
     counter_after_processing_weights = time.perf_counter()
     diff_seconds = counter_after_processing_weights - counter_before_processing_weights
-    logger.debug("Processing weights took %.2f seconds", diff_seconds)
+    logger.debug_once("Processed weights in %.2f seconds", diff_seconds, scope="local")
 
 
 @contextmanager
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index d9eb6e6df05f..caa2f097a665 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -55,9 +55,6 @@
     RELOADABLE_QUANT_CONFIGS,
     restore_weights_for_reloading,
 )
-from vllm.model_executor.model_loader.utils import (
-    default_model_weight_loader,
-)
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading as _process_weights_after_loading,
 )
@@ -3242,7 +3239,16 @@ def reload_weights(
 
         else:
             # load weights from kernel format
-            default_model_weight_loader(model, weights_iterator)
+            logger.warning_once(
+                "Reloading with `process_weights_after_loading=False` requires that "
+                "the provided weights already be in kernel format and sharded",
+                scope="local",
+            )
+            loaded_weights = set()
+            for name, loaded_weight in weights_iterator:
+                param = model.get_parameter(name)
+                param.copy_(loaded_weight)
+                loaded_weights.add(name)
 
         # logging and validation
         counter_after_reloading = time.perf_counter()
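---

For reviewers who want to see the record/restore idea from this series in isolation, here is a minimal, self-contained sketch (not part of the diff). The helper names `record_for_reloading` and `restore_for_reloading` and the toy "kernel format" step are illustrative assumptions of mine; the actual `record_weights_for_reloading`/`restore_weights_for_reloading` in these patches additionally preserve tensor subclasses, original devices, and weight-loader attributes via meta tensors.

```python
import torch
from torch import nn


def record_for_reloading(model: nn.Module) -> dict[str, torch.Tensor]:
    # Snapshot the checkpoint-format parameter layout as meta tensors
    # (shape/dtype only, no storage), keyed by fully-qualified name.
    return {
        name: torch.empty_like(param, device="meta")
        for name, param in model.named_parameters()
    }


def restore_for_reloading(model: nn.Module, metadata: dict[str, torch.Tensor]) -> None:
    current = dict(model.named_parameters())
    # Drop parameters that only exist in kernel format (created during
    # weight processing) ...
    for name in current.keys() - metadata.keys():
        module_name, _, param_name = name.rpartition(".")
        delattr(model.get_submodule(module_name), param_name)
    # ... and re-materialize the checkpoint-format parameters so the stock
    # weight loaders line up with the checkpoint again.
    for name, meta in metadata.items():
        module_name, _, param_name = name.rpartition(".")
        param = nn.Parameter(
            torch.empty(meta.shape, dtype=meta.dtype), requires_grad=False
        )
        setattr(model.get_submodule(module_name), param_name, param)


model = nn.Linear(8, 8, bias=False)
metadata = record_for_reloading(model)

# Stand-in for process_weights_after_loading: replace the fp16 weight with a
# quantized kernel-format weight plus a scale, roughly what the fp8 path does.
del model.weight
model.register_parameter(
    "weight_packed",
    nn.Parameter(torch.zeros(8, 8, dtype=torch.int8), requires_grad=False),
)
model.register_parameter(
    "weight_scale", nn.Parameter(torch.ones(1), requires_grad=False)
)

# Before reloading a checkpoint, restore the original parameter layout.
restore_for_reloading(model, metadata)
assert set(dict(model.named_parameters())) == set(metadata)
```

With the series applied, the same flow is reachable end to end through `GPUModelRunner.reload_weights(weights_iterator=..., process_weights_after_loading=True)`, which restores checkpoint-format parameters, runs `model.load_weights`, and then re-processes the weights into kernel format.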