diff --git a/tests/dialects/test_riscv.py b/tests/dialects/test_riscv.py index 9b20a5de41..1b0bfec64d 100644 --- a/tests/dialects/test_riscv.py +++ b/tests/dialects/test_riscv.py @@ -141,3 +141,101 @@ def test_return_op(): code = riscv_code(ModuleOp([return_op])) assert code == " ebreak # my comment\n" + + +def test_immediate_i_inst(): + # I-Type - 12-bits immediate + a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1)) + + with pytest.raises(VerifyException): + riscv.AddiOp(a1, 1 << 11, rd=riscv.Registers.A0) + + with pytest.raises(VerifyException): + riscv.AddiOp(a1, -(1 << 11) - 2, rd=riscv.Registers.A0) + + riscv.AddiOp(a1, -(1 << 11), rd=riscv.Registers.A0) + + riscv.AddiOp(a1, (1 << 11) - 1, rd=riscv.Registers.A0) + + """ + Special handling for signed immediates for I- and S-Type instructions + https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions + """ + + riscv.AddiOp(a1, 0xFFFFFFFFFFFFF800, rd=riscv.Registers.A0) + riscv.AddiOp(a1, 0xFFFFFFFFFFFFFFFF, rd=riscv.Registers.A0) + riscv.AddiOp(a1, 0xFFFFF800, rd=riscv.Registers.A0) + riscv.AddiOp(a1, 0xFFFFFFFF, rd=riscv.Registers.A0) + + +def test_immediate_s_inst(): + # S-Type - 12-bits immediate + a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1)) + a2 = TestSSAValue(riscv.RegisterType(riscv.Registers.A2)) + + with pytest.raises(VerifyException): + riscv.SwOp(a1, a2, 1 << 11) + + with pytest.raises(VerifyException): + riscv.SwOp(a1, a2, -(1 << 11) - 2) + + riscv.SwOp(a1, a2, -(1 << 11)) + riscv.SwOp(a1, a2, (1 << 11) - 1) + + """ + Special handling for signed immediates for I- and S-Type instructions + https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions + """ + + riscv.SwOp(a1, a2, 0xFFFFFFFFFFFFF800) + riscv.SwOp(a1, a2, 0xFFFFFFFFFFFFFFFF) + riscv.SwOp(a1, a2, 0xFFFFF800) + riscv.SwOp(a1, a2, 0xFFFFFFFF) + + +def test_immediate_u_j_inst(): + # U-Type and J-Type - 20-bits immediate + with pytest.raises(VerifyException): + riscv.LuiOp(1 << 20) + + with pytest.raises(VerifyException): + riscv.LuiOp(-(1 << 20) - 2) + + riscv.LuiOp((1 << 20) - 1) + + +def test_immediate_jalr_inst(): + # Jalr - 12-bits immediate + a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1)) + + with pytest.raises(VerifyException): + riscv.JalrOp(a1, 1 << 12, rd=riscv.Registers.A0) + + with pytest.raises(VerifyException): + riscv.JalrOp(a1, -(1 << 12) - 2, rd=riscv.Registers.A0) + + riscv.JalrOp(a1, (1 << 11) - 1, rd=riscv.Registers.A0) + + +def test_immediate_pseudo_inst(): + # Pseudo-Instruction with custom handling + with pytest.raises(VerifyException): + riscv.LiOp(-(1 << 31) - 1, rd=riscv.Registers.A0) + + with pytest.raises(VerifyException): + riscv.LiOp(1 << 32, rd=riscv.Registers.A0) + + riscv.LiOp((1 << 31) - 1, rd=riscv.Registers.A0) + + +def test_immediate_shift_inst(): + # Shift instructions (SLLI, SRLI, SRAI) - 5-bits immediate + a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1)) + + with pytest.raises(VerifyException): + riscv.SlliOp(a1, 1 << 5, rd=riscv.Registers.A0) + + with pytest.raises(VerifyException): + riscv.SlliOp(a1, -1, rd=riscv.Registers.A0) + + riscv.SlliOp(a1, (1 << 5) - 1, rd=riscv.Registers.A0) diff --git a/xdsl/dialects/riscv.py b/xdsl/dialects/riscv.py index a7ea281310..7145b6b809 100644 --- a/xdsl/dialects/riscv.py +++ b/xdsl/dialects/riscv.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import IO, Annotated, Iterable, TypeAlias, Sequence + from xdsl.ir import ( Dialect, Operation, @@ -31,7 +32,9 @@ from xdsl.printer import Printer from xdsl.dialects.builtin import ( AnyIntegerAttr, + IntegerType, ModuleOp, + Signedness, UnitAttr, IntegerAttr, StringAttr, @@ -153,6 +156,36 @@ def print_parameter(self, printer: Printer) -> None: printer.print_string(name) +@irdl_attr_definition +class SImm12Attr(IntegerAttr[IntegerType]): + """ + A 12-bit immediate signed value. + """ + + name = "riscv.simm12" + + def __init__(self, value: int) -> None: + super().__init__(value, IntegerType(12, Signedness.SIGNED)) + + def verify(self) -> None: + """ + All I- and S-type instructions with 12-bit signed immediates --- e.g., addi but not slli --- + accept their immediate argument as an integer in the interval [-2048, 2047]. Integers in the subinterval [-2048, -1] + can also be passed by their (unsigned) associates in the interval [0xfffff800, 0xffffffff] on RV32I, + and in [0xfffffffffffff800, 0xffffffffffffffff] on both RV32I and RV64I. + + https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions + """ + + if 0xFFFFFFFFFFFFF800 <= self.value.data <= 0xFFFFFFFFFFFFFFFF: + return + + if 0xFFFFF800 <= self.value.data <= 0xFFFFFFFF: + return + + super().verify() + + @irdl_attr_definition class LabelAttr(Data[str]): name = "riscv.label" @@ -328,7 +361,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(immediate, int): - immediate = IntegerAttr.from_int_and_width(immediate, 32) + immediate = IntegerAttr(immediate, IntegerType(20, Signedness.UNSIGNED)) elif isinstance(immediate, str): immediate = LabelAttr(immediate) if rd is None: @@ -373,7 +406,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(immediate, int): - immediate = IntegerAttr.from_int_and_width(immediate, 32) + immediate = IntegerAttr(immediate, IntegerType(20, Signedness.SIGNED)) elif isinstance(immediate, str): immediate = LabelAttr(immediate) if isinstance(rd, Register): @@ -413,7 +446,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(immediate, int): - immediate = IntegerAttr(immediate, 32) + immediate = SImm12Attr(immediate) elif isinstance(immediate, str): immediate = LabelAttr(immediate) @@ -436,6 +469,34 @@ def assembly_line_args(self) -> tuple[_AssemblyInstructionArg, ...]: return self.rd, self.rs1, self.immediate +class RdRsImmShiftOperation(RdRsImmOperation): + """ + A base class for RISC-V operations that have one destination register, one source + register and one immediate operand. + + This is called I-Type in the RISC-V specification. + + Shifts by a constant are encoded as a specialization of the I-type format. + The shift amount is encoded in the lower 5 bits of the I-immediate field for RV32 + + For RV32I, SLLI, SRLI, and SRAI generate an illegal instruction exception if + imm[5] 6 != 0 but the shift amount is encoded in the lower 6 bits of the I-immediate field for RV64I. + """ + + def __init__( + self, + rs1: Operation | SSAValue, + immediate: int | AnyIntegerAttr | str | LabelAttr, + *, + rd: RegisterType | Register | None = None, + comment: str | StringAttr | None = None, + ): + if isinstance(immediate, int): + immediate = IntegerAttr(immediate, IntegerType(5, Signedness.UNSIGNED)) + + super().__init__(rs1, immediate, rd=rd, comment=comment) + + class RdRsImmJumpOperation(IRDLOperation, RISCVInstruction, ABC): """ A base class for RISC-V operations that have one destination register, one source @@ -466,7 +527,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(immediate, int): - immediate = IntegerAttr(immediate, 32) + immediate = IntegerAttr(immediate, IntegerType(12, Signedness.SIGNED)) elif isinstance(immediate, str): immediate = LabelAttr(immediate) @@ -542,7 +603,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(offset, int): - offset = IntegerAttr.from_int_and_width(offset, 32) + offset = IntegerAttr(offset, 12) if isinstance(offset, str): offset = LabelAttr(offset) if isinstance(comment, str): @@ -581,7 +642,7 @@ def __init__( comment: str | StringAttr | None = None, ): if isinstance(immediate, int): - immediate = IntegerAttr.from_int_and_width(immediate, 32) + immediate = SImm12Attr(immediate) elif isinstance(immediate, str): immediate = LabelAttr(immediate) if isinstance(comment, str): @@ -946,7 +1007,7 @@ class XoriOp(RdRsImmOperation): @irdl_op_definition -class SlliOp(RdRsImmOperation): +class SlliOp(RdRsImmShiftOperation): """ Performs logical left shift on the value in register rs1 by the shift amount held in the lower 5 bits of the immediate. @@ -960,7 +1021,7 @@ class SlliOp(RdRsImmOperation): @irdl_op_definition -class SrliOp(RdRsImmOperation): +class SrliOp(RdRsImmShiftOperation): """ Performs logical right shift on the value in register rs1 by the shift amount held in the lower 5 bits of the immediate. @@ -974,7 +1035,7 @@ class SrliOp(RdRsImmOperation): @irdl_op_definition -class SraiOp(RdRsImmOperation): +class SraiOp(RdRsImmShiftOperation): """ Performs arithmetic right shift on the value in register rs1 by the shift amount held in the lower 5 bits of the immediate. @@ -1703,6 +1764,18 @@ class LiOp(RdImmOperation): name = "riscv.li" + def __init__( + self, + immediate: int | AnyIntegerAttr | str | LabelAttr, + *, + rd: RegisterType | Register | None = None, + comment: str | StringAttr | None = None, + ): + if isinstance(immediate, int): + immediate = IntegerAttr(immediate, IntegerType(32, Signedness.SIGNED)) + + super().__init__(immediate, rd=rd, comment=comment) + @irdl_op_definition class EcallOp(NullaryOperation):