Skip to content

[MLIR][Linalg] Harden parsing Linalg named ops #145337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025

Conversation

joker-eph
Copy link
Collaborator

@joker-eph joker-eph commented Jun 23, 2025

This thread through proper error handling / reporting capabilities to avoid hitting llvm_unreachable while parsing linalg ops.

Fixes #132755
Fixes #132740
Fixes #129185

@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This thread through proper error handling / reporting capabilities to avoid hitting llvm_unreachable while parsing linalg ops.


Patch is 33.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145337.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+4)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+33-17)
  • (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+134-35)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+10)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8-4)
  • (modified) mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml (+16-9)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+12-5)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..a0fb0111d6ace 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -26,6 +27,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+
 #include <optional>
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 33601c5d6dad9..a459656b982e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
         kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
 
     using RegionBuilderFunType = llvm::function_ref<
-      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
+      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
+           function_ref<InFlightDiagnostic()>)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
       return namedStructuredOpRegionBuilders.lookup(name);
     }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..594d6c757d7bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -842,7 +842,7 @@ def LinalgStructuredInterface
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
+      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 61783812920bc..7bbc56f549c0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
     // Implement functions necessary for DestinationStyleOpInterface.
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+        mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+                              mlir::ArrayRef<mlir::NamedAttribute>, 
+                              function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       /// Implements the block region builder for the elementwiseOp. This is
       /// called by the 'fillStructuredOpRegion'.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     static unsigned getNumRegionArgs();
 
     static void regionBuilder(ImplicitLocOpBuilder &b,
-                              Block &block, ArrayRef<NamedAttribute> attrs);
+                              Block &block, ArrayRef<NamedAttribute> attrs,
+                              function_ref<InFlightDiagnostic()> emitError);
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return regionBuilder;
     }
@@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 
       SmallVector<utils::IteratorType> getIteratorTypesArray();
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 0c4f6e88e7078..21db18dfd47ed 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
   b.setInsertionPointToStart(body);
-  fun(b, *body, op->getAttrs());
+  fun(b, *body, op->getAttrs(), /*emitError=*/{});
 }
 
 MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..9cc60394e6635 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
 // Support for named Linalg ops defined in ods-gen.
 //===----------------------------------------------------------------------===//
 
-using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
-                                                ArrayRef<NamedAttribute>)>;
+using RegionBuilderFn = llvm::function_ref<void(
+    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
+    function_ref<InFlightDiagnostic()>)>;
 
 /// Fills the region of a structured operation using the provided
 /// `regionBuilder`. The method is used by both named structured ops created by
@@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
+                                   function_ref<InFlightDiagnostic()> emitError,
                                    RegionBuilderFn regionBuilder) {
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   opBuilder.setInsertionPointToStart(body);
   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  regionBuilder(b, *body, attrs);
+  regionBuilder(b, *body, attrs, emitError);
 
   // indexing_maps is an auto-generated method.
 
@@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   // Create and fill the region of the structured operation.
   Region &region = *state.addRegion();
   fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
-                         state.attributes.getAttrs(), regionBuilder);
+                         state.attributes.getAttrs(), /*emitError=*/{},
+                         regionBuilder);
 }
 
 static void buildMatmulOp(OpBuilder &b, OperationState &state,
@@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion(
   }
 
   OpBuilder opBuilder(parser.getContext());
-  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
-                         regionBuilder);
-  return success();
+  ParseResult result = success();
+  fillStructuredOpRegion(
+      opBuilder, region, inputTypes, outputTypes, attrs,
+      [&]() {
+        result = failure();
+        return parser.emitError(parser.getCurrentLocation());
+      },
+      regionBuilder);
+  return result;
 }
 
 static ParseResult
