Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (affine) use maps in affine.for bounds, and add some helpers #1209

Merged
merged 2 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 10 additions & 9 deletions tests/dialects/test_affine.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import pytest

from xdsl.dialects.affine import For, Yield
from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType
from xdsl.dialects.builtin import AffineMapAttr, IndexType, IntegerAttr, IntegerType
from xdsl.ir import Attribute, Region, Block
from xdsl.ir.affine.affine_expr import AffineExpr


def test_simple_for():
f = For.from_region([], [], 0, 5, Region())
assert f.lower_bound.value.data == 0
assert f.upper_bound.value.data == 5
assert f.lower_bound.data.results == [AffineExpr.constant(0)]
assert f.upper_bound.data.results == [AffineExpr.constant(5)]


def test_for_mismatch_operands_results_counts():
attributes: dict[str, Attribute] = {
"lower_bound": IntegerAttr.from_index_int_value(0),
"upper_bound": IntegerAttr.from_index_int_value(5),
"lower_bound": AffineMapAttr.constant_map(0),
"upper_bound": AffineMapAttr.constant_map(5),
"step": IntegerAttr.from_index_int_value(1),
}
f = For.create(
Expand All @@ -30,8 +31,8 @@ def test_for_mismatch_operands_results_counts():

def test_for_mismatch_operands_results_types():
attributes: dict[str, Attribute] = {
"lower_bound": IntegerAttr.from_index_int_value(0),
"upper_bound": IntegerAttr.from_index_int_value(5),
"lower_bound": AffineMapAttr.constant_map(0),
"upper_bound": AffineMapAttr.constant_map(5),
"step": IntegerAttr.from_index_int_value(1),
}
b = Block(arg_types=(IntegerType(32),))
Expand All @@ -52,8 +53,8 @@ def test_for_mismatch_operands_results_types():

def test_for_mismatch_blockargs():
attributes: dict[str, Attribute] = {
"lower_bound": IntegerAttr.from_index_int_value(0),
"upper_bound": IntegerAttr.from_index_int_value(5),
"lower_bound": AffineMapAttr.constant_map(0),
"upper_bound": AffineMapAttr.constant_map(5),
"step": IntegerAttr.from_index_int_value(1),
}
b = Block(arg_types=(IndexType(),))
Expand Down
9 changes: 5 additions & 4 deletions tests/filecheck/dialects/affine/affine_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
"affine.for"() ({
^0(%i : index):
"affine.yield"() : () -> ()
}) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()

// CHECK: "affine.for"() ({
// CHECK-NEXT: ^0(%{{.*}} : index):
// CHECK-NEXT: "affine.yield"() : () -> ()
// CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
// CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()


// For with values being passed during iterations
Expand All @@ -22,13 +22,14 @@
^1(%i : index, %step_value : !test.type<"int">):
%next_value = "test.op"() : () -> !test.type<"int">
"affine.yield"(%next_value) : (!test.type<"int">) -> ()
}) {"lower_bound" = -10 : index, "upper_bound" = 10 : index, "step" = 1 : index} : (!test.type<"int">) -> (!test.type<"int">)
}) {"lower_bound" = affine_map<() -> (-10)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : (!test.type<"int">) -> (!test.type<"int">)

// CHECK: %res = "affine.for"(%{{.*}}) ({
// CHECK-NEXT: ^1(%{{.*}} : index, %{{.*}} : !test.type<"int">):
// CHECK-NEXT: %{{.*}} = "test.op"() : () -> !test.type<"int">
// CHECK-NEXT: "affine.yield"(%{{.*}}) : (!test.type<"int">) -> ()
// CHECK-NEXT: }) {"lower_bound" = -10 : index, "upper_bound" = 10 : index, "step" = 1 : index} : (!test.type<"int">) -> !test.type<"int">
// CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (-10)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : (!test.type<"int">) -> !test.type<"int">


%memref = "test.op"() : () -> memref<2x3xf64>
%value = "test.op"() : () -> f64
Expand Down
16 changes: 8 additions & 8 deletions tests/filecheck/dialects/affine/examples.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
%val = "memref.load"(%ref, %i) : (memref<128xi32>, index) -> i32
%res = "arith.addi"(%sum, %val) : (i32, i32) -> i32
"affine.yield"(%res) : (i32) -> ()
}) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : (i32) -> i32
}) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : (i32) -> i32
func.return %r : i32
}) {"sym_name" = "sum_vec", "function_type" = (memref<128xi32>) -> i32, "sym_visibility" = "private"} : () -> ()

Expand All @@ -24,7 +24,7 @@
// CHECK-NEXT: %{{.*}} = "memref.load"(%{{.*}}, %{{.*}}) : (memref<128xi32>, index) -> i32
// CHECK-NEXT: %{{.*}} = "arith.addi"(%{{.*}}, %{{.*}}) : (i32, i32) -> i32
// CHECK-NEXT: "affine.yield"(%{{.*}}) : (i32) -> ()
// CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : (i32) -> i32
// CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : (i32) -> i32
// CHECK-NEXT: "func.return"(%{{.*}}) : (i32) -> ()
// CHECK-NEXT: }) {"sym_name" = "sum_vec", "function_type" = (memref<128xi32>) -> i32, "sym_visibility" = "private"} : () -> ()

