
[mlir][tablegen] Underlying signless integer storage for Enum Attributes is handled incorrectly #144005

Closed
@0xMihir

Description

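If we create an integer enum attribute with a case whose value has the most significant bit (MSB) set, any attribute using that case fails verification. The stored value is read back as a negative (sign-extended) number, so it no longer matches the case value.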

The following example illustrates the bug.

First, we define an enum in TableGen:

def I32Case5:  I32EnumAttrCase<"case5", 5>;
def I32Case10: I32EnumAttrCase<"case10", 10>;
def I32CaseSignedMaxPlusOne: I32EnumAttrCase<"caseSignedMaxPlusOne", 2147483648>;
def I32CaseUnsignedMax: I32EnumAttrCase<"caseUnsignedMax", 4294967295>;


def SomeI32Enum: I32EnumAttr<
  "SomeI32Enum", "", [I32Case5, I32Case10, 
                      I32CaseSignedMaxPlusOne, I32CaseUnsignedMax]>;

def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
  let arguments = (ins SomeI32Enum:$attr);
  let results = (outs I32:$val);
}

Then, using the defined op, we can observe that the last two cases (%2 and %3, whose values have the MSB set) fail verification even though they are valid enum cases:

// CHECK-LABEL: func @allowed_cases_pass
func.func @allowed_cases_pass() {
  // CHECK: test.i32_enum_attr
  %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %1 = "test.i32_enum_attr"() {attr = 10: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %2 = "test.i32_enum_attr"() {attr = 2147483648: i32} : () -> i32
  // CHECK: test.i32_enum_attr
  %3 = "test.i32_enum_attr"() {attr = 4294967295: i32} : () -> i32
  return
}

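This is because the underlying generator (EnumsGen.cpp) and the TableGen definitions (EnumAttr.td) use the deprecated IntegerAttr::getInt method, which sign-extends the stored value from its most significant bit. For example, the i32 value 2147483648 is read back as -2147483648, and 4294967295 is read back as -1. I don't think sign extension is the best default behavior here, but I also don't think we should change this old API.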

Fix

To fix this, EnumsGen.cpp can return the zero-extended value (getValue().getZExtValue()) instead, like so:

diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 95767a29b9c3..def322a9d684 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -524,7 +524,7 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
   os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
 
-  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getValue().getZExtValue());\n",
                 enumName);
 
   os << "}\n";

Similarly, in EnumAttr.td, the case predicate can compare the stored APInt against an APInt of the enum's bit width using APInt's eq method:

diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 9fec28f03ec2..8d004f8b7b8c 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -34,7 +34,7 @@ class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
     EnumAttrCaseInfo<sym, intVal, strVal>,
     SignlessIntegerAttrBase<intType, "case " # strVal> {
   let predicate =
-    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() == " # intVal>;
+    CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().eq(::llvm::APInt(" # intType.bitwidth # ", " # intVal # "))">;
 }
 
 // Cases of integer enum attributes with a specific type. By default, the string

Happy to open a PR for the above changes and additional test cases.
