diff --git a/tests/dialects/test_stencil.py b/tests/dialects/test_stencil.py index 4fb70b97b4..53a1aa9f78 100644 --- a/tests/dialects/test_stencil.py +++ b/tests/dialects/test_stencil.py @@ -14,15 +14,19 @@ i64, IntegerType, ArrayAttr, + IndexType, ) from xdsl.dialects.experimental.stencil import ( ReturnOp, ResultType, ApplyOp, + StoreOp, TempType, LoadOp, FieldType, IndexAttr, + IndexOp, + AccessOp, ) from xdsl.dialects.stencil import CastOp from xdsl.ir import Block @@ -498,3 +502,79 @@ def test_stencil_resulttype(float_type: AnyFloat): assert isinstance(stencil_resulttype, ResultType) assert stencil_resulttype.elem == float_type + + +def test_stencil_store(): + temp_type = TempType([5, 5], f32) + temp_type_ssa_val = TestSSAValue(temp_type) + + field_type = FieldType([2, 2], f32) + field_type_ssa_val = TestSSAValue(field_type) + + lb = IndexAttr.get(1, 1) + ub = IndexAttr.get(64, 64) + + store = StoreOp.get(temp_type_ssa_val, field_type_ssa_val, lb, ub) + + assert isinstance(store, StoreOp) + assert isinstance(store.field.typ, FieldType) + assert store.field.typ == field_type + assert isinstance(store.temp.typ, TempType) + assert store.temp.typ == temp_type + assert len(store.field.typ.shape) == 2 + assert len(store.temp.typ.shape) == 2 + assert store.lb is lb + assert store.ub is ub + + +def test_stencil_store_load_overlap(): + temp_type = TempType([5, 5], f32) + temp_type_ssa_val = TestSSAValue(temp_type) + + field_type = FieldType([2, 2], f32) + field_type_ssa_val = TestSSAValue(field_type) + + lb = IndexAttr.get(1, 1) + ub = IndexAttr.get(64, 64) + + load = LoadOp.get(field_type_ssa_val, lb, ub) + store = StoreOp.get(temp_type_ssa_val, field_type_ssa_val, lb, ub) + + with pytest.raises(VerifyException) as exc_info: + load.verify() + assert exc_info.value.args[0] == "Cannot Load and Store the same field!" + + with pytest.raises(VerifyException) as exc_info: + store.verify() + assert exc_info.value.args[0] == "Cannot Load and Store the same field!" + + +def test_stencil_index(): + dim = IntAttr(10) + offset = IndexAttr.get(1) + + index = IndexOp.build( + attributes={ + "dim": dim, + "offset": offset, + }, + result_types=[IndexType()], + ) + + assert isinstance(index, IndexOp) + assert index.dim is dim + assert index.offset is offset + + +def test_stencil_access(): + temp_type = TempType([5, 5], f32) + temp_type_ssa_val = TestSSAValue(temp_type) + + offset = [1, 1] + offset_index_attr = IndexAttr.get(*offset) + + access = AccessOp.get(temp_type_ssa_val, offset) + + assert isinstance(access, AccessOp) + assert access.offset == offset_index_attr + assert access.temp.typ == temp_type