Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (riscv) add verifiers for riscv ops using immediate #1027

Merged
merged 11 commits into from
Jun 5, 2023
133 changes: 133 additions & 0 deletions tests/dialects/test_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,136 @@ def test_return_op():

code = riscv_code(ModuleOp([return_op]))
assert code == " ebreak # my comment\n"


def test_immediate_i_inst():
# I-Type - 12-bits immediate
a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1))

with pytest.raises(VerifyException):
pos_invalid_addi_op = riscv.AddiOp(a1, 1 << 11, rd=riscv.Registers.A0)
superlopuh marked this conversation as resolved.
Show resolved Hide resolved
pos_invalid_addi_op.verify()

with pytest.raises(VerifyException):
neg_invalid_addi_op = riscv.AddiOp(a1, -(1 << 11) - 2, rd=riscv.Registers.A0)
neg_invalid_addi_op.verify()

neg_valid_addi_op = riscv.AddiOp(a1, -(1 << 11), rd=riscv.Registers.A0)
neg_valid_addi_op.verify()

pos_valid_addi_op = riscv.AddiOp(a1, (1 << 11) - 1, rd=riscv.Registers.A0)
pos_valid_addi_op.verify()

"""
Special handling for signed immediates for I- and S-Type instructions
https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions
"""

sext_0_valid_addi_op = riscv.AddiOp(a1, 0xFFFFFFFFFFFFF800, rd=riscv.Registers.A0)
sext_0_valid_addi_op.verify()

sext_1_valid_addi_op = riscv.AddiOp(a1, 0xFFFFFFFFFFFFFFFF, rd=riscv.Registers.A0)
sext_1_valid_addi_op.verify()

sext_2_valid_addi_op = riscv.AddiOp(a1, 0xFFFFF800, rd=riscv.Registers.A0)
sext_2_valid_addi_op.verify()

sext_3_valid_addi_op = riscv.AddiOp(a1, 0xFFFFFFFF, rd=riscv.Registers.A0)
sext_3_valid_addi_op.verify()


def test_immediate_s_inst():
# S-Type - 12-bits immediate
a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1))
a2 = TestSSAValue(riscv.RegisterType(riscv.Registers.A2))

with pytest.raises(VerifyException):
pos_invalid_sw_op = riscv.SwOp(a1, a2, 1 << 11)
pos_invalid_sw_op.verify()

with pytest.raises(VerifyException):
neg_invalid_sw_op = riscv.SwOp(a1, a2, -(1 << 11) - 2)
neg_invalid_sw_op.verify()

neg_valid_sw_op = riscv.SwOp(a1, a2, -(1 << 11))
neg_valid_sw_op.verify()

pos_valid_sw_op = riscv.SwOp(a1, a2, (1 << 11) - 1)
pos_valid_sw_op.verify()

"""
Special handling for signed immediates for I- and S-Type instructions
https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions
"""

sext_0_valid_sw_op = riscv.SwOp(a1, a2, 0xFFFFFFFFFFFFF800)
sext_0_valid_sw_op.verify()

sext_1_valid_sw_op = riscv.SwOp(a1, a2, 0xFFFFFFFFFFFFFFFF)
sext_1_valid_sw_op.verify()

sext_2_valid_sw_op = riscv.SwOp(a1, a2, 0xFFFFF800)
sext_2_valid_sw_op.verify()

sext_3_valid_sw_op = riscv.SwOp(a1, a2, 0xFFFFFFFF)
sext_3_valid_sw_op.verify()


def test_immediate_u_j_inst():
# U-Type and J-Type - 20-bits immediate
with pytest.raises(VerifyException):
pos_invalid_j_op = riscv.LuiOp(1 << 20)
pos_invalid_j_op.verify()

with pytest.raises(VerifyException):
neg_invalid_j_op = riscv.LuiOp(-(1 << 20) - 2)
neg_invalid_j_op.verify()

valid_j_op = riscv.LuiOp((1 << 20) - 1)
valid_j_op.verify()


def test_immediate_jalr_inst():
# Jalr - 12-bits immediate
a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1))

with pytest.raises(VerifyException):
pos_invalid_jalr_op = riscv.JalrOp(a1, 1 << 12, rd=riscv.Registers.A0)
pos_invalid_jalr_op.verify()

with pytest.raises(VerifyException):
neg_invalid_jalr_op = riscv.JalrOp(a1, -(1 << 12) - 2, rd=riscv.Registers.A0)
neg_invalid_jalr_op.verify()

jalr_op = riscv.JalrOp(a1, (1 << 11) - 1, rd=riscv.Registers.A0)
jalr_op.verify()


def test_immediate_pseudo_inst():
# Pseudo-Instruction with custom handling
with pytest.raises(VerifyException):
neg_invalid_li_op = riscv.LiOp(-(1 << 31) - 1, rd=riscv.Registers.A0)
neg_invalid_li_op.verify()

with pytest.raises(VerifyException):
pos_invalid_li_op = riscv.LiOp(1 << 32, rd=riscv.Registers.A0)
pos_invalid_li_op.verify()

valid_li_op = riscv.LiOp((1 << 31) - 1, rd=riscv.Registers.A0)
valid_li_op.verify()


def test_immediate_shift_inst():
# Shift instructions (SLLI, SRLI, SRAI) - 5-bits immediate
a1 = TestSSAValue(riscv.RegisterType(riscv.Registers.A1))

with pytest.raises(VerifyException):
pos_invalid_slli_op = riscv.SlliOp(a1, 1 << 5, rd=riscv.Registers.A0)
pos_invalid_slli_op.verify()

with pytest.raises(VerifyException):
neg_invalid_slli_op = riscv.SlliOp(a1, -1, rd=riscv.Registers.A0)
neg_invalid_slli_op.verify()

valid_slli_op = riscv.SlliOp(a1, (1 << 5) - 1, rd=riscv.Registers.A0)
valid_slli_op.verify()
95 changes: 86 additions & 9 deletions xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import dataclass, field
from typing import IO, Annotated, Iterable, TypeAlias, Sequence


from xdsl.ir import (
Dialect,
Operation,
Expand Down Expand Up @@ -31,7 +32,9 @@
from xdsl.printer import Printer
from xdsl.dialects.builtin import (
AnyIntegerAttr,
IntegerType,
ModuleOp,
Signedness,
UnitAttr,
IntegerAttr,
StringAttr,
Expand Down Expand Up @@ -154,6 +157,40 @@ def print_parameter(self, printer: Printer) -> None:
printer.print_string(name)


@irdl_attr_definition
class SImm12Attr(IntegerAttr[IntegerType]):
"""
A 12-bit immediate signed value.
"""

name = "riscv.simm12"

def __init__(self, value: int) -> None:
super().__init__(value, IntegerType(12, Signedness.SIGNED))

# Pyright is confused by the parent class __new__ signature
def __new__(cls, value: int) -> SImm12Attr:
return super().__new__(cls) # type: ignore
superlopuh marked this conversation as resolved.
Show resolved Hide resolved

def verify(self) -> None:
"""
All I- and S-type instructions with 12-bit signed immediates --- e.g., addi but not slli ---
accept their immediate argument as an integer in the interval [-2048, 2047]. Integers in the subinterval [-2048, -1]
can also be passed by their (unsigned) associates in the interval [0xfffff800, 0xffffffff] on RV32I,
and in [0xfffffffffffff800, 0xffffffffffffffff] on both RV32I and RV64I.

https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#signed-immediates-for-i--and-s-type-instructions
"""

if 0xFFFFFFFFFFFFF800 <= self.value.data <= 0xFFFFFFFFFFFFFFFF:
return

if 0xFFFFF800 <= self.value.data <= 0xFFFFFFFF:
return

super().verify()


@irdl_attr_definition
class LabelAttr(Data[str]):
name = "riscv.label"
Expand Down Expand Up @@ -329,7 +366,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr.from_int_and_width(immediate, 32)
immediate = IntegerAttr(immediate, IntegerType(20, Signedness.UNSIGNED))
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)
if rd is None:
Expand Down Expand Up @@ -374,7 +411,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr.from_int_and_width(immediate, 32)
immediate = IntegerAttr(immediate, IntegerType(20, Signedness.SIGNED))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether we might want to convert all parameters to this bitwidth. So if we get another integer attr, we'd still catch whether the bitwidth is correct. This should possibly be done at the rewrite level, though.

elif isinstance(immediate, str):
immediate = LabelAttr(immediate)
if isinstance(rd, Register):
Expand Down Expand Up @@ -414,7 +451,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, 32)
immediate = SImm12Attr(immediate)
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)

