Skip to content

Delay pass creation until time to run #1621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions olive/engine/engine.py
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@
from olive.logging import enable_filelog
from olive.model import ModelConfig
from olive.package_config import OlivePackageConfig
from olive.passes.olive_pass import FullPassConfig
from olive.search.search_sample import SearchSample
from olive.search.search_strategy import SearchStrategy, SearchStrategyConfig
from olive.systems.common import SystemType
@@ -677,28 +678,27 @@ def _run_pass(
"""Run a pass on the input model."""
run_start_time = datetime.now().timestamp()

pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = pass_config.type
run_pass_config: RunPassConfig = self.computed_passes_configs[pass_name]
pass_type_name = run_pass_config.type

logger.info("Running pass %s:%s", pass_name, pass_type_name)

# check whether the config is valid
pass_cls: Type[Pass] = self.olive_config.import_pass_module(pass_config.type)
if not pass_cls.validate_config(pass_config.config, accelerator_spec):
pass_cls: Type[Pass] = self.olive_config.import_pass_module(run_pass_config.type)
if not pass_cls.validate_config(run_pass_config.config, accelerator_spec):
logger.warning("Invalid config, pruned.")
logger.debug(pass_config)
logger.debug(run_pass_config)
# no need to record in footprint since there was no run and thus no valid/failed model
# invalid configs are also not cached since the same config can be valid for other accelerator specs
# a pass can be accelerator agnostic but still have accelerator specific invalid configs
# this helps reusing cached models for different accelerator specs
return INVALID_CONFIG, None

p: Pass = pass_cls(accelerator_spec, pass_config.config, self.get_host_device())
pass_config = p.config.to_json()
pass_config = run_pass_config.config.to_json()
output_model_config = None

# load run from cache if it exists
run_accel = None if p.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
run_accel = None if pass_cls.is_accelerator_agnostic(accelerator_spec) else accelerator_spec
output_model_id = self.cache.get_output_model_id(pass_type_name, pass_config, input_model_id, run_accel)
run_cache = self.cache.load_run_from_model_id(output_model_id)
if run_cache:
@@ -729,15 +729,18 @@ def _run_pass(
input_model_config = self.cache.prepare_resources_for_local(input_model_config)

try:
if p.run_on_target:
if pass_cls.run_on_target:
if self.target.system_type == SystemType.IsolatedORT:
logger.warning(
"Cannot run pass %s on IsolatedORT target, will use the host to run the pass.", pass_name
)
else:
host = self.target

output_model_config = host.run_pass(p, input_model_config, output_model_path)
full_pass_config = FullPassConfig.from_run_pass_config(
run_pass_config, accelerator_spec, self.get_host_device()
)
output_model_config = host.run_pass(full_pass_config, input_model_config, output_model_path)
except OlivePassError:
logger.exception("Pass run_pass failed")
output_model_config = FAILED_CONFIG
52 changes: 32 additions & 20 deletions olive/passes/olive_pass.py
Original file line number Diff line number Diff line change
@@ -6,14 +6,15 @@
import logging
import shutil
from abc import ABC, abstractmethod
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, get_args

from olive.common.config_utils import ParamCategory, validate_config
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model
from olive.common.pydantic_v1 import BaseModel, ValidationError, create_model, validator
from olive.common.user_module_loader import UserModuleLoader
from olive.data.config import DataConfig
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec
from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec, Device
from olive.model import CompositeModelHandler, DistributedOnnxModelHandler, OliveModelHandler, ONNXModelHandler
from olive.passes.pass_config import (
AbstractPassConfig,
@@ -80,17 +81,10 @@
if hasattr(self.config, "user_script") and hasattr(self.config, "script_dir"):
self._user_module_loader = UserModuleLoader(self.config.user_script, self.config.script_dir)

# Params that are paths [(param_name, required)]
self.path_params = [
(param, param_config.required, param_config.category)
for param, param_config in self.default_config(accelerator_spec).items()
if param_config.category in (ParamCategory.PATH, ParamCategory.DATA)
]

self._initialized = False

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Whether the pass is accelerator agnostic. If True, the pass will be reused for all accelerators.

The default value is True. The subclass could choose to override this method to return False by using the
@@ -482,17 +476,35 @@
reconstruct the pass from the JSON file.
"""

accelerator: Dict[str, str] = None
host_device: Optional[str] = None
accelerator: Optional[AcceleratorSpec] = None
host_device: Optional[Device] = None

@validator("accelerator", pre=True)
def validate_accelerator(cls, v):
    """Coerce the raw ``accelerator`` input into an ``AcceleratorSpec``.

    Accepts an existing ``AcceleratorSpec`` (returned unchanged), a dict of
    ``AcceleratorSpec`` keyword arguments, or ``None``.

    Raises ``ValueError`` for any other input type.
    """
    # The field is declared Optional with a None default, so an explicit None
    # (e.g. from a serialized config with no accelerator) must pass through
    # instead of being rejected as invalid input.
    if v is None or isinstance(v, AcceleratorSpec):
        return v
    if isinstance(v, dict):
        return AcceleratorSpec(**v)
    raise ValueError("Invalid accelerator input.")

def create_pass(self):
if not isinstance(self.accelerator, dict):
raise ValueError(f"accelerator must be a dict, got {self.accelerator}")
def create_pass(self) -> Pass:
    """Build the pass from this config using the stored accelerator and host device."""
    accelerator, host_device = self.accelerator, self.host_device
    return super().create_pass_with_args(accelerator, host_device)

pass_cls = Pass.registry[self.type.lower()]
accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping
self.config = pass_cls.generate_config(accelerator_spec, self.config)
return pass_cls(accelerator_spec, self.config, self.host_device)
@staticmethod
def from_run_pass_config(
    run_pass_config: Dict[str, Any],
    accelerator: "AcceleratorSpec",
    host_device: Device = None,
) -> "FullPassConfig":
    """Build a ``FullPassConfig`` from a run-pass config plus execution context.

    ``run_pass_config`` may be a plain dict (deep-copied so the caller's dict is
    not modified) or an object exposing ``.dict()``. The ``accelerator`` and
    ``host_device`` entries are then set on the resulting dict before validation.
    """
    if isinstance(run_pass_config, dict):
        merged = deepcopy(run_pass_config)
    else:
        merged = run_pass_config.dict()

    merged["accelerator"] = accelerator
    merged["host_device"] = host_device
    return validate_config(merged, FullPassConfig)


# TODO(myguo): deprecate or remove this function by explicitly specify the accelerator_spec in the arguments
4 changes: 2 additions & 2 deletions olive/passes/onnx/inc_quantization.py
Original file line number Diff line number Diff line change
@@ -251,8 +251,8 @@
class IncQuantization(Pass):
"""Quantize ONNX model with Intel® Neural Compressor."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific.

    Always returns False, regardless of ``accelerator_spec``.
    """
    return False

4 changes: 2 additions & 2 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
@@ -129,8 +129,8 @@ def validate_config(
return False
return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific; always returns False."""
    return False

def _run_for_config(
4 changes: 2 additions & 2 deletions olive/passes/onnx/optimum_merging.py
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@ class OptimumMerging(Pass):

_accepts_composite_model = True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific.

    Always returns False, regardless of ``accelerator_spec``.
    """
    return False

4 changes: 2 additions & 2 deletions olive/passes/onnx/session_params_tuning.py
Original file line number Diff line number Diff line change
@@ -79,8 +79,8 @@ def get_thread_affinity_nums(affinity_str):
class OrtSessionParamsTuning(Pass):
"""Optimize ONNX Runtime inference settings."""

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific.

    Always returns False, regardless of ``accelerator_spec``.
    """
    return False

4 changes: 2 additions & 2 deletions olive/passes/onnx/transformer_optimization.py
Original file line number Diff line number Diff line change
@@ -36,8 +36,8 @@ class OrtTransformersOptimization(Pass):
# using a Linux machine which doesn't support onnxruntime-directml package.
# It is enough for the pass to fail if `opt_level` > 0 and the host doesn't have the required packages.

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
"""Override this method to return False by using the accelerator spec information."""
from onnxruntime import __version__ as OrtVersion
from packaging import version
4 changes: 2 additions & 2 deletions olive/passes/onnx/vitis_ai_quantization.py
Original file line number Diff line number Diff line change
@@ -209,8 +209,8 @@ def _initialize(self):
super()._initialize()
self.tmp_dir = tempfile.TemporaryDirectory(prefix="olive_vaiq_tmp")

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific.

    Always returns False, regardless of ``accelerator_spec``.
    """
    return False

20 changes: 19 additions & 1 deletion olive/passes/pass_config.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Set, Type, Union

from olive.common.config_utils import (
ConfigBase,
@@ -21,6 +21,9 @@
from olive.resource_path import validate_resource_path
from olive.search.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec


class PassParamDefault(StrEnumBase):
"""Default values for passes."""
@@ -129,6 +132,21 @@
def validate_type(cls, v):
return validate_lowercase(v)

@validator("config", pre=True, always=True)
def validate_config(cls, v):
    """Normalize a missing or empty ``config`` to an empty dict.

    Declared with ``always=True`` so it also runs when no value is supplied.
    """
    return v if v else {}

def create_pass_with_args(self, accelerator: "AcceleratorSpec", host_device: "Device"):
    """Instantiate the pass described by this config.

    The pass class is resolved from ``self.type`` through the default Olive
    package config. ``self.config`` is replaced by the result of
    ``pass_cls.generate_config`` before the pass is constructed, so after this
    call the stored config is the one the pass was built with.

    :param accelerator: accelerator spec passed to ``generate_config`` and to the pass constructor.
    :param host_device: host device passed to the pass constructor.
    :return: an instance of the resolved pass class.
    """
    # Annotations are quoted so they are not evaluated when the method is defined;
    # AcceleratorSpec is only imported under TYPE_CHECKING in this module.
    from olive.package_config import OlivePackageConfig  # pylint: disable=cyclic-import

    olive_config = OlivePackageConfig.load_default_config()
    pass_cls = olive_config.import_pass_module(self.type)
    # generate_config is given a plain dict; non-dict configs are converted via .dict()
    raw_config = self.config if isinstance(self.config, dict) else self.config.dict()
    self.config = pass_cls.generate_config(accelerator, raw_config)
    return pass_cls(accelerator, self.config, host_device)


def create_config_class(
pass_type: str,
4 changes: 2 additions & 2 deletions olive/passes/pytorch/capture_split_info.py
Original file line number Diff line number Diff line change
@@ -77,8 +77,8 @@ def validate_config(

return True

@staticmethod
def is_accelerator_agnostic(accelerator_spec: AcceleratorSpec) -> bool:
@classmethod
def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool:
    """Report this pass as accelerator specific; always returns False."""
    return False

def _run_for_config(
19 changes: 13 additions & 6 deletions olive/systems/azureml/aml_system.py
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@

if TYPE_CHECKING:
from olive.hardware.accelerator import AcceleratorSpec
from olive.passes.olive_pass import Pass
from olive.passes.olive_pass import FullPassConfig


logger = logging.getLogger(__name__)
@@ -243,23 +243,30 @@ def _assert_not_none(self, obj):
if obj is None:
raise ValueError(f"{obj.__class__.__name__} is missing in the inputs!")

def run_pass(self, the_pass: "Pass", model_config: ModelConfig, output_model_path: str) -> ModelConfig:
def run_pass(
self,
full_pass_config: "FullPassConfig",
model_config: "ModelConfig",
output_model_path: str,
) -> ModelConfig:
"""Run the pass on the model."""
ml_client = self.azureml_client_config.create_client()

# serialize pass
pass_config = the_pass.to_json(check_object=True)
# serialize config
serialized_pass_config = full_pass_config.to_json(check_object=True)

with tempfile.TemporaryDirectory() as tempdir:
pipeline_job = self._create_pipeline_for_pass(tempdir, model_config.to_json(check_object=True), pass_config)
pipeline_job = self._create_pipeline_for_pass(
tempdir, model_config.to_json(check_object=True), serialized_pass_config
)

# submit job
named_outputs_dir = self._run_job(
ml_client,
pipeline_job,
"olive-pass",
tempdir,
tags={"Pass": pass_config["type"]},
tags={"Pass": serialized_pass_config["type"]},
output_name="pipeline_output",
)
pipeline_output_path = named_outputs_dir / "pipeline_output"
Loading
Oops, something went wrong.