-
Notifications
You must be signed in to change notification settings - Fork 57
/
mlir_printer_test.py
224 lines (169 loc) · 5.89 KB
/
mlir_printer_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
from io import StringIO
from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute
from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, VarOperandDef,
VarResultDef, irdl_attr_definition, irdl_op_definition)
from xdsl.parser import Parser
from xdsl.printer import Printer
import re
@irdl_op_definition
class ModuleOp(Operation):
    """Single-region container operation used as the top level in tests.

    Defined locally (instead of importing the builtin dialect's module op)
    so these printer tests do not depend on the builtin dialect.
    """

    name = "module"
    region = RegionDef()
@irdl_op_definition
class AnyOp(Operation):
    """Test-only operation accepting any number of operands and results.

    Both the variadic operand and variadic result definitions accept any
    attribute as their type constraint.
    """

    name = "any"
    op = VarOperandDef(AnyAttr())
    res = VarResultDef(AnyAttr())
@irdl_attr_definition
class DataAttr(Data[int]):
    """Test-only data attribute whose payload is a single integer.

    The payload is read with ``Parser.parse_int_literal`` and written back
    with ``Printer.print``.
    """

    name = "data_attr"

    @staticmethod
    def parse_parameter(parser: Parser) -> int:
        """Read the integer payload from ``parser``."""
        value = parser.parse_int_literal()
        return value

    @staticmethod
    def print_parameter(data: int, printer: Printer) -> None:
        """Write the integer payload ``data`` through ``printer``."""
        printer.print(data)
@irdl_attr_definition
class DataType(Data[int], MLIRType):
    """Test-only data *type* (mixes in ``MLIRType``) carrying an integer.

    Being a type, it is printed with the ``!`` prefix in MLIR output; the
    integer payload uses the same parse/print scheme as a data attribute.
    """

    name = "data_type"

    @staticmethod
    def parse_parameter(parser: Parser) -> int:
        """Read the integer payload from ``parser``."""
        value = parser.parse_int_literal()
        return value

    @staticmethod
    def print_parameter(data: int, printer: Printer) -> None:
        """Write the integer payload ``data`` through ``printer``."""
        printer.print(data)
@irdl_attr_definition
class ParamAttr(ParametrizedAttribute):
    """Test-only parametrized attribute that declares no parameters."""

    name = "param_attr"
@irdl_attr_definition
class ParamAttrWithParam(ParametrizedAttribute):
    """Test-only parametrized attribute holding one attribute parameter."""

    name = "param_attr_with_param"

    # The single parameter; any attribute (or type) is accepted.
    data: ParameterDef[Attribute]
@irdl_attr_definition
class ParamType(ParametrizedAttribute, MLIRType):
    """Test-only parameterless type (mixes in ``MLIRType``)."""

    name = "param_type"
@irdl_attr_definition
class ParamAttrWithCustomFormat(ParametrizedAttribute):
    """Parametrized attribute that overrides how its parameters are printed.

    Used to check that a custom ``print_parameters`` takes the place of the
    default ``<...>`` parameter list in the printed output.
    """

    name = "param_custom_format"

    # Single parameter, constrained to ``ParamAttr``.
    param1: ParameterDef[ParamAttr]

    def print_parameters(self, printer: Printer) -> None:
        """Print the fixed marker ``~~`` instead of the parameter values.

        :param printer: printer the marker is written to.
        """
        printer.print("~~")
def print_as_mlir_and_compare(test_prog: str, expected: str) -> None:
    """Parse ``test_prog``, print it in MLIR syntax and compare to ``expected``.

    The program is parsed with a context that registers every test
    operation and attribute defined in this module. It is then printed with
    the MLIR printer target. Whitespace is removed from BOTH the printed
    output and ``expected`` before comparing, so formatting and indentation
    differences are ignored.

    :param test_prog: source text parsed with ``Parser.parse_op`` (a single
        top-level operation).
    :param expected: the MLIR text the printer should produce, up to
        whitespace.
    :raises AssertionError: if the whitespace-insensitive texts differ; the
        message contains both the printed and the expected text verbatim.
    """
    ctx = MLContext()
    for op_def in (ModuleOp, AnyOp):
        ctx.register_op(op_def)
    for attr_def in (DataAttr, DataType, ParamAttr, ParamType,
                     ParamAttrWithParam, ParamAttrWithCustomFormat):
        ctx.register_attr(attr_def)

    parser = Parser(ctx, test_prog)
    top_op = parser.parse_op()

    res = StringIO()
    printer = Printer(target=Printer.Target.MLIR, stream=res)
    printer.print_op(top_op)
    printed = res.getvalue()

    # Remove every whitespace run from both sides; only the token sequence
    # matters for these tests. (No .strip() needed: nothing is left to strip.)
    printed_compact = re.sub(r"\s+", "", printed)
    expected_compact = re.sub(r"\s+", "", expected)
    assert printed_compact == expected_compact, (
        f"MLIR output mismatch (whitespace ignored).\n"
        f"printed:\n{printed}\n"
        f"expected:\n{expected}")
