Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion _doc/design/optimizer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ The method returns None if no match is found or an instance of class :class:`Mat
self,
node: Optional[NodeProto] = None,
lineno: Optional[int] = None,
msg: str = "",
msg: Optional[Union[Callable[[], str], str]] = None,
):

It may be useful which reason made a pattern matching fail.
Expand All @@ -84,6 +84,8 @@ expression:

By setting the verbosity (see next Section), the user may then know
which lines in the code returned None and which condition failed.
The last parameter is used to print a more comprehensive message about the
reason why the match failed.

PatternOptimization.apply
+++++++++++++++++++++++++
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_torch_interpreter/test_onnx_export_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def forward(self, x):
self.assertEqual(shape_x, ["batch", ""])
for obs in onx.graph.value_info:
shape = tuple((d.dim_param or d.dim_value) for d in obs.type.tensor_type.shape.dim)
self.assertIn(shape, (("2048*batch//1024", 1024), ("batch", 2, 1024)))
self.assertIn(shape, (("2*batch", 1024), ("batch", 2, 1024)))
sess = ExtendedReferenceEvaluator(model_path, verbose=0)
feeds = dict(zip(sess.input_names, [x.numpy() for x in xs]))
got = sess.run(None, feeds)[0]
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_xbuilder/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ def test__apply_reshape_to_shape(self):
(
("s23", 1, "seq_length", "s31+seq_length"),
(-1,),
("s23*seq_length*(s31+seq_length)",),
("s23*(s31+seq_length)*seq_length",),
),
(("s44", 16, 1), (0, 1, -1), ("s44", 1, 16)),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7564,7 +7564,7 @@ def test_local_attention_gqa_1(self):
"ShapeBasedExpandSwap",
"FunctionAttention",
],
verbose=10,
verbose=0,
),
)
ort = self._check_with_ort(onx)
Expand Down
3 changes: 0 additions & 3 deletions _unittests/ut_xrun_doc/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,6 @@ def test_rename_dynamic_expression(self):
self.assertEqual("batch", rename_dynamic_expression("1*batch", {}))
self.assertEqual("batch", rename_dynamic_expression("batch*1", {}))

def test_rename_dynamic_expression_reorder(self):
self.assertEqual("a+b", rename_dynamic_expression("b+a", {}))

def test_dot_plot(self):
TFLOAT = onnx.TensorProto.FLOAT
model = oh.make_model(
Expand Down
4 changes: 2 additions & 2 deletions _unittests/ut_xshape/test_shape_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,14 @@ def test_evaluate_shape(self):
builder.run_model(model)
self.assertEqual(
builder._known_shapes,
{"Y": ("batch", "seq1"), "X": ("batch", "seq2"), "Z": ("batch", "seq2+seq1")},
{"Y": ("batch", "seq1"), "X": ("batch", "seq2"), "Z": ("batch", "seq1+seq2")},
)
feeds = dict(
X=np.random.rand(3, 5).astype(np.float32), Y=np.random.rand(3, 6).astype(np.float32)
)
got = ExtendedReferenceEvaluator(model).run(None, feeds)
res = builder.compare_with_true_inputs(feeds, got)
self.assertEqual(res, {"Z": (("batch", 3, 3), ("seq2+seq1", 11, 11))})
self.assertEqual(res, {"Z": (("batch", 3, 3), ("seq1+seq2", 11, 11))})

def test_concat_split(self):
model = oh.make_model(
Expand Down
14 changes: 14 additions & 0 deletions _unittests/ut_xshape/test_simplify_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,20 @@ def test_simplify_add_sub(self):
def test_simplify_function(self):
self.assertEqual("CeilToInt(b+c,2)", simplify_expression("CeilToInt(b+c,2)"))

def test_simplify_function_order(self):
self.assertEqual("a+b", simplify_expression("b+a"))

def test_simplify_function_order3(self):
self.assertEqual("a+b+c", simplify_expression("c+b+a"))
self.assertEqual("a+b+c", simplify_expression("b+c+a"))
self.assertEqual("a+b+c", simplify_expression("a+c+b"))

def test_simplify_function_floordiv_int(self):
self.assertEqual("512*a", simplify_expression("1024*a//2"))
self.assertEqual("a", simplify_expression("1024*a//1024"))
self.assertEqual("a+b", simplify_expression("1024*(a+b)//1024"))
self.assertEqual("2*a+2*b", simplify_expression("1024*(a+b)//1024*2"))


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 10 additions & 2 deletions experimental_experiment/xbuilder/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2085,7 +2085,11 @@ def set_shape(
if hasattr(self, "replacements_dimensions_"):
self.replacements_dimensions_[name] = tuple(
(
rename_dynamic_expression(_, self.replacements_for_replacements_dimensions_)
simplify_expression(
rename_dynamic_expression(
_, self.replacements_for_replacements_dimensions_
)
)
if isinstance(_, str)
else _
)
Expand Down Expand Up @@ -6175,7 +6179,11 @@ def _update(dd_flat, names):
if v is None:
continue
self.replacements_dimensions_[k] = tuple(
(rename_dynamic_expression(_, replacements) if isinstance(_, str) else _)
(
simplify_expression(rename_dynamic_expression(_, replacements))
if isinstance(_, str)
else _
)
for _ in v
)
return replacements
Expand Down
6 changes: 5 additions & 1 deletion experimental_experiment/xoptim/patterns/onnx_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,11 @@ def match(
shape1 = g.get_shape_renamed(gqa_unsqueeze.input[0])
shape2 = g.get_shape_renamed(gqa_reshape.output[0])
if shape1[0] != shape2[0] or shape1[2] != shape2[2] or shape1[3] != shape2[3]:
return self.none(node, inspect.currentframe().f_lineno)
return self.none(
node,
inspect.currentframe().f_lineno,
msg=lambda: f"Shape mismatch {shape1=}, {shape2=}",
)
else:
# No Attention, no MultiHeadAttention, no GroupQueryAttention
return self.none(node, inspect.currentframe().f_lineno)
Expand Down
6 changes: 3 additions & 3 deletions experimental_experiment/xoptim/patterns_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def none(
self,
node: Optional[NodeProto] = None,
lineno: Optional[int] = None,
msg: Optional[Union[Callable, str]] = None,
msg: Optional[Union[Callable[[], str], str]] = None,
):
"""
It may be useful which reason made a pattern matching fail.
Expand All @@ -250,11 +250,11 @@ def none(
elif callable(msg):
msg = msg()
if msg:
msg = f"\n{msg}"
msg = f"\n reason: {msg}"
if self.verbose >= 10 and hasattr(self, "_debug"):
msg2 = self._debug_print()
if msg2:
msg2 = f"\n{textwrap.indent(msg2, ' ')}"
msg2 = f"\n reason: {textwrap.indent(msg2, ' ')}"
print(
f"[{self.__class__.__name__}.match] NONE - line: {lineno}:"
f"{os.path.split(self.__class__.__module__)[-1]}, "
Expand Down
64 changes: 12 additions & 52 deletions experimental_experiment/xshape/rename_expressions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Dict, List, Optional, Set
from typing import Dict, Optional, Set
from .simplify_expressions import SimpleSimpliflyTransformer, CommonTransformer


Expand All @@ -20,58 +20,20 @@ def parse_expression_tokens(expr: str) -> Set[str]:


class RenameTransformer(CommonTransformer):
def __init__(self, mapping, expr: Optional[str] = None):
super().__init__(expr)
self.mapping = mapping

def visit_Name(self, node):
if node.id in self.mapping:
return ast.copy_location(ast.Name(id=self.mapping[node.id], ctx=node.ctx), node)
return node


class ReorderCommutativeOpsTransformer(ast.NodeTransformer):
def __init__(self):
super().__init__()

def visit_BinOp(self, node: ast.BinOp):
# First recurse into children
self.generic_visit(node)

# Only process + and *
if isinstance(node.op, (ast.Add, ast.Mult)):
operands = self._flatten(node, type(node.op))
operands.sort(key=self._expr_key)
return self._rebuild(operands, node.op)

return node

def _flatten(self, node: ast.AST, op_type) -> List[ast.AST]:
"""Flattens a chain of same-type binary operations."""
if isinstance(node, ast.BinOp) and isinstance(node.op, op_type):
return self._flatten(node.left, op_type) + self._flatten(node.right, op_type)
return [node]

def _rebuild(self, operands: List[ast.AST], op: ast.operator) -> ast.AST:
"""Rebuilds a binary tree from sorted operands."""
expr = operands[0]
for operand in operands[1:]:
expr = ast.BinOp(left=expr, op=op, right=operand)
return expr

def _expr_key(self, node: ast.AST) -> str:
"""Generates a sortable key for expressions."""
return ast.unparse(node)
"""
Renames variable names into other based on a mapping.

:param magging: mapping
:param expr: only use for error messages
"""

class RenameVariable(CommonTransformer):
def __init__(self, mapping, expr: Optional[str] = None):
super().__init__()
def __init__(self, mapping: Dict[str, str], expr: Optional[str] = None):
super().__init__(expr)
self.mapping = mapping

def visit_Name(self, node):
if node.id in self.mapping:
node.id = self.mapping[node.id]
return ast.copy_location(ast.Name(id=self.mapping[node.id], ctx=node.ctx), node)
return node


Expand All @@ -85,8 +47,7 @@ def rename_expression(expr: str, mapping: Dict[str, str]) -> str:
"""
tree = ast.parse(expr, mode="eval")
transformer = RenameTransformer(mapping)
reorder = ReorderCommutativeOpsTransformer()
new_tree = reorder.visit(transformer.visit(tree))
new_tree = transformer.visit(tree)
ast.fix_missing_locations(new_tree)
return ast.unparse(new_tree).replace(" ", "")

Expand All @@ -104,10 +65,9 @@ def rename_dynamic_expression(expression: str, replacements: Dict[str, str]):
tree = ast.parse(expression)
except SyntaxError:
return expression
transformer = RenameVariable(replacements)
transformer = RenameTransformer(replacements)
simplify = SimpleSimpliflyTransformer()
reorder = ReorderCommutativeOpsTransformer()
new_tree = reorder.visit(simplify.visit(transformer.visit(tree)))
new_tree = simplify.visit(transformer.visit(tree))
res = ast.unparse(new_tree).replace(" ", "")
return res

Expand Down
Loading
Loading