@@ -435,9 +444,15 @@ class RegionBuilderHelper {
       : builder(builder), block(block) {}
 
   // Build the unary functions defined by OpDSL.
-  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
-    if (!isFloatingPoint(arg))
+  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
+                     function_ref<InFlightDiagnostic()> emitError = {}) {
+    if (!isFloatingPoint(arg)) {
+      if (emitError) {
+        emitError() << "unsupported non numeric type";
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (unaryFn) {
@@ -472,18 +487,34 @@ class RegionBuilderHelper {
     case UnaryFn::erf:
       return builder.create<math::ErfOp>(arg.getLoc(), arg);
     }
+    if (emitError) {
+      emitError() << "unsupported unary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported unary function");
   }
 
   // Build the binary functions defined by OpDSL.
-  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+  // If emitError is provided, an error will be emitted if the operation is not
+  // supported and a nullptr will be returned, otherwise an assertion will be
+  // raised.
+  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
+                      function_ref<InFlightDiagnostic()> emitError = {}) {
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
     bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                    arg1.getType().getIntOrFloatBitWidth() == 1;
-    if (!allComplex && !allFloatingPoint && !allInteger)
+    if (!allComplex && !allFloatingPoint && !allInteger) {
+      if (emitError) {
+        emitError()
+            << "Cannot build binary Linalg operation: expects allComplex, "
+               "allFloatingPoint, or allInteger, got "
+            << arg0.getType() << " and " << arg1.getType();
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (binaryFn) {
@@ -500,8 +531,13 @@ class RegionBuilderHelper {
         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: sub with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: sub with bools");
+      }
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allComplex)
@@ -516,12 +552,22 @@ class RegionBuilderHelper {
         return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: div with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: div with bools");
+      }
       return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::div_unsigned:
-      if (!allInteger || allBool)
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned div not on uint";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: unsigned div not on uint");
+      }
       return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
       assert(!allComplex);
@@ -547,12 +593,16 @@ class RegionBuilderHelper {
       assert(allFloatingPoint);
       return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
     }
+    if (emitError) {
+      emitError() << "unsupported binary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported binary function");
   }
 
   // Build the ternary functions defined by OpDSL.
-  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
-                       Value arg2) {
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+                       function_ref<InFlightDiagnostic()> emitError = {}) {
     bool headBool =
         isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
     bool tailFloatingPoint =
@@ -566,17 +616,26 @@ class RegionBuilderHelper {
         llvm_unreachable("unsupported non numeric type");
       return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
     }
+    if (emitError) {
+      emitError() << "unsupported ternary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported ternary function");
   }
 
   // Build the type functions defined by OpDSL.
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
+                    function_ref<InFlightDiagnostic()> emitError = {}) {
     switch (typeFn) {
     case TypeFn::cast_signed:
       return cast(toType, operand, false);
     case TypeFn::cast_unsigned:
       return cast(toType, operand, true);
     }
+    if (emitError) {
+      emitError() << "unsupported type conversion function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported type conversion function");
   }
 
@@ -3664,9 +3723,15 @@ bool MatmulOp::hasUserDefinedMaps() {
 /// Implements the block region builder for the MatmulOp. This is called by
 /// 'fillStructuredOpRegion'.
 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                             ArrayRef<NamedAttribute> attrs) {
-  assert(3 > 0 && block.getNumArguments() == 3 &&
-         "MatmulOp regionBuilder expects 3 (>=0) args");
+                             ArrayRef<NamedAttribute> attrs,
+                             function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "MatmulOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
@@ -3683,9 +3748,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                     block.getArgument(0));
   Value value2 = helper.buildTypeFn(cast...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jun 23, 2025

@llvm/pr-subscribers-mlir-core

Author: Mehdi Amini (joker-eph)

Changes

This thread through proper error handling / reporting capabilities to avoid hitting llvm_unreachable while parsing linalg ops.


Patch is 33.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145337.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+4)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+33-17)
  • (modified) mlir/lib/CAPI/Dialect/Linalg.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+134-35)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+10)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (+8-4)
  • (modified) mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml (+16-9)
  • (modified) mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp (+12-5)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..a0fb0111d6ace 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -26,6 +27,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
+
+#include "llvm/ADT/STLFunctionalExtras.h"
+
 #include <optional>
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 33601c5d6dad9..a459656b982e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
         kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";
 
     using RegionBuilderFunType = llvm::function_ref<
-      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
+      void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
+           function_ref<InFlightDiagnostic()>)>;
     RegionBuilderFunType getRegionBuilder(StringRef name) {
       return namedStructuredOpRegionBuilders.lookup(name);
     }
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
index 74c4c0a8835f2..594d6c757d7bd 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -842,7 +842,7 @@ def LinalgStructuredInterface
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
+      /*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 61783812920bc..7bbc56f549c0b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
     }
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
 
     // Implement functions necessary for DestinationStyleOpInterface.
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-                              mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return nullptr;
     }
@@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+        mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
     static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
-        mlir::ArrayRef<mlir::NamedAttribute>) {
+                              mlir::ArrayRef<mlir::NamedAttribute>, 
+                              function_ref<InFlightDiagnostic()> emitError) {
       OpBuilder::InsertionGuard guard(b);
       b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
     }
 
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
-        mlir::ArrayRef<mlir::NamedAttribute>)>
+                              mlir::ArrayRef<mlir::NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
       return regionBuilder;
     }
@@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
       /// Implements the block region builder for the elementwiseOp. This is
       /// called by the 'fillStructuredOpRegion'.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
     static unsigned getNumRegionArgs();
 
     static void regionBuilder(ImplicitLocOpBuilder &b,
-                              Block &block, ArrayRef<NamedAttribute> attrs);
+                              Block &block, ArrayRef<NamedAttribute> attrs,
+                              function_ref<InFlightDiagnostic()> emitError);
 
     static std::function<void(ImplicitLocOpBuilder &,
-                              Block &, ArrayRef<NamedAttribute>)>
+                              Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
     getRegionBuilder() {
       return regionBuilder;
     }
@@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
 
       SmallVector<utils::IteratorType> getIteratorTypesArray();
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                                function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
@@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
 
       /// Implements the block region builder.
       static void regionBuilder(ImplicitLocOpBuilder &b,
-                                Block &block, ArrayRef<NamedAttribute> attrs);
+                                Block &block, ArrayRef<NamedAttribute> attrs,
+                                function_ref<InFlightDiagnostic()> emitError);
 
       /// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
       static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
@@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
       bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);
 
       static std::function<void(ImplicitLocOpBuilder &,
-                                Block &, ArrayRef<NamedAttribute>)>
+                                Block &, ArrayRef<NamedAttribute>,
+                              function_ref<InFlightDiagnostic()>)>
       getRegionBuilder() {
         return regionBuilder;
       }
diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp
index 0c4f6e88e7078..21db18dfd47ed 100644
--- a/mlir/lib/CAPI/Dialect/Linalg.cpp
+++ b/mlir/lib/CAPI/Dialect/Linalg.cpp
@@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
   Region &region = op->getRegion(0);
   Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
   b.setInsertionPointToStart(body);
-  fun(b, *body, op->getAttrs());
+  fun(b, *body, op->getAttrs(), /*emitError=*/{});
 }
 
 MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..9cc60394e6635 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -117,8 +117,9 @@ OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
 // Support for named Linalg ops defined in ods-gen.
 //===----------------------------------------------------------------------===//
 
-using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
-                                                ArrayRef<NamedAttribute>)>;
+using RegionBuilderFn = llvm::function_ref<void(
+    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
+    function_ref<InFlightDiagnostic()>)>;
 
 /// Fills the region of a structured operation using the provided
 /// `regionBuilder`. The method is used by both named structured ops created by
@@ -128,6 +129,7 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
+                                   function_ref<InFlightDiagnostic()> emitError,
                                    RegionBuilderFn regionBuilder) {
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
@@ -148,7 +150,7 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   opBuilder.setInsertionPointToStart(body);
   ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
-  regionBuilder(b, *body, attrs);
+  regionBuilder(b, *body, attrs, emitError);
 
   // indexing_maps is an auto-generated method.
 
@@ -184,7 +186,8 @@ static void buildStructuredOp(OpBuilder &b, OperationState &state,
   // Create and fill the region of the structured operation.
   Region &region = *state.addRegion();
   fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
-                         state.attributes.getAttrs(), regionBuilder);
+                         state.attributes.getAttrs(), /*emitError=*/{},
+                         regionBuilder);
 }
 
 static void buildMatmulOp(OpBuilder &b, OperationState &state,
@@ -339,9 +342,15 @@ static ParseResult parseNamedStructuredOpRegion(
   }
 
   OpBuilder opBuilder(parser.getContext());
-  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
-                         regionBuilder);
-  return success();
+  ParseResult result = success();
+  fillStructuredOpRegion(
+      opBuilder, region, inputTypes, outputTypes, attrs,
+      [&]() {
+        result = failure();
+        return parser.emitError(parser.getCurrentLocation());
+      },
+      regionBuilder);
+  return result;
 }
 
 static ParseResult
@@ -435,9 +444,15 @@ class RegionBuilderHelper {
       : builder(builder), block(block) {}
 
   // Build the unary functions defined by OpDSL.
-  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
-    if (!isFloatingPoint(arg))
+  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
+                     function_ref<InFlightDiagnostic()> emitError = {}) {
+    if (!isFloatingPoint(arg)) {
+      if (emitError) {
+        emitError() << "unsupported non numeric type";
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (unaryFn) {
@@ -472,18 +487,34 @@ class RegionBuilderHelper {
     case UnaryFn::erf:
       return builder.create<math::ErfOp>(arg.getLoc(), arg);
     }
+    if (emitError) {
+      emitError() << "unsupported unary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported unary function");
   }
 
   // Build the binary functions defined by OpDSL.
-  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
+  // If emitError is provided, an error will be emitted if the operation is not
+  // supported and a nullptr will be returned, otherwise an assertion will be
+  // raised.
+  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
+                      function_ref<InFlightDiagnostic()> emitError = {}) {
     bool allComplex = isComplex(arg0) && isComplex(arg1);
     bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
     bool allInteger = isInteger(arg0) && isInteger(arg1);
     bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                    arg1.getType().getIntOrFloatBitWidth() == 1;
-    if (!allComplex && !allFloatingPoint && !allInteger)
+    if (!allComplex && !allFloatingPoint && !allInteger) {
+      if (emitError) {
+        emitError()
+            << "Cannot build binary Linalg operation: expects allComplex, "
+               "allFloatingPoint, or allInteger, got "
+            << arg0.getType() << " and " << arg1.getType();
+        return nullptr;
+      }
       llvm_unreachable("unsupported non numeric type");
+    }
     OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPointToEnd(&block);
     switch (binaryFn) {
@@ -500,8 +531,13 @@ class RegionBuilderHelper {
         return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: sub with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: sub with bools");
+      }
       return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::mul:
       if (allComplex)
@@ -516,12 +552,22 @@ class RegionBuilderHelper {
         return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
       if (allFloatingPoint)
         return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
-      if (allBool)
+      if (allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: div with bools";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: div with bools");
+      }
       return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::div_unsigned:
-      if (!allInteger || allBool)
+      if (!allInteger || allBool) {
+        if (emitError) {
+          emitError() << "unsupported operation: unsigned div not on uint";
+          return nullptr;
+        }
         llvm_unreachable("unsupported operation: unsigned div not on uint");
+      }
       return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
     case BinaryFn::max_signed:
       assert(!allComplex);
@@ -547,12 +593,16 @@ class RegionBuilderHelper {
       assert(allFloatingPoint);
       return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
     }
+    if (emitError) {
+      emitError() << "unsupported binary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported binary function");
   }
 
   // Build the ternary functions defined by OpDSL.
-  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
-                       Value arg2) {
+  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
+                       function_ref<InFlightDiagnostic()> emitError = {}) {
     bool headBool =
         isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
     bool tailFloatingPoint =
@@ -566,17 +616,26 @@ class RegionBuilderHelper {
         llvm_unreachable("unsupported non numeric type");
       return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
     }
+    if (emitError) {
+      emitError() << "unsupported ternary function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported ternary function");
   }
 
   // Build the type functions defined by OpDSL.
-  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
+  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
+                    function_ref<InFlightDiagnostic()> emitError = {}) {
     switch (typeFn) {
     case TypeFn::cast_signed:
       return cast(toType, operand, false);
     case TypeFn::cast_unsigned:
       return cast(toType, operand, true);
     }
+    if (emitError) {
+      emitError() << "unsupported type conversion function";
+      return nullptr;
+    }
     llvm_unreachable("unsupported type conversion function");
   }
 
@@ -3664,9 +3723,15 @@ bool MatmulOp::hasUserDefinedMaps() {
 /// Implements the block region builder for the MatmulOp. This is called by
 /// 'fillStructuredOpRegion'.
 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
-                             ArrayRef<NamedAttribute> attrs) {
-  assert(3 > 0 && block.getNumArguments() == 3 &&
-         "MatmulOp regionBuilder expects 3 (>=0) args");
+                             ArrayRef<NamedAttribute> attrs,
+                             function_ref<InFlightDiagnostic()> emitError) {
+  if (emitError && block.getNumArguments() != 3) {
+    emitError() << "MatmulOp regionBuilder expects 3 args, got "
+                << block.getNumArguments();
+    return;
+  }
+  assert(block.getNumArguments() == 3 &&
+         "MatmulOp regionBuilder expects 3 args");
   RegionBuilderHelper helper(b, block);
   SmallVector<Value> yields;
 
@@ -3683,9 +3748,13 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                                     block.getArgument(0));
   Value value2 = helper.buildTypeFn(cast...
[truncated]

@joker-eph joker-eph requested review from Copilot, banach-space and Groverkss and removed request for rengolin, nicolasvasilache and dcaballe June 23, 2025 15:04
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds an optional error‐reporting callback (emitError) to all Linalg region builders and wiring code, updates ODS YAML generation and tests to pass the new parameter, and improves parsing logic to gracefully handle failures instead of using llvm_unreachable.

  • Extend regionBuilder signatures across C++ and ODS definitions to take function_ref<InFlightDiagnostic()> emitError.
  • Update fillStructuredOpRegion, parsing routines, and helper methods (buildUnaryFn, buildBinaryFn, etc.) to invoke emitError and return early on errors.
  • Adjust the YAML generator, TestOps definitions, and MLIR Linalg tests to incorporate the new callback parameter.

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated no comments.

Show a summary per file
File Description
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp Add emitError to generated region builders
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml Update test patterns for new emitError parameter
mlir/test/lib/Dialect/Test/TestOps.td Extend test op region builders with emitError callback
mlir/test/Dialect/Linalg/invalid.mlir Add an invalid case expecting the new diagnostic
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp Propagate emitError through parsing and helper code
mlir/lib/CAPI/Dialect/Linalg.cpp Pass an empty/default callback in C API region builders
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td Update ODS definitions for getRegionBuilder signature
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td Change interface return type to include emitError
mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td Adjust dialect‐wide RegionBuilderFunType for callback
mlir/include/mlir/Dialect/Linalg/IR/Linalg.h Include diagnostics and functional extras for function_ref
Comments suppressed due to low confidence (5)

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:451

  • [nitpick] Include the actual operand type in this diagnostic for better context, e.g.: emitError() << "unsupported non-numeric type: " << arg.getType();
        emitError() << "unsupported non numeric type";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:491

  • [nitpick] Specify which unary function is unsupported in the message, e.g.: emitError() << "unsupported unary function: " << unaryFn;
      emitError() << "unsupported unary function";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:597

  • [nitpick] Include the binary function name in the diagnostic, e.g.: emitError() << "unsupported binary function: " << binaryFn;
      emitError() << "unsupported binary function";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:636

  • [nitpick] Add the specific type conversion operation to the message, e.g.: emitError() << "unsupported type conversion: " << typeFn;
      emitError() << "unsupported type conversion function";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp:620

  • [nitpick] Include the ternary function identifier in this error, e.g.: emitError() << "unsupported ternary function: " << ternaryFn;
      emitError() << "unsupported ternary function";

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice, thanks! Small comment on my part but otherwise ok.

Will give time to others to look, too.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking this on!

Shouldn't this also include tests for linalg.matmul and linalg.elemwise_unary?

@joker-eph
Copy link
Collaborator Author

Added the tests from the two other bugs.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

I posted 2 additional questions/requests - these are non-blocking.

Thanks again for tackling this!

This thread through proper error handling / reporting capabilities to
avoid hitting llvm_unreachable while parsing linalg ops.
Copy link
Contributor

@nicolasvasilache nicolasvasilache left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, thanks much!

@joker-eph joker-eph merged commit ff0dcc4 into llvm:main Jun 25, 2025
7 checks passed
@hiraditya
Copy link
Collaborator

Thanks for fixing it the right way. I was having this issue a while back and it was tricky to handle without modifying the region builder API https://discourse.llvm.org/t/rfc-linalgops-regionbuilder-function-needs-to-indicate-if-anything-goes-wrong/86053

@vzakhari
Copy link
Contributor

Am I the only one who is having build issues with this patch?

The errors look like this:

.../llvm/tools/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc:1450:39: error: could not convert ‘mlir::linalg::MinOp
::getRegionBuilder()()’ from ‘function<void(mlir::ImplicitLocOpBuilder&, mlir::Block&, llvm::ArrayRef<mlir::NamedAttribute>)>’ to ‘function<void(mlir::ImplicitLocOpBui
lder&, mlir::Block&, llvm::ArrayRef<mlir::NamedAttribute>, llvm::function_ref<mlir::InFlightDiagnostic()>)>’
.../llvm/tools/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc: In instantiation of ‘static std::function<void(mlir::
ImplicitLocOpBuilder&, mlir::Block&, llvm::ArrayRef<mlir::NamedAttribute>, llvm::function_ref<mlir::InFlightDiagnostic()>)> mlir::linalg::detail::LinalgOpInterfaceTrai
ts::Model<ConcreteOp>::getRegionBuilder() [with ConcreteOp = mlir::linalg::Mmt4DOp]’:
.../llvm/tools/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc:361:724:   required from ‘mlir::linalg::detail::Linalg
OpInterfaceTraits::Model<ConcreteOp>::Model() [with ConcreteOp = mlir::linalg::Mmt4DOp]’
/llvm-project/mlir/include/mlir/Support/InterfaceSupport.h:238:9:   required from ‘void mlir::detail::InterfaceMap::insertModel() [with
 InterfaceModel = mlir::linalg::detail::LinalgOpInterfaceTraits::Model<mlir::linalg::Mmt4DOp>]’

@vzakhari
Copy link
Contributor

I was a false alarm (somewhat), it passed with a clean build.

anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
This thread through proper error handling / reporting capabilities to
avoid hitting llvm_unreachable while parsing linalg ops.

Fixes llvm#132755
Fixes llvm#132740
Fixes llvm#129185
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
This thread through proper error handling / reporting capabilities to
avoid hitting llvm_unreachable while parsing linalg ops.

Fixes llvm#132755
Fixes llvm#132740
Fixes llvm#129185
@benvanik
Copy link
Contributor

benvanik commented Jul 1, 2025

I too am getting errors with this - will try a clean build - I suspect there's a missing cmake dependency that is causing something to not be rebuilt.

@benvanik
Copy link
Contributor

benvanik commented Jul 1, 2025

Clean build fixed it, so definitely something shady with the linalg generator cmake dependencies.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][linalg] linalg::AddOp::parse crashes [mlir] Linalg MatmulOp::parse crashes [mlir][linalg] Crash parsing linalg.elemwise_unary
8 participants