Skip to content

Commit

Permalink
dialect: (riscv) Custom print attributes (#1363)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingiler committed Jul 31, 2023
1 parent a861cfd commit cc83842
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 15 deletions.
12 changes: 6 additions & 6 deletions tests/filecheck/dialects/riscv/riscv_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@
riscv.j {"immediate" = #riscv.label<"label">} : () -> ()
// CHECK-NEXT: riscv.j {"immediate" = #riscv.label<"label">} : () -> ()

riscv.jalr %0 {"immediate" = 1 : i32}: (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0 {"immediate" = 1 : i32} : (!riscv.reg<>) -> ()
riscv.jalr %0 {"immediate" = 1 : i32, "rd" = !riscv.reg<>} : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0 {"immediate" = 1 : i32, "rd" = !riscv.reg<>} : (!riscv.reg<>) -> ()
riscv.jalr %0 {"immediate" = #riscv.label<"label">} : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0 {"immediate" = #riscv.label<"label">} : (!riscv.reg<>) -> ()
riscv.jalr %0, 1: (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0, 1 : (!riscv.reg<>) -> ()
riscv.jalr %0, 1, !riscv.reg<> : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0, 1, !riscv.reg<> : (!riscv.reg<>) -> ()
riscv.jalr %0, "label" : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv.jalr %0, "label" : (!riscv.reg<>) -> ()

riscv.ret : () -> ()
// CHECK-NEXT: riscv.ret : () -> ()
Expand Down
78 changes: 69 additions & 9 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Mapping, Set
from io import StringIO
from typing import IO, ClassVar, Sequence, TypeAlias

Expand Down Expand Up @@ -42,7 +43,7 @@
var_operand_def,
var_result_def,
)
from xdsl.parser import AttrParser, Parser
from xdsl.parser import AttrParser, Parser, UnresolvedOperand
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, NoTerminator
from xdsl.utils.exceptions import VerifyException
Expand Down Expand Up @@ -333,13 +334,11 @@ def assembly_line(self) -> str | None:

@classmethod
def parse(cls, parser: Parser) -> Self:
args = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_unresolved_operand,
parser.parse_unresolved_operand,
)
if args is None:
args = []
attributes = parser.parse_optional_attr_dict()
args = cls.parse_unresolved_operand(parser)
custom_attributes = cls.custom_parse_attributes(parser)
remaining_attributes = parser.parse_optional_attr_dict()
# TODO ensure distinct keys for attributes
attributes = custom_attributes | remaining_attributes
regions = parser.parse_region_list()
parser.parse_punctuation(":")
func_type = parser.parse_function_type()
Expand All @@ -351,15 +350,51 @@ def parse(cls, parser: Parser) -> Self:
regions=regions,
)

@classmethod
def parse_unresolved_operand(cls, parser: Parser) -> list[UnresolvedOperand]:
"""
Parse a list of comma separated unresolved operands.
Notice that this method will consume trailing comma.
"""
if operand := parser.parse_optional_unresolved_operand():
operands = [operand]
while parser.parse_optional_punctuation(",") and (
operand := parser.parse_optional_unresolved_operand()
):
operands.append(operand)
return operands
return []

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> Mapping[str, Attribute]:
"""
Parse attributes with custom syntax. Subclasses may override this method.
"""
return parser.parse_optional_attr_dict()

def print(self, printer: Printer) -> None:
if self.operands:
printer.print(" ")
printer.print_list(self.operands, printer.print_operand)
printer.print_op_attributes(self.attributes)
printed_attributes = self.custom_print_attributes(printer)
unprinted_attributes = {
name: attr
for name, attr in self.attributes.items()
if name not in printed_attributes
}
printer.print_op_attributes(unprinted_attributes)
printer.print_regions(self.regions)
printer.print(" : ")
printer.print_operation_type(self)

def custom_print_attributes(self, printer: Printer) -> Set[str]:
"""
Print attributes with custom syntax. Return the names of the attributes printed. Subclasses may override this method.
"""
printer.print_op_attributes(self.attributes)
return self.attributes.keys()


AssemblyInstructionArg: TypeAlias = (
AnyIntegerAttr | LabelAttr | SSAValue | IntRegisterType | str
Expand Down Expand Up @@ -721,6 +756,31 @@ def __init__(
def assembly_line_args(self) -> tuple[AssemblyInstructionArg | None, ...]:
return self.rd, self.rs1, self.immediate

@classmethod
def custom_parse_attributes(cls, parser: Parser) -> Mapping[str, Attribute]:
attributes = dict[str, Attribute]()
if immediate := parser.parse_optional_integer(allow_boolean=False):
attributes["immediate"] = IntegerAttr(
immediate, IntegerType(12, Signedness.SIGNED)
)
elif immediate := parser.parse_optional_str_literal():
attributes["immediate"] = LabelAttr(immediate)
if parser.parse_optional_punctuation(","):
attributes["rd"] = parser.parse_attribute()
return attributes

def custom_print_attributes(self, printer: Printer) -> Set[str]:
printer.print(", ")
match self.immediate:
case IntegerAttr():
printer.print(self.immediate.value.data)
case LabelAttr():
printer.print_string_literal(self.immediate.data)
if self.rd is not None:
printer.print(", ")
printer.print_attribute(self.rd)
return {"immediate", "rd"}


class RdRsIntegerOperation(IRDLOperation, RISCVInstruction, ABC):
"""
Expand Down

0 comments on commit cc83842

Please sign in to comment.