diff --git a/tests/dialects/test_print.py b/tests/dialects/test_print.py new file mode 100644 index 0000000000..4a04059f16 --- /dev/null +++ b/tests/dialects/test_print.py @@ -0,0 +1,58 @@ +from xdsl.dialects import print as print_dialect, test, builtin +from xdsl.transforms import print_to_println +import pytest + + +@pytest.mark.parametrize( + "given,expected", + ( + ("test", "test"), + ("@hello world", "hello_world"), + ("something.with.dots", "something_with_dots"), + ("this is 💩", "this_is"), + ("123 is a number!", "123_is_a_number"), + ), +) +def test_symbol_sanitizer(given: str, expected: str): + assert print_to_println.legalize_str_for_symbol_name(given) == expected + + +def test_format_str_from_op(): + a1, a2 = test.TestOp.create(result_types=[builtin.i32, builtin.f32]).results + op = print_dialect.PrintLnOp("test {} value {}", a1, a2) + + parts = print_to_println._format_string_spec_from_print_op( # pyright: ignore[reportPrivateUsage] + op + ) + + assert list(parts) == [ + "test ", + a1, + " value ", + a2, + ] + + op2 = print_dialect.PrintLnOp("{}", a1) + + parts2 = print_to_println._format_string_spec_from_print_op( # pyright: ignore[reportPrivateUsage] + op2 + ) + + assert list(parts2) == [a1] + + +def test_global_symbol_name_generation(): + """ + Check that two strings that are invalid symbol names still result in two distinct + global symbol names. + + Similarly, test that the same string results in the same symbol name. 
+ """ + s1 = print_to_println._key_from_str("(") # pyright: ignore[reportPrivateUsage] + s2 = print_to_println._key_from_str(")") # pyright: ignore[reportPrivateUsage] + + assert s1 != s2 + + s3 = print_to_println._key_from_str(")") # pyright: ignore[reportPrivateUsage] + + assert s2 == s3 diff --git a/tests/filecheck/dialects/print/print_to_printf.mlir b/tests/filecheck/dialects/print/print_to_printf.mlir new file mode 100644 index 0000000000..88281f1f6b --- /dev/null +++ b/tests/filecheck/dialects/print/print_to_printf.mlir @@ -0,0 +1,22 @@ +// RUN: xdsl-opt %s -p print-to-printf | filecheck %s + +builtin.module { + "func.func"() ({ + %pi = "arith.constant"() {value = 3.14159:f32} : () -> f32 + %12 = "arith.constant"() {value = 12 : i32} : () -> i32 + + print.println "Hello: {} {}", %pi : f32, %12 : i32 + + "func.return"() : () -> () + }) {sym_name = "main", function_type=() -> ()} : () -> () +} + +// CHECK: %{{\d+}} = "llvm.mlir.addressof"() {"global_name" = @Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871} : () -> !llvm.ptr + +// CHECK: "llvm.call"(%{{\d+}}, %{{\d+}}, %{{\d+}}) {"callee" = @printf{{.*}}} : (!llvm.ptr, f64, i32) -> () + +// CHECK: "llvm.func"() ({ +// CHECK-NEXT: }) {"sym_name" = "printf", "function_type" = !llvm.func, "CConv" = #llvm.cconv, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> () + +// CHECK: "llvm.mlir.global"() ({ +// CHECK-NEXT: }) {"global_type" = !llvm.array<14 x i8>, "sym_name" = "Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871", "linkage" = #llvm.linkage<"internal">, "addr_space" = 0 : i32, "constant", "value" = dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>} : () -> () diff --git a/tests/filecheck/mlir-conversion/with-mlir/dialects/print/print_to_printf.mlir b/tests/filecheck/mlir-conversion/with-mlir/dialects/print/print_to_printf.mlir new file mode 100644 index 0000000000..f3c13ee7bb --- /dev/null +++ 
b/tests/filecheck/mlir-conversion/with-mlir/dialects/print/print_to_printf.mlir @@ -0,0 +1,20 @@ +// RUN: xdsl-opt %s -p print-to-printf | mlir-opt --test-lower-to-llvm | filecheck %s +// this tests straight to llvmir to verify intended target compatibility + +builtin.module { + "func.func"() ({ + %pi = "arith.constant"() {value = 3.14159:f32} : () -> f32 + %12 = "arith.constant"() {value = 12 : i32} : () -> i32 + + print.println "Hello: {} {}", %pi : f32, %12 : i32 + + "func.return"() : () -> () + }) {sym_name = "main", function_type=() -> ()} : () -> () +} + + +// CHECK: llvm.call @printf(%{{\d+}}, %{{\d+}}, %{{\d+}}) : (!llvm.ptr, f64, i32) -> () + +// CHECK: llvm.func @printf(!llvm.ptr, ...) + +// CHECK: llvm.mlir.global internal constant @Hello_f_{{\w+}}(dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.arr diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index 7c032e1e55..5a953a79da 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -783,9 +783,7 @@ class AddressOfOp(IRDLOperation): def get( global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType ): - if isinstance(global_name, str): - global_name = StringAttr(global_name) - if isinstance(global_name, StringAttr): + if isinstance(global_name, (StringAttr, str)): global_name = SymbolRefAttr(global_name) return AddressOfOp.build( diff --git a/xdsl/dialects/print.py b/xdsl/dialects/print.py index 77c96174b6..88192facc8 100644 --- a/xdsl/dialects/print.py +++ b/xdsl/dialects/print.py @@ -26,6 +26,7 @@ class PrintLnOp(IRDLOperation): """ name = "print.println" + format_str: builtin.StringAttr = attr_def(builtin.StringAttr) format_vals: VarOperand = var_operand_def() diff --git a/xdsl/transforms/print_to_println.py b/xdsl/transforms/print_to_println.py new file mode 100644 index 0000000000..6410d729fc --- /dev/null +++ b/xdsl/transforms/print_to_println.py @@ -0,0 +1,165 @@ +from typing import Iterable 
+import hashlib +import re + +from xdsl.ir import SSAValue, Attribute, MLContext, Operation +from xdsl.dialects import print, builtin, arith, llvm +from xdsl.pattern_rewriter import ( + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, + PatternRewriteWalker, +) + +from xdsl.passes import ModulePass + + +i8 = builtin.IntegerType(8) + + +def legalize_str_for_symbol_name(val: str): + """ + Takes any string and legalizes it to be a global llvm symbol. + (for the strictest possible interpretation of this) + + - Replaces every run of whitespace and every dot with _ + - Deletes all characters that are not ascii letters, digits or underscores + - Strips all underscores from the start and end of the string + + The resulting string consists only of ascii letters, underscores and digits. + + This mapping is not injective, meaning that multiple inputs can produce the same + output. This function alone cannot be used to get a uniquely identifying global + symbol name for a string! + """ + val = re.sub(r"(\s+|\.)", "_", val) + val = re.sub(r"[^A-Za-z0-9_]+", "", val).strip("_") + return val + + +def _key_from_str(val: str) -> str: + """ + Generate a symbol name from any given string. + + Takes the legalized first ten characters of the string plus its sha1 hash to + create a (pretty much) globally unique symbol name. + """ + h = hashlib.new("sha1") + h.update(val.encode()) + return f"{legalize_str_for_symbol_name(val[:10])}_{h.hexdigest()}" + + +def _format_string_spec_from_print_op(op: print.PrintLnOp) -> Iterable[str | SSAValue]: + """ + Translates the op: + print.println "val = {}, val2 = {}", %1 : i32, %2 : f32 + + into this sequence: + ["val = ", %1, ", val2 = ", %2] + + Empty string parts are omitted. 
+ """ + format_str = op.format_str.data.split("{}") + args = iter(op.format_vals) + + for part in format_str[:-1]: + if part != "": + yield part + yield next(args) + if format_str[-1] != "": + yield format_str[-1] + + +def _format_str_for_typ(t: Attribute): + match t: + case builtin.f64: + return "%f" + case builtin.i32: + return "%i" + case builtin.i64: + return "%li" + case _: + raise ValueError(f"Cannot find printf code for {t}") + + +class PrintlnOpToPrintfCall(RewritePattern): + collected_global_symbs: dict[str, llvm.GlobalOp] + + def __init__(self): + self.collected_global_symbs = dict() + + def _construct_global(self, val: str): + """ + Constructs an llvm.global operation containing the string. Assigns a unique + symbol name to the value that is derived from the string value. + """ + data = val.encode() + b"\x00" + + t_type = builtin.TensorType.from_type_and_list(i8, [len(data)]) + + return llvm.GlobalOp.get( + llvm.LLVMArrayType.from_size_and_type(len(data), i8), + _key_from_str(val), + constant=True, + linkage="internal", + value=builtin.DenseIntOrFPElementsAttr.from_list(t_type, data), + ) + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: print.PrintLnOp, rewriter: PatternRewriter, /): + format_str = "" + args: list[SSAValue] = [] + casts: list[Operation] = [] + # make sure all arguments are in the format libc expects them to be + # e.g. 
floats must be promoted to double before calling + for part in _format_string_spec_from_print_op(op): + if isinstance(part, str): + format_str += part.replace("%", "%%")  # literal text: escape % so printf prints it verbatim + elif isinstance(part.typ, builtin.IndexType): + # index must be cast to fixed bitwidth before printing + casts.append(new_val := arith.IndexCastOp.get(part, builtin.i64)) + args.append(new_val.result) + format_str += "%li" + elif part.typ == builtin.f32: + # f32 must be promoted to f64 before printing + casts.append(new_val := arith.ExtFOp.get(part, builtin.f64)) + args.append(new_val.result) + format_str += "%f" + else: + args.append(part) + format_str += _format_str_for_typ(part.typ) + + globl = self._construct_global(format_str + "\n") + self.collected_global_symbs[globl.sym_name.data] = globl + + rewriter.replace_matched_op( + casts + + [ + ptr := llvm.AddressOfOp.get( + globl.sym_name, llvm.LLVMPointerType.opaque() + ), + llvm.CallOp("printf", ptr.result, *args), + ] + ) + + +class PrintToPrintf(ModulePass): + name = "print-to-printf" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + add_printf_call = PrintlnOpToPrintfCall() + + PatternRewriteWalker(add_printf_call).rewrite_module(op) + + op.body.block.add_ops( + [ + llvm.FuncOp( + "printf", + llvm.LLVMFunctionType( + [llvm.LLVMPointerType.opaque()], is_variadic=True + ), + linkage=llvm.LinkageAttr("external"), + ), + *add_printf_call.collected_global_symbs.values(), + ] + ) diff --git a/xdsl/xdsl_opt_main.py b/xdsl/xdsl_opt_main.py index 2398d622e1..5564a8f588 100644 --- a/xdsl/xdsl_opt_main.py +++ b/xdsl/xdsl_opt_main.py @@ -58,6 +58,7 @@ from xdsl.transforms.experimental.dmp.scatter_gather import ( DmpScatterGatherTrivialLowering, ) +from xdsl.transforms.print_to_println import PrintToPrintf from xdsl.utils.exceptions import DiagnosticException from xdsl.utils.parse_pipeline import parse_pipeline @@ -109,6 +110,7 @@ def get_all_passes() -> list[type[ModulePass]]: LowerRISCVFunc, LowerSnitchPass, LowerSnitchRuntimePass, + PrintToPrintf, 
RISCVRegisterAllocation, StencilShapeInferencePass, StencilStorageMaterializationPass,