def test_empty_op():
    """An op with no operands, results or attributes prints as ``() -> ()``."""
    source = "any()"
    expected_mlir = '"any"() : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)
def test_data_attr():
    """A data attribute is printed with the ``#`` prefix in an attr dict."""
    source = 'any() [ "attr" = !data_attr<42> ]'
    expected_mlir = '"any"() {"attr" = #data_attr<42>} : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)
def test_data_type():
    """A data type used as a result type keeps the ``!`` prefix."""
    source = "%0 : !data_type<42> = any()"
    expected_mlir = '%0 = "any"() : () -> !data_type<42>'
    print_as_mlir_and_compare(source, expected_mlir)
def test_param_attr():
    """A parameterless parametrized attribute prints as ``#name``."""
    source = 'any() [ "attr" = !param_attr ]'
    expected_mlir = '"any"() {"attr" = #param_attr } : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)
def test_param_type():
    """A parameterless parametrized type prints as ``!name``."""
    source = "%0 : !param_type = any()"
    expected_mlir = '%0 = "any"() : () -> !param_type'
    print_as_mlir_and_compare(source, expected_mlir)
def test_param_attr_with_param():
    """A parametrized attribute prints its parameter inside ``<...>``.

    Covers both an attribute parameter (printed with ``#``) and a type
    parameter (printed with ``!``).
    """
    cases = [
        ('any() [ "attr" = !param_attr_with_param<!param_attr> ]',
         '"any"() {"attr" = #param_attr_with_param<#param_attr> }\n'
         ': () -> ()'),
        ('any() [ "attr" = !param_attr_with_param<!param_type> ]',
         '"any"() {"attr" = #param_attr_with_param<!param_type> }\n'
         ': () -> ()'),
    ]
    for source, expected_mlir in cases:
        print_as_mlir_and_compare(source, expected_mlir)
def test_op_with_region():
    """An op with an empty region prints the region as ``({})``."""
    source = "module() {}"
    expected_mlir = '"module"() ({}) : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)
def test_op_with_results():
    """Result names go before ``=``; result types go after ``->``.

    A single result type is printed bare; several are parenthesized.
    """
    cases = [
        ("%0 : !param_attr = any()",
         '%0 = "any"() : () -> #param_attr'),
        ("%0 : !param_attr, %1 : !param_type = any()",
         '%0, %1 = "any"() : () -> (#param_attr, !param_type)'),
    ]
    for source, expected_mlir in cases:
        print_as_mlir_and_compare(source, expected_mlir)
def test_op_with_operands():
    """Operands print as SSA names, with their types in the signature.

    Covers one use of a value and the same value passed twice.
    """
    cases = [
        (
            ("module() {\n"
             "%0 : !param_attr = any()\n"
             "any(%0 : !param_attr)\n"
             "}"),
            ('"module"() ({\n'
             '%0 = "any"() : () -> #param_attr\n'
             '"any"(%0) : (#param_attr) -> ()\n'
             '}) : () -> ()\n'),
        ),
        (
            ("module() {\n"
             "%0 : !param_attr = any()\n"
             "any(%0 : !param_attr, %0 : !param_attr)\n"
             "}"),
            ('"module"() ({\n'
             '%0 = "any"() : () -> #param_attr\n'
             '"any"(%0, %0) : (#param_attr, #param_attr) -> ()\n'
             '}) : () -> ()\n'),
        ),
    ]
    for source, expected_mlir in cases:
        print_as_mlir_and_compare(source, expected_mlir)
def test_op_with_attributes():
    """An op's attribute dictionary is printed in ``{...}`` form.

    NOTE(review): input and expectation are identical to test_data_attr;
    consider extending this to several attributes once the multi-attribute
    syntax is confirmed.
    """
    source = 'any() [ "attr" = !data_attr<42> ]'
    expected_mlir = '"any"() {"attr" = #data_attr<42>} : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)
def test_param_custom_format():
    """A custom ``print_parameters`` replaces the ``<...>`` parameter list."""
    source = 'any() [ "attr" = !param_custom_format<!param_attr> ]'
    expected_mlir = '"any"() {"attr" = #param_custom_format~~} : () -> ()'
    print_as_mlir_and_compare(source, expected_mlir)