
Commit a879613

titaiwangms and bmehta001 authored and committed
[rewriter] Enable llama rule sets (microsoft#2124)
Enable llama_rule_sets. We might need to come up with a better name.
1 parent 433b751 commit a879613

File tree

4 files changed: +10 −33 lines

onnxscript/optimizer/_optimizer.py
onnxscript/rewriter/_ir_utils.py
onnxscript/rewriter/llama_rule_sets.py
onnxscript/rewriter/llama_rule_sets_test.py
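For context, a minimal usage sketch of what this change enables, assuming the rewrite-rule list edited in onnxscript/optimizer/_optimizer.py below is the one consumed by the default onnxscript.optimizer.optimize pass (the file name "model.onnx" is illustrative):

```python
# Minimal sketch, not from the commit: run the default optimizer, which after
# this change also applies the rules from llama_rule_sets.llama_p0_rule_set().
import onnx

import onnxscript.optimizer

model = onnx.load("model.onnx")  # illustrative path
# Assumption: optimize() returns the optimized model; the exact entry-point
# behavior (in-place vs. returned copy) may differ between versions.
optimized = onnxscript.optimizer.optimize(model)
onnx.save(optimized, "model_optimized.onnx")
```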

onnxscript/optimizer/_optimizer.py

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,7 @@
     cast_constant_of_shape,
     collapse_slices,
     gemm_to_matmul_add,
+    llama_rule_sets,
     no_op,
 )

@@ -23,6 +24,7 @@
     gemm_to_matmul_add.rule,
     *cast_constant_of_shape.rules.rules,
     *collapse_slices.rules.rules,
+    *llama_rule_sets.llama_p0_rule_set().rules,
 ]
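The second hunk splices the new rules into the existing flat rule list. A hedged sketch of that splicing in isolation; `orp` is the onnxscript.rewriter.pattern alias used elsewhere in the repo, and wrapping the list in a RewriteRuleSet here is an assumption about how the combined list is ultimately consumed:

```python
# Sketch: llama_p0_rule_set() returns a rule set; unpacking its .rules
# attribute flattens the individual rules into the surrounding list.
from onnxscript.rewriter import llama_rule_sets
from onnxscript.rewriter import pattern as orp

rules = [
    # ... pre-existing rules (gemm_to_matmul_add.rule, cast_constant_of_shape, ...)
    *llama_rule_sets.llama_p0_rule_set().rules,
]
rule_set = orp.RewriteRuleSet(rules)  # assumption: how the combined list is used
```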

onnxscript/rewriter/_ir_utils.py

Lines changed: 2 additions & 3 deletions

@@ -7,8 +7,7 @@
 
 import numpy as np
 
-import onnxscript.ir as ir
-from onnxscript.optimizer import basic_constant_propagation
+from onnxscript import ir, optimizer
 
 
 def display_nodes(nodes: Sequence[ir.Node]) -> None:

@@ -54,7 +53,7 @@ def visit(node: ir.Node, depth):
 def get_const_value(value: ir.Value) -> ir.TensorProtocol | None:
     node = value.producer()
     if node is not None:
-        basic_constant_propagation([node])
+        optimizer.basic_constant_propagation([node])
     return value.const_value
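The helper touched above folds only the producer node before reading const_value. A small sketch of that pattern in use; const_as_numpy is a hypothetical wrapper name, not part of the diff:

```python
# Sketch of the pattern get_const_value relies on: basic_constant_propagation
# folds just the producing node, so value.const_value becomes available
# without running a full optimization pass over the model.
from __future__ import annotations

from onnxscript import ir, optimizer


def const_as_numpy(value: ir.Value):  # hypothetical wrapper, for illustration
    node = value.producer()
    if node is not None:
        optimizer.basic_constant_propagation([node])
    tensor = value.const_value
    # Assumption: the returned tensor object exposes a numpy() accessor.
    return None if tensor is None else tensor.numpy()
```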

onnxscript/rewriter/llama_rule_sets.py

Lines changed: 5 additions & 10 deletions

@@ -6,10 +6,9 @@
 
 import onnx.numpy_helper
 
-import onnxscript.ir as ir
-import onnxscript.rewriter._ir_utils as ir_utils
-import onnxscript.rewriter.no_op as no_op
-import onnxscript.rewriter.pattern as orp
+from onnxscript import ir
+from onnxscript.rewriter import _ir_utils as ir_utils
+from onnxscript.rewriter import pattern as orp
 
 
 class CastIdentity(orp.RewriteRuleAsClass):

@@ -271,15 +270,11 @@ def llama_p0_rule_set() -> orp.RewriteRuleSet:
     """
     return orp.RewriteRuleSet(
         [
-            no_op.mul_by_1_rule,
-            no_op.add_0_rule,
-            no_op.add_0_rule,
-            no_op.div_by_1_rule,
-            cast_cast_rule,
+            # cast_cast_rule,  # Might have precision issues.
             cast_identity_rule,
             expand_identity_rule,
             reshape_reshape_rule,
-            slice_split_rule,
+            slice_split_rule,  # Affect collapse slices rules?
             transpose_identity_rule,
             transpose_transpose_rule,
             unsqueeze_unsqueeze_rule,
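A hedged sketch of applying the enabled rule set to a model through the IR, mirroring how the tests below drive it; the input file name is illustrative and the return value of apply_to_model is assumed to be the number of applied rewrites:

```python
# Sketch: deserialize a model into the IR, apply the p0 rule set, serialize back.
import onnx

from onnxscript import ir
from onnxscript.rewriter import llama_rule_sets

model_proto = onnx.load("llama_block.onnx")      # illustrative input
model = ir.serde.deserialize_model(model_proto)  # proto -> IR
count = llama_rule_sets.llama_p0_rule_set().apply_to_model(model)  # in-place
print(f"applied {count} rewrite(s)")             # assumption: count is returned
rewritten_proto = ir.serde.serialize_model(model)
onnx.save(rewritten_proto, "llama_block_rewritten.onnx")
```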

onnxscript/rewriter/llama_rule_sets_test.py

Lines changed: 1 addition & 20 deletions

@@ -80,25 +80,6 @@ def _check_model(
                 opset_imports=[onnx.helper.make_opsetid("", 18)],
             ),
         ),
-        (
-            "mul_by_one",
-            _make_model(
-                onnx.helper.make_graph(
-                    [
-                        onnx.helper.make_node("Mul", ["X", "one"], ["Y"]),
-                    ],
-                    "name",
-                    [onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
-                    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
-                    [
-                        onnx.numpy_helper.from_array(
-                            np.array([1], dtype=np.float32), name="one"
-                        )
-                    ],
-                ),
-                opset_imports=[onnx.helper.make_opsetid("", 18)],
-            ),
-        ),
         (
             "canceled_out_transposes",
             _make_model(

@@ -180,7 +161,7 @@ def test_llama_p0_rule_set_transpose_transpose(self, _: str, model: ir.Model):
         ]
     )
     def test_llama_p0_rule_set_cast_cast(self, _: str, model: ir.Model):
-        rule_set = llama_rule_sets.llama_p0_rule_set()
+        rule_set = llama_rule_sets.cast_cast_rule
         model_proto = ir.serde.serialize_model(model)
         rule_set.apply_to_model(model)
         rewritten_model = ir.serde.serialize_model(model)
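The cast_cast test now exercises the single rule object directly, since cast_cast_rule is commented out of the set. A sketch of that pattern on a hand-built two-Cast model; the graph construction below is illustrative, not copied from the test:

```python
# Sketch: apply one rewrite rule (Cast(Cast(x)) folding) directly, as the
# updated test does, instead of going through llama_p0_rule_set().
import onnx

from onnxscript import ir
from onnxscript.rewriter import llama_rule_sets

FLOAT = onnx.TensorProto.FLOAT
FLOAT16 = onnx.TensorProto.FLOAT16

graph = onnx.helper.make_graph(
    [
        onnx.helper.make_node("Cast", ["X"], ["T"], to=FLOAT16),
        onnx.helper.make_node("Cast", ["T"], ["Y"], to=FLOAT),
    ],
    "cast_cast",
    [onnx.helper.make_tensor_value_info("X", FLOAT, [None])],
    [onnx.helper.make_tensor_value_info("Y", FLOAT, [None])],
)
model_proto = onnx.helper.make_model(
    graph, opset_imports=[onnx.helper.make_opsetid("", 18)]
)

model = ir.serde.deserialize_model(model_proto)
llama_rule_sets.cast_cast_rule.apply_to_model(model)  # single rule, not the set
rewritten = ir.serde.serialize_model(model)
```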

Comments (0)