Skip to content
138 changes: 101 additions & 37 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.onnx_transforms import BaseOnnxTransform, OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
Expand Down Expand Up @@ -47,11 +47,12 @@ class QEFFBaseModel(ABC):
"""

_pytorch_transforms: List[PytorchTransform]
_onnx_transforms: List[OnnxTransform]
_onnx_transforms = [BaseOnnxTransform]

@classmethod
def _transform_names(cls) -> List[str]:
    """Return the class names of all registered transforms.

    Returns:
        List[str]: names of every class in ``cls._pytorch_transforms``
        followed by every class in ``cls._onnx_transforms``.
    """
    # Both lists hold transform *classes*; extract ``__name__`` from each so
    # callers (e.g. the ONNX metadata writer that does ``",".join(...)``)
    # always receive plain strings. Returning the onnx transform classes
    # themselves would break the declared List[str] contract and make the
    # join raise TypeError.
    return [t.__name__ for t in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
Expand All @@ -78,28 +79,71 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
else:
logger.info(f"Pytorch transforms applied to model: {self.model_name}")

def _offload_model_weights(self, offload_pt_weights) -> bool:
"""
Clear PyTorch weights after export if offload_pt_weights is set to True

Returns:
bool: True if weights were successfully offloaded, False otherwise
"""
# Check if offloading is enabled and weights are not already offloaded
if offload_pt_weights and not self._is_weights_offloaded:
try:
self.model = self.model.to_empty(device="meta")
self._is_weights_offloaded = True
logger.info("Model weights offloaded to meta device")

gc.collect()
logger.info("PyTorch weights cleared after export")
return True
def _clear_model_weights(self) -> None:
    """Best-effort release of PyTorch weight memory after ONNX export.

    Shrinks the tensor storage behind every parameter and buffer of
    ``self.model`` to zero, clears the per-module parameter/buffer/hook
    registries, and finally replaces ``self.model`` with a minimal shell
    object that keeps just enough of the ``nn.Module`` surface
    (``config``, ``parameters()``, ``eval()``, ``to()``, ...) for later
    code to call without keeping the weights alive.

    Any failure is logged and swallowed: weight clearing is a memory
    optimization and must never abort the export flow.
    """
    try:
        # Resize the underlying storages to zero elements so the weight
        # memory is released even while other references to the tensors
        # still exist.
        # NOTE(review): ``Tensor.storage()`` is deprecated in recent torch
        # in favor of ``untyped_storage()`` — confirm the minimum supported
        # torch version before switching.
        for param in self.model.parameters():
            if hasattr(param, "data") and hasattr(param.data, "storage"):
                param.data.storage().resize_(0)

        for buffer in self.model.buffers():
            if hasattr(buffer, "data") and hasattr(buffer.data, "storage"):
                buffer.data.storage().resize_(0)

        # Drop the registries that keep tensors and callbacks reachable.
        for module in self.model.modules():
            if hasattr(module, "_parameters"):
                module._parameters.clear()
            if hasattr(module, "_buffers"):
                module._buffers.clear()

            # Clear all hook dictionaries; ``getattr`` with a fresh dict
            # default makes missing attributes a harmless no-op.
            for hook_dict in [
                getattr(module, "_forward_hooks", {}),
                getattr(module, "_forward_pre_hooks", {}),
                getattr(module, "_backward_hooks", {}),
                getattr(module, "_state_dict_hooks", {}),
                getattr(module, "_load_state_dict_pre_hooks", {}),
            ]:
                hook_dict.clear()

        class ModelShell:
            # Minimal stand-in exposing the nn.Module attributes and
            # methods downstream code may still touch after the real
            # weights are gone.
            def __init__(self, config):
                self.config = config
                self.qaic_config = None
                self.device = torch.device("meta")

            def parameters(self):
                return iter([])

            def named_parameters(self):
                return iter([])

            def buffers(self):
                return iter([])

            def named_buffers(self):
                return iter([])

            def modules(self):
                return iter([self])

            def state_dict(self):
                return {}

            def to(self, device):
                return self

            def eval(self):
                return self

        # Preserve the config (if any) so callers can still inspect it
        # on the shell that replaces the real model.
        config = getattr(self.model, "config", None)
        self.model = ModelShell(config)

    except Exception as e:
        # Deliberately broad: clearing weights is best-effort; log and
        # continue so the export pipeline is never broken by cleanup.
        logger.warning(f"Weight clearing failed, continuing: {e}")

def _model_offloaded_check(self) -> None:
"""
Expand Down Expand Up @@ -244,19 +288,32 @@ def _export(

try:
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)

with torch.no_grad():
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
**export_kwargs,
)
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)
# Clear PyTorch weights after successful export to reduce memory usage
if offload_pt_weights:
self._clear_model_weights()
self._is_weights_offloaded = True
logger.info("PyTorch weights cleared after ONNX export")

# Clear temporary references
example_inputs.clear()
input_names.clear()

# Force garbage collection
gc.collect()

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
Expand All @@ -266,8 +323,9 @@ def _export(
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)

for transform in self._onnx_transforms:
model, transformed = transform.apply(model, **transform_kwargs)
transform_kwargs["transforms"] = self._onnx_transforms
# for transform in self._onnx_transforms:
model, transformed = OnnxTransform.apply(model, **transform_kwargs)

model.metadata_props.append(
onnx.StringStringEntryProto(key="qeff_transforms", value=",".join(self._transform_names()))
Expand All @@ -283,6 +341,12 @@ def _export(

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
# Clear external data from memory and cache after all transforms and saving
# Make sure model exists before trying to clean it up
if "model" in locals():
BaseOnnxTransform._cleanup_external_data_and_cache(model)
BaseOnnxTransform._cleanup_memory()
logger.info("Cleanup complete.")

self.onnx_path = onnx_path
return onnx_path
Expand Down
Loading
Loading