Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][tosa] Switch zero point of negate to input variable type #129758

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Mar 4, 2025

This commit changes the zero point attribute to an input to align with the 1.0 spec.

@llvmbot
Copy link
Member

llvmbot commented Mar 4, 2025

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Tai Ly (Tai78641)

Changes

This commit changes the zero point attribute to an input to align with the 1.0 spec.


Patch is 38.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129758.diff

17 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+7-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+17-4)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+20-9)
  • (modified) mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp (+41-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+27-4)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+86-15)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+20-30)
  • (modified) mlir/test/Dialect/Mesh/sharding-propagation.mlir (+8-5)
  • (modified) mlir/test/Dialect/Mesh/spmdization.mlir (+11-4)
  • (modified) mlir/test/Dialect/Tosa/availability.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+43-2)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+64)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+8-4)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..77e67a162fb9e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -112,8 +112,12 @@ profileComplianceMap = {
     {"tosa.logical_not",
      {{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
     {"tosa.negate",
-     {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
-      {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+     {{{Profile::pro_int},
+       {{i8T, i8T, i8T, i8T},
+        {i16T, i16T, i16T, i16T},
+        {i32T, i32T, i32T, i32T}}},
+      {{Profile::pro_fp},
+       {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.reciprocal",
      {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
     {"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -308,7 +312,7 @@ extensionComplianceMap = {
     {"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
-    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+    {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
     {"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
     {"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index dceac03375606..743af85f3dd7b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
                                   input, kernel, stride, pad, acc_type);
   }]>;
 
-// This builder is called on single-parameter unary operators that have a scale
+// This builder is called on single-parameter negate operators that have a scale
 // relationship between their input and output, expressed by the
 // UnaryOpQuantizationAttr.
-def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
+def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
   (ins "Type":$outputType, "Value":$input),
   [{
-    buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+    buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
   }]>;
 
 // These builders are called on the TOSA pad operator that needs to create its
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e0f2fd411bbe4..2a119be5fcd24 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1343,7 +1343,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
 //===----------------------------------------------------------------------===//
 // Operator: negate
 //===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
+    TosaElementwiseOperator,
+    Pure]> {
   let summary = "Elementwise negate op";
 
   let description = [{
@@ -1352,8 +1354,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
 
   let arguments = (ins
       Tosa_Tensor:$input1,
-      OptionalAttr<I32Attr>:$input1_zp,
-      OptionalAttr<I32Attr>:$output_zp
+      Tosa_ScalarTensor:$input1_zp,
+      Tosa_ScalarTensor:$output_zp
   );
 
   let results = (outs
@@ -1365,9 +1367,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
     Extension<[Tosa_EXT_BF16]>,
   ];
 
-  let builders = [Tosa_UnaryOpQuantInfoBuilder];
+  let builders = [Tosa_NegateOpQuantInfoBuilder];
+
+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getInput1ZeroPoint();
+    FailureOr<int64_t> getOutputZeroPoint();
+    LogicalResult verifyInput1ZeroPoint(int64_t zp);
+    LogicalResult verifyOutputZeroPoint(int64_t zp);
+  }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
+
+  let assemblyFormat =
+      "operands attr-dict `:` functional-type(operands, results)";
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..af24a68704792 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::NegateOp
   if (isa<tosa::NegateOp>(op)) {
-    if (isa<FloatType>(elementTy))
-      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+    auto negate = cast<tosa::NegateOp>(op);
 
-    if (isa<IntegerType>(elementTy)) {
-      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
-      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+    if (failed(maybeInZp)) {
+      (void)rewriter.notifyMatchFailure(
+          op, "input1 zero point cannot be statically determined");
+      return nullptr;
+    }
+
+    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+    if (failed(maybeOutZp)) {
+      (void)rewriter.notifyMatchFailure(
+          op, "output zero point cannot be statically determined");
+      return nullptr;
+    }
 
-      const int64_t inZp =
-          inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
-      const int64_t outZp =
-          outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
+    int64_t inZp = *maybeInZp;
+    int64_t outZp = *maybeOutZp;
 
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+
+    if (isa<IntegerType>(elementTy)) {
       if (!inZp && !outZp) {
         auto constant = rewriter.create<arith::ConstantOp>(
             loc, IntegerAttr::get(elementTy, 0));
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index ffbb707344b8c..e549e43c1cb8b 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -60,6 +60,45 @@ struct MatMulOpSharding
   }
 };
 
+struct NegateOpSharding
+    : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    SmallVector<utils::IteratorType> types(type.getRank(),
+                                           utils::IteratorType::parallel);
+    return types;
+  }
+
+  SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+    MLIRContext *ctx = op->getContext();
+    Value val = op->getOperand(0);
+    auto type = dyn_cast<RankedTensorType>(val.getType());
+    if (!type)
+      return {};
+    int64_t rank = type.getRank();
+    SmallVector<AffineMap> maps = {
+        AffineMap::getMultiDimIdentityMap(rank, ctx),
+        AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
+        AffineMap::getMultiDimIdentityMap(rank, ctx)};
+    return maps;
+  }
+
+  LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+                        ArrayRef<MeshSharding> operandShardings,
+                        ArrayRef<MeshSharding> resultShardings,
+                        IRMapping &spmdizationMap,
+                        SymbolTableCollection &symbolTable,
+                        OpBuilder &builder) const {
+    spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+                                       resultShardings, spmdizationMap,
+                                       symbolTable, builder);
+    return success();
+  }
+};
+
 template <typename OpType>
 static void registerElemwiseOne(MLIRContext *ctx) {
   OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -82,9 +121,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
         BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
         LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
         MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
-        LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+        LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
         GreaterOp, GreaterEqualOp>(ctx);
 
     MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+    NegateOp::attachInterface<NegateOpSharding>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 363b5958bc0fd..663c2f9f4bd3b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1190,13 +1190,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
-  auto input = getInput1();
   // Element-wise negate(negate(x)) = x
-  if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
-    return op.getInput1();
+  // iff all zero points are constant 0
+  auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
+  if (!definingOp) {
+    // defining op of input1 is not a negate, cannot fold
+    return {};
   }
 
-  return {};
+  if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+      failed(maybeIZp) || *maybeIZp != 0) {
+    // input1 zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+      failed(maybeOZp) || *maybeOZp != 0) {
+    // output zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
+      failed(maybeIZp) || *maybeIZp != 0) {
+    // definingOp's input1 zero point is not constant 0, cannot fold
+    return {};
+  }
+  if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
+      failed(maybeOZp) || *maybeOZp != 0) {
+    // definingOp's output zero point is not constant 0, cannot fold
+    return {};
+  }
+
+  return definingOp.getInput1();
 }
 
 OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8841d53b6e64d..9df9b77052d78 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -680,23 +680,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.types.push_back(outputType);
 }
 
-/// This builder is called on single-parameter unary operators that have scale
-/// relationship between their input and output, expressed by the
-/// UnaryOpQuantizationAttr.
-static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
-                                      OperationState &result, Type outputType,
-                                      Value input) {
-  result.addOperands(input);
+/// This builder is called on single-parameter negate operator that
+/// have scale relationship between their input and output, expressed
+/// by the UnaryOpQuantizationAttr.
+static void buildNegateOpWithQuantInfo(OpBuilder &builder,
+                                       OperationState &result, Type outputType,
+                                       Value input) {
+  const Location loc{result.location};
+  int64_t input1Zp{0};
+  int64_t outputZp{0};
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
   if (quantAttr) {
-    // note: negateOp has attributes input1_zp and output_zp
-    result.addAttribute("input1_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getInputZp())));
-    result.addAttribute("output_zp",
-                        builder.getI32IntegerAttr(
-                            static_cast<int32_t>(quantAttr.getOutputZp())));
+    input1Zp = quantAttr.getInputZp();
+    outputZp = quantAttr.getOutputZp();
+  }
+  const std::optional<Value> input1ZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), input1Zp);
+  if (!input1ZpOp) {
+    (void)emitError(
+        loc, "Failed to create input1 zero point for quantized NEGATE op");
   }
+
+  const std::optional<Value> outputZpOp =
+      createZeroPointTensor(builder, loc, input.getType(), outputZp);
+  if (!outputZpOp) {
+    (void)emitError(
+        loc, "Failed to create output zero point for quantized NEGATE op");
+  }
+
+  if (input1ZpOp && outputZpOp) {
+    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
+  } else {
+    // failed to create one or more zero points above: just add input as
+    // operands. This will trigger error in building the op because of
+    // missing zero points
+    result.addOperands({input});
+  }
+
   result.types.push_back(outputType);
 }
 
@@ -1560,6 +1580,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(NegateOp, Input1)
+ZERO_POINT_HELPER(NegateOp, Output)
+
 #undef ZERO_POINT_HELPER
 
 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2039,7 +2062,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
 NARY_SHAPE_INFER(tosa::LogicalXorOp)
 NARY_SHAPE_INFER(tosa::MaximumOp)
 NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::NegateOp)
 NARY_SHAPE_INFER(tosa::PowOp)
 NARY_SHAPE_INFER(tosa::ReciprocalOp)
 NARY_SHAPE_INFER(tosa::RescaleOp)
@@ -2053,6 +2075,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
+LogicalResult tosa::NegateOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    NegateOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  ShapeAdaptor inputShape(adaptor.getInput1().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult tosa::NegateOp::verify() {
+  // Verify same element type
+  const Type input1Type = getInput1().getType();
+  const Type outputType = getOutput().getType();
+  if (verifySameElementTypes(*this, input1Type, outputType).failed())
+    return failure();
+
+  // Verify same shape
+  const SmallVector<Type, 2> types = {input1Type, outputType};
+  if (failed(verifyCompatibleShapes(types)))
+    return emitOpError() << "requires the same shape for input1 and output";
+
+  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
+  const Type input1ZpEType =
+      getStorageElementTypeOrSelf(getInput1Zp().getType());
+  if (input1EType != input1ZpEType) {
+    return emitOpError("expect both input1 and its zero point are the same "
+                       "element type, got ")
+           << input1EType << " and " << input1ZpEType;
+  }
+  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
+  const Type outputZpEType =
+      getStorageElementTypeOrSelf(getOutputZp().getType());
+  if (outputEType != outputZpEType) {
+    return emitOpError("expect both output and its zero point are the same "
+                       "element type, got ")
+           << outputEType << " and " << outputZpEType;
+  }
+
+  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+    return failure();
+
+  return success();
+}
+
 static LogicalResult poolingInferReturnTypes(
     ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
     ArrayRef<int64_t> pad,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..e9dfc68c89357 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -477,7 +477,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
-  %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32>
+  %in_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %out_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %5 = tosa.negate %0, %in_zp, %out_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: pow
@@ -662,10 +664,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
-  %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+  %in_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: and
@@ -852,40 +856,22 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
-  // CHECK: linalg.yield [[SUB]]
-  %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[C32639:%.+]] = arith.constant 32639
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
-  // CHECK: [[MIN:%.+]] = arith.constant -128
-  // CHECK: [[MAX:%.+]] = arith.constant 127
-  // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
-  // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
-  // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
-  // CHECK: linalg.yield [[TRUNC]]
-  %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[C32640:%.+]] = arith.constant 32640
-  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
-  // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+  %in_zp0 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %out_zp0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
   // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -895,14 +881,18 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -...
[truncated]

@Jerry-Ge Jerry-Ge requested review from GeorgeARM and lhutton1 and removed request for GeorgeARM March 4, 2025 18:29
Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Tai78641, had a couple of comments, otherwise LGTM. Note: I won't explicitly approve since I authored some of this change

lhutton1
lhutton1 previously approved these changes Mar 6, 2025
@lhutton1 lhutton1 dismissed their stale review March 6, 2025 08:54

Accidental approval - I authored some of this patch, so don't want to explicitly approve

@Tai78641
Copy link
Contributor Author

Tai78641 commented Mar 6, 2025

rebased and resolved merge conflict in invalid.mlir

@Tai78641 Tai78641 force-pushed the pr_negate_zp branch 3 times, most recently from 3287c35 to edba22f Compare March 7, 2025 19:36
This commit changes the zero point attribute to an input to align
with the 1.0 spec.

Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184
Signed-off-by: Tai Ly <tai.ly@arm.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants