Skip to content

Commit ddc3b81

Browse files
committed
Delay pass creation until time to run
Logic avoids instantiating a Pass unless it is to be run. To avoid instantiation, only FullPassConfig is passed around.
1 parent f57bb8c commit ddc3b81

23 files changed

+206
-136
lines changed

olive/engine/engine.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from olive.logging import enable_filelog
2727
from olive.model import ModelConfig
2828
from olive.package_config import OlivePackageConfig
29+
from olive.passes.olive_pass import FullPassConfig
2930
from olive.search.search_sample import SearchSample
3031
from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig
3132
from olive.systems.common import SystemType
@@ -677,28 +678,27 @@ def _run_pass(
677678
"""Run a pass on the input model."""
678679
run_start_time = datetime.now().timestamp()
679680

680-
pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
681-
pass_type_name = pass_config.type
681+
run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
682+
pass_type_name = run_pass_config.type
682683

683684
logger.info("Running pass %s:%s", pass_name, pass_type_name)
684685

685686
# check whether the config is valid
686-
pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
687-
if not pass_cls.validate_config(pass_config.config, accelerator_spec):
687+
pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type)
688+
if not pass_cls.validate_config(run_pass_config.config, accelerator_spec):
688689
logger.warning("Invalid config, pruned.")
689-
logger.debug(pass_config)
690+
logger.debug(run_pass_config)
690691
# no need to record in footprint since there was no run and thus no valid/failed model
691692
# invalid configs are also not cached since the same config can be valid for other accelerator specs
692693
# a pass can be accelerator agnostic but still have accelerator specific invalid configs
693694
# this helps reusing cached models for different accelerator specs
694695
return INVALID_CONFIG, None
695696

696-
p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
697-
pass_config = p.config.to_json()
697+
pass_config = run_pass_config.config.to_json()
698698
output_model_config = None
699699

700700
# load run from cache if it exists
701-
run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
701+
run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
702702
output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
703703
run_cache = self.cache.load_run_from_model_id(output_model_id)
704704
if run_cache:
@@ -729,15 +729,18 @@ def _run_pass(
729729
input_model_config = self.cache.prepare_resources_for_local(input_model_config)
730730

731731
try:
732-
if p.run_on_target:
732+
if pass_cls.run_on_target:
733733
if self.target.system_type == SystemType.IsolatedORT:
734734
logger.warning(
735735
"Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name
736736
)
737737
else:
738738
host = self.target
739739

740-
output_model_config = host.run_pass(p, input_model_config, output_model_path)
740+
full_pass_config = FullPassConfig.from_run_pass_config(
741+
run_pass_config, accelerator_spec, self.get_host_device()
742+
)
743+
output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path)
741744
except OlivePassError:
742745
logger.exception("Pass run_pass failed")
743746
output_model_config = FAILED_CONFIG

olive/passes/olive_pass.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import logging
77
import shutil
88
from abc import ABC, abstractmethod
9+
from copy import deepcopy
910
from pathlib import Path
1011
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args
1112