Expand All @@ -437,6 +474,34 @@ def assembly_line_args(self) -> tuple[_AssemblyInstructionArg, ...]:
return self.rd, self.rs1, self.immediate


class RdRsImmShiftOperation(RdRsImmOperation):
"""
A base class for RISC-V operations that have one destination register, one source
register and one immediate operand.

This is called I-Type in the RISC-V specification.

Shifts by a constant are encoded as a specialization of the I-type format.
The shift amount is encoded in the lower 5 bits of the I-immediate field for RV32

For RV32I, SLLI, SRLI, and SRAI generate an illegal instruction exception if
imm[5] 6 != 0 but the shift amount is encoded in the lower 6 bits of the I-immediate field for RV64I.
"""

def __init__(
self,
rs1: Operation | SSAValue,
immediate: int | AnyIntegerAttr | str | LabelAttr,
*,
rd: RegisterType | Register | None = None,
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, IntegerType(5, Signedness.UNSIGNED))

super().__init__(rs1, immediate, rd=rd, comment=comment)


class RdRsImmJumpOperation(IRDLOperation, RISCVInstruction, ABC):
"""
A base class for RISC-V operations that have one destination register, one source
Expand Down Expand Up @@ -467,7 +532,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, 32)
immediate = IntegerAttr(immediate, IntegerType(12, Signedness.SIGNED))
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)

Expand Down Expand Up @@ -543,7 +608,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(offset, int):
offset = IntegerAttr.from_int_and_width(offset, 32)
offset = IntegerAttr(offset, 12)
if isinstance(offset, str):
offset = LabelAttr(offset)
if isinstance(comment, str):
Expand Down Expand Up @@ -582,7 +647,7 @@ def __init__(
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr.from_int_and_width(immediate, 32)
immediate = SImm12Attr(immediate)
elif isinstance(immediate, str):
immediate = LabelAttr(immediate)
if isinstance(comment, str):
Expand Down Expand Up @@ -947,7 +1012,7 @@ class XoriOp(RdRsImmOperation):


@irdl_op_definition
class SlliOp(RdRsImmOperation):
class SlliOp(RdRsImmShiftOperation):
"""
Performs logical left shift on the value in register rs1 by the shift amount
held in the lower 5 bits of the immediate.
Expand All @@ -961,7 +1026,7 @@ class SlliOp(RdRsImmOperation):


@irdl_op_definition
class SrliOp(RdRsImmOperation):
class SrliOp(RdRsImmShiftOperation):
"""
Performs logical right shift on the value in register rs1 by the shift amount held
in the lower 5 bits of the immediate.
Expand All @@ -975,7 +1040,7 @@ class SrliOp(RdRsImmOperation):


@irdl_op_definition
class SraiOp(RdRsImmOperation):
class SraiOp(RdRsImmShiftOperation):
"""
Performs arithmetic right shift on the value in register rs1 by the shift amount
held in the lower 5 bits of the immediate.
Expand Down Expand Up @@ -1704,6 +1769,18 @@ class LiOp(RdImmOperation):

name = "riscv.li"

def __init__(
self,
immediate: int | AnyIntegerAttr | str | LabelAttr,
*,
rd: RegisterType | Register | None = None,
comment: str | StringAttr | None = None,
):
if isinstance(immediate, int):
immediate = IntegerAttr(immediate, IntegerType(32, Signedness.SIGNED))

super().__init__(immediate, rd=rd, comment=comment)


@irdl_op_definition
class EcallOp(NullaryOperation):
Expand Down