Skip to content

Commit

Permalink
transforms: (print) Add a lowering to printf (#1142)
Browse files Browse the repository at this point in the history
This adds the `print-to-printf` pass to lower println ops to printf calls
  • Loading branch information
AntonLydike authored Jun 16, 2023
1 parent 85b16da commit 056a98f
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 3 deletions.
58 changes: 58 additions & 0 deletions tests/dialects/test_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from xdsl.dialects import print as print_dialect, test, builtin
from xdsl.transforms import print_to_println
import pytest


@pytest.mark.parametrize(
"given,expected",
(
("test", "test"),
("@hello world", "hello_world"),
("something.with.dots", "something_with_dots"),
("this is 💩", "this_is"),
("123 is a number!", "123_is_a_number"),
),
)
def test_symbol_sanitizer(given: str, expected: str):
assert print_to_println.legalize_str_for_symbol_name(given) == expected


def test_format_str_from_op():
a1, a2 = test.TestOp.create(result_types=[builtin.i32, builtin.f32]).results
op = print_dialect.PrintLnOp("test {} value {}", a1, a2)

parts = print_to_println._format_string_spec_from_print_op( # pyright: ignore[reportPrivateUsage]
op
)

assert list(parts) == [
"test ",
a1,
" value ",
a2,
]

op2 = print_dialect.PrintLnOp("{}", a1)

parts2 = print_to_println._format_string_spec_from_print_op( # pyright: ignore[reportPrivateUsage]
op2
)

assert list(parts2) == [a1]


def test_global_symbol_name_generation():
"""
Check that two strings that are invalid symbol names still result in two distinct
global symbol names.
Similarly, test that the same string results in the same symbol name.
"""
s1 = print_to_println._key_from_str("(") # pyright: ignore[reportPrivateUsage]
s2 = print_to_println._key_from_str(")") # pyright: ignore[reportPrivateUsage]

assert s1 != s2

s3 = print_to_println._key_from_str(")") # pyright: ignore[reportPrivateUsage]

assert s2 == s3
22 changes: 22 additions & 0 deletions tests/filecheck/dialects/print/print_to_printf.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: xdsl-opt %s -p print-to-printf | filecheck %s

builtin.module {
"func.func"() ({
%pi = "arith.constant"() {value = 3.14159:f32} : () -> f32
%12 = "arith.constant"() {value = 12 : i32} : () -> i32

print.println "Hello: {} {}", %pi : f32, %12 : i32

"func.return"() : () -> ()
}) {sym_name = "main", function_type=() -> ()} : () -> ()
}

// CHECK: %{{\d+}} = "llvm.mlir.addressof"() {"global_name" = @Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871} : () -> !llvm.ptr

// CHECK: "llvm.call"(%{{\d+}}, %{{\d+}}, %{{\d+}}) {"callee" = @printf{{.*}}} : (!llvm.ptr, f64, i32) -> ()

// CHECK: "llvm.func"() ({
// CHECK-NEXT: }) {"sym_name" = "printf", "function_type" = !llvm.func<void (!llvm.ptr, ...)>, "CConv" = #llvm.cconv<ccc>, "linkage" = #llvm.linkage<"external">, "visibility_" = 0 : i64} : () -> ()

// CHECK: "llvm.mlir.global"() ({
// CHECK-NEXT: }) {"global_type" = !llvm.array<14 x i8>, "sym_name" = "Hello_f_842f9d94ff2eba9703926bef3c2bc5f427db9871", "linkage" = #llvm.linkage<"internal">, "addr_space" = 0 : i32, "constant", "value" = dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>} : () -> ()
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// RUN: xdsl-opt %s -p print-to-printf | mlir-opt --test-lower-to-llvm | filecheck %s
// this tests straight to llvmir to verify intended target compatibility

builtin.module {
"func.func"() ({
%pi = "arith.constant"() {value = 3.14159:f32} : () -> f32
%12 = "arith.constant"() {value = 12 : i32} : () -> i32

print.println "Hello: {} {}", %pi : f32, %12 : i32

"func.return"() : () -> ()
}) {sym_name = "main", function_type=() -> ()} : () -> ()
}


// CHECK: llvm.call @printf(%{{\d+}}, %{{\d+}}, %{{\d+}}) : (!llvm.ptr, f64, i32) -> ()

// CHECK: llvm.func @printf(!llvm.ptr, ...)

// CHECK: llvm.mlir.global internal constant @Hello_f_{{\w+}}(dense<[72, 101, 108, 108, 111, 58, 32, 37, 102, 32, 37, 105, 10, 0]> : tensor<14xi8>) {addr_space = 0 : i32} : !llvm.arr
4 changes: 1 addition & 3 deletions xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,9 +783,7 @@ class AddressOfOp(IRDLOperation):
def get(
global_name: str | StringAttr | SymbolRefAttr, result_type: LLVMPointerType
):
if isinstance(global_name, str):
global_name = StringAttr(global_name)
if isinstance(global_name, StringAttr):
if isinstance(global_name, (StringAttr, str)):
global_name = SymbolRefAttr(global_name)

return AddressOfOp.build(
Expand Down
1 change: 1 addition & 0 deletions xdsl/dialects/print.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class PrintLnOp(IRDLOperation):
"""

name = "print.println"

format_str: builtin.StringAttr = attr_def(builtin.StringAttr)
format_vals: VarOperand = var_operand_def()

Expand Down
165 changes: 165 additions & 0 deletions xdsl/transforms/print_to_println.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from typing import Iterable
import hashlib
import re

from xdsl.ir import SSAValue, Attribute, MLContext, Operation
from xdsl.dialects import print, builtin, arith, llvm
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
PatternRewriteWalker,
)

from xdsl.passes import ModulePass


i8 = builtin.IntegerType(8)


def legalize_str_for_symbol_name(val: str):
"""
Takes any string and legalizes it to be a global llvm symbol.
(for the strictest possible interpretation of this)
- Replaces all whitespaces and dots with _
- Deletes all non ascii alphanumerical characters
- Strips all underscores from the start and end of the string
The resulting string consists only of ascii letters, underscores and digits.
This is a surjective mapping, meaning that multiple inputs will produce the same
output. This function alone cannot be used to get a uniquely identifying global
symbol name for a string!
"""
val = re.sub(r"(\s+|\.)", "_", val)
val = re.sub(r"[^A-Za-z0-9_]+", "", val).strip("_")
return val


def _key_from_str(val: str) -> str:
"""
Generate a symbol name from any given string.
Takes the first ten letters of the string plus it's sha1 hash to create a
(pretty much) globally unique symbol name.
"""
h = hashlib.new("sha1")
h.update(val.encode())
return f"{legalize_str_for_symbol_name(val[:10])}_{h.hexdigest()}"


def _format_string_spec_from_print_op(op: print.PrintLnOp) -> Iterable[str | SSAValue]:
"""
Translates the op:
print.println "val = {}, val2 = {}", %1 : i32, %2 : f32
into this sequence:
["val = ", %1, ", val2 = ", %2]
Empty string parts are omitted.
"""
format_str = op.format_str.data.split("{}")
args = iter(op.format_vals)

for part in format_str[:-1]:
if part != "":
yield part
yield next(args)
if format_str[-1] != "":
yield format_str[-1]


def _format_str_for_typ(t: Attribute):
match t:
case builtin.f64:
return "%f"
case builtin.i32:
return "%i"
case builtin.i64:
return "%li"
case _:
raise ValueError(f"Cannot find printf code for {t}")


class PrintlnOpToPrintfCall(RewritePattern):
collected_global_symbs: dict[str, llvm.GlobalOp]

def __init__(self):
self.collected_global_symbs = dict()

def _construct_global(self, val: str):
"""
Constructs an llvm.global operation containing the string. Assigns a unique
symbol name to the value that is derived from the string value.
"""
data = val.encode() + b"\x00"

t_type = builtin.TensorType.from_type_and_list(i8, [len(data)])

return llvm.GlobalOp.get(
llvm.LLVMArrayType.from_size_and_type(len(data), i8),
_key_from_str(val),
constant=True,
linkage="internal",
value=builtin.DenseIntOrFPElementsAttr.from_list(t_type, data),
)

@op_type_rewrite_pattern
def match_and_rewrite(self, op: print.PrintLnOp, rewriter: PatternRewriter, /):
format_str = ""
args: list[SSAValue] = []
casts: list[Operation] = []
# make sure all arguments are in the format libc expects them to be
# e.g. floats must be promoted to double before calling
for part in _format_string_spec_from_print_op(op):
if isinstance(part, str):
format_str += part
elif isinstance(part.typ, builtin.IndexType):
# index must be cast to fixed bitwidth before printing
casts.append(new_val := arith.IndexCastOp.get(part, builtin.i64))
args.append(new_val.result)
format_str += "%li"
elif part.typ == builtin.f32:
# f32 must be promoted to f64 before printing
casts.append(new_val := arith.ExtFOp.get(part, builtin.f64))
args.append(new_val.result)
format_str += "%f"
else:
args.append(part)
format_str += _format_str_for_typ(part.typ)

globl = self._construct_global(format_str + "\n")
self.collected_global_symbs[globl.sym_name.data] = globl

rewriter.replace_matched_op(
casts
+ [
ptr := llvm.AddressOfOp.get(
globl.sym_name, llvm.LLVMPointerType.opaque()
),
llvm.CallOp("printf", ptr.result, *args),
]
)


class PrintToPrintf(ModulePass):
name = "print-to-printf"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
add_printf_call = PrintlnOpToPrintfCall()

PatternRewriteWalker(add_printf_call).rewrite_module(op)

op.body.block.add_ops(
[
llvm.FuncOp(
"printf",
llvm.LLVMFunctionType(
[llvm.LLVMPointerType.opaque()], is_variadic=True
),
linkage=llvm.LinkageAttr("external"),
),
*add_printf_call.collected_global_symbs.values(),
]
)
2 changes: 2 additions & 0 deletions xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from xdsl.transforms.experimental.dmp.scatter_gather import (
DmpScatterGatherTrivialLowering,
)
from xdsl.transforms.print_to_println import PrintToPrintf

from xdsl.utils.exceptions import DiagnosticException
from xdsl.utils.parse_pipeline import parse_pipeline
Expand Down Expand Up @@ -109,6 +110,7 @@ def get_all_passes() -> list[type[ModulePass]]:
LowerRISCVFunc,
LowerSnitchPass,
LowerSnitchRuntimePass,
PrintToPrintf,
RISCVRegisterAllocation,
StencilShapeInferencePass,
StencilStorageMaterializationPass,
Expand Down

0 comments on commit 056a98f

Please sign in to comment.