Extend optimize_for_ort to cover passes #2274

Open · wants to merge 5 commits into base: main
Changes from all commits
20 changes: 17 additions & 3 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import onnxscript.ir as ir
-from onnxscript.ir.passes.common import shape_inference
+import onnxscript.ir.passes.common as common_passes
 from onnxscript.optimizer import optimize
 from onnxscript.rewriter import rewrite
 from onnxscript.rewriter.ort_fusions import (
@@ -48,7 +48,7 @@
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.
-    shape_inference.infer_shapes(model)
+    common_passes.ShapeInferencePass()(model)
     optimize(model)
     return model
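For context, this hunk swaps the functional shape-inference helper for the pass-based API. Below is a minimal sketch of the new call pattern, not part of the PR; it assumes, as the diff indicates, that a pass instance is callable directly on an ir.Model and that ShapeInferencePass updates the model in place. The helper name is illustrative only.

    # Minimal sketch: old vs. new shape-inference invocation.
    import onnxscript.ir as ir
    import onnxscript.ir.passes.common as common_passes

    def infer_shapes_inplace(model: ir.Model) -> None:
        # Previously: from onnxscript.ir.passes.common import shape_inference
        #             shape_inference.infer_shapes(model)
        # Now: instantiate the common pass and call it on the model.
        common_passes.ShapeInferencePass()(model)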

@@ -135,4 +135,18 @@
     )
     # Apply the ORT pattern rewrite rules.
     rewrite(model, ORT_PATTERN_REWRITE_RULES)
-    return model, fusion_count
+
+    passes = [
Codecov / codecov/patch warning (onnxscript/rewriter/ort_fusions/_core.py#L139): added line #L139 was not covered by tests.
Collaborator (suggested change):
-    passes = [
+    passes = ir.passes.Sequential(

+        # TODO(exporter team): Fold transpose into initializers
+        # Apply the ORT optimization passes.
+        # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L172
+        common_passes.ClearMetadataAndDocStringPass(),
+        # https://github.com/microsoft/onnxruntime/blob/74dcf7e296639095dfa55d31336998b6f719ed76/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py#L139
+        common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),
Contributor (author): We have another pass called LiftSubgraphInitializersToMainGraphPass. Do we know if it's needed in genAI? @kunal-vaishnavi

Reply: If the pass logic is in DynamoOnnxHelper, then it is used for ONNX Runtime GenAI.

Collaborator: We don't really produce graphs with subgraph initializers. I think we are ok either way.

+        common_passes.RemoveInitializersFromInputsPass(),
+        common_passes.ShapeInferencePass(),
+        common_passes.CheckerPass(),
+    ]
+    optimize_for_ort_passes = ir.passes.Sequential(*passes)
Collaborator, on lines +149 to +150 (suggested change):
-    ]
-    optimize_for_ort_passes = ir.passes.Sequential(*passes)
+    )

+    result = optimize_for_ort_passes(model)
+    return result.model, fusion_count
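Taken together, the added lines build a small post-fusion cleanup pipeline and run it on the model. The sketch below, not part of the PR, shows the pipeline assembled directly as the review suggestions propose (equivalent to the `passes = [...]` plus `ir.passes.Sequential(*passes)` form in the diff); it assumes, as the diff itself does, that ir.passes.Sequential is callable on an ir.Model and returns a result object exposing .model.

    # Sketch only: the cleanup pipeline added by this PR, constructed directly.
    import onnxscript.ir as ir
    import onnxscript.ir.passes.common as common_passes

    cleanup = ir.passes.Sequential(
        common_passes.ClearMetadataAndDocStringPass(),
        common_passes.LiftConstantsToInitializersPass(lift_all_constants=False, size_limit=1),
        common_passes.RemoveInitializersFromInputsPass(),
        common_passes.ShapeInferencePass(),
        common_passes.CheckerPass(),
    )

    def run_cleanup(model: ir.Model) -> ir.Model:
        result = cleanup(model)   # running the sequential pass returns a pass result
        return result.model       # the cleaned-up model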

Codecov / codecov/patch warning (onnxscript/rewriter/ort_fusions/_core.py#L150-L152): added lines #L150 - L152 were not covered by tests.
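Both Codecov warnings point at the same gap: nothing currently exercises the new pass pipeline. A minimal sketch of a test that would cover those lines follows; the tiny Identity model and the test name are invented here, and it assumes optimize_for_ort still returns a (model, fusion-count) pair, as the diff's final line shows.

    # Hypothetical coverage test, not part of the PR.
    import onnx
    import onnx.helper as oh

    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions import optimize_for_ort

    def test_optimize_for_ort_runs_cleanup_passes():
        # Trivial float[2] -> Identity -> float[2] model.
        x = oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2])
        y = oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2])
        graph = oh.make_graph([oh.make_node("Identity", ["x"], ["y"])], "g", [x], [y])
        model = ir.serde.deserialize_model(oh.make_model(graph))

        optimized, _fusion_count = optimize_for_ort(model)

        # The pipeline should hand back an ir.Model that survived the checker pass.
        assert isinstance(optimized, ir.Model)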