diff --git a/tests/dialects/test_affine.py b/tests/dialects/test_affine.py index d5bd401379..3f3ce386df 100644 --- a/tests/dialects/test_affine.py +++ b/tests/dialects/test_affine.py @@ -1,20 +1,21 @@ import pytest from xdsl.dialects.affine import For, Yield -from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType +from xdsl.dialects.builtin import AffineMapAttr, IndexType, IntegerAttr, IntegerType from xdsl.ir import Attribute, Region, Block +from xdsl.ir.affine.affine_expr import AffineExpr def test_simple_for(): f = For.from_region([], [], 0, 5, Region()) - assert f.lower_bound.value.data == 0 - assert f.upper_bound.value.data == 5 + assert f.lower_bound.data.results == [AffineExpr.constant(0)] + assert f.upper_bound.data.results == [AffineExpr.constant(5)] def test_for_mismatch_operands_results_counts(): attributes: dict[str, Attribute] = { - "lower_bound": IntegerAttr.from_index_int_value(0), - "upper_bound": IntegerAttr.from_index_int_value(5), + "lower_bound": AffineMapAttr.constant_map(0), + "upper_bound": AffineMapAttr.constant_map(5), "step": IntegerAttr.from_index_int_value(1), } f = For.create( @@ -30,8 +31,8 @@ def test_for_mismatch_operands_results_counts(): def test_for_mismatch_operands_results_types(): attributes: dict[str, Attribute] = { - "lower_bound": IntegerAttr.from_index_int_value(0), - "upper_bound": IntegerAttr.from_index_int_value(5), + "lower_bound": AffineMapAttr.constant_map(0), + "upper_bound": AffineMapAttr.constant_map(5), "step": IntegerAttr.from_index_int_value(1), } b = Block(arg_types=(IntegerType(32),)) @@ -52,8 +53,8 @@ def test_for_mismatch_operands_results_types(): def test_for_mismatch_blockargs(): attributes: dict[str, Attribute] = { - "lower_bound": IntegerAttr.from_index_int_value(0), - "upper_bound": IntegerAttr.from_index_int_value(5), + "lower_bound": AffineMapAttr.constant_map(0), + "upper_bound": AffineMapAttr.constant_map(5), "step": IntegerAttr.from_index_int_value(1), } b = Block(arg_types=(IndexType(),)) diff --git a/tests/filecheck/dialects/affine/affine_ops.mlir b/tests/filecheck/dialects/affine/affine_ops.mlir index 92571b75d0..17efbb79fb 100644 --- a/tests/filecheck/dialects/affine/affine_ops.mlir +++ b/tests/filecheck/dialects/affine/affine_ops.mlir @@ -7,12 +7,12 @@ "affine.for"() ({ ^0(%i : index): "affine.yield"() : () -> () - }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () // CHECK: "affine.for"() ({ // CHECK-NEXT: ^0(%{{.*}} : index): // CHECK-NEXT: "affine.yield"() : () -> () - // CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + // CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () // For with values being passed during iterations @@ -22,13 +22,14 @@ ^1(%i : index, %step_value : !test.type<"int">): %next_value = "test.op"() : () -> !test.type<"int"> "affine.yield"(%next_value) : (!test.type<"int">) -> () - }) {"lower_bound" = -10 : index, "upper_bound" = 10 : index, "step" = 1 : index} : (!test.type<"int">) -> (!test.type<"int">) + }) {"lower_bound" = affine_map<() -> (-10)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : (!test.type<"int">) -> (!test.type<"int">) // CHECK: %res = "affine.for"(%{{.*}}) ({ // CHECK-NEXT: ^1(%{{.*}} : index, %{{.*}} : !test.type<"int">): // CHECK-NEXT: %{{.*}} = "test.op"() : () -> !test.type<"int"> // CHECK-NEXT: "affine.yield"(%{{.*}}) : (!test.type<"int">) -> () - // CHECK-NEXT: }) {"lower_bound" = -10 : index, "upper_bound" = 10 : index, "step" = 1 : index} : (!test.type<"int">) -> !test.type<"int"> + // CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (-10)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : (!test.type<"int">) -> !test.type<"int"> + %memref = "test.op"() : () -> memref<2x3xf64> %value = "test.op"() : () -> f64 diff --git a/tests/filecheck/dialects/affine/examples.mlir b/tests/filecheck/dialects/affine/examples.mlir index 4cd67295ba..9560d85d82 100644 --- a/tests/filecheck/dialects/affine/examples.mlir +++ b/tests/filecheck/dialects/affine/examples.mlir @@ -12,7 +12,7 @@ %val = "memref.load"(%ref, %i) : (memref<128xi32>, index) -> i32 %res = "arith.addi"(%sum, %val) : (i32, i32) -> i32 "affine.yield"(%res) : (i32) -> () - }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : (i32) -> i32 + }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : (i32) -> i32 func.return %r : i32 }) {"sym_name" = "sum_vec", "function_type" = (memref<128xi32>) -> i32, "sym_visibility" = "private"} : () -> () @@ -24,7 +24,7 @@ // CHECK-NEXT: %{{.*}} = "memref.load"(%{{.*}}, %{{.*}}) : (memref<128xi32>, index) -> i32 // CHECK-NEXT: %{{.*}} = "arith.addi"(%{{.*}}, %{{.*}}) : (i32, i32) -> i32 // CHECK-NEXT: "affine.yield"(%{{.*}}) : (i32) -> () - // CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : (i32) -> i32 + // CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : (i32) -> i32 // CHECK-NEXT: "func.return"(%{{.*}}) : (i32) -> () // CHECK-NEXT: }) {"sym_name" = "sum_vec", "function_type" = (memref<128xi32>) -> i32, "sym_visibility" = "private"} : () -> () @@ -46,11 +46,11 @@ %10 = "arith.addf"(%8, %9) : (f32, f32) -> f32 "memref.store"(%10, %2, %3, %4) : (f32, memref<256x256xf32>, index, index) -> () "affine.yield"() : () -> () - }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () "affine.yield"() : () -> () - }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () "affine.yield"() : () -> () - }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () "func.return"(%2) : (memref<256x256xf32>) -> () }) {"sym_name" = "affine_mm", "function_type" = (memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> memref<256x256xf32>, "sym_visibility" = "private"} : () -> () @@ -69,11 +69,11 @@ //CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 //CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<256x256xf32>, index, index) -> () //CHECK-NEXT: "affine.yield"() : () -> () - //CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + //CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () //CHECK-NEXT: "affine.yield"() : () -> () - //CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + //CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () //CHECK-NEXT: "affine.yield"() : () -> () - //CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 256 : index, "step" = 1 : index} : () -> () + //CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (256)>, "step" = 1 : index} : () -> () //CHECK-NEXT: "func.return"(%{{.*}}) : (memref<256x256xf32>) -> () //CHECK-NEXT: }) {"sym_name" = "affine_mm", "function_type" = (memref<256x256xf32>, memref<256x256xf32>, memref<256x256xf32>) -> memref<256x256xf32>, "sym_visibility" = "private"} : () -> () diff --git a/tests/filecheck/frontend/dialects/affine.py b/tests/filecheck/frontend/dialects/affine.py index 7e47331bc3..f1458038fa 100644 --- a/tests/filecheck/frontend/dialects/affine.py +++ b/tests/filecheck/frontend/dialects/affine.py @@ -11,7 +11,7 @@ # CHECK-NEXT: "affine.for"() ({ # CHECK-NEXT: ^0(%0 : index): # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 100 : index, "step" = 1 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (100)>, "step" = 1 : index} : () -> () # CHECK-NEXT: func.return # CHECK-NEXT: } @@ -24,7 +24,7 @@ def test_affine_for_I(): # CHECK-NEXT: "affine.for"() ({ # CHECK-NEXT: ^1(%1 : index): # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 10 : index, "upper_bound" = 30 : index, "step" = 1 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (10)>, "upper_bound" = affine_map<() -> (30)>, "step" = 1 : index} : () -> () # CHECK-NEXT: func.return # CHECK-NEXT: } def test_affine_for_II(): @@ -36,7 +36,7 @@ def test_affine_for_II(): # CHECK-NEXT: "affine.for"() ({ # CHECK-NEXT: ^2(%2 : index): # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 1 : index, "upper_bound" = 20 : index, "step" = 5 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (1)>, "upper_bound" = affine_map<() -> (20)>, "step" = 5 : index} : () -> () # CHECK-NEXT: func.return # CHECK-NEXT: } def test_affine_for_III(): @@ -52,11 +52,11 @@ def test_affine_for_III(): # CHECK-NEXT: "affine.for"() ({ # CHECK-NEXT: ^5(%5 : index): # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 30 : index, "step" = 1 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (30)>, "step" = 1 : index} : () -> () # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 20 : index, "step" = 1 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (20)>, "step" = 1 : index} : () -> () # CHECK-NEXT: "affine.yield"() : () -> () - # CHECK-NEXT: }) {"lower_bound" = 0 : index, "upper_bound" = 10 : index, "step" = 1 : index} : () -> () + # CHECK-NEXT: }) {"lower_bound" = affine_map<() -> (0)>, "upper_bound" = affine_map<() -> (10)>, "step" = 1 : index} : () -> () # CHECK-NEXT: func.return # CHECK-NEXT: } def test_affine_for_IV(): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index b8b7c9cf14..d191a65ecf 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -1,10 +1,19 @@ from __future__ import annotations -from typing import Annotated, Sequence - -from xdsl.dialects.builtin import AffineMapAttr, AnyIntegerAttr, IndexType, IntegerAttr +from typing import Annotated, Any, Sequence, cast + +from xdsl.dialects.builtin import ( + AffineMapAttr, + AnyIntegerAttr, + ContainerType, + IndexType, + IntegerAttr, + ShapedType, +) from xdsl.dialects.memref import MemRefType from xdsl.ir import Attribute, Operation, SSAValue, Block, Region, Dialect +from xdsl.ir.affine.affine_expr import AffineExpr +from xdsl.ir.affine.affine_map import AffineMap from xdsl.traits import IsTerminator from xdsl.irdl import ( ConstraintVar, @@ -30,10 +39,8 @@ class For(IRDLOperation): arguments: VarOperand = var_operand_def(AnyAttr()) res: VarOpResult = var_result_def(AnyAttr()) - # TODO the bounds are in fact affine_maps - # TODO support dynamic bounds as soon as maps are here - lower_bound: AnyIntegerAttr = attr_def(AnyIntegerAttr) - upper_bound: AnyIntegerAttr = attr_def(AnyIntegerAttr) + lower_bound = attr_def(AffineMapAttr) + upper_bound = attr_def(AffineMapAttr) step: AnyIntegerAttr = attr_def(AnyIntegerAttr) body: Region = region_def() @@ -64,15 +71,19 @@ def verify_(self) -> None: def from_region( operands: Sequence[Operation | SSAValue], result_types: Sequence[Attribute], - lower_bound: int | AnyIntegerAttr, - upper_bound: int | AnyIntegerAttr, + lower_bound: int | AffineMapAttr, + upper_bound: int | AffineMapAttr, region: Region, step: int | AnyIntegerAttr = 1, ) -> For: if isinstance(lower_bound, int): - lower_bound = IntegerAttr.from_index_int_value(lower_bound) + lower_bound = AffineMapAttr( + AffineMap(0, 0, [AffineExpr.constant(lower_bound)]) + ) if isinstance(upper_bound, int): - upper_bound = IntegerAttr.from_index_int_value(upper_bound) + upper_bound = AffineMapAttr( + AffineMap(0, 0, [AffineExpr.constant(upper_bound)]) + ) if isinstance(step, int): step = IntegerAttr.from_index_int_value(step) attributes: dict[str, Attribute] = { @@ -96,11 +107,28 @@ class Store(IRDLOperation): value = operand_def(T) memref = operand_def(MemRefType[T]) + indices = var_operand_def(IndexType) map = opt_attr_def(AffineMapAttr) - def __init__(self, value: SSAValue, memref: SSAValue, map: AffineMapAttr): + def __init__( + self, + value: SSAValue, + memref: SSAValue, + indices: Sequence[SSAValue], + map: AffineMapAttr | None = None, + ): + if map is None: + # Create identity map for memrefs with at least one dimension or () -> () + # for zero-dimensional memrefs. + if not isinstance(memref.typ, MemRefType): + raise ValueError( + "affine.store memref operand must be of type MemrefType" + ) + memref_type = cast(MemRefType[Attribute], memref.typ) + rank = memref_type.get_num_dims() + map = AffineMapAttr(AffineMap.identity(rank)) super().__init__( - operands=(value, memref), + operands=(value, memref, indices), attributes={"map": map}, ) @@ -122,9 +150,29 @@ def __init__( self, memref: SSAValue, indices: Sequence[SSAValue], - map: AffineMapAttr, - result_type: T, + map: AffineMapAttr | None = None, + result_type: T | None = None, ): + if map is None: + # Create identity map for memrefs with at least one dimension or () -> () + # for zero-dimensional memrefs. + if not isinstance(memref.typ, ShapedType): + raise ValueError( + "affine.store memref operand must be of type ShapedType" + ) + memref_type = cast(MemRefType[Attribute], memref.typ) + rank = memref_type.get_num_dims() + map = AffineMapAttr(AffineMap.identity(rank)) + if result_type is None: + # Create identity map for memrefs with at least one dimension or () -> () + # for zero-dimensional memrefs. + if not isinstance(memref.typ, ContainerType): + raise ValueError( + "affine.store memref operand must be of type ContainerType" + ) + memref_type = cast(ContainerType[Any], memref.typ) + result_type = memref_type.get_element_type() + super().__init__( operands=(memref, indices), attributes={"map": map}, diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 72615e0d94..cf835b77d9 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1068,6 +1068,10 @@ def parse_parameter(parser: Parser) -> AffineMap: def print_parameter(self, printer: Printer) -> None: printer.print_string(f"{self.data}") + @staticmethod + def constant_map(value: int) -> AffineMapAttr: + return AffineMapAttr(AffineMap.constant_map(value)) + @irdl_op_definition class UnrealizedConversionCastOp(IRDLOperation): diff --git a/xdsl/ir/affine/affine_expr.py b/xdsl/ir/affine/affine_expr.py index f0787ef971..7e02a360f8 100644 --- a/xdsl/ir/affine/affine_expr.py +++ b/xdsl/ir/affine/affine_expr.py @@ -214,6 +214,8 @@ def __radd__(self, other: AffineExpr | int) -> AffineExpr: return self.__add__(other) def __neg__(self) -> AffineExpr: + if isinstance(self._impl, _AffineConstantExprStorage): + return AffineExpr.constant(-self._impl.value) return self * -1 def __sub__(self, other: AffineExpr | int) -> AffineExpr: @@ -251,7 +253,7 @@ def __mul__(self, other: AffineExpr | int) -> AffineExpr: # TODO (#1087): MLIR also supports multiplication by symbols also, making # maps semi-affine. Currently, we do not implement semi-affine maps. raise NotImplementedError( - "Multiplication with non-constant (semi-affine) is not supported yet" + f"Multiplication with non-constant (semi-affine) is not supported yet self: {self} other: {other}" ) # TODO (#1086): Simplify multiplication here before returning. return AffineExpr(_AffineBinaryOpExprStorage(_AffineExprKind.Mul, self, other)) diff --git a/xdsl/ir/affine/affine_map.py b/xdsl/ir/affine/affine_map.py index 463c98bf9f..3c1488252b 100644 --- a/xdsl/ir/affine/affine_map.py +++ b/xdsl/ir/affine/affine_map.py @@ -15,6 +15,18 @@ class AffineMap: num_symbols: int results: list[AffineExpr] + @staticmethod + def constant_map(value: int) -> AffineMap: + return AffineMap(0, 0, [AffineExpr.constant(value)]) + + @staticmethod + def identity(rank: int) -> AffineMap: + return AffineMap(rank, 0, [AffineExpr.dimension(dim) for dim in range(rank)]) + + @staticmethod + def empty() -> AffineMap: + return AffineMap(0, 0, []) + def eval(self, dims: list[int], symbols: list[int]) -> list[int]: """Evaluate the AffineMap given the values of dimensions and symbols.""" assert len(dims) == self.num_dims