Expand All @@ -46,11 +46,11 @@
%10 = "arith.addf"(%8, %9) : (f32, f32) -> f32
"memref.store"(%10, %2, %3, %4) : (f32, memref<256x256xf32>, index, index) -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
"affine.yield"() : () -> ()
}) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
}) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
"func.return"(%2) : (memref<256x256xf32>) -> ()
}) {"sym_name" = "affine_mm", "function_type" = (memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> memref<256x256xf32>, "sym_visibility" = "private"} : () -> ()

Expand All @@ -69,11 +69,11 @@
//CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
//CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<256x256xf32>, index, index) -> ()
//CHECK-NEXT: "affine.yield"() : () -> ()
//CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
//CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
//CHECK-NEXT: "affine.yield"() : () -> ()
//CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
//CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
//CHECK-NEXT: "affine.yield"() : () -> ()
//CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> ()
//CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> ()
//CHECK-NEXT: "func.return"(%{{.*}}) : (memref<256x256xf32>) -> ()
//CHECK-NEXT: }) {"sym_name" = "affine_mm", "function_type" = (memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> memref<256x256xf32>, "sym_visibility" = "private"} : () -> ()

Expand Down
12 changes: 6 additions & 6 deletions tests/filecheck/frontend/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# CHECK-NEXT: "affine.for"() ({
# CHECK-NEXT: ^0(%0 : index):
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 100 : index, "step" = 1 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (100)>, "step" = 1 : index} : () -> ()
# CHECK-NEXT: func.return
# CHECK-NEXT: }

Expand All @@ -24,7 +24,7 @@ def test_affine_for_I():
# CHECK-NEXT: "affine.for"() ({
# CHECK-NEXT: ^1(%1 : index):
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 10 : index, "upper_bound" = 30 : index, "step" = 1 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (10)>, "upper_bound" = affine_map<() -> (30)>, "step" = 1 : index} : () -> ()
# CHECK-NEXT: func.return
# CHECK-NEXT: }
def test_affine_for_II():
Expand All @@ -36,7 +36,7 @@ def test_affine_for_II():
# CHECK-NEXT: "affine.for"() ({
# CHECK-NEXT: ^2(%2 : index):
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 1 : index, "upper_bound" = 20 : index, "step" = 5 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (1)>, "upper_bound" = affine_map<() -> (20)>, "step" = 5 : index} : () -> ()
# CHECK-NEXT: func.return
# CHECK-NEXT: }
def test_affine_for_III():
Expand All @@ -52,11 +52,11 @@ def test_affine_for_III():
# CHECK-NEXT: "affine.for"() ({
# CHECK-NEXT: ^5(%5 : index):
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 30 : index, "step" = 1 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (30)>, "step" = 1 : index} : () -> ()
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 20 : index, "step" = 1 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (20)>, "step" = 1 : index} : () -> ()
# CHECK-NEXT: "affine.yield"() : () -> ()
# CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 10 : index, "step" = 1 : index} : () -> ()
# CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : () -> ()
# CHECK-NEXT: func.return
# CHECK-NEXT: }
def test_affine_for_IV():
Expand Down
78 changes: 63 additions & 15 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from __future__ import annotations

from typing import Annotated, Sequence

from xdsl.dialects.builtin import AffineMapAttr, AnyIntegerAttr, IndexType, IntegerAttr
from typing import Annotated, Any, Sequence, cast

from xdsl.dialects.builtin import (
AffineMapAttr,
AnyIntegerAttr,
ContainerType,
IndexType,
IntegerAttr,
ShapedType,
)
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute, Operation, SSAValue, Block, Region, Dialect
from xdsl.ir.affine.affine_expr import AffineExpr
from xdsl.ir.affine.affine_map import AffineMap
from xdsl.traits import IsTerminator
from xdsl.irdl import (
ConstraintVar,
Expand All @@ -30,10 +39,8 @@ class For(IRDLOperation):
arguments: VarOperand = var_operand_def(AnyAttr())
res: VarOpResult = var_result_def(AnyAttr())

# TODO the bounds are in fact affine_maps
# TODO support dynamic bounds as soon as maps are here
lower_bound: AnyIntegerAttr = attr_def(AnyIntegerAttr)
upper_bound: AnyIntegerAttr = attr_def(AnyIntegerAttr)
lower_bound = attr_def(AffineMapAttr)
upper_bound = attr_def(AffineMapAttr)
step: AnyIntegerAttr = attr_def(AnyIntegerAttr)

body: Region = region_def()
Expand Down Expand Up @@ -64,15 +71,19 @@ def verify_(self) -> None:
def from_region(
operands: Sequence[Operation | SSAValue],
result_types: Sequence[Attribute],
lower_bound: int | AnyIntegerAttr,
upper_bound: int | AnyIntegerAttr,
lower_bound: int | AffineMapAttr,
upper_bound: int | AffineMapAttr,
region: Region,
step: int | AnyIntegerAttr = 1,
) -> For:
if isinstance(lower_bound, int):
lower_bound = IntegerAttr.from_index_int_value(lower_bound)
lower_bound = AffineMapAttr(
AffineMap(0, 0, [AffineExpr.constant(lower_bound)])
)
if isinstance(upper_bound, int):
upper_bound = IntegerAttr.from_index_int_value(upper_bound)
upper_bound = AffineMapAttr(
AffineMap(0, 0, [AffineExpr.constant(upper_bound)])
)
if isinstance(step, int):
step = IntegerAttr.from_index_int_value(step)
attributes: dict[str, Attribute] = {
Expand All @@ -96,11 +107,28 @@ class Store(IRDLOperation):

value = operand_def(T)
memref = operand_def(MemRefType[T])
indices = var_operand_def(IndexType)
map = opt_attr_def(AffineMapAttr)

def __init__(self, value: SSAValue, memref: SSAValue, map: AffineMapAttr):
def __init__(
self,
value: SSAValue,
memref: SSAValue,
indices: Sequence[SSAValue],
map: AffineMapAttr | None = None,
):
if map is None:
# Create identity map for memrefs with at least one dimension or () -> ()
# for zero-dimensional memrefs.
if not isinstance(memref.typ, MemRefType):
raise ValueError(
"affine.store memref operand must be of type MemrefType"
)
memref_type = cast(MemRefType[Attribute], memref.typ)
rank = memref_type.get_num_dims()
map = AffineMapAttr(AffineMap.identity(rank))
super().__init__(
operands=(value, memref),
operands=(value, memref, indices),
attributes={"map": map},
)

Expand All @@ -122,9 +150,29 @@ def __init__(
self,
memref: SSAValue,
indices: Sequence[SSAValue],
map: AffineMapAttr,
result_type: T,
map: AffineMapAttr | None = None,
result_type: T | None = None,
):
if map is None:
# Create identity map for memrefs with at least one dimension or () -> ()
# for zero-dimensional memrefs.
if not isinstance(memref.typ, ShapedType):
raise ValueError(
"affine.store memref operand must be of type ShapedType"
)
memref_type = cast(MemRefType[Attribute], memref.typ)
rank = memref_type.get_num_dims()
map = AffineMapAttr(AffineMap.identity(rank))
if result_type is None:
# Create identity map for memrefs with at least one dimension or () -> ()
# for zero-dimensional memrefs.
if not isinstance(memref.typ, ContainerType):
raise ValueError(
"affine.store memref operand must be of type ContainerType"
)
memref_type = cast(ContainerType[Any], memref.typ)
result_type = memref_type.get_element_type()

super().__init__(
operands=(memref, indices),
attributes={"map": map},
Expand Down
4 changes: 4 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ def parse_parameter(parser: Parser) -> AffineMap:
def print_parameter(self, printer: Printer) -> None:
printer.print_string(f"{self.data}")

@staticmethod
def constant_map(value: int) -> AffineMapAttr:
return AffineMapAttr(AffineMap.constant_map(value))


@irdl_op_definition
class UnrealizedConversionCastOp(IRDLOperation):
Expand Down
4 changes: 3 additions & 1 deletion xdsl/ir/affine/affine_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ def __radd__(self, other: AffineExpr | int) -> AffineExpr:
return self.__add__(other)

def __neg__(self) -> AffineExpr:
if isinstance(self._impl, _AffineConstantExprStorage):
return AffineExpr.constant(-self._impl.value)
return self * -1

def __sub__(self, other: AffineExpr | int) -> AffineExpr:
Expand Down Expand Up @@ -251,7 +253,7 @@ def __mul__(self, other: AffineExpr | int) -> AffineExpr:
# TODO (#1087): MLIR also supports multiplication by symbols also, making
# maps semi-affine. Currently, we do not implement semi-affine maps.
raise NotImplementedError(
"Multiplication with non-constant (semi-affine) is not supported yet"
f"Multiplication with non-constant (semi-affine) is not supported yet self: {self} other: {other}"
)
# TODO (#1086): Simplify multiplication here before returning.
return AffineExpr(_AffineBinaryOpExprStorage(_AffineExprKind.Mul, self, other))
Expand Down
12 changes: 12 additions & 0 deletions xdsl/ir/affine/affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class AffineMap:
num_symbols: int
results: list[AffineExpr]

@staticmethod
def constant_map(value: int) -> AffineMap:
return AffineMap(0, 0, [AffineExpr.constant(value)])

@staticmethod
def identity(rank: int) -> AffineMap:
return AffineMap(rank, 0, [AffineExpr.dimension(dim) for dim in range(rank)])

@staticmethod
def empty() -> AffineMap:
return AffineMap(0, 0, [])

Comment on lines +26 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the empty method with dims = 0, syms = 0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean explicitly specifying the arguments?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I meant the api. But it's fine, I can extend this later.

def eval(self, dims: list[int], symbols: list[int]) -> list[int]:
"""Evaluate the AffineMap given the values of dimensions and symbols."""
assert len(dims) == self.num_dims
Expand Down