1213
from olive.common.config_utils import ParamCategory, validate_config
13-
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model
14+
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator
1415
from olive.common.user_module_loader import UserModuleLoader
1516
from olive.data.config import DataConfig
16-
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
17+
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device
1718
from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler
1819
from olive.passes.pass_config import (
1920
AbstractPassConfig,
@@ -80,17 +81,10 @@ def __init__(
8081
if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
8182
self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)
8283

83-
# Params that are paths [(param_name, required)]
84-
self.path_params = [
85-
(param, param_config.required, param_config.category)
86-
for param, param_config in self.default_config(accelerator_spec).items()
87-
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
88-
]
89-
9084
self._initialized = False
9185

92-
@staticmethod
93-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
86+
@classmethod
87+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
9488
"""Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators.
9589
9690
The default value is True. The subclass could choose to override this method to return False by using the
@@ -482,17 +476,35 @@ class FullPassConfig(AbstractPassConfig):
482476
reconstruct the pass from the JSON file.
483477
"""
484478

485-
accelerator: Dict[str, str] = None
486-
host_device: Optional[str] = None
479+
accelerator: Optional[AcceleratorSpec] = None
480+
host_device: Optional[Device] = None
481+
482+
@validator("accelerator", pre=True)
483+
def validate_accelerator(cls, v):
484+
if isinstance(v, AcceleratorSpec):
485+
return v
486+
elif isinstance(v, dict):
487+
return AcceleratorSpec(**v)
488+
raise ValueError("Invalid accelerator input.")
487489

488-
def create_pass(self):
489-
if not isinstance(self.accelerator, dict):
490-
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
490+
def create_pass(self) -> Pass:
491+
"""Create a Pass."""
492+
return super().create_pass_with_args(self.accelerator, self.host_device)
491493

492-
pass_cls = Pass.registry[self.type.lower()]
493-
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
494-
self.config = pass_cls.generate_config(accelerator_spec, self.config)
495-
return pass_cls(accelerator_spec, self.config, self.host_device)
494+
@staticmethod
495+
def from_run_pass_config(
496+
run_pass_config: Dict[str, Any],
497+
accelerator: "AcceleratorSpec",
498+
host_device: Device = None,
499+
) -> "FullPassConfig":
500+
config = deepcopy(run_pass_config) if isinstance(run_pass_config, dict) else run_pass_config.dict()
501+
config.update(
502+
{
503+
"accelerator": accelerator,
504+
"host_device": host_device,
505+
}
506+
)
507+
return validate_config(config, FullPassConfig)
496508

497509

498510
# TODO(myguo): deprecate or remove this function by explicitly specify the accelerator_spec in the arguments

olive/passes/onnx/inc_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@
251251
class IncQuantization(Pass):
252252
"""Quantize ONNX model with Intel® Neural Compressor."""
253253

254-
@staticmethod
255-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
254+
@classmethod
255+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
256256
"""Override this method to return False by using the accelerator spec information."""
257257
return False
258258

olive/passes/onnx/model_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def validate_config(
129129
return False
130130
return True
131131

132-
@staticmethod
133-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
132+
@classmethod
133+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
134134
return False
135135

136136
def _run_for_config(

olive/passes/onnx/optimum_merging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ class OptimumMerging(Pass):
1919

2020
_accepts_composite_model = True
2121

22-
@staticmethod
23-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
22+
@classmethod
23+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
2424
"""Override this method to return False by using the accelerator spec information."""
2525
return False
2626

olive/passes/onnx/session_params_tuning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str):
7979
class OrtSessionParamsTuning(Pass):
8080
"""Optimize ONNX Runtime inference settings."""
8181

82-
@staticmethod
83-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
82+
@classmethod
83+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
8484
"""Override this method to return False by using the accelerator spec information."""
8585
return False
8686

olive/passes/onnx/transformer_optimization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass):
3636
# using a Linux machine which doesn't support onnxruntime-directml package.
3737
# It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages.
3838

39-
@staticmethod
40-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
39+
@classmethod
40+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
4141
"""Override this method to return False by using the accelerator spec information."""
4242
from onnxruntime import __version__ as OrtVersion
4343
from packaging import version

olive/passes/onnx/vitis_ai_quantization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def _initialize(self):
209209
super()._initialize()
210210
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp")
211211

212-
@staticmethod
213-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
212+
@classmethod
213+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
214214
"""Override this method to return False by using the accelerator spec information."""
215215
return False
216216

olive/passes/pass_config.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55
from pathlib import Path
6-
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
6+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
77

88
from olive.common.config_utils import (
99
ConfigBase,
@@ -21,6 +21,9 @@
2121
from olive.resource_path import validate_resource_path
2222
from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter
2323

24+
if TYPE_CHECKING:
25+
from olive.hardware.accelerator import AcceleratorSpec
26+
2427

2528
class PassParamDefault(StrEnumBase):
2629
"""Default values for passes."""
@@ -129,6 +132,21 @@ class AbstractPassConfig(NestedConfig):
129132
def validate_type(cls, v):
130133
return validate_lowercase(v)
131134

135+
@validator("config", pre=True, always=True)
136+
def validate_config(cls, v):
137+
return v or {}
138+
139+
def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: Device):
140+
"""Create a Pass."""
141+
from olive.package_config import OlivePackageConfig # pylint: disable=cyclic-import
142+
143+
olive_config = OlivePackageConfig.load_default_config()
144+
pass_cls = olive_config.import_pass_module(self.type)
145+
self.config = pass_cls.generate_config(
146+
accelerator, self.config if isinstance(self.config, dict) else self.config.dict()
147+
)
148+
return pass_cls(accelerator, self.config, host_device)
149+
132150

133151
def create_config_class(
134152
pass_type: str,

olive/passes/pytorch/capture_split_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def validate_config(
7777

7878
return True
7979

80-
@staticmethod
81-
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
80+
@classmethod
81+
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
8282
return False
8383

8484
def _run_for_config(

olive/systems/azureml/aml_system.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040

4141
if TYPE_CHECKING:
4242
from olive.hardware.accelerator import AcceleratorSpec
43-
from olive.passes.olive_pass import Pass
43+
from olive.passes.olive_pass import FullPassConfig
4444

4545

4646
logger = logging.getLogger(__name__)
@@ -243,23 +243,30 @@ def _assert_not_none(self, obj):
243243
if obj is None:
244244
raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!")
245245

246-
def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig:
246+
def run_pass(
247+
self,
248+
full_pass_config: "FullPassConfig",
249+
model_config: "ModelConfig",
250+
output_model_path: str,
251+
) -> ModelConfig:
247252
"""Run the pass on the model."""
248253
ml_client = self.azureml_client_config.create_client()
249254

250-
# serialize pass
251-
pass_config = the_pass.to_json(check_object=True)
255+
# serialize config
256+
serialized_pass_config = full_pass_config.to_json(check_object=True)
252257

253258
with tempfile.TemporaryDirectory() as tempdir:
254-
pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config)
259+
pipeline_job = self._create_pipeline_for_pass(
260+
tempdir, model_config.to_json(check_object=True), serialized_pass_config
261+
)
255262

256263
# submit job
257264
named_outputs_dir = self._run_job(
258265
ml_client,
259266
pipeline_job,
260267
"olive-pass",
261268
tempdir,
262-
tags={"Pass": pass_config["type"]},
269+
tags={"Pass": serialized_pass_config["type"]},
263270
output_name="pipeline_output",
264271
)
265272
pipeline_output_path = named_outputs_dir / "pipeline_output"

0 commit comments

Comments
 (0)