diff --git a/tests/filecheck/mlir-conversion/builtin_attrs.mlir b/tests/filecheck/mlir-conversion/builtin_attrs.mlir index a91d343580..0e3b03c103 100644 --- a/tests/filecheck/mlir-conversion/builtin_attrs.mlir +++ b/tests/filecheck/mlir-conversion/builtin_attrs.mlir @@ -106,4 +106,16 @@ // CHECK: "value1" = opaque<"test", "contents">, "value2" = opaque<"test", "contents"> : tensor<2xf64> + "func.func"() ({}) {function_type = () -> (), + symbol = @some_symbol, + sym_name = "symbol_attr"} : () -> () + + // CHECK: "symbol" = @some_symbol + + "func.func"() ({}) {function_type = () -> (), + value1 = tensor, + sym_name = "non_static_shaped_tensor"} : () -> () + + // CHECK: tensor + }) : () -> () diff --git a/tests/filecheck/mlir-conversion/with-bindings/symbol_tests.xdsl b/tests/filecheck/mlir-conversion/with-bindings/symbol_tests.xdsl new file mode 100644 index 0000000000..2730ffc652 --- /dev/null +++ b/tests/filecheck/mlir-conversion/with-bindings/symbol_tests.xdsl @@ -0,0 +1,10 @@ +// RUN: xdsl-opt -t mlir %s | mlir-opt --mlir-print-op-generic > %t-1 && xdsl-opt -t mlir %s | mlir-opt --mlir-print-op-generic > %t-2 && diff %t-1 %t-2 + +// Tests if the non generic form can be printed. + +// CHECK: module { +builtin.module() { + func.func() ["function_type" = !fun<[], []>, "symbol" = @some_symbol, "sym_name" = "symbol_attr", "sym_visibility" = "private"] {} + + func.func() ["function_type" = !fun<[], []>, "value1" = !tensor<[-1 : !index], !i32>, "sym_name" = "unranked_tensor_type", "sym_visibility" = "private"] {} +} diff --git a/xdsl/parser.py b/xdsl/parser.py index cb2fa61d37..0d5193d43f 100644 --- a/xdsl/parser.py +++ b/xdsl/parser.py @@ -943,6 +943,11 @@ def parse_optional_mlir_attribute(self, self.parse_char("]") return ArrayAttr.from_list(contents) + # FlatSymbolRefAttr + if self.parse_optional_char("@"): + symbol_name = self.parse_alpha_num(skip_white_space=False) + return FlatSymbolRefAttr.from_str(symbol_name) + # tensor type if (tensor := self.parse_optional_mlir_tensor()) is not None: return tensor diff --git a/xdsl/printer.py b/xdsl/printer.py index ca027695b0..f2d267e1fd 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -408,8 +408,9 @@ def print_one_elem(val: Attribute): attribute = cast(AnyVectorType, attribute) self.print( "vector<" if isinstance(attribute, VectorType) else "tensor<") - self.print_list(attribute.shape.data, - lambda x: self.print(x.value.data), "x") + self.print_list( + attribute.shape.data, lambda x: self.print(x.value.data) + if x.value.data != -1 else self.print("?"), "x") if len(attribute.shape.data) != 0: self.print("x") self.print(attribute.element_type)