diff --git a/tests/dialects/test_stencil.py b/tests/dialects/test_stencil.py index 98d402f2f6..2b07385c20 100644 --- a/tests/dialects/test_stencil.py +++ b/tests/dialects/test_stencil.py @@ -17,6 +17,9 @@ IndexType, ) from xdsl.dialects.experimental.stencil import ( + BufferOp, + ExternalLoadOp, + ExternalStoreOp, ReturnOp, ResultType, ApplyOp, @@ -26,9 +29,11 @@ LoadOp, FieldType, IndexAttr, + StoreResultOp, IndexOp, AccessOp, ) +from xdsl.dialects.memref import MemRefType from xdsl.dialects.stencil import CastOp from xdsl.ir import Attribute, Block from xdsl.utils.exceptions import VerifyException @@ -560,3 +565,50 @@ def test_stencil_access(): assert isinstance(access, AccessOp) assert access.offset == offset_index_attr assert access.temp.typ == temp_type + + +def test_store_result(): + elem = IndexAttr.get(1) + elem_ssa_val = TestSSAValue(elem) + result_type = ResultType(f32) + + store_result = StoreResultOp.build( + operands=[elem_ssa_val], result_types=[result_type] + ) + + assert isinstance(store_result, StoreResultOp) + assert store_result.args[0] == elem_ssa_val + assert store_result.res.typ == result_type + + +def test_external_load(): + memref = TestSSAValue(MemRefType.from_element_type_and_shape(f32, ([5]))) + field_type = FieldType((5), f32) + + external_load = ExternalLoadOp.get(memref, field_type) + + assert isinstance(external_load, ExternalLoadOp) + assert external_load.field == memref + assert external_load.result.typ == field_type + + +def test_external_store(): + field = TestSSAValue(FieldType((5), f32)) + memref = TestSSAValue(MemRefType.from_element_type_and_shape(f32, ([5]))) + + external_store = ExternalStoreOp.build(operands=[field, memref]) + + assert isinstance(external_store, ExternalStoreOp) + assert external_store.field == memref + assert external_store.temp == field + + +def test_buffer(): + temp = TestSSAValue(TempType((5), f32)) + res_typ = TempType((5), f32) + + buffer = BufferOp.build(operands=[temp], result_types=[res_typ]) + + assert isinstance(buffer, BufferOp) + assert buffer.temp == temp + assert buffer.res.typ == res_typ