diff --git a/tests/interpreters/test_memref_interpreter.py b/tests/interpreters/test_memref_interpreter.py new file mode 100644 index 0000000000..41ac972797 --- /dev/null +++ b/tests/interpreters/test_memref_interpreter.py @@ -0,0 +1,50 @@ +import pytest + +from xdsl.dialects import arith, memref +from xdsl.dialects.builtin import IndexType, ModuleOp, i32 +from xdsl.interpreter import Interpreter +from xdsl.interpreters.arith import ArithFunctions +from xdsl.interpreters.memref import MemrefFunctions, MemrefValue +from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.utils.exceptions import InterpretationError + +interpreter = Interpreter(ModuleOp([])) +interpreter.register_implementations(ArithFunctions()) +interpreter.register_implementations(MemrefFunctions()) + +index = IndexType() + + +def test_functions(): + alloc_op = memref.Alloc.get(i32, None, (2, 3)) + zero_op = arith.Constant.from_int_and_width(0, index) + one_op = arith.Constant.from_int_and_width(1, index) + two_op = arith.Constant.from_int_and_width(2, index) + forty_two_op = arith.Constant.from_int_and_width(42, 32) + store_op = memref.Store.get(forty_two_op, alloc_op, (zero_op, one_op)) + load_42_op = memref.Load.get(alloc_op, (zero_op, one_op)) + load_undef_op = memref.Load.get(alloc_op, (zero_op, two_op)) + dealloc_op = memref.Dealloc.get(alloc_op) + + (shaped_array,) = interpreter.run_op(alloc_op, ()) + v = MemrefValue.Allocated + assert shaped_array == ShapedArray([v, v, v, v, v, v], [2, 3]) + (zero,) = interpreter.run_op(zero_op, ()) + (one,) = interpreter.run_op(one_op, ()) + (two,) = interpreter.run_op(two_op, ()) + (forty_two_0,) = interpreter.run_op(forty_two_op, ()) + store_res = interpreter.run_op(store_op, (forty_two_0, shaped_array, zero, one)) + assert store_res == () + (forty_two_1,) = interpreter.run_op(load_42_op, (shaped_array, zero, one)) + assert forty_two_1 == 42 + + with pytest.raises(InterpretationError) as e: + interpreter.run_op(load_undef_op, (shaped_array, zero, two)) + e.match("uninitialized") + + dealloc_res = interpreter.run_op(dealloc_op, (shaped_array,)) + assert dealloc_res == () + + with pytest.raises(InterpretationError) as e: + interpreter.run_op(load_undef_op, (shaped_array, zero, one)) + e.match("deallocated") diff --git a/xdsl/interpreters/memref.py b/xdsl/interpreters/memref.py new file mode 100644 index 0000000000..ffcb5fc105 --- /dev/null +++ b/xdsl/interpreters/memref.py @@ -0,0 +1,69 @@ +from enum import Enum +from math import prod +from typing import Any, cast + +from xdsl.dialects import memref +from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls +from xdsl.interpreters.shaped_array import ShapedArray +from xdsl.ir.core import Attribute +from xdsl.utils.exceptions import InterpretationError + + +class MemrefValue(Enum): + Allocated = 1 + Deallocated = 2 + + +@register_impls +class MemrefFunctions(InterpreterFunctions): + @impl(memref.Alloc) + def run_alloc(self, interpreter: Interpreter, op: memref.Alloc, args: tuple[()]): + memref_typ = cast(memref.MemRefType[Attribute], op.memref.type) + + shape = memref_typ.get_shape() + size = prod(shape) + data = [MemrefValue.Allocated] * size + + shaped_array = ShapedArray(data, list(shape)) + return (shaped_array,) + + @impl(memref.Dealloc) + def run_dealloc( + self, interpreter: Interpreter, op: memref.Dealloc, args: tuple[Any, ...] + ): + (shaped_array,) = args + for i in range(len(shaped_array.data)): + shaped_array.data[i] = MemrefValue.Deallocated + return () + + @impl(memref.Store) + def run_store( + self, interpreter: Interpreter, op: memref.Store, args: tuple[Any, ...] + ): + value, memref, *indices = args + + memref = cast(ShapedArray[Any], memref) + + indices = tuple(indices) + memref.store(indices, value) + + return () + + @impl(memref.Load) + def run_load( + self, interpreter: Interpreter, op: memref.Load, args: tuple[Any, ...] + ): + shaped_array, *indices = args + + shaped_array = cast(ShapedArray[Any], shaped_array) + + indices = tuple(indices) + value = shaped_array.load(indices) + + if isinstance(value, MemrefValue): + state = "uninitialized" if value == MemrefValue.Allocated else "deallocated" + raise InterpretationError( + f"Cannot load {state} value from memref {shaped_array}" + ) + + return (value,)