Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

documentation: (toy) add toy accelerator dialect #1342

Merged
merged 3 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/Toy/examples/interpret.toy
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# RUN: python -m toy %s --emit=toy-infer-shapes | filecheck %s
# RUN: python -m toy %s --emit=affine | filecheck %s

# RUN: python -m toy %s --emit=affine --accelerate | filecheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you check that this runs? I'm not sure RUN lines are picked up if they are not among the first lines of the file.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an assert False, seems to work ok


# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return transpose(a) * transpose(b);
Expand Down
106 changes: 106 additions & 0 deletions docs/Toy/examples/tests/accelerate_toy.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// RUN: python -m toy %s --emit=affine --accelerate --ir | filecheck %s

// Test input: affine-level IR in which two 2x3 buffers are transposed into
// 3x2 buffers and then multiplied element-wise, all written as explicit
// affine.for loop nests. The expected output below replaces each loop nest
// with a single toy_accelerator op.
builtin.module {
func.func @main() {
%0 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
%1 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
%2 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
%3 = arith.constant 1.000000e+00 : f64
%4 = arith.constant 2.000000e+00 : f64
%5 = arith.constant 3.000000e+00 : f64
%6 = arith.constant 4.000000e+00 : f64
%7 = arith.constant 5.000000e+00 : f64
%8 = arith.constant 6.000000e+00 : f64
%9 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
%10 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
// Fill %10, then %9, with the constants 1.0 .. 6.0 in row-major order.
"affine.store"(%3, %10) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%4, %10) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%5, %10) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%6, %10) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%7, %10) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%8, %10) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%3, %9) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%4, %9) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%5, %9) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%6, %9) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%7, %9) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
"affine.store"(%8, %9) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
// Transpose %9 (2x3) into %2 (3x2): %2[i, j] = %9[j, i].
"affine.for"() ({
^0(%arg0 : index):
"affine.for"() ({
^1(%arg1 : index):
%11 = "affine.load"(%9, %arg1, %arg0) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64
"affine.store"(%11, %2, %arg0, %arg1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
// Transpose %10 (2x3) into %1 (3x2): %1[i, j] = %10[j, i].
"affine.for"() ({
^2(%arg0_1 : index):
"affine.for"() ({
^3(%arg1_1 : index):
%12 = "affine.load"(%10, %arg1_1, %arg0_1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64
"affine.store"(%12, %1, %arg0_1, %arg1_1) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
// Element-wise product: %0[i, j] = %2[i, j] * %1[i, j].
"affine.for"() ({
^4(%arg0_2 : index):
"affine.for"() ({
^5(%arg1_2 : index):
%13 = "affine.load"(%2, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<3x2xf64>, index, index) -> f64
%14 = "affine.load"(%1, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<3x2xf64>, index, index) -> f64
%15 = arith.mulf %13, %14 : f64
"affine.store"(%15, %0, %arg0_2, %arg1_2) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (f64, memref<3x2xf64>, index, index) -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (2)>} : () -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "step" = 1 : index, "upper_bound" = affine_map<() -> (3)>} : () -> ()
// Print the product, then free every buffer allocated above.
printf.print_format "{}", %0 : memref<3x2xf64>
"memref.dealloc"(%10) : (memref<2x3xf64>) -> ()
"memref.dealloc"(%9) : (memref<2x3xf64>) -> ()
"memref.dealloc"(%2) : (memref<3x2xf64>) -> ()
"memref.dealloc"(%1) : (memref<3x2xf64>) -> ()
"memref.dealloc"(%0) : (memref<3x2xf64>) -> ()
func.return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func @main() {
// CHECK-NEXT: %0 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %1 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %2 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<3x2xf64>
// CHECK-NEXT: %3 = arith.constant 1.000000e+00 : f64
// CHECK-NEXT: %4 = arith.constant 2.000000e+00 : f64
// CHECK-NEXT: %5 = arith.constant 3.000000e+00 : f64
// CHECK-NEXT: %6 = arith.constant 4.000000e+00 : f64
// CHECK-NEXT: %7 = arith.constant 5.000000e+00 : f64
// CHECK-NEXT: %8 = arith.constant 6.000000e+00 : f64
// CHECK-NEXT: %9 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
// CHECK-NEXT: %10 = "memref.alloc"() {"operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<2x3xf64>
// CHECK-NEXT: "affine.store"(%3, %10) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%4, %10) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%5, %10) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%6, %10) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%7, %10) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%8, %10) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%3, %9) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%4, %9) {"map" = affine_map<() -> (0, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%5, %9) {"map" = affine_map<() -> (0, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%6, %9) {"map" = affine_map<() -> (1, 0)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%7, %9) {"map" = affine_map<() -> (1, 1)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "affine.store"(%8, %9) {"map" = affine_map<() -> (1, 2)>} : (f64, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.transpose"(%2, %9) {"source_rows" = 2 : index, "source_cols" = 3 : index} : (memref<3x2xf64>, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.transpose"(%1, %10) {"source_rows" = 2 : index, "source_cols" = 3 : index} : (memref<3x2xf64>, memref<2x3xf64>) -> ()
// CHECK-NEXT: "toy_accelerator.mul"(%0, %2, %1) : (memref<3x2xf64>, memref<3x2xf64>, memref<3x2xf64>) -> ()
// CHECK-NEXT: printf.print_format "{}", %0 : memref<3x2xf64>
// CHECK-NEXT: "memref.dealloc"(%10) : (memref<2x3xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%9) : (memref<2x3xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%2) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%1) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: "memref.dealloc"(%0) : (memref<3x2xf64>) -> ()
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
2 changes: 2 additions & 0 deletions docs/Toy/examples/tests/with-mlir/interpret.toy
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# RUN: python -m toy %s --emit=scf | filecheck %s

# RUN: python -m toy %s --emit=scf --accelerate | filecheck %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here


# User defined generic function that operates on unknown shaped arguments
def multiply_transpose(a, b) {
return transpose(a) * transpose(b);
Expand Down
12 changes: 8 additions & 4 deletions docs/Toy/toy/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from xdsl.printer import Printer

from .compiler import context, transform
from .emulator.toy_accelerator_functions import ToyAcceleratorFunctions
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser as ToyParser
from .interpreter import Interpreter, ToyFunctions
Expand All @@ -35,9 +36,10 @@
)
parser.add_argument("--ir", dest="ir", action="store_true")
parser.add_argument("--print-op-generic", dest="print_generic", action="store_true")
parser.add_argument("--accelerate", dest="accelerate", action="store_true")


def main(path: Path, emit: str, ir: bool, print_generic: bool):
def main(path: Path, emit: str, ir: bool, accelerate: bool, print_generic: bool):
ctx = context()

path = args.source
Expand All @@ -60,7 +62,7 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):
print(f"Unknown file format {path}")
return

transform(ctx, module_op, target=emit)
transform(ctx, module_op, target=emit, accelerate=accelerate)

if ir:
printer = Printer(print_generic_format=print_generic)
Expand All @@ -72,7 +74,9 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):
interpreter.register_implementations(ToyFunctions())
if emit in ("affine"):
interpreter.register_implementations(AffineFunctions())
if emit in ("affine", "scf", "cf"):
if accelerate and emit in ("affine", "scf"):
interpreter.register_implementations(ToyAcceleratorFunctions())
if emit in ("affine", "scf"):
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(MemrefFunctions())
interpreter.register_implementations(PrintfFunctions())
Expand All @@ -85,4 +89,4 @@ def main(path: Path, emit: str, ir: bool, print_generic: bool):

if __name__ == "__main__":
args = parser.parse_args()
main(args.source, args.emit, args.ir, args.print_generic)
main(args.source, args.emit, args.ir, args.accelerate, args.print_generic)
23 changes: 19 additions & 4 deletions docs/Toy/toy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from xdsl.dialects import printf, scf
from xdsl.dialects import affine, arith, func, memref, printf, scf
from xdsl.dialects.builtin import Builtin, ModuleOp
from xdsl.ir import MLContext
from xdsl.transforms.mlir_opt import MLIROptPass
Expand All @@ -9,17 +9,22 @@
from .frontend.ir_gen import IRGen
from .frontend.parser import Parser
from .rewrites.inline_toy import InlineToyPass
from .rewrites.lower_to_toy_accelerator import LowerToToyAccelerator
from .rewrites.lower_toy_affine import LowerToAffinePass
from .rewrites.optimise_toy import OptimiseToy
from .rewrites.shape_inference import ShapeInferencePass


def context() -> MLContext:
ctx = MLContext()
ctx.register_dialect(affine.Affine)
ctx.register_dialect(arith.Arith)
ctx.register_dialect(Builtin)
ctx.register_dialect(toy.Toy)
ctx.register_dialect(scf.Scf)
ctx.register_dialect(func.Func)
ctx.register_dialect(memref.MemRef)
ctx.register_dialect(printf.Printf)
ctx.register_dialect(scf.Scf)
ctx.register_dialect(toy.Toy)
return ctx


Expand All @@ -30,7 +35,13 @@ def parse_toy(program: str, ctx: MLContext | None = None) -> ModuleOp:
return module_op


def transform(ctx: MLContext, module_op: ModuleOp, *, target: str = "toy-infer-shapes"):
def transform(
ctx: MLContext,
module_op: ModuleOp,
*,
target: str = "riscv-assembly",
accelerate: bool,
):
if target == "toy":
return

Expand All @@ -52,6 +63,10 @@ def transform(ctx: MLContext, module_op: ModuleOp, *, target: str = "toy-infer-s
LowerToAffinePass().apply(ctx, module_op)
module_op.verify()

if accelerate:
LowerToToyAccelerator().apply(ctx, module_op)
module_op.verify()

if target == "affine":
return

Expand Down
104 changes: 104 additions & 0 deletions docs/Toy/toy/dialects/toy_accelerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import abc
from typing import Annotated

from xdsl.dialects.builtin import AnyIntegerAttr, ArrayAttr
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
ConstraintVar,
IRDLOperation,
attr_def,
irdl_op_definition,
operand_def,
)
from xdsl.utils.exceptions import VerifyException


@irdl_op_definition
class Transpose(IRDLOperation):
    """
    Transpose of a memref, written into a separate destination memref.

    `source_rows` and `source_cols` give the shape of `source`. Verification
    requires `source` to have shape (source_rows, source_cols) and
    `destination` to have shape (source_cols, source_rows).
    """

    name = "toy_accelerator.transpose"

    destination = operand_def(MemRefType)
    source = operand_def(MemRefType)

    source_rows = attr_def(AnyIntegerAttr)
    source_cols = attr_def(AnyIntegerAttr)

    def __init__(
        self,
        destination: SSAValue | Operation,
        input: SSAValue | Operation,
        source_rows: AnyIntegerAttr,
        source_cols: AnyIntegerAttr,
    ):
        """
        Build a transpose op.

        `input` becomes the `source` operand; the parameter name is kept for
        backward compatibility with existing callers.
        """
        super().__init__(
            operands=(destination, input),
            attributes={
                "source_rows": source_rows,
                "source_cols": source_cols,
            },
        )

    def verify_(self) -> None:
        """
        Check operand types and that both shapes agree with the attributes.

        Raises:
            VerifyException: if either operand is not a memref, or if the
                source or destination shape does not match the shape implied
                by `source_rows` and `source_cols`.
        """
        if not isinstance(self.source.type, MemRefType):
            raise VerifyException(
                f"Invalid transpose source type {self.source.type}, "
                "expected MemRefType"
            )
        if not isinstance(self.destination.type, MemRefType):
            raise VerifyException(
                f"Invalid transpose destination type {self.destination.type}, "
                "expected MemRefType"
            )

        expected_source_shape = ArrayAttr((self.source_rows, self.source_cols))
        source_shape = self.source.type.shape

        if source_shape != expected_source_shape:
            raise VerifyException("Transpose source shape mismatch")

        # The destination has the source dimensions swapped.
        expected_destination_shape = ArrayAttr((self.source_cols, self.source_rows))
        destination_shape = self.destination.type.shape

        if destination_shape != expected_destination_shape:
            # Previously reported as a *source* mismatch, which pointed at the
            # wrong operand.
            raise VerifyException("Transpose destination shape mismatch")


class BinOp(IRDLOperation, abc.ABC):
    """
    Abstract base for binary operations that take an explicit `dest` operand
    alongside `lhs` and `rhs`.

    All three operands are declared with the same constraint variable `T`.
    """

    T = Annotated[MemRefType[Attribute], ConstraintVar("T")]

    dest = operand_def(T)
    lhs = operand_def(T)
    rhs = operand_def(T)

    def __init__(
        self,
        dest: SSAValue | Operation,
        lhs: SSAValue | Operation,
        rhs: SSAValue | Operation,
    ):
        # Operands are passed in declaration order: dest, lhs, rhs.
        operand_values = (dest, lhs, rhs)
        super().__init__(operands=operand_values)


@irdl_op_definition
class Add(BinOp):
    """Binary `toy_accelerator.add` op; operands are inherited from `BinOp`."""

    name = "toy_accelerator.add"


@irdl_op_definition
class Mul(BinOp):
    """Binary `toy_accelerator.mul` op; operands are inherited from `BinOp`."""

    name = "toy_accelerator.mul"


# The toy accelerator dialect: three operations (first list) and an empty
# second list.
ToyAccelerator = Dialect(
    [
        Transpose,
        Add,
        Mul,
    ],
    [],
)
51 changes: 51 additions & 0 deletions docs/Toy/toy/emulator/toy_accelerator_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

from typing import Any

from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls

from ..dialects import toy_accelerator


@register_impls
class ToyAcceleratorFunctions(InterpreterFunctions):
    """Interpreter implementations of the `toy_accelerator` operations."""

    @impl(toy_accelerator.Transpose)
    def run_transpose(
        self,
        interpreter: Interpreter,
        op: toy_accelerator.Transpose,
        args: tuple[Any, ...],
    ) -> tuple[Any, ...]:
        """Store `source[row, col]` at `dest[col, row]` for every source element."""
        target, origin = args

        row_count = op.source_rows.value.data
        col_count = op.source_cols.value.data

        for row in range(row_count):
            for col in range(col_count):
                target.store((col, row), origin.load((row, col)))

        return ()

    @impl(toy_accelerator.Add)
    def run_add(
        self, interpreter: Interpreter, op: toy_accelerator.Add, args: tuple[Any, ...]
    ) -> tuple[Any, ...]:
        """Write the pairwise sums of `lhs.data` and `rhs.data` into `dest.data`."""
        target, left_operand, right_operand = args

        pairs = zip(left_operand.data, right_operand.data)
        for position, (left, right) in enumerate(pairs):
            target.data[position] = left + right

        return ()

    @impl(toy_accelerator.Mul)
    def run_mul(
        self, interpreter: Interpreter, op: toy_accelerator.Mul, args: tuple[Any, ...]
    ) -> tuple[Any, ...]:
        """Write the pairwise products of `lhs.data` and `rhs.data` into `dest.data`."""
        target, left_operand, right_operand = args

        pairs = zip(left_operand.data, right_operand.data)
        for position, (left, right) in enumerate(pairs):
            target.data[position] = left * right

        return ()
Loading