documentation: (toy) add toy accelerator dialect (#1342)
The main idea is to show how to divert the compilation flow to a toy
accelerator. This is much simpler than the planned Snitch flow, but it still
demonstrates the general idea of matching on affine loop nests and raising
them to custom accelerator operations.

Because the Toy compiler itself generates the affine loops, we know that the
accelerator patterns match every loop nest we create, meaning that the example
can be compiled and interpreted end-to-end in Python without relying on MLIR.
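For example, the accelerated flow can be exercised with the same command the
test RUN lines below use (shown here with a concrete path, assuming it is run
from the docs/Toy directory):

    python -m toy examples/interpret.toy --emit=affine --accelerate --ir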
superlopuh committed Jul 27, 2023
1 parent 7f5d954 commit b3a7e5c
Showing 9 changed files with 419 additions and 19 deletions.
2 changes: 2 additions & 0 deletions docs/Toy/examples/interpret.toy
@@ -4,6 +4,8 @@
# RUN: python -m toy %s --emit=toy-infer-shapes | filecheck %s
# RUN: python -m toy %s --emit=affine | filecheck %s

# RUN: python -m toy %s --emit=affine --accelerate | filecheck %s

# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return transpose(a) * transpose(b);
106 changes: 106 additions & 0 deletions docs/Toy/examples/tests/accelerate_toy.mlir
@@ -0,0 +1,106 @@
// RUN: python -m toy %s --emit=affine --accelerate --ir | filecheck %s

builtin.module {
  func.func @main() {
    %0 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
    %1 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
    %2 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
    %3 = arith.constant 1.000000e+00 : f64
    %4 = arith.constant 2.000000e+00 : f64
    %5 = arith.constant 3.000000e+00 : f64
    %6 = arith.constant 4.000000e+00 : f64
    %7 = arith.constant 5.000000e+00 : f64
    %8 = arith.constant 6.000000e+00 : f64
    %9 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
    %10 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
    "affine.store"(%3, %10) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%4, %10) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%5, %10) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%6, %10) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%7, %10) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%8, %10) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%3, %9) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%4, %9) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%5, %9) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%6, %9) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%7, %9) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
    "affine.store"(%8, %9) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
    "affine.for"() ({
    ^0(%arg0 : index):
      "affine.for"() ({
      ^1(%arg1 : index):
        %11 = "affine.load"(%9, %arg1, %arg0) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64
        "affine.store"(%11, %2, %arg0, %arg1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
      "affine.yield"() : () -> ()
    }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
    "affine.for"() ({
    ^2(%arg0_1 : index):
      "affine.for"() ({
      ^3(%arg1_1 : index):
        %12 = "affine.load"(%10, %arg1_1, %arg0_1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64
        "affine.store"(%12, %1, %arg0_1, %arg1_1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
      "affine.yield"() : () -> ()
    }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
    "affine.for"() ({
    ^4(%arg0_2 : index):
      "affine.for"() ({
      ^5(%arg1_2 : index):
        %13 = "affine.load"(%2, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<3x2xf64>, index, index) -> f64
        %14 = "affine.load"(%1, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<3x2xf64>, index, index) -> f64
        %15 = arith.mulf %13, %14 : f64
        "affine.store"(%15, %0, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
        "affine.yield"() : () -> ()
      }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
      "affine.yield"() : () -> ()
    }) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
    printf.print_format "{}", %0 : memref<3x2xf64>
    "memref.dealloc"(%10) : (memref<2x3xf64>) -> ()
    "memref.dealloc"(%9) : (memref<2x3xf64>) -> ()
    "memref.dealloc"(%2) : (memref<3x2xf64>) -> ()
    "memref.dealloc"(%1) : (memref<3x2xf64>) -> ()
    "memref.dealloc"(%0) : (memref<3x2xf64>) -> ()
    func.return
  }
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func @main() {
// CHECK-NEXT: %0 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %1 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %2 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %3 = arith.constant 1.000000e+00 : f64
// CHECK-NEXT: %4 = arith.constant 2.000000e+00 : f64
// CHECK-NEXT: %5 = arith.constant 3.000000e+00 : f64
// CHECK-NEXT: %6 = arith.constant 4.000000e+00 : f64
// CHECK-NEXT: %7 = arith.constant 5.000000e+00 : f64
// CHECK-NEXT: %8 = arith.constant 6.000000e+00 : f64
// CHECK-NEXT: %9 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
// CHECK-NEXT: %10 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
// CHECK-NEXT: "affine.store"(%3, %10) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%4, %10) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%5, %10) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%6, %10) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%7, %10) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%8, %10) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%3, %9) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%4, %9) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%5, %9) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%6, %9) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%7, %9) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%8, %9) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.transpose"(%2, %9) {"source_rows" = 2 : index, "source_cols" = 3 : index} : (memref<3x2xf64>, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.transpose"(%1, %10) {"source_rows" = 2 : index, "source_cols" = 3 : index} : (memref<3x2xf64>, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.mul"(%0, %2, %1) : (memref<3x2xf64>, memref<3x2xf64>, memref<3x2xf64>) -> ()
// CHECK-NEXT: printf.print_format "{}", %0 : memref<3x2xf64>
// CHECK-NEXT: "memref.dealloc"(%10) : (memref<2x3xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%9) : (memref<2x3xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%2) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%1) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%0) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
2 changes: 2 additions & 0 deletions docs/Toy/examples/tests/with-mlir/interpret.toy
@@ -1,5 +1,7 @@
# RUN: python -m toy %s --emit=scf | filecheck %s

# RUN: python -m toy %s --emit=scf --accelerate | filecheck %s

# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return transpose(a) * transpose(b);
12 changes: 8 additions & 4 deletions docs/Toy/toy/__main__.py
@@ -12,6 +12,7 @@
from xdsl.printer import Printer

from .compiler import context, transform
from .emulator.toy_accelerator_functions import ToyAcceleratorFunctions
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser as ToyParser
from .interpreter import Interpreter, ToyFunctions
@@ -35,9 +36,10 @@
)
parser.add_argument("--ir", dest="ir", action="store_true")
parser.add_argument("--print-op-generic", dest="print_generic", action="store_true")
parser.add_argument("--accelerate", dest="accelerate", action="store_true")


def main(path: Path, emit: str, ir: bool, print_generic: bool):
def main(path: Path, emit: str, ir: bool, accelerate: bool, print_generic: bool):
ctx = context()

path = args.source
@@ -60,7 +62,7 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):
print(f"Unknown file format {path}")
return

transform(ctx, module_op, target=emit)
transform(ctx, module_op, target=emit, accelerate=accelerate)

if ir:
printer = Printer(print_generic_format=print_generic)
@@ -72,7 +74,9 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):
interpreter.register_implementations(ToyFunctions())
if emit in ("affine"):
interpreter.register_implementations(AffineFunctions())
if emit in ("affine", "scf", "cf"):
if accelerate and emit in ("affine", "scf"):
interpreter.register_implementations(ToyAcceleratorFunctions())
if emit in ("affine", "scf"):
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(MemrefFunctions())
interpreter.register_implementations(PrintfFunctions())
@@ -85,4 +89,4 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):

if __name__ == "__main__":
args = parser.parse_args()
main(args.source, args.emit, args.ir, args.print_generic)
main(args.source, args.emit, args.ir, args.accelerate, args.print_generic)
23 changes: 19 additions & 4 deletions docs/Toy/toy/compiler.py
@@ -1,6 +1,6 @@
from pathlib import Path

from xdsl.dialects import printf, scf
from xdsl.dialects import affine, arith, func, memref, printf, scf
from xdsl.dialects.builtin import Builtin, ModuleOp
from xdsl.ir import MLContext
from xdsl.transforms.mlir_opt import MLIROptPass
@@ -9,17 +9,22 @@
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser
from .rewrites.inline_toy import InlineToyPass
from .rewrites.lower_to_toy_accelerator import LowerToToyAccelerator
from .rewrites.lower_toy_affine import LowerToAffinePass
from .rewrites.optimise_toy import OptimiseToy
from .rewrites.shape_inference import ShapeInferencePass


def context() -> MLContext:
ctx = MLContext()
ctx.register_dialect(affine.Affine)
ctx.register_dialect(arith.Arith)
ctx.register_dialect(Builtin)
ctx.register_dialect(toy.Toy)
ctx.register_dialect(scf.Scf)
ctx.register_dialect(func.Func)
ctx.register_dialect(memref.MemRef)
ctx.register_dialect(printf.Printf)
ctx.register_dialect(scf.Scf)
ctx.register_dialect(toy.Toy)
return ctx


@@ -30,7 +35,13 @@ def parse_toy(program: str, ctx: MLContext | None = None) -> ModuleOp:
return module_op


def transform(ctx: MLContext, module_op: ModuleOp, *, target: str = "toy-infer-shapes"):
def transform(
ctx: MLContext,
module_op: ModuleOp,
*,
target: str = "riscv-assembly",
accelerate: bool,
):
if target == "toy":
return

@@ -52,6 +63,10 @@ def transform(ctx: MLContext, module_op: ModuleOp, *, target: str = "toy-infer-s
LowerToAffinePass().apply(ctx, module_op)
module_op.verify()

if accelerate:
LowerToToyAccelerator().apply(ctx, module_op)
module_op.verify()

if target == "affine":
return

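The rewrite pass registered here, LowerToToyAccelerator, lives in
docs/Toy/toy/rewrites/lower_to_toy_accelerator.py and is part of this commit
but not shown in this view. The sketch below illustrates the general shape
such a pass might take in xDSL; the pattern name and the (elided) matching
logic are assumptions for illustration, not the actual implementation.

# Hypothetical sketch only; the real pass in this commit is more involved.
from xdsl.dialects import affine
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
    PatternRewriter,
    PatternRewriteWalker,
    RewritePattern,
    op_type_rewrite_pattern,
)


class RaiseAffineLoopNest(RewritePattern):
    """Match a two-deep affine.for nest and raise it to a toy_accelerator op."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: affine.For, rewriter: PatternRewriter):
        # Structural checks (elided in this sketch): the body must contain a
        # single inner affine.for whose body is a load/store pair with swapped
        # indices (transpose) or a load/load/mulf/store sequence (mul). If it
        # matches, the nest is replaced with one accelerator op via
        # rewriter.replace_matched_op(...).
        ...


class LowerToToyAccelerator(ModulePass):
    name = "lower-to-toy-accelerator"

    def apply(self, ctx: MLContext, op: ModuleOp) -> None:
        PatternRewriteWalker(RaiseAffineLoopNest()).rewrite_module(op)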
104 changes: 104 additions & 0 deletions docs/Toy/toy/dialects/toy_accelerator.py
@@ -0,0 +1,104 @@
from __future__ import annotations

import abc
from typing import Annotated

from xdsl.dialects.builtin import AnyIntegerAttr, ArrayAttr
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
    ConstraintVar,
    IRDLOperation,
    attr_def,
    irdl_op_definition,
    operand_def,
)
from xdsl.utils.exceptions import VerifyException


@irdl_op_definition
class Transpose(IRDLOperation):
    """Write the transpose of `source` into `destination`."""

    name = "toy_accelerator.transpose"

    destination = operand_def(MemRefType)
    source = operand_def(MemRefType)

    source_rows = attr_def(AnyIntegerAttr)
    source_cols = attr_def(AnyIntegerAttr)

    def __init__(
        self,
        destination: SSAValue | Operation,
        input: SSAValue | Operation,
        source_rows: AnyIntegerAttr,
        source_cols: AnyIntegerAttr,
    ):
        super().__init__(
            operands=(destination, input),
            attributes={
                "source_rows": source_rows,
                "source_cols": source_cols,
            },
        )

    def verify_(self) -> None:
        if not isinstance(self.source.type, MemRefType):
            raise VerifyException(
                f"Invalid transpose source type {self.source.type}, expected MemRefType"
            )
        if not isinstance(self.destination.type, MemRefType):
            raise VerifyException(
                f"Invalid transpose destination type {self.destination.type}, expected MemRefType"
            )

        expected_source_shape = ArrayAttr((self.source_rows, self.source_cols))
        source_shape = self.source.type.shape

        if source_shape != expected_source_shape:
            raise VerifyException(
                f"Transpose source shape mismatch: expected {expected_source_shape}, "
                f"got {source_shape}"
            )

        expected_destination_shape = ArrayAttr((self.source_cols, self.source_rows))
        destination_shape = self.destination.type.shape

        if destination_shape != expected_destination_shape:
            raise VerifyException(
                f"Transpose destination shape mismatch: expected "
                f"{expected_destination_shape}, got {destination_shape}"
            )

class BinOp(IRDLOperation, abc.ABC):
    """
    An in-place mutating binary operation.
    """

    T = Annotated[MemRefType[Attribute], ConstraintVar("T")]

    dest = operand_def(T)
    lhs = operand_def(T)
    rhs = operand_def(T)

    def __init__(
        self,
        dest: SSAValue | Operation,
        lhs: SSAValue | Operation,
        rhs: SSAValue | Operation,
    ):
        super().__init__(operands=(dest, lhs, rhs))


@irdl_op_definition
class Add(BinOp):
    name = "toy_accelerator.add"


@irdl_op_definition
class Mul(BinOp):
    name = "toy_accelerator.mul"


ToyAccelerator = Dialect(
    [
        Transpose,
        Add,
        Mul,
    ],
    [],
)
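Because dest, lhs, and rhs in BinOp all share the ConstraintVar T, the IRDL
verifier requires the three operands to have the exact same memref type. A
well-formed use therefore looks like the toy_accelerator.mul in the test
above; for example, in generic syntax (operand names are placeholders):

    "toy_accelerator.add"(%dest, %lhs, %rhs) : (memref<3x2xf64>, memref<3x2xf64>, memref<3x2xf64>) -> ()

Mixing operand types, say a memref<2x3xf64> rhs with memref<3x2xf64> for the
others, would be rejected at verification time.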
51 changes: 51 additions & 0 deletions docs/Toy/toy/emulator/toy_accelerator_functions.py
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any

from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls

from ..dialects import toy_accelerator


@register_impls
class ToyAcceleratorFunctions(InterpreterFunctions):
    @impl(toy_accelerator.Transpose)
    def run_transpose(
        self,
        interpreter: Interpreter,
        op: toy_accelerator.Transpose,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        dest, source = args

        source_rows = op.source_rows.value.data
        source_cols = op.source_cols.value.data

        for row in range(source_rows):
            for col in range(source_cols):
                value = source.load((row, col))
                dest.store((col, row), value)

        return ()

    @impl(toy_accelerator.Add)
    def run_add(
        self, interpreter: Interpreter, op: toy_accelerator.Add, args: tuple[Any, ...]
    ) -> tuple[Any, ...]:
        dest, lhs, rhs = args

        for i, (l, r) in enumerate(zip(lhs.data, rhs.data)):
            dest.data[i] = l + r

        return ()

    @impl(toy_accelerator.Mul)
    def run_mul(
        self, interpreter: Interpreter, op: toy_accelerator.Mul, args: tuple[Any, ...]
    ) -> tuple[Any, ...]:
        dest, lhs, rhs = args

        for i, (l, r) in enumerate(zip(lhs.data, rhs.data)):
            dest.data[i] = l * r

        return ()
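These implementations are picked up by the interpreter through the same
registration mechanism used for the other function sets in __main__.py above.
A rough sketch of that wiring is below; the helper name is hypothetical and
module_op is assumed to be a module already lowered with --accelerate.

# Sketch only: mirrors the registration added to docs/Toy/toy/__main__.py.
from xdsl.dialects.builtin import ModuleOp
from xdsl.interpreter import Interpreter

from .toy_accelerator_functions import ToyAcceleratorFunctions


def interpret_accelerated(module_op: ModuleOp) -> None:
    interpreter = Interpreter(module_op)
    interpreter.register_implementations(ToyAcceleratorFunctions())
    # The arith/memref/printf implementation sets are registered the same
    # way, exactly as __main__.py does above.
    interpreter.call_op("main", ())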
