diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 0b07bb6b3..6c9f88d9f 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -5,6 +5,7 @@
 #
 # ----------------------------------------------------------------------------
 
+import gc
 import inspect
 import logging
 import shutil
@@ -63,6 +64,9 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
             (arch := getattr(self.model.config, "architectures", None)) and len(arch) > 0 and arch[0]
         ) or None
 
+        # Flag tracking whether the PyTorch weights have been offloaded
+        self._is_weights_offloaded: bool = False
+
         # Apply the transformations
         any_transformed = False
         for transform in self._pytorch_transforms:
@@ -74,6 +78,44 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
         else:
             logger.info(f"Pytorch transforms applied to model: {self.model_name}")
 
+    def _offload_model_weights(self, offload_pt_weights: bool) -> bool:
+        """
+        Offload the PyTorch model weights to the meta device if offload_pt_weights is True.
+
+        Returns:
+            bool: True if weights were successfully offloaded, False otherwise.
+        """
+        # Check if offloading is enabled and weights are not already offloaded
+        if offload_pt_weights and not self._is_weights_offloaded:
+            try:
+                self.model = self.model.to_empty(device="meta")
+                self._is_weights_offloaded = True
+                logger.info("Model weights offloaded to meta device")
+
+                gc.collect()
+                logger.info("PyTorch weights cleared after export")
+                return True
+
+            except Exception as e:
+                logger.error(f"Failed to offload model weights: {e}")
+                return False
+        return False
+
+    def _model_offloaded_check(self) -> None:
+        """
+        Check whether the model is in the meta state or its weights are offloaded.
+
+        Raises:
+            RuntimeError: If the model is in the meta state or its weights are offloaded.
+        """
+        if self._is_weights_offloaded or any(param.is_meta for param in self.model.parameters()):
+            error_msg = (
+                "Cannot re-export model: weights have been offloaded to save memory. "
+                "To re-export, please create a new model instance using from_pretrained() method."
+            )
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
     @property
     @abstractmethod
     def model_name(self) -> str: ...
@@ -130,9 +172,15 @@ def _export(
         export_kwargs: Optional[Dict[str, any]] = None,
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
+        offload_pt_weights: bool = True,
     ) -> str:
         """
-        Export the Pytorch model to ONNX.
+        Export the PyTorch model to ONNX and apply ONNX transforms.
+
+        This method:
+        1. Exports the PyTorch model to ONNX using torch.onnx.export
+        2. Optionally offloads the PyTorch weights to the meta device
+        3. Applies ONNX transforms with a reduced memory footprint
 
         Args:
             :example_inputs (dict): Sample inputs to trace the model.
@@ -141,18 +189,30 @@ def _export(
             :export_kwargs (dict): Additional arguments to be passed to `torch.onnx.export`.
             :onnx_transform_kwargs (dict): Additional arguments to be passed to `Transform.apply` for this class.
             :export_dir (str): Specify the export directory. The export_dir will be suffixed with a hash corresponding to current model.
+            :offload_pt_weights (bool): If True, offload PyTorch model weights to the meta
+                device after a successful export to reduce memory usage. Set to False if you
+                need to keep the weights for further operations. Defaults to True.
+        Note:
+            Once weights are offloaded, the model cannot be re-exported. Create a new
+            instance using from_pretrained() for re-export.
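+
+        Example (illustrative sketch; assumes a concrete subclass with prepared
+        example inputs, output names, and dynamic axes)::
+
+            # Keep the PyTorch weights around, e.g. for a later second export:
+            onnx_path = self._export(
+                example_inputs, output_names, dynamic_axes, offload_pt_weights=False
+            )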
+
         """
         onnx_path = export_dir / f"{self.model_name}.onnx"
+
+        # Return early if ONNX already exists
         if onnx_path.is_file():
             self.onnx_path = onnx_path
             return onnx_path
 
+        # Check if the model is in meta state or weights are offloaded
+        self._model_offloaded_check()
+
+        # Setup temporary paths
         tmp_onnx_dir = export_dir / "onnx_tmp"
         tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
         tmp_onnx_dir.mkdir(parents=True, exist_ok=True)
 
         # Create input_names from example_inputs
-
         input_names = []
         for param in inspect.signature(self.model.forward).parameters:
             if param in example_inputs:
@@ -188,7 +248,9 @@ def _export(
                 opset_version=constants.ONNX_EXPORT_OPSET,
                 **export_kwargs,
             )
-            logger.info("Pytorch export successful")
+            logger.info("PyTorch export successful")
+
+            _ = self._offload_model_weights(offload_pt_weights)
 
             model = onnx.load(tmp_onnx_path, load_external_data=False)
             transform_kwargs = {
@@ -200,17 +262,17 @@ def _export(
             for transform in self._onnx_transforms:
                 model, transformed = transform.apply(model, **transform_kwargs)
+
             model.metadata_props.append(
                 onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
             )
             logger.info("ONNX transforms applied")
 
             onnx.save(model, onnx_path)
-            logger.info("Transformed onnx saved")
+            logger.info("Transformed ONNX saved")
 
         except Exception as e:
-            logger.error(f"ONNX export (or) ONNXTransforms failed: {e}")
-
+            logger.error(f"ONNX export or transforms failed: {e}")
             raise e
 
         finally:
diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py
index f1532ad1b..820372561 100644
--- a/QEfficient/peft/auto.py
+++ b/QEfficient/peft/auto.py
@@ -287,7 +287,7 @@ def generate(
         generation_config = generation_config or self.model.generation_config
         generation_config, model_kwargs = self.model._prepare_generation_config(generation_config, **kwargs)
-        self.model._prepare_special_tokens(generation_config)
+        self.model._prepare_special_tokens(generation_config, device="cpu")
         if generation_config.do_sample:
             raise NotImplementedError("do_sample=True not supported currently")
         if generation_config.num_beams > 1:
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 3e50a2783..b3d27f3a5 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -433,8 +433,10 @@ def __init__(self, model: nn.modules, **kwargs):
         self.model = model.get_qeff_vision_encoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
 
-    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
-        return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
+        return self._export(
+            inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
+        )
 
     def compile(
         self,
@@ -488,8 +490,10 @@ def __init__(self, model, **kwargs):
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
 
-    def export(self, inputs, output_names, dynamic_axes, export_dir=None):
-        return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir)
+    def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
+        return self._export(
+            inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
+        )
 
     def compile(
         self,
@@ -583,14 +587,18 @@ def export(
         inputs = self.model.get_dummy_inputs(kv_offload=True)
         dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True)
         output_names = self.model.get_output_names(kv_offload=True)
+
         self.vision_model.export(
             inputs["vision"],
             output_names["vision"],
             dynamic_axes["vision"],
             export_dir=export_dir,
+            offload_pt_weights=False,
+        )
+        self.lang_model.export(
+            inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True
         )
-        self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir)
         return self.onnx_path
 
     def compile(
diff --git a/tests/base/test_export_memory_offload.py b/tests/base/test_export_memory_offload.py
new file mode 100644
index 000000000..d1b7a4653
--- /dev/null
+++ b/tests/base/test_export_memory_offload.py
@@ -0,0 +1,160 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import os
+
+import pytest
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
+
+# Simple test config for memory reduction testing
+test_config = AutoConfig.for_model(
+    "gpt2",
+    max_position_embeddings=256,
+    num_hidden_layers=2,
+    num_attention_heads=4,
+    hidden_size=128,
+    intermediate_size=512,
+    vocab_size=127,
+    num_key_value_heads=2,
+)
+
+model_kwargs = {"attn_implementation": "eager"}
+
+
+@pytest.fixture
+def tmp_cache(tmp_path, monkeypatch):
+    monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path)
+    yield tmp_path
+
+
+def test_offload_weights_method():
+    """Test the _offload_model_weights method with both True and False values."""
+    model = AutoModelForCausalLM.from_config(test_config, **model_kwargs)
+    qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False)
+
+    # Initially weights should not be offloaded
+    assert not qeff_model._is_weights_offloaded
+    assert not any(param.is_meta for param in qeff_model.model.parameters())
+
+    # Test with offload_pt_weights=True
+    success = qeff_model._offload_model_weights(offload_pt_weights=True)
+    assert success
+    assert qeff_model._is_weights_offloaded
+    assert all(param.is_meta for param in qeff_model.model.parameters())
+
+    # Reset for next test
+    model2 = AutoModelForCausalLM.from_config(test_config, **model_kwargs)
+    qeff_model2 = QEFFAutoModelForCausalLM(model2, continuous_batching=False)
+
+    # Test with offload_pt_weights=False
+    success = qeff_model2._offload_model_weights(offload_pt_weights=False)
+    assert not success
+    assert not qeff_model2._is_weights_offloaded
+    assert not any(param.is_meta for param in qeff_model2.model.parameters())
+
+
+def test_re_export_behavior_with_offloaded_weights(tmp_cache):
+    """Test that re-export fails when weights are offloaded."""
+    model = AutoModelForCausalLM.from_config(test_config, **model_kwargs)
+    qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False)
+
+    # First export should succeed
+    _ = qeff_model.export()
+    assert qeff_model.onnx_path is not None
+
+    # Manually offload weights
+    qeff_model._offload_model_weights(offload_pt_weights=True)
+    assert qeff_model._is_weights_offloaded
+
+    # Force a new export by removing the cached ONNX file
+    os.remove(qeff_model.onnx_path)
+    qeff_model.onnx_path = None
+
+    # Re-export should fail with RuntimeError due to offloaded weights
+    with pytest.raises(RuntimeError, match="weights have been offloaded"):
+        qeff_model.export()
+
+
+def test_vlm_dual_qpc_memory_offload_behavior():
+    """Test asymmetric memory offload behavior for VLM dual QPC models."""
+
+    # Mock vision model (should NOT offload weights)
+    class MockVisionModel:
+        def __init__(self):
+            self._is_weights_offloaded = False
+
+        def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
+            if offload_pt_weights:
+                self._is_weights_offloaded = True
+            return "vision_export_path"
+
+    # Mock language model (should offload weights)
+    class MockLangModel:
+        def __init__(self):
+            self._is_weights_offloaded = False
+
+        def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
+            if offload_pt_weights:
+                self._is_weights_offloaded = True
+            return "lang_export_path"
+
+    # Test dual QPC behavior
+    vision_model = MockVisionModel()
+    lang_model = MockLangModel()
+
+    # Simulate dual QPC export behavior
+    vision_model.export({}, [], {}, offload_pt_weights=False)  # Vision model doesn't offload
+    lang_model.export({}, [], {}, offload_pt_weights=True)  # Language model offloads
+
+    # Verify asymmetric behavior
+    assert not vision_model._is_weights_offloaded  # Vision model should NOT be offloaded
+    assert lang_model._is_weights_offloaded  # Language model should be offloaded
+
+
+def test_vlm_single_qpc_memory_offload_behavior():
+    """Test memory offload behavior for VLM single QPC models with both True and False."""
+
+    class MockParam:
+        def __init__(self, is_meta=False):
+            self.is_meta = is_meta
+
+    class MockModel:
+        def __init__(self):
+            self._params = [MockParam(is_meta=False)]
+
+        def parameters(self):
+            return self._params
+
+    class MockSingleQPCModel:
+        def __init__(self):
+            self._is_weights_offloaded = False
+            self.model = MockModel()
+
+        def _offload_model_weights(self):
+            self._is_weights_offloaded = True
+            for param in self.model.parameters():
+                param.is_meta = True
+            return True
+
+        def export(self, export_dir=None, offload_pt_weights=True):
+            if offload_pt_weights:
+                self._offload_model_weights()
+            return "single_qpc_export_path"
+
+    # Test with offload_pt_weights=True
+    qeff_model = MockSingleQPCModel()
+    qeff_model.export(offload_pt_weights=True)
+    assert qeff_model._is_weights_offloaded
+    assert all(param.is_meta for param in qeff_model.model.parameters())
+
+    # Test with offload_pt_weights=False
+    qeff_model2 = MockSingleQPCModel()
+    qeff_model2.export(offload_pt_weights=False)
+    assert not qeff_model2._is_weights_offloaded
+    assert not any(param.is_meta for param in qeff_model2.model.parameters())
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index 77354ee23..49d2ccf8c 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -282,6 +282,10 @@ def test_causal_lm_export_with_deprecated_api(model_name):
     tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name)
     qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name)
     new_api_onnx_model_path = qeff_model.export()
+
+    # Reload the model: export() offloads its weights to the meta device
+    model, _ = load_causal_lm_model(model_name, n_layer=1)
+    qeff_model = QEFFAutoModelForCausalLM(model, model_name=model_name, pretrained_model_name_or_path=model_name)
     _, old_api_onnx_model_path = qualcomm_efficient_converter(
         model_name=model_name, model_kv=qeff_model, tokenizer=tokenizer
     )
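
Usage sketch (illustrative only; it mirrors the flow exercised in tests/base/test_export_memory_offload.py, with a tiny GPT-2 config standing in for any supported model). It shows the behavior this diff introduces: export() offloads the PyTorch weights by default, a repeated export() still returns the cached ONNX path, and only a forced fresh export hits the new RuntimeError guard.

    import os

    from transformers import AutoConfig, AutoModelForCausalLM

    from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM

    config = AutoConfig.for_model("gpt2", num_hidden_layers=2)
    qeff_model = QEFFAutoModelForCausalLM(
        AutoModelForCausalLM.from_config(config, attn_implementation="eager"),
        continuous_batching=False,
    )

    onnx_path = qeff_model.export()  # weights are offloaded to the meta device here
    qeff_model.export()              # fine: returns the cached ONNX path

    os.remove(onnx_path)             # drop the cached artifact to force a fresh export
    qeff_model.onnx_path = None
    qeff_model.export()              # raises RuntimeError: weights have been offloaded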