Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interpreter: (memref) add implementations for some memref operations #1254

Merged
merged 4 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/interpreters/test_memref_interpreter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest

from xdsl.dialects import arith, memref
from xdsl.dialects.builtin import IndexType, ModuleOp, i32
from xdsl.interpreter import Interpreter
from xdsl.interpreters.arith import ArithFunctions
from xdsl.interpreters.memref import MemrefFunctions, MemrefValue
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.utils.exceptions import InterpretationError

interpreter = Interpreter(ModuleOp([]))
interpreter.register_implementations(ArithFunctions())
interpreter.register_implementations(MemrefFunctions())

index = IndexType()


def test_functions():
alloc_op = memref.Alloc.get(i32, None, (2, 3))
zero_op = arith.Constant.from_int_and_width(0, index)
one_op = arith.Constant.from_int_and_width(1, index)
two_op = arith.Constant.from_int_and_width(2, index)
forty_two_op = arith.Constant.from_int_and_width(42, 32)
store_op = memref.Store.get(forty_two_op, alloc_op, (zero_op, one_op))
load_42_op = memref.Load.get(alloc_op, (zero_op, one_op))
load_undef_op = memref.Load.get(alloc_op, (zero_op, two_op))
dealloc_op = memref.Dealloc.get(alloc_op)

(shaped_array,) = interpreter.run_op(alloc_op, ())
v = MemrefValue.Allocated
assert shaped_array == ShapedArray([v, v, v, v, v, v], [2, 3])
(zero,) = interpreter.run_op(zero_op, ())
(one,) = interpreter.run_op(one_op, ())
(two,) = interpreter.run_op(two_op, ())
(forty_two_0,) = interpreter.run_op(forty_two_op, ())
store_res = interpreter.run_op(store_op, (forty_two_0, shaped_array, zero, one))
assert store_res == ()
(forty_two_1,) = interpreter.run_op(load_42_op, (shaped_array, zero, one))
assert forty_two_1 == 42

with pytest.raises(InterpretationError) as e:
interpreter.run_op(load_undef_op, (shaped_array, zero, two))
e.match("uninitialized")

dealloc_res = interpreter.run_op(dealloc_op, (shaped_array,))
assert dealloc_res == ()

with pytest.raises(InterpretationError) as e:
interpreter.run_op(load_undef_op, (shaped_array, zero, one))
e.match("deallocated")
69 changes: 69 additions & 0 deletions xdsl/interpreters/memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from enum import Enum
from math import prod
from typing import Any, cast

from xdsl.dialects import memref
from xdsl.interpreter import Interpreter, InterpreterFunctions, impl, register_impls
from xdsl.interpreters.shaped_array import ShapedArray
from xdsl.ir.core import Attribute
from xdsl.utils.exceptions import InterpretationError


class MemrefValue(Enum):
Allocated = 1
Deallocated = 2


@register_impls
class MemrefFunctions(InterpreterFunctions):
@impl(memref.Alloc)
def run_alloc(self, interpreter: Interpreter, op: memref.Alloc, args: tuple[()]):
memref_typ = cast(memref.MemRefType[Attribute], op.memref.type)

shape = memref_typ.get_shape()
size = prod(shape)
data = [MemrefValue.Allocated] * size

shaped_array = ShapedArray(data, list(shape))
return (shaped_array,)

@impl(memref.Dealloc)
def run_dealloc(
self, interpreter: Interpreter, op: memref.Dealloc, args: tuple[Any, ...]
):
(shaped_array,) = args
for i in range(len(shaped_array.data)):
shaped_array.data[i] = MemrefValue.Deallocated
return ()

@impl(memref.Store)
def run_store(
self, interpreter: Interpreter, op: memref.Store, args: tuple[Any, ...]
):
value, memref, *indices = args

memref = cast(ShapedArray[Any], memref)

indices = tuple(indices)
memref.store(indices, value)

return ()

@impl(memref.Load)
def run_load(
self, interpreter: Interpreter, op: memref.Load, args: tuple[Any, ...]
):
shaped_array, *indices = args

shaped_array = cast(ShapedArray[Any], shaped_array)

indices = tuple(indices)
value = shaped_array.load(indices)

if isinstance(value, MemrefValue):
state = "uninitialized" if value == MemrefValue.Allocated else "deallocated"
raise InterpretationError(
f"Cannot load {state} value from memref {shaped_array}"
)

return (value,)