Skip to content

Commit 15dbee5

Browse files
committed
Delay pass creation until time to run
The engine now avoids instantiating a Pass unless it is actually going to be run. To avoid premature instantiation, only a FullPassConfig is passed around.
1 parent 34eb17a commit 15dbee5

23 files changed

+201
-136
lines changed

olive/engine/engine.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from olive.logging import enable_filelog
2727
from olive.model import ModelConfig
2828
from olive.package_config import OlivePackageConfig
29+
from olive.passes.olive_pass import FullPassConfig
2930
from olive.search.search_sample import SearchSample
3031
from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig
3132
from olive.systems.common import SystemType
@@ -682,28 +683,27 @@ def _run_pass(
682683
"""Run a pass on the input model."""
683684
run_start_time = datetime.now().timestamp()
684685

685-
pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
686-
pass_type_name = pass_config.type
686+
run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
687+
pass_type_name = run_pass_config.type
687688

688689
logger.info("Running pass %s:%s", pass_name, pass_type_name)
689690

690691
# check whether the config is valid
691-
pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
692-
if not pass_cls.validate_config(pass_config.config, accelerator_spec):
692+
pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type)
693+
if not pass_cls.validate_config(run_pass_config.config, accelerator_spec):
693694
logger.warning("Invalid config, pruned.")
694-
logger.debug(pass_config)
695+
logger.debug(run_pass_config)
695696
# no need to record in footprint since there was no run and thus no valid/failed model
696697
# invalid configs are also not cached since the same config can be valid for other accelerator specs
697698
# a pass can be accelerator agnostic but still have accelerator specific invalid configs
698699
# this helps reusing cached models for different accelerator specs
699700
return INVALID_CONFIG, None
700701

701-
p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
702-
pass_config = p.config.to_json()
702+
pass_config = run_pass_config.config.to_json()
703703
output_model_config = None
704704

705705
# load run from cache if it exists
706-
run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
706+
run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
707707
output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
708708
run_cache = self.cache.load_run_from_model_id(output_model_id)
709709
if run_cache:
@@ -734,15 +734,18 @@ def _run_pass(
734734
input_model_config = self.cache.prepare_resources_for_local(input_model_config)
735735

736736
try:
737-
if p.run_on_target:
737+
if pass_cls.run_on_target:
738738
if self.target.system_type == SystemType.IsolatedORT:
739739
logger.warning(
740740
"Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name
741741
)
742742
else:
743743
host = self.target
744744

745-
output_model_config = host.run_pass(p, input_model_config, output_model_path)
745+
full_pass_config = FullPassConfig.from_run_pass_config(
746+
run_pass_config, accelerator_spec, self.get_host_device()
747+
)
748+
output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path)
746749
except OlivePassError:
747750
logger.exception("Pass run_pass failed")
748751
output_model_config = FAILED_CONFIG

olive/passes/olive_pass.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import logging
77
import shutil
88
from abc import ABC, abstractmethod
9+
from copy import deepcopy
910
from pathlib import Path
1011
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args
1112

1213
from olive.common.config_utils import ParamCategory, validate_config
13-
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model
14+
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator
1415
from olive.common.user_module_loader import UserModuleLoader
1516
from olive.data.config import DataConfig
16-
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
17+
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device
1718
from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler
1819
from olive.passes.pass_config import (
1920
AbstractPassConfig,
@@ -80,17 +81,10 @@ def __init__(
8081
if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
8182
self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)
8283

83-
# Params that are paths [(param_name, required)]
84-
self.path_params = [
85-
(param, param_config.required, param_config.category)
86-
for param, param_config in self.default_config(accelerator_spec).items()
87-
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
88-
]
89-
9084
self._initialized = False
9185

92-
@staticmethod
93-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
86+
@classmethod
87+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
9488
"""Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators.
9589
9690
The default value is True. The subclass could choose to override this method to return False by using the
@@ -482,17 +476,35 @@ class FullPassConfig(AbstractPassConfig):
482476
reconstruct the pass from the JSON file.
483477
"""
484478

485-
accelerator: Dict[str, str] = None
486-
host_device: Optional[str] = None
479+
accelerator: Optional[AcceleratorSpec] = None
480+
host_device: Optional[Device] = None
481+
482+
@validator("accelerator", pre=True)
483+
def validate_accelerator(cls, v):
484+
if isinstance(v, AcceleratorSpec):
485+
return v
486+
elif isinstance(v, dict):
487+
return AcceleratorSpec(**v)
488+
raise ValueError("Invalid accelerator input.")
487489

488-
def create_pass(self):
489-
if not isinstance(self.accelerator, dict):
490-
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
490+
def create_pass(self) -> Pass:
491+
"""Create a Pass."""
492+
return super().create_pass_with_args(self.accelerator, self.host_device)
491493

492-
pass_cls = Pass.registry[self.type.lower()]
493-
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
494-
self.config = pass_cls.generate_config(accelerator_spec, self.config)
495-
return pass_cls(accelerator_spec, self.config, self.host_device)
494+
@staticmethod
495+
def from_run_pass_config(
496+
run_pass_config: Dict[str, Any],
497+
accelerator: "AcceleratorSpec",
498+
host_device: Device = None,
499+
) -> "FullPassConfig":
500+
config = deepcopy(run_pass_config) if isinstance(run_pass_config, dict) else run_pass_config.dict()
501+
config.update(
502+
{
503+
"accelerator": accelerator,
504+
"host_device": host_device,
505+
}
506+
)
507+
return validate_config(config, FullPassConfig)
496508

497509

498510
# TODO(myguo): deprecate or remove this function by explicitly specifying the accelerator_spec in the arguments

olive/passes/onnx/inc_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@
251251
class IncQuantization(Pass):
252252
"""Quantize ONNX model with Intel® Neural Compressor."""
253253

254-
@staticmethod
255-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
254+
@classmethod
255+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
256256
"""Override this method to return False by using the accelerator spec information."""
257257
return False
258258

olive/passes/onnx/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def validate_config(
121121
return False
122122
return True
123123

124-
@staticmethod
125-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
124+
@classmethod
125+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
126126
return False
127127

128128
def _run_for_config(

olive/passes/onnx/optimum_merging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class OptimumMerging(Pass):
1919

2020
_accepts_composite_model = True
2121

22-
@staticmethod
23-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
22+
@classmethod
23+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
2424
"""Override this method to return False by using the accelerator spec information."""
2525
return False
2626

olive/passes/onnx/session_params_tuning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str):
7979
class OrtSessionParamsTuning(Pass):
8080
"""Optimize ONNX Runtime inference settings."""
8181

82-
@staticmethod
83-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
82+
@classmethod
83+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
8484
"""Override this method to return False by using the accelerator spec information."""
8585
return False
8686

olive/passes/onnx/transformer_optimization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass):
3636
# using a Linux machine which doesn't support onnxruntime-directml package.
3737
# It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages.
3838

39-
@staticmethod
40-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
39+
@classmethod
40+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
4141
"""Override this method to return False by using the accelerator spec information."""
4242
from onnxruntime import __version__ as OrtVersion
4343
from packaging import version

olive/passes/onnx/vitis_ai_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def _initialize(self):
209209
super()._initialize()
210210
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp")
211211

212-
@staticmethod
213-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
212+
@classmethod
213+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
214214
"""Override this method to return False by using the accelerator spec information."""
215215
return False
216216

olive/passes/pass_config.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55
from pathlib import Path
6-
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
6+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
77

88
from olive.common.config_utils import (
99
ConfigBase,
@@ -20,6 +20,9 @@
2020
from olive.resource_path import validate_resource_path
2121
from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter
2222

23+
if TYPE_CHECKING:
24+
from olive.hardware.accelerator import AcceleratorSpec
25+
2326

2427
class PassParamDefault(StrEnumBase):
2528
"""Default values for passes."""
@@ -128,6 +131,23 @@ class AbstractPassConfig(NestedConfig):
128131
def validate_type(cls, v):
129132
return validate_lowercase(v)
130133

134+
@validator("config", pre=True, always=True)
135+
def validate_config(cls, v):
136+
return v or {}
137+
138+
def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: Device):
139+
"""Create a Pass."""
140+
if TYPE_CHECKING:
141+
from olive.passes.olive_pass import Pass # pylint: disable=cyclic-import
142+
from olive.package_config import OlivePackageConfig # pylint: disable=cyclic-import
143+
144+
olive_config = OlivePackageConfig.load_default_config()
145+
pass_cls: Type[Pass] = olive_config.import_pass_module(self.type)
146+
self.config = pass_cls.generate_config(
147+
accelerator, self.config if isinstance(self.config, dict) else self.config.dict()
148+
)
149+
return pass_cls(accelerator, self.config, host_device)
150+
131151

132152
def create_config_class(
133153
pass_type: str,

olive/passes/pytorch/capture_split_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def validate_config(
8181

8282
return True
8383

84-
@staticmethod
85-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
84+
@classmethod
85+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
8686
return False
8787

8888
def _run_for_config(

0 commit comments

Comments
 (0)