Skip to content

Commit

Permalink
dialects: (affine) add store and load operations (#1183)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Jun 26, 2023
1 parent aefa483 commit f2c69dc
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 3 deletions.
15 changes: 15 additions & 0 deletions tests/filecheck/dialects/affine/affine_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,19 @@
// CHECK-NEXT: "affine.yield"(%{{.*}}) : (!test.type<"int">) -> ()
// CHECK-NEXT: }) {"lower_bound" = -10 : index, "upper_bound" = 10 : index, "step" = 1 : index} : (!test.type<"int">) -> !test.type<"int">

%memref = "test.op"() : () -> memref<2x3xf64>
%value = "test.op"() : () -> f64
"affine.store"(%value, %memref) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()

// CHECK: %memref = "test.op"() : () -> memref<2x3xf64>
// CHECK-NEXT: %value = "test.op"() : () -> f64
// CHECK-NEXT: "affine.store"(%value, %memref) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%same_value = "affine.load"(%memref, %zero, %zero) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64

// CHECK: %zero = "test.op"() : () -> index
// CHECK-NEXT: %same_value = "affine.load"(%memref, %zero, %zero) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64


}) : () -> ()
10 changes: 10 additions & 0 deletions tests/filecheck/mlir-conversion/with-mlir/affine_map.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,14 @@
"f"() {map = affine_map<(d0, d1, d2) -> ()>} : () -> ()
"f"() {map = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>} : () -> ()
"f"() {map = affine_map<(d0, d1, d2) -> (d0 floordiv 2)>} : () -> ()
"func.func"() ({
%memref = "test.op"() : () -> memref<2x3xf64>
%value = "test.op"() : () -> f64
"affine.store"(%value, %memref) {"map" = affine_map<() -> (0, 0)>} : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%same_value = "affine.load"(%memref, %zero, %zero) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf64>, index, index) -> f64

"func.return"() : () -> ()
}) {function_type = () -> (), sym_name = "store_load"} : () -> ()
}) : () -> ()
63 changes: 60 additions & 3 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from __future__ import annotations

from typing import Sequence
from typing import Annotated, Sequence

from xdsl.dialects.builtin import AnyIntegerAttr, IndexType, IntegerAttr
from xdsl.dialects.builtin import AffineMapAttr, AnyIntegerAttr, IndexType, IntegerAttr
from xdsl.dialects.memref import MemRefType
from xdsl.ir import Attribute, Operation, SSAValue, Block, Region, Dialect
from xdsl.traits import IsTerminator
from xdsl.irdl import (
ConstraintVar,
VarOpResult,
attr_def,
irdl_op_definition,
VarOperand,
AnyAttr,
IRDLOperation,
operand_def,
opt_attr_def,
region_def,
result_def,
var_operand_def,
var_result_def,
)
Expand Down Expand Up @@ -83,6 +88,50 @@ def from_region(
)


@irdl_op_definition
class Store(IRDLOperation):
name = "affine.store"

T = Annotated[Attribute, ConstraintVar("T")]

value = operand_def(T)
memref = operand_def(MemRefType[T])
map = opt_attr_def(AffineMapAttr)

def __init__(self, value: SSAValue, memref: SSAValue, map: AffineMapAttr):
super().__init__(
operands=(value, memref),
attributes={"map": map},
)


@irdl_op_definition
class Load(IRDLOperation):
name = "affine.load"

T = Annotated[Attribute, ConstraintVar("T")]

memref = operand_def(MemRefType[T])
indices = var_operand_def(IndexType)

result = result_def(T)

map = opt_attr_def(AffineMapAttr)

def __init__(
self,
memref: SSAValue,
indices: Sequence[SSAValue],
map: AffineMapAttr,
result_type: T,
):
super().__init__(
operands=(memref, indices),
attributes={"map": map},
result_types=(result_type,),
)


@irdl_op_definition
class Yield(IRDLOperation):
name = "affine.yield"
Expand All @@ -95,4 +144,12 @@ def get(*operands: SSAValue | Operation) -> Yield:
return Yield.create(operands=[SSAValue.get(operand) for operand in operands])


Affine = Dialect([For, Yield], [])
Affine = Dialect(
[
For,
Store,
Load,
Yield,
],
[],
)

0 comments on commit f2c69dc

Please sign in to comment.