diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index c06ac550a94a..f0f68000da29 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -277,39 +277,13 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
                 "0.14.0"
             ):
                 self.load_config.safetensors_load_strategy = "torchao"
-        weights_to_load = {name for name, _ in model.named_parameters()}
-
-        # if we don't have `model.weight_metadata_and_attr_saved` defined and
-        # set to True, it means that this is either offline quantization case
-        # or the first run of online quantization
-        # see online_quantization.py for detailed notes
-        offline_quantization_or_first_run_of_online_quantization = not getattr(
-            model, "weight_metadata_and_attr_saved", False
-        )
-        if model_config.quantization is None:
-            # model is not quantized
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        elif offline_quantization_or_first_run_of_online_quantization:
-            # case 1: offline quantized checkpoint
-            # case 2: Step I1 first run of weight loading with
-            # online quantization
-            # see online_quantization.py for detailed notes
-            loaded_weights = model.load_weights(
-                self.get_all_weights(model_config, model)
-            )
-        else:
-            # to avoid circular dependency
-            from vllm.model_executor.model_loader.online_quantization import (
-                load_weights_and_online_quantize,
-            )
-
-            # subsequent runs of weight loading with online
-            # quantization
-            loaded_weights = load_weights_and_online_quantize(self, model, model_config)
+        # load weights into model
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        weights_iterator = self.get_all_weights(model_config, model)
+        loaded_weights = model.load_weights(weights_iterator)
+        # logging and validation
         self.counter_after_loading_weights = time.perf_counter()
         logger.info_once(
             "Loading weights took %.2f seconds",
diff --git a/vllm/model_executor/model_loader/online_quantization.py b/vllm/model_executor/model_loader/online_quantization.py
index 890dd7231a0e..73b46a8b99b7 100644
--- a/vllm/model_executor/model_loader/online_quantization.py
+++ b/vllm/model_executor/model_loader/online_quantization.py
@@ -1,224 +1,129 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import types
+"""
+Utilities for enabling weight reloading and online quantization
+For more information and diagrams, see https://github.com/neuralmagic/vllm/pull/128
+
+## Model Reloading Lifecycle ##
+1. Model is loaded for the first time
+    a. Checkpoint is loaded by `ModelLoader.get_all_weights` into `weights_iterator`
+    b. `weights_iterator` is loaded into model by `model.load_weights`
+    c. Model state is captured by `record_weights_for_reloading`
+    d. `process_weights_after_loading` converts model state into kernel format.
+       The model is no longer loadable while its weights are in kernel format
+
+2. Model is reloaded via `reload_weights`
+    a. A `weights_iterator` is provided, which may be async/chunked/sharded
+    b. The original model state is restored by `restore_weights_for_reloading`
+       using metadata information from `record_weights_for_reloading`
+    c. `weights_iterator` is loaded into model by `model.load_weights`
+    d. `process_weights_after_loading` converts model state into kernel format.
+       The model is no longer loadable while its weights are in kernel format
+
+Alternatively, if a user does not want to use `reload_weights`, they can call
+steps 2b-2d manually:
+
+```python
+restore_weights_for_reloading(model)
+
+for weights in weights_iterator:  # may be async/chunked/sharded
+    model.load_weights(weights)
+
+process_weights_after_loading(model, model_config, device)
+```
+"""
 import torch
 from torch import nn
 
-from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
-from vllm.model_executor.model_loader.utils import process_weights_after_loading
 
 logger = init_logger(__name__)
 
-# Notes for Online Quantization
-# In terms of state of checkpoints, quantization config and their
-# correspondance to online quantization:
-# | Use Case      | Checkpoints    | model_config.quantization |
-# | no quant      | high precision | None                      |
-# | offline quant | quantized      | fp8, torchao etc.         |
-# | online quant  | high precision | torchao etc.              |
-#
-# The process for loading non-quantized checkpoint
-# 1. load non-quantized weights (load_weights)
-# 2. do any additional post processing (process_weights_after_loading)
-#
-# The process for loading offline quantized checkpoint
-# 1. load offline-quantized weights (load_weights)
-# 2. do any additional post processing (process_weights_after_loading)
-
-# The process for unquantized model reloading
-# (repeated run in RL training loop)
-# first run
-# UI1. load_weights: load bfloat16 weights
-# UI2. process_weights_after_loading: any additional post processing
-# subsequent run
-# UC1: load_weights: load bfloat16 weights
-#      (shouldn't be any issues since we didn't change any attributes
-#      of the weights)
-# UC2: process_weights_after_loading: any additional post processing
-
-# The process for weight reloading with online quantization
-# (repeated run in RL training loop)
-# first run
-# I1. load_weights: load bfloat16 weights
-# I2. process_weights_after_loading:
-#     record weight metadata and attributes for R1 and R2
-#     quantize weights to fp8
-# subsequent run
-# (beginning model weight is in fp8)
-# load_weights:
-# R1. restore bfloat16 model weight metadata
-# R2. restore the model weight attributes
-# R3. reload bfloat16 weights
-# R4. quantize weights (by calling process_weights_after_loading),
-#     also set `process_weights_after_loading_already_called` to
-#     True to stop it from running again
-# process_weights_after_loading (if called):
-#     this will be skipped since it's already ran in
-#     load_weights
-
-
-def maybe_save_metadata_and_attributes_for_weight_reloading(
-    model: nn.Module, model_config: ModelConfig
-):
-    # following is to support on the fly quantization, currently only supported
-    # for torchao
-    if model_config.quantization != "torchao":
-        return
-
-    if getattr(model, "process_weights_after_loading_already_called", False):
-        # In case `process_weights_after_loading` is called multiple times
-        # we'll skip it at later times
-        logger.warning(
-            "process_weights_after_loading already called for model %s", model
-        )
-        return
-
-    from vllm.model_executor.model_loader.weight_utils import get_quant_config
-
-    quant_config = get_quant_config(model_config, None)
-
-    # If checkpoint is already torchao serialized, this means it's
-    # pre-quantized quantization case, we'll skip saving the metadata
-    # Otherwise, this is Step I2 of initialization steps of
-    # online quantization
-    # This step record the weights metadata and weight attributes so we can
-    # restore the bfloat16 model weights during the relad step (R1 and R2)
-    # see Notes in online_quantization.py for more details
-    if not (
-        hasattr(quant_config, "is_checkpoint_torchao_serialized")
-        and not quant_config.is_checkpoint_torchao_serialized
-    ):
-        return
-
-    # This is the I2 step of online quantiztion that saves
-    # metadata and attributes of weights so they can be used in R1 and
-    # R2 step, note that we only save these during initialization
-
-    # Includes two things
-    # 1. save floating point metadata (shape, dtype, device) for init
-    # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
-
-    if getattr(model, "weight_metadata_and_attr_saved", False):
-        return
-
-    # save the dtype, shape and device for model parameter, used for
-    # restoring the model high precision parameters before
-    # reloading the weights
-    assert not hasattr(model, "original_weights_rebuild_keys")
-    model.original_weights_rebuild_keys = {}
-    for name, p in model.named_parameters():
-        model.original_weights_rebuild_keys[name] = {
-            "shape": p.shape,
-            "dtype": p.dtype,
-            "device": p.device,
+__all__ = [
+    "RELOADABLE_QUANT_CONFIGS",
+    "record_weights_for_reloading",
+    "restore_weights_for_reloading",
+]
+
+# in theory, this implementation of weight recording/restoring
+# should support any quantization config
+RELOADABLE_QUANT_CONFIGS = {
+    None,
+    "torchao",
+    "fp8",
+}
+
+
+def record_weights_for_reloading(model: nn.Module):
+    # this function should be called before `process_weights_after_loading`
+    # in practice, this happens at the very start of `process_weights_after_loading`
+    if not hasattr(model, "weight_loading_metadata"):
+        model.weight_loading_metadata = {
+            name: _copy_to_meta_tensor(param)
+            for name, param in model.named_parameters(remove_duplicate=False)
         }
-    # record the weight attributes (loader functions etc.)
-    # so these can be recovered later when we reload the weights
-    # structure: {"weight_name": {"weight_attr_key": attr}}
-    assert not hasattr(model, "recorded_weight_attr")
-    model.recorded_weight_attr = {}
-    for name, param in model.named_parameters():
-        model.recorded_weight_attr[name] = {}
-        for key in param.__dict__:
-            if hasattr(param, key):
-                attr = getattr(param, key)
-                if not callable(attr):
-                    model.recorded_weight_attr[name][key] = attr
-                elif hasattr(attr, "__self__") and param is attr.__self__:
-                    # if attr is a bonded method for an instance, and
-                    # attr.__self__ points to the instance (param)
-                    # we'll record the underlying function object
-                    model.recorded_weight_attr[name][key] = attr.__func__
-                else:
-                    model.recorded_weight_attr[name][key] = attr
-    # mark the metadata and attributes saved so we don't run it again
-    model.weight_metadata_and_attr_saved = True
-
-
-def _bond_method_to_cls(func, obj):
-    if hasattr(func, "__self__") or not callable(func):
-        # If the function is already bound to an instance, return it as is
-        return func
-    else:
-        return types.MethodType(func, obj)
-
-
-def load_weights_and_online_quantize(
-    model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
-) -> set[str]:
-    # online quantization, right now only enabled for
-    # torchao
-    # R1, R2, R3, R4 in the Notes
-
-    # TODO: Add fp8 support
-    assert model_config.quantization == "torchao", (
-        "online quantization is only enabled for torchao currently"
+
+
+def restore_weights_for_reloading(model: nn.Module):
+    assert hasattr(model, "weight_loading_metadata")
+    metadata: dict[str, torch.Tensor] = model.weight_loading_metadata
+    model_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
+
+    # remove parameters which were not present at load time
+    params_to_remove = model_param_names - metadata.keys()
+    for param_fqn in params_to_remove:
+        module_name, param_name = param_fqn.rsplit(".", 1)
+        module = model.get_submodule(module_name)
+
+        # sometimes modules are shared, as is the case for `shared_experts`
+        if hasattr(module, param_name):
+            delattr(module, param_name)
+
+    # restore parameters that were present at load time
+    for param_fqn, meta_tensor in metadata.items():
+        module_name, param_name = param_fqn.rsplit(".", 1)
+        module = model.get_submodule(module_name)
+
+        # for faster runtime, skip materialization if the tensors match
+        original_tensor = getattr(module, param_name, None)
+        if _tensors_alike(original_tensor, meta_tensor):
+            continue
+
+        param = _materialize_meta_tensor(meta_tensor)
+        setattr(module, param_name, param)
+
+
+def _copy_to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    meta_tensor = tensor.to("meta")
+    meta_tensor.__class__ = tensor.__class__
+    meta_tensor.__dict__ = tensor.__dict__
+    meta_tensor._original_device = tensor.device
+
+    return meta_tensor
+
+
+def _tensors_alike(tensor: torch.Tensor | None, meta_tensor: torch.Tensor) -> bool:
+    if tensor is None:
+        return False
+
+    return (
+        tensor.device == meta_tensor._original_device
+        and tensor.dtype == meta_tensor.dtype
+        and tensor.shape == meta_tensor.shape
+        and tensor.__dict__ == meta_tensor.__dict__
     )
-    # TODO: use create_weights to restore the weights to original state
-
-    # Step R1: First restore the quantized weights to original bfloat16
-    # weights, with original metadata (shape, dtype, device)
-    # and attributes, so that bfloat16 weights can be loaded properly
-    existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
-    named_modules = dict(model.named_modules(remove_duplicate=False))
-    model_device = None
-
-    # Step R2: recover the parameter to the state before first loading
-    for name, d in model.original_weights_rebuild_keys.items():
-        _shape = d["shape"]
-        _dtype = d["dtype"]
-        _device = d["device"]
-        if model_device is not None:
-            assert model_device == _device, (
-                "Expecting all weights "
-                "to be in the same device for now, got both: "
-                f"{model_device} and {_device}"
-            )
-        else:
-            model_device = _device
-
-        if name in existing_param_names:
-            module_name, weight_name = name.rsplit(".", 1)
-            module = named_modules[module_name]
-            setattr(
-                module,
-                weight_name,
-                torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
-            )
-
-    # recorded_weight_attr is
-    # {"weight_name": {"weight_attr_key": attr}}
-    # e.g.
-    # {
-    #     {
-    #         "layer.0.weight": {
-    #             "weight_loader": weight_loader_function_object,
-    #             "input_dim": 0, ...
-    #         },
-    #         "layer.1.weight": ...,
-    #     }
-    # }
-    for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
-        for attr_name, attr in weight_attr_dict.items():
-            module_name, weight_name = full_weight_name.rsplit(".", 1)
-            module = named_modules[module_name]
-            weight = getattr(module, weight_name)
-            if not hasattr(weight, attr_name):
-                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
-
-    # Step I1: reload bfloat16 / high precision weights
-    loaded_weights = model.load_weights(
-        model_loader.get_all_weights(model_config, model)
+
+
+def _materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
+    tensor = torch.empty_strided(
+        size=tuple(meta_tensor.size()),
+        stride=tuple(meta_tensor.stride()),
+        dtype=meta_tensor.dtype,
+        device=meta_tensor._original_device,
+        requires_grad=meta_tensor.requires_grad,
     )
+    tensor.__class__ = meta_tensor.__class__
+    tensor.__dict__ = meta_tensor.__dict__
 
-    # Step I2: online quantize the weights
-    # manually process weights after loading
-    model.process_weights_after_loading_already_called = False
-    process_weights_after_loading(model, model_config, model_device)
-    model.process_weights_after_loading_already_called = True
-    return loaded_weights
+    return tensor
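The record/restore pair above works by round-tripping each parameter through the `meta` device: the meta copy keeps shape, dtype, stride, and the parameter's Python attributes (e.g. `weight_loader`, `output_dim`) while dropping the storage, so the checkpoint-format parameter can later be re-materialized as an empty tensor with the same layout. A minimal standalone sketch of that round trip, assuming this mechanism (not vLLM code; `output_dim` is just a stand-in for the attributes vLLM layers attach to their weights):

```python
import torch

# Stand-in for a vLLM weight: a bf16 parameter with one attached attribute.
param = torch.nn.Parameter(torch.randn(4, 8, dtype=torch.bfloat16))
param.output_dim = 0

# "record": copy to the meta device, keeping metadata and attributes only.
meta = param.detach().to("meta")
meta.__dict__ = param.__dict__          # carries output_dim (and friends)
meta._original_device = param.device    # remember where the data lived

# "restore": materialize an empty tensor with the recorded layout.
restored = torch.empty_strided(
    tuple(meta.size()),
    tuple(meta.stride()),
    dtype=meta.dtype,
    device=meta._original_device,
)
restored.__dict__ = meta.__dict__

assert restored.shape == param.shape and restored.dtype == param.dtype
assert restored.output_dim == 0         # attributes survive the round trip
```

In addition, `restore_weights_for_reloading` deletes parameters that were not present at load time (tensors created by `process_weights_after_loading`), so the module again looks the way `model.load_weights` expects.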
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index ba708a098c0d..8c90eda51e0b 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -3,6 +3,7 @@
 """Utilities for selecting and loading models."""
 
 import inspect
+import time
 import warnings
 from contextlib import contextmanager
 from dataclasses import dataclass, field
@@ -19,6 +20,10 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.model_loader.online_quantization import (
+    RELOADABLE_QUANT_CONFIGS,
+    record_weights_for_reloading,
+)
 from vllm.model_executor.models.adapters import (
     as_embedding_model,
     as_reward_model,
@@ -88,13 +93,11 @@ def initialize_model(
 def process_weights_after_loading(
     model: nn.Module, model_config: ModelConfig, target_device: torch.device
 ) -> None:
-    # to avoid circular dependency
-    from vllm.model_executor.model_loader.online_quantization import (
-        maybe_save_metadata_and_attributes_for_weight_reloading,
-    )
-
-    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
+    # weight reloading: must be called before weights are processed
+    if model_config.quantization in RELOADABLE_QUANT_CONFIGS:
+        record_weights_for_reloading(model)
+
+    counter_before_processing_weights = time.perf_counter()
     for _, module in model.named_modules():
         quant_method = getattr(module, "quant_method", None)
         if isinstance(quant_method, QuantizeMethodBase):
@@ -116,6 +119,10 @@ def process_weights_after_loading(
             # of process_weights_after_loading
             module.process_weights_after_loading(model_config.dtype)
 
+    counter_after_processing_weights = time.perf_counter()
+    diff_seconds = counter_after_processing_weights - counter_before_processing_weights
+    logger.debug_once("Processed weights in %.2f seconds", diff_seconds, scope="local")
+
 
 @contextmanager
 def device_loading_context(module: torch.nn.Module, target_device: torch.device):
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fbd3e5f31316..caa2f097a665 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5,7 +5,7 @@
 import itertools
 import time
 from collections import defaultdict
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from copy import deepcopy
 from functools import reduce
@@ -51,6 +51,13 @@
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.model_loader.online_quantization import (
+    RELOADABLE_QUANT_CONFIGS,
+    restore_weights_for_reloading,
+)
+from vllm.model_executor.model_loader.utils import (
+    process_weights_after_loading as _process_weights_after_loading,
+)
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     is_mixture_of_experts,
@@ -3177,13 +3184,87 @@ def _get_eagle3_aux_layers_from_config(self) -> tuple[int, ...] | None:
 
         return None
 
-    def reload_weights(self) -> None:
-        assert getattr(self, "model", None) is not None, (
-            "Cannot reload weights before model is loaded."
+    def reload_weights(
+        self,
+        weights_iterator: Iterable[tuple[str, torch.Tensor]] | None = None,
+        process_weights_after_loading: bool = True,
+    ) -> None:
+        # argument validation
+        if weights_iterator is None and not process_weights_after_loading:
+            logger.warning(
+                "Reloading from disk means that weights will be in checkpoint format. "
+                "Please use `process_weights_after_loading=True` "
+                "to avoid weight reloading errors"
+            )
+        if getattr(self, "model", None) is None:
+            raise ValueError("Cannot reload weights before model is loaded.")
+
+        model = self.get_model()
+        weights_to_load = {name for name, _ in model.named_parameters()}
+        counter_before_reloading = time.perf_counter()
+
+        # load weights from disk if none are provided
+        if weights_iterator is None:
+            model_loader = get_model_loader(self.load_config)
+            weights_iterator = model_loader.get_all_weights(self.model_config, model)
+            weights_iterator = cast(
+                Iterable[tuple[str, torch.Tensor]], weights_iterator
+            )
+
+        # begin loading weights
+        logger.info_once("Reloading weights inplace...", scope="local")
+        if process_weights_after_loading:
+            if self.model_config.quantization in RELOADABLE_QUANT_CONFIGS:
+                restore_weights_for_reloading(model)
+            else:
+                logger.warning_once(
+                    "Given quantization %s does not support weight reloading. "
" + "Consider adding to list of supported configs: %s", + self.model_config.quantization, + f"{RELOADABLE_QUANT_CONFIGS}", + ) + + # load weights from checkpoint/ original model format + loaded_weights = model.load_weights(weights_iterator) + + # process weights into kernel format + device_config = self.vllm_config.device_config + load_config = self.vllm_config.load_config + load_device = ( + device_config.device + if load_config.device is None + else load_config.device + ) + _process_weights_after_loading(model, self.model_config, load_device) + + else: + # load weights from kernel format + logger.warning_once( + "Reloading with `process_weights_after_loading=True` requires that " + "weights be in kernel format and already sharded", + scope="local", + ) + loaded_weights = set() + for name, loaded_weight in weights_iterator: + param = model.get_parameter(name) + param.copy_(loaded_weight) + loaded_weights.add(name) + + # logging and validation + counter_after_reloading = time.perf_counter() + diff_seconds = counter_after_reloading - counter_before_reloading + logger.info_once( + "Reloading and processing weights took %.2f seconds", + diff_seconds, + scope="local", ) - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model_loader.load_weights(self.get_model(), model_config=self.model_config) + if self.model_config.quantization is None and loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) def save_tensorized_model( self,