Skip to content

[MLIR][Tosa] Fix argmax NaN propagate lowering #133074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 14, 2025

Conversation

RoboTux
Copy link
Contributor

@RoboTux RoboTux commented Mar 26, 2025

In the propagate mode, NaN compare equal to each other so in case of
several NaNs the index of the first one needs to be returned. This
commit changes the index update condition to check that the current
index is not that of a NaN.

The commit also simplifies argmax NaN ignore lowering to only use OGT.
This prevent any update in case of NaN. The only case where the index of
a NaN is returned is when all values are NaN and this is covered by the
fact that the initial index value is 0 so no update will result in 0
being returned.

In the propagate mode, NaN compare equal to each other so in case of
several NaNs the index of the first one needs to be returned. This
commit changes the index update condition to check that the current
index is not that of a NaN.

The commit also simplifies argmax NaN ignore lowering to only use OGT.
This prevent any update in case of NaN. The only case where the index of
a NaN is returned is when all values are NaN and this is covered by the
fact that the initial index value is 0 so no update will result in 0
being returned.
@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes

In the propagate mode, NaN compare equal to each other so in case of
several NaNs the index of the first one needs to be returned. This
commit changes the index update condition to check that the current
index is not that of a NaN.

The commit also simplifies argmax NaN ignore lowering to only use OGT.
This prevent any update in case of NaN. The only case where the index of
a NaN is returned is when all values are NaN and this is covered by the
fact that the initial index value is 0 so no update will result in 0
being returned.


Full diff: https://github.com/llvm/llvm-project/pull/133074.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+16-24)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-7)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e18fa849e9f30..9ca93ab28daed 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2285,8 +2285,22 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
           Value predicate;
           if (isa<FloatType>(inElementTy)) {
-            predicate = rewriter.create<arith::CmpFOp>(
-                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+            if (argmaxOp.getNanMode() == "IGNORE") {
+              // Only update index & max value for non NaN values. If all
+              // values are NaNs, the initial index will be return which is 0.
+              predicate = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+            } else {
+              // Update max value if either of the following is true:
+              // - new value is bigger
+              // - cur max is not NaN and new value is NaN
+              Value gt = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
+              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
+              predicate = rewriter.create<arith::AndIOp>(
+                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
+            }
           } else if (isa<IntegerType>(inElementTy)) {
             predicate = rewriter.create<arith::CmpIOp>(
                 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
@@ -2299,28 +2313,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, predicate, newValue, oldValue);
           auto resultIndex = rewriter.create<arith::SelectOp>(
               nestedLoc, predicate, newIndex, oldIndex);
-
-          // Check if we need to materialize compare and select for the given
-          // NaN propagation mode.
-
-          // "PROPAGATE" matches the default NaN propagation mode of the arith
-          // dialect so no compare and select is required.
-          //
-          // In the case "IGNORE" we check if the current argument is NaN and
-          // select the old index and value otherwise take the updated index and
-          // value.
-          if (const auto nanMode = argmaxOp.getNanMode();
-              isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
-            // Unordered comparison of NaN against itself will always return
-            // true.
-            Value isNaN = rewriter.create<arith::CmpFOp>(
-                argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
-                newValue);
-            resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
-                                                         oldValue, resultMax);
-            resultIndex = rewriter.create<arith::SelectOp>(
-                nestedLoc, isNaN, oldIndex, resultIndex);
-          }
           nestedBuilder.create<linalg::YieldOp>(
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9258442de5a45..eafc62eb71e05 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1525,7 +1525,9 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK: arith.constant -3.40282347E+38 : f32
   // CHECK: linalg.index
   // CHECK: arith.index_cast
-  // CHECK: arith.cmpf ogt
+  // CHECK: arith.cmpf ugt
+  // CHECK: arith.cmpf ord
+  // CHECK: andi
   // CHECK: select
   // CHECK: select
   // CHECK: linalg.yield
@@ -2230,12 +2232,12 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
 // CHECK-LABEL: @argmax_nan_propagate
 func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf ogt
+  // CHECK: arith.cmpf ugt
+  // CHECK: arith.cmpf ord
+  // CHECK: andi
   // CHECK: arith.select
   // CHECK: arith.select
   // CHECK-NOT: arith.cmpf uno
-  // CHECK-NOT: arith.cmpf uno
-  // CHECK-NOT: arith.select
   // CHECK-NOT: arith.select
   // CHECK: linalg.yield
   %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
@@ -2267,9 +2269,6 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
   // CHECK: arith.cmpf ogt
   // CHECK: arith.select
   // CHECK: arith.select
-  // CHECK: arith.cmpf uno
-  // CHECK: arith.select
-  // CHECK: arith.select
   // CHECK: linalg.yield
   %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
   return

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Thomas Preud'homme (RoboTux)

Changes

In the propagate mode, NaN compare equal to each other so in case of
several NaNs the index of the first one needs to be returned. This
commit changes the index update condition to check that the current
index is not that of a NaN.

The commit also simplifies argmax NaN ignore lowering to only use OGT.
This prevent any update in case of NaN. The only case where the index of
a NaN is returned is when all values are NaN and this is covered by the
fact that the initial index value is 0 so no update will result in 0
being returned.


Full diff: https://github.com/llvm/llvm-project/pull/133074.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+16-24)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-7)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e18fa849e9f30..9ca93ab28daed 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2285,8 +2285,22 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
 
           Value predicate;
           if (isa<FloatType>(inElementTy)) {
-            predicate = rewriter.create<arith::CmpFOp>(
-                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+            if (argmaxOp.getNanMode() == "IGNORE") {
+              // Only update index & max value for non NaN values. If all
+              // values are NaNs, the initial index will be return which is 0.
+              predicate = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
+            } else {
+              // Update max value if either of the following is true:
+              // - new value is bigger
+              // - cur max is not NaN and new value is NaN
+              Value gt = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
+              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
+                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
+              predicate = rewriter.create<arith::AndIOp>(
+                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
+            }
           } else if (isa<IntegerType>(inElementTy)) {
             predicate = rewriter.create<arith::CmpIOp>(
                 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
@@ -2299,28 +2313,6 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
               nestedLoc, predicate, newValue, oldValue);
           auto resultIndex = rewriter.create<arith::SelectOp>(
               nestedLoc, predicate, newIndex, oldIndex);
-
-          // Check if we need to materialize compare and select for the given
-          // NaN propagation mode.
-
-          // "PROPAGATE" matches the default NaN propagation mode of the arith
-          // dialect so no compare and select is required.
-          //
-          // In the case "IGNORE" we check if the current argument is NaN and
-          // select the old index and value otherwise take the updated index and
-          // value.
-          if (const auto nanMode = argmaxOp.getNanMode();
-              isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
-            // Unordered comparison of NaN against itself will always return
-            // true.
-            Value isNaN = rewriter.create<arith::CmpFOp>(
-                argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
-                newValue);
-            resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
-                                                         oldValue, resultMax);
-            resultIndex = rewriter.create<arith::SelectOp>(
-                nestedLoc, isNaN, oldIndex, resultIndex);
-          }
           nestedBuilder.create<linalg::YieldOp>(
               nestedLoc, ValueRange({resultIndex, resultMax}));
         });
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 9258442de5a45..eafc62eb71e05 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1525,7 +1525,9 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
   // CHECK: arith.constant -3.40282347E+38 : f32
   // CHECK: linalg.index
   // CHECK: arith.index_cast
-  // CHECK: arith.cmpf ogt
+  // CHECK: arith.cmpf ugt
+  // CHECK: arith.cmpf ord
+  // CHECK: andi
   // CHECK: select
   // CHECK: select
   // CHECK: linalg.yield
@@ -2230,12 +2232,12 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
 // CHECK-LABEL: @argmax_nan_propagate
 func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
   // CHECK: linalg.generic
-  // CHECK: arith.cmpf ogt
+  // CHECK: arith.cmpf ugt
+  // CHECK: arith.cmpf ord
+  // CHECK: andi
   // CHECK: arith.select
   // CHECK: arith.select
   // CHECK-NOT: arith.cmpf uno
-  // CHECK-NOT: arith.cmpf uno
-  // CHECK-NOT: arith.select
   // CHECK-NOT: arith.select
   // CHECK: linalg.yield
   %11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
@@ -2267,9 +2269,6 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
   // CHECK: arith.cmpf ogt
   // CHECK: arith.select
   // CHECK: arith.select
-  // CHECK: arith.cmpf uno
-  // CHECK: arith.select
-  // CHECK: arith.select
   // CHECK: linalg.yield
   %12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>)  -> tensor<4xi32>
   return

@RoboTux
Copy link
Contributor Author

RoboTux commented Apr 14, 2025

Ping?

@RoboTux RoboTux merged commit 95d526f into llvm:main Apr 14, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants