-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
transforms: (print) Add a lowering to printf (#1142)
This adds the `print-to-printf` pass to lower println ops to printf calls
- Loading branch information
1 parent
85b16da
commit 056a98f
Showing
7 changed files
with
269 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from xdsl.dialects import print as print_dialect, test, builtin | ||
from xdsl.transforms import print_to_println | ||
import pytest | ||
|
||
|
||
@pytest.mark.parametrize(
    "given,expected",
    [
        ("test", "test"),
        ("@hello world", "hello_world"),
        ("something.with.dots", "something_with_dots"),
        ("this is 💩", "this_is"),
        ("123 is a number!", "123_is_a_number"),
    ],
)
def test_symbol_sanitizer(given: str, expected: str):
    """
    Check that arbitrary strings are reduced to ascii letters, digits and
    underscores, without leading or trailing underscores.
    """
    sanitized = print_to_println.legalize_str_for_symbol_name(given)
    assert sanitized == expected
|
||
|
||
def test_format_str_from_op():
    """
    Check that a println op is split into alternating literal text and SSA
    values, and that empty literal parts are left out.
    """
    producer = test.TestOp.create(result_types=[builtin.i32, builtin.f32])
    a1, a2 = producer.results

    # Two placeholders surrounded by literal text.
    mixed_op = print_dialect.PrintLnOp("test {} value {}", a1, a2)
    mixed_spec = list(
        print_to_println._format_string_spec_from_print_op(  # pyright: ignore[reportPrivateUsage]
            mixed_op
        )
    )
    assert mixed_spec == ["test ", a1, " value ", a2]

    # A lone placeholder: the empty strings around it must not show up.
    bare_op = print_dialect.PrintLnOp("{}", a1)
    bare_spec = list(
        print_to_println._format_string_spec_from_print_op(  # pyright: ignore[reportPrivateUsage]
            bare_op
        )
    )
    assert bare_spec == [a1]
|
||
|
||
def test_global_symbol_name_generation():
    """
    Two strings that both legalize to an empty symbol prefix must still get
    distinct global symbol names, while equal strings must map to the same name.
    """
    key_of = print_to_println._key_from_str  # pyright: ignore[reportPrivateUsage]

    open_paren_key = key_of("(")
    close_paren_key = key_of(")")
    assert open_paren_key != close_paren_key

    # Asking again for the same string yields the same key.
    repeated_close_paren_key = key_of(")")
    assert close_paren_key == repeated_close_paren_key
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
// RUN: xdsl-opt %s -p print-to-printf | filecheck %s | ||
|
||
builtin.module { | ||
"func.func"() ({ | ||
%pi = "arith.constant"() {value = 3.14159:f32} : () -> f32 | ||
%12 = "arith.constant"() {value = 12 : i32} : () -> i32 | ||
|
||
print.println "Hello: {} {}", %pi : f32, %12 : i32 | ||
|
||
"func.return"() : () -> () | ||
}) {sym_name = "main", function_type=() -> ()} : () -> () | ||
} | ||
|
||
// CHECK: %{{\d+}} = "llvm.mlir.addressof"() {"global_name" = @Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871} : () -> !llvm.ptr | ||
|
||
// CHECK: "llvm.call"(%{{\d+}}, %{{\d+}}, %{{\d+}}) {"callee" = @printf{{.*}}} : (!llvm.ptr, f64, i32) -> () | ||
|
||
// CHECK: "llvm.func"() ({ | ||
// CHECK-NEXT: }) {"sym_name" = "printf", "function_type" = !llvm.func<void (!llvm.ptr, ...)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> () | ||
|
||
// CHECK: "llvm.mlir.global"() ({ | ||
// CHECK-NEXT: }) {"global_type" = !llvm.array<14 x i8>, "sym_name" = "Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871", "linkage" = #llvm.linkage<"internal">, "addr_space" = 0 : i32, "constant", "value" = dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>} : () -> () |
20 changes: 20 additions & 0 deletions
20
tests/filecheck/mlir-conversion/with-mlir/dialects/print/print_to_printf.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// RUN: xdsl-opt %s -p print-to-printf | mlir-opt --test-lower-to-llvm | filecheck %s | ||
// this tests straight to llvmir to verify intended target compatibility | ||
|
||
builtin.module { | ||
"func.func"() ({ | ||
%pi = "arith.constant"() {value = 3.14159:f32} : () -> f32 | ||
%12 = "arith.constant"() {value = 12 : i32} : () -> i32 | ||
|
||
print.println "Hello: {} {}", %pi : f32, %12 : i32 | ||
|
||
"func.return"() : () -> () | ||
}) {sym_name = "main", function_type=() -> ()} : () -> () | ||
} | ||
|
||
|
||
// CHECK: llvm.call @printf(%{{\d+}}, %{{\d+}}, %{{\d+}}) : (!llvm.ptr, f64, i32) -> () | ||
|
||
// CHECK: llvm.func @printf(!llvm.ptr, ...) | ||
|
||
// CHECK: llvm.mlir.global internal constant @Hello_f_{{\w+}}(dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.arr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
from typing import Iterable | ||
import hashlib | ||
import re | ||
|
||
from xdsl.ir import SSAValue, Attribute, MLContext, Operation | ||
from xdsl.dialects import print, builtin, arith, llvm | ||
from xdsl.pattern_rewriter import ( | ||
PatternRewriter, | ||
RewritePattern, | ||
op_type_rewrite_pattern, | ||
PatternRewriteWalker, | ||
) | ||
|
||
from xdsl.passes import ModulePass | ||
|
||
|
||
i8 = builtin.IntegerType(8) | ||
|
||
|
||
def legalize_str_for_symbol_name(val: str):
    """
    Reduce an arbitrary string to something usable as a global llvm symbol name
    (under the strictest possible interpretation of what is allowed).

    Steps, in order:
     - every run of whitespace and every single dot becomes one underscore
     - every character that is not an ascii letter, digit or underscore is dropped
     - leading and trailing underscores are stripped

    The result contains only ascii letters, digits and underscores, and may be
    empty. The mapping is not injective: different inputs can produce the same
    output, so this function alone cannot provide a unique symbol name.
    """
    underscored = re.sub(r"(\s+|\.)", "_", val)
    ascii_only = re.sub(r"[^A-Za-z0-9_]+", "", underscored)
    return ascii_only.strip("_")
|
||
|
||
def _key_from_str(val: str) -> str:
    """
    Derive a symbol name from any given string.

    The name is the legalized form of the first ten characters of the string,
    followed by an underscore and the hex sha1 digest of the whole string. The
    digest makes the name (pretty much) globally unique, and equal strings
    always produce equal names.
    """
    digest = hashlib.sha1(val.encode()).hexdigest()
    readable_prefix = legalize_str_for_symbol_name(val[:10])
    return f"{readable_prefix}_{digest}"
|
||
|
||
def _format_string_spec_from_print_op(op: print.PrintLnOp) -> Iterable[str | SSAValue]:
    """
    Lazily split a println op into literal text pieces and the values to print.

    For example the op:
        print.println "val = {}, val2 = {}", %1 : i32, %2 : f32
    yields the sequence:
        "val = ", %1, ", val2 = ", %2

    Every "{}" in the format string consumes the next value of op.format_vals.
    Literal pieces that are empty strings are skipped.
    """
    # Splitting on the placeholder always gives at least one piece; the last
    # piece is the text after the final placeholder and has no value attached.
    *leading, trailing = op.format_str.data.split("{}")
    remaining_vals = iter(op.format_vals)

    for text in leading:
        if text:
            yield text
        yield next(remaining_vals)
    if trailing:
        yield trailing
|
||
|
||
def _format_str_for_typ(t: Attribute):
    """
    Return the printf conversion specifier used for a value of type t.

    Only f64 ("%f"), i32 ("%i") and i64 ("%li") are known here.

    Raises:
        ValueError: if t is none of the types listed above.
    """
    if t == builtin.f64:
        return "%f"
    if t == builtin.i32:
        return "%i"
    if t == builtin.i64:
        return "%li"
    raise ValueError(f"Cannot find printf code for {t}")
|
||
|
||
class PrintlnOpToPrintfCall(RewritePattern):
    """
    Rewrite pattern that replaces each print.println op with a call to printf.

    For every matched op a printf format string is assembled, stored in an
    llvm global and passed (by address) as the first argument of the call.
    The globals are not inserted into the module by this pattern; they are
    collected in `collected_global_symbs` so the caller can add them afterwards.
    """

    # Maps the symbol name of each generated global to the global op itself.
    # Equal format strings produce equal symbol names, so they share one entry.
    collected_global_symbs: dict[str, llvm.GlobalOp]

    def __init__(self):
        self.collected_global_symbs = dict()

    def _construct_global(self, val: str):
        """
        Constructs an llvm.global operation containing the string. Assigns a
        symbol name to the value that is derived from the string value.

        The stored data is the encoded string followed by a terminating zero byte.
        """
        data = val.encode() + b"\x00"

        t_type = builtin.TensorType.from_type_and_list(i8, [len(data)])

        return llvm.GlobalOp.get(
            llvm.LLVMArrayType.from_size_and_type(len(data), i8),
            _key_from_str(val),
            constant=True,
            linkage="internal",
            value=builtin.DenseIntOrFPElementsAttr.from_list(t_type, data),
        )

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: print.PrintLnOp, rewriter: PatternRewriter, /):
        """
        Replace `op` by the casts needed for its arguments, an llvm.mlir.addressof
        of the format string global, and the llvm.call to printf.

        Raises:
            ValueError: (from _format_str_for_typ) if an argument has a type for
                which no printf conversion is known.
        """
        format_parts: list[str] = []
        args: list[SSAValue] = []
        casts: list[Operation] = []
        # make sure all arguments are in the format libc expects them to be
        # e.g. floats must be promoted to double before calling
        for part in _format_string_spec_from_print_op(op):
            if isinstance(part, str):
                # Literal text ends up inside a printf format string, where a
                # bare "%" would start a conversion specification. Escape it so
                # it is printed as-is.
                format_parts.append(part.replace("%", "%%"))
            elif isinstance(part.typ, builtin.IndexType):
                # index must be cast to fixed bitwidth before printing
                casts.append(new_val := arith.IndexCastOp.get(part, builtin.i64))
                args.append(new_val.result)
                format_parts.append("%li")
            elif part.typ == builtin.f32:
                # f32 must be promoted to f64 before printing
                casts.append(new_val := arith.ExtFOp.get(part, builtin.f64))
                args.append(new_val.result)
                format_parts.append("%f")
            else:
                args.append(part)
                format_parts.append(_format_str_for_typ(part.typ))

        # println semantics: the output always ends with a newline
        globl = self._construct_global("".join(format_parts) + "\n")
        self.collected_global_symbs[globl.sym_name.data] = globl

        rewriter.replace_matched_op(
            casts
            + [
                ptr := llvm.AddressOfOp.get(
                    globl.sym_name, llvm.LLVMPointerType.opaque()
                ),
                llvm.CallOp("printf", ptr.result, *args),
            ]
        )
|
||
|
||
class PrintToPrintf(ModulePass):
    """
    Module pass that lowers print.println ops to printf calls.

    After rewriting, a declaration of the external variadic `printf` function
    and all format string globals collected during the rewrite are appended to
    the module body.
    """

    name = "print-to-printf"

    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
        println_lowering = PrintlnOpToPrintfCall()
        PatternRewriteWalker(println_lowering).rewrite_module(op)

        # printf takes a pointer to the format string plus variadic arguments
        printf_type = llvm.LLVMFunctionType(
            [llvm.LLVMPointerType.opaque()], is_variadic=True
        )
        printf_decl = llvm.FuncOp(
            "printf",
            printf_type,
            linkage=llvm.LinkageAttr("external"),
        )
        format_string_globals = list(println_lowering.collected_global_symbs.values())

        op.body.block.add_ops([printf_decl] + format_string_globals)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters