From 322851a457ea2acbd9682b3f18a4b6ebbeb8655c Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 5 May 2025 22:10:28 +0000 Subject: [PATCH 1/2] draft --- onnxscript/rewriter/ort_fusions/_core.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 6e23700ee..51138a551 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -3,7 +3,7 @@ from __future__ import annotations import onnxscript.ir as ir -from onnxscript.ir.passes.common import shape_inference +import onnxscript.ir.passes.common as common_passes from onnxscript.optimizer import optimize from onnxscript.rewriter import rewrite from onnxscript.rewriter.ort_fusions import ( @@ -47,7 +47,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model: # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet # incorporated in our optimizer. - shape_inference.infer_shapes(model) + common_passes.ShapeInferencePass()(model) optimize(model) return model @@ -130,4 +130,17 @@ def optimize_for_ort( ) # Apply the ORT pattern rewrite rules. rewrite(model, ORT_PATTERN_REWRITE_RULES) + + # TODO(exporter team): Fold transpose into initializers + # Apply the ORT optimization passes. + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 + common_passes.ClearMetadataAndDocStringPass()(model) + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 + common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1)( + model + ) + common_passes.RemoveInitializersFromInputsPass()(model) + common_passes.ShapeInferencePass()(model) + common_passes.CheckerPass()(model) + return model, fusion_count From 0226ed72acfe6f5787a89a970066bdbd1aed5075 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Thu, 8 May 2025 22:15:45 +0000 Subject: [PATCH 2/2] use pass manager --- onnxscript/rewriter/ort_fusions/_core.py | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index 51138a551..116c0016e 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -131,16 +131,17 @@ def optimize_for_ort( # Apply the ORT pattern rewrite rules. rewrite(model, ORT_PATTERN_REWRITE_RULES) - # TODO(exporter team): Fold transpose into initializers - # Apply the ORT optimization passes. - # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 - common_passes.ClearMetadataAndDocStringPass()(model) - # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 - common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1)( - model - ) - common_passes.RemoveInitializersFromInputsPass()(model) - common_passes.ShapeInferencePass()(model) - common_passes.CheckerPass()(model) - - return model, fusion_count + passes = [ + # TODO(exporter team): Fold transpose into initializers + # Apply the ORT optimization passes. + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172 + common_passes.ClearMetadataAndDocStringPass(), + # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139 + common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1), + common_passes.RemoveInitializersFromInputsPass(), + common_passes.ShapeInferencePass(), + common_passes.CheckerPass(), + ] + optimize_for_ort_passes = ir.passes.Sequential(*passes) + result = optimize_for_ort_passes(model) + return result.model, fusion_count