Skip to content

Commit

Permalink
dialects: Add tests for remaining ops in Stencil dialect (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
meshtag committed Jun 6, 2023
1 parent 31de8ad commit d9d2067
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/dialects/test_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
IndexType,
)
from xdsl.dialects.experimental.stencil import (
BufferOp,
ExternalLoadOp,
ExternalStoreOp,
ReturnOp,
ResultType,
ApplyOp,
Expand All @@ -26,9 +29,11 @@
LoadOp,
FieldType,
IndexAttr,
StoreResultOp,
IndexOp,
AccessOp,
)
from xdsl.dialects.memref import MemRefType
from xdsl.dialects.stencil import CastOp
from xdsl.ir import Attribute, Block
from xdsl.utils.exceptions import VerifyException
Expand Down Expand Up @@ -560,3 +565,50 @@ def test_stencil_access():
assert isinstance(access, AccessOp)
assert access.offset == offset_index_attr
assert access.temp.typ == temp_type


def test_store_result():
elem = IndexAttr.get(1)
elem_ssa_val = TestSSAValue(elem)
result_type = ResultType(f32)

store_result = StoreResultOp.build(
operands=[elem_ssa_val], result_types=[result_type]
)

assert isinstance(store_result, StoreResultOp)
assert store_result.args[0] == elem_ssa_val
assert store_result.res.typ == result_type


def test_external_load():
memref = TestSSAValue(MemRefType.from_element_type_and_shape(f32, ([5])))
field_type = FieldType((5), f32)

external_load = ExternalLoadOp.get(memref, field_type)

assert isinstance(external_load, ExternalLoadOp)
assert external_load.field == memref
assert external_load.result.typ == field_type


def test_external_store():
field = TestSSAValue(FieldType((5), f32))
memref = TestSSAValue(MemRefType.from_element_type_and_shape(f32, ([5])))

external_store = ExternalStoreOp.build(operands=[field, memref])

assert isinstance(external_store, ExternalStoreOp)
assert external_store.field == memref
assert external_store.temp == field


def test_buffer():
temp = TestSSAValue(TempType((5), f32))
res_typ = TempType((5), f32)

buffer = BufferOp.build(operands=[temp], result_types=[res_typ])

assert isinstance(buffer, BufferOp)
assert buffer.temp == temp
assert buffer.res.typ == res_typ

0 comments on commit d9d2067

Please sign in to comment.