Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement is_optional_symbol for MLIR compliance. #1330

Merged
merged 9 commits into from
Jul 24, 2023
31 changes: 30 additions & 1 deletion tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@
attr_def,
irdl_op_definition,
operand_def,
opt_attr_def,
result_def,
)
from xdsl.traits import SymbolOpInterface
from xdsl.traits import OptionalSymbolOpInterface, SymbolOpInterface
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.test_value import TestSSAValue

Expand Down Expand Up @@ -295,3 +296,31 @@ class SymNameOp(IRDLOperation):

op2 = SymNameOp(attributes={"sym_name": StringAttr("symbol_name")})
op2.verify()


def test_optional_symbol_op_interface():
"""
Test that operations that conform to OptionalSymbolOpInterface have necessary attributes.
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
"""

@irdl_op_definition
class OptionalSymNameOp(IRDLOperation):
name = "no_sym_name"

sym_name = opt_attr_def(StringAttr)

traits = frozenset((OptionalSymbolOpInterface(),))

non_symbol = OptionalSymNameOp()
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
interface = non_symbol.get_trait(SymbolOpInterface)
assert interface is not None
assert interface.is_optional_symbol(non_symbol)
non_symbol.verify()
assert SymbolOpInterface.get_sym_attr_name(non_symbol) is None

symbol = OptionalSymNameOp(attributes={"sym_name": StringAttr("main")})
interface = symbol.get_trait(SymbolOpInterface)
assert interface is not None
assert interface.is_optional_symbol(symbol)
symbol.verify()
assert SymbolOpInterface.get_sym_attr_name(symbol) == StringAttr("main")
6 changes: 4 additions & 2 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
var_region_def,
var_result_def,
)
from xdsl.traits import IsolatedFromAbove, NoTerminator
from xdsl.traits import IsolatedFromAbove, NoTerminator, OptionalSymbolOpInterface
from xdsl.utils.exceptions import VerifyException

if TYPE_CHECKING:
Expand Down Expand Up @@ -1217,7 +1217,9 @@ class ModuleOp(IRDLOperation):

body: Region = region_def("single_block")

traits = frozenset([IsolatedFromAbove(), NoTerminator()])
traits = frozenset(
[IsolatedFromAbove(), NoTerminator(), OptionalSymbolOpInterface()]
)

def __init__(
self,
Expand Down
3 changes: 2 additions & 1 deletion xdsl/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def symbol_table(self) -> dict[str, Operation]:
for op in self.module.walk():
if op.has_trait(SymbolOpInterface):
symbol = SymbolOpInterface.get_sym_attr_name(op)
self._symbol_table[symbol.data] = op
if symbol:
self._symbol_table[symbol.data] = op
return self._symbol_table

def get_values(self, values: Iterable[SSAValue]) -> tuple[Any, ...]:
Expand Down
29 changes: 28 additions & 1 deletion xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,21 +212,38 @@ class SymbolOpInterface(OpTrait):
"""

@staticmethod
def get_sym_attr_name(op: Operation) -> StringAttr:
def get_sym_attr_name(op: Operation) -> StringAttr | None:
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the symbol of the operation
compor marked this conversation as resolved.
Show resolved Hide resolved
"""
# import builtin here to avoid circular import
from xdsl.dialects.builtin import StringAttr

concrete = op.get_trait(SymbolOpInterface)
assert concrete is not None
if concrete.is_optional_symbol(op) and "sym_name" not in op.attributes:
return None
attr = op.attributes["sym_name"]
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(attr, StringAttr)
return attr

@staticmethod
def is_optional_symbol(op: Operation) -> bool:
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns true if this operation optionally defines a symbol based on the
presence of the symbol name.
"""
return False

def verify(self, op: Operation) -> None:
# import builtin here to avoid circular import
from xdsl.dialects.builtin import StringAttr

# If this is an optional symbol, bail out early if possible.
concrete = op.get_trait(SymbolOpInterface)
assert concrete is not None
if concrete.is_optional_symbol(op) and "sym_name" not in op.attributes:
return
PapyChacal marked this conversation as resolved.
Show resolved Hide resolved
if "sym_name" not in op.attributes or not isinstance(
op.attributes["sym_name"], StringAttr
):
Expand All @@ -236,6 +253,16 @@ def verify(self, op: Operation) -> None:
)


class OptionalSymbolOpInterface(SymbolOpInterface):
"""
Helper interface specialization for an optional Symbol.
"""

@staticmethod
def is_optional_symbol(op: Operation) -> bool:
return True


Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Digging in MLIR's sources, that's the only specialization I saw, and a few times, hence exposing it directly as is in the xDSL practical/firendly mindset!

class CallableOpInterface(OpTrait, abc.ABC):
"""
Interface for function-like Operations that can be called in a generic way.
Expand Down