Skip to content

Commit

Permalink
core: (trait) Polish HasParent trait (#1438)
Browse files Browse the repository at this point in the history
Main change is the use of `*args` for `parameters`. Another change is
using generators instead of list.
  • Loading branch information
kingiler committed Aug 10, 2023
1 parent cd00bd8 commit ec6c3ed
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/test_builtin_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class HasMultipleParentOp(IRDLOperation):

name = "test.has_multiple_parent"

traits = frozenset([HasParent((ParentOp, Parent2Op))])
traits = frozenset([HasParent(ParentOp, Parent2Op)])


def test_has_parent_no_parent():
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/irdl/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class ParametersOp(IRDLOperation):

args: VarOperand = var_operand_def(AttributeType)

traits = frozenset([HasParent((TypeOp, AttributeOp))])
traits = frozenset([HasParent(TypeOp, AttributeOp)])

def __init__(self, args: Sequence[SSAValue]):
super().__init__(operands=[args])
Expand Down
10 changes: 4 additions & 6 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,20 @@ class HasParent(OpTrait):

parameters: tuple[type[Operation], ...]

def __init__(self, parameters: type[Operation] | tuple[type[Operation], ...]):
if not isinstance(parameters, tuple):
parameters = (parameters,)
if len(parameters) == 0:
def __init__(self, *parameters: type[Operation]):
if not parameters:
raise ValueError("parameters must not be empty")
super().__init__(parameters)

def verify(self, op: Operation) -> None:
parent = op.parent_op()
if isinstance(parent, tuple(self.parameters)):
if isinstance(parent, self.parameters):
return
if len(self.parameters) == 1:
raise VerifyException(
f"'{op.name}' expects parent op '{self.parameters[0].name}'"
)
names = ", ".join([f"'{p.name}'" for p in self.parameters])
names = ", ".join(f"'{p.name}'" for p in self.parameters)
raise VerifyException(f"'{op.name}' expects parent op to be one of {names}")


Expand Down

0 comments on commit ec6c3ed

Please sign in to comment.