-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tosa] Switch zero point of negate to input variable type #129758
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir-linalg Author: Tai Ly (Tai78641). Changes: This commit changes the zero point attribute to an input to align with the 1.0 spec. Patch is 38.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129758.diff 17 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..77e67a162fb9e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -112,8 +112,12 @@ profileComplianceMap = {
{"tosa.logical_not",
{{{Profile::pro_int, Profile::pro_fp}, {{boolT, boolT}}}}},
{"tosa.negate",
- {{{Profile::pro_int}, {{i8T, i8T}, {i16T, i16T}, {i32T, i32T}}},
- {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
+ {{{Profile::pro_int},
+ {{i8T, i8T, i8T, i8T},
+ {i16T, i16T, i16T, i16T},
+ {i32T, i32T, i32T, i32T}}},
+ {{Profile::pro_fp},
+ {{fp16T, fp16T, fp16T, fp16T}, {fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.reciprocal",
{{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
{"tosa.rsqrt", {{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -308,7 +312,7 @@ extensionComplianceMap = {
{"tosa.exp", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.floor", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.log", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
- {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
+ {"tosa.negate", {{{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T}}}}},
{"tosa.reciprocal", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.rsqrt", {{{Extension::bf16}, {{bf16T, bf16T}}}}},
{"tosa.select", {{{Extension::bf16}, {{bf16T, bf16T, bf16T}}}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index dceac03375606..743af85f3dd7b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -178,13 +178,13 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
input, kernel, stride, pad, acc_type);
}]>;
-// This builder is called on single-parameter unary operators that have a scale
+// This builder is called on single-parameter negate operators that have a scale
// relationship between their input and output, expressed by the
// UnaryOpQuantizationAttr.
-def Tosa_UnaryOpQuantInfoBuilder : OpBuilder<
+def Tosa_NegateOpQuantInfoBuilder : OpBuilder<
(ins "Type":$outputType, "Value":$input),
[{
- buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+ buildNegateOpWithQuantInfo($_builder, $_state, outputType, input);
}]>;
// These builders are called on the TOSA pad operator that needs to create its
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e0f2fd411bbe4..2a119be5fcd24 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1343,7 +1343,9 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseUnaryOp<"logical_not"> {
//===----------------------------------------------------------------------===//
// Operator: negate
//===----------------------------------------------------------------------===//
-def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
+def Tosa_NegateOp : Tosa_InferShapedTypeOp<"negate", [
+ TosaElementwiseOperator,
+ Pure]> {
let summary = "Elementwise negate op";
let description = [{
@@ -1352,8 +1354,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
let arguments = (ins
Tosa_Tensor:$input1,
- OptionalAttr<I32Attr>:$input1_zp,
- OptionalAttr<I32Attr>:$output_zp
+ Tosa_ScalarTensor:$input1_zp,
+ Tosa_ScalarTensor:$output_zp
);
let results = (outs
@@ -1365,9 +1367,20 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
Extension<[Tosa_EXT_BF16]>,
];
- let builders = [Tosa_UnaryOpQuantInfoBuilder];
+ let builders = [Tosa_NegateOpQuantInfoBuilder];
+
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getInput1ZeroPoint();
+ FailureOr<int64_t> getOutputZeroPoint();
+ LogicalResult verifyInput1ZeroPoint(int64_t zp);
+ LogicalResult verifyOutputZeroPoint(int64_t zp);
+ }];
let hasFolder = 1;
+ let hasVerifier = 1;
+
+ let assemblyFormat =
+ "operands attr-dict `:` functional-type(operands, results)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 8732ddafa24d4..af24a68704792 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -193,18 +193,29 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::NegateOp
if (isa<tosa::NegateOp>(op)) {
- if (isa<FloatType>(elementTy))
- return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+ auto negate = cast<tosa::NegateOp>(op);
- if (isa<IntegerType>(elementTy)) {
- auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
- auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();
+ FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
+ if (failed(maybeInZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "input1 zero point cannot be statically determined");
+ return nullptr;
+ }
+
+ FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
+ if (failed(maybeOutZp)) {
+ (void)rewriter.notifyMatchFailure(
+ op, "output zero point cannot be statically determined");
+ return nullptr;
+ }
- const int64_t inZp =
- inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
- const int64_t outZp =
- outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;
+ int64_t inZp = *maybeInZp;
+ int64_t outZp = *maybeOutZp;
+ if (isa<FloatType>(elementTy))
+ return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
+
+ if (isa<IntegerType>(elementTy)) {
if (!inZp && !outZp) {
auto constant = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(elementTy, 0));
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index ffbb707344b8c..e549e43c1cb8b 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -60,6 +60,45 @@ struct MatMulOpSharding
}
};
+struct NegateOpSharding
+ : public ShardingInterface::ExternalModel<NegateOpSharding, NegateOp> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ Value val = op->getOperand(0);
+ auto type = dyn_cast<RankedTensorType>(val.getType());
+ if (!type)
+ return {};
+ SmallVector<utils::IteratorType> types(type.getRank(),
+ utils::IteratorType::parallel);
+ return types;
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ MLIRContext *ctx = op->getContext();
+ Value val = op->getOperand(0);
+ auto type = dyn_cast<RankedTensorType>(val.getType());
+ if (!type)
+ return {};
+ int64_t rank = type.getRank();
+ SmallVector<AffineMap> maps = {
+ AffineMap::getMultiDimIdentityMap(rank, ctx),
+ AffineMap::get(0, 0, {}, ctx), AffineMap::get(0, 0, {}, ctx),
+ AffineMap::getMultiDimIdentityMap(rank, ctx)};
+ return maps;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshSharding> operandShardings,
+ ArrayRef<MeshSharding> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings,
+ resultShardings, spmdizationMap,
+ symbolTable, builder);
+ return success();
+ }
+};
+
template <typename OpType>
static void registerElemwiseOne(MLIRContext *ctx) {
OpType::template attachInterface<ElementwiseShardingInterface<OpType>>(*ctx);
@@ -82,9 +121,10 @@ void mlir::tosa::registerShardingInterfaceExternalModels(
BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp,
LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp,
MulOp, PowOp, SubOp, AbsOp, BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp,
- LogOp, LogicalNotOp, NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
+ LogOp, LogicalNotOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp,
GreaterOp, GreaterEqualOp>(ctx);
MatMulOp::attachInterface<MatMulOpSharding>(*ctx);
+ NegateOp::attachInterface<NegateOpSharding>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 363b5958bc0fd..663c2f9f4bd3b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1190,13 +1190,36 @@ OpFoldResult tosa::ExpOp::fold(FoldAdaptor adaptor) {
}
OpFoldResult tosa::NegateOp::fold(FoldAdaptor adaptor) {
- auto input = getInput1();
// Element-wise negate(negate(x)) = x
- if (auto op = input.getDefiningOp<tosa::NegateOp>()) {
- return op.getInput1();
+ // iff all zero points are constant 0
+ auto definingOp = getInput1().getDefiningOp<tosa::NegateOp>();
+ if (!definingOp) {
+ // defining op of input1 is not a negate, cannot fold
+ return {};
}
- return {};
+ if (FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+ failed(maybeIZp) || *maybeIZp != 0) {
+ // input1 zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ failed(maybeOZp) || *maybeOZp != 0) {
+ // output zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeIZp = definingOp.getInput1ZeroPoint();
+ failed(maybeIZp) || *maybeIZp != 0) {
+ // definingOp's input1 zero point is not constant 0, cannot fold
+ return {};
+ }
+ if (FailureOr<int64_t> maybeOZp = definingOp.getOutputZeroPoint();
+ failed(maybeOZp) || *maybeOZp != 0) {
+ // definingOp's output zero point is not constant 0, cannot fold
+ return {};
+ }
+
+ return definingOp.getInput1();
}
OpFoldResult tosa::AbsOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 8841d53b6e64d..9df9b77052d78 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -680,23 +680,43 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
result.types.push_back(outputType);
}
-/// This builder is called on single-parameter unary operators that have scale
-/// relationship between their input and output, expressed by the
-/// UnaryOpQuantizationAttr.
-static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
- OperationState &result, Type outputType,
- Value input) {
- result.addOperands(input);
+/// This builder is called on single-parameter negate operators that
+/// have a scale relationship between their input and output, expressed
+/// by the UnaryOpQuantizationAttr.
+static void buildNegateOpWithQuantInfo(OpBuilder &builder,
+ OperationState &result, Type outputType,
+ Value input) {
+ const Location loc{result.location};
+ int64_t input1Zp{0};
+ int64_t outputZp{0};
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
if (quantAttr) {
- // note: negateOp has attributes input1_zp and output_zp
- result.addAttribute("input1_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getInputZp())));
- result.addAttribute("output_zp",
- builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getOutputZp())));
+ input1Zp = quantAttr.getInputZp();
+ outputZp = quantAttr.getOutputZp();
+ }
+ const std::optional<Value> input1ZpOp =
+ createZeroPointTensor(builder, loc, input.getType(), input1Zp);
+ if (!input1ZpOp) {
+ (void)emitError(
+ loc, "Failed to create input1 zero point for quantized NEGATE op");
}
+
+ const std::optional<Value> outputZpOp =
+ createZeroPointTensor(builder, loc, input.getType(), outputZp);
+ if (!outputZpOp) {
+ (void)emitError(
+ loc, "Failed to create output zero point for quantized NEGATE op");
+ }
+
+ if (input1ZpOp && outputZpOp) {
+ result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
+ } else {
+ // failed to create one or more zero points above: just add input as
+ // operands. This will trigger error in building the op because of
+ // missing zero points
+ result.addOperands({input});
+ }
+
result.types.push_back(outputType);
}
@@ -1560,6 +1580,9 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(NegateOp, Input1)
+ZERO_POINT_HELPER(NegateOp, Output)
+
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
@@ -2039,7 +2062,6 @@ NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
-NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
@@ -2053,6 +2075,55 @@ NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef PRED_SHAPE_INFER
+LogicalResult tosa::NegateOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ NegateOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ ShapeAdaptor inputShape(adaptor.getInput1().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
+
+LogicalResult tosa::NegateOp::verify() {
+ // Verify same element type
+ const Type input1Type = getInput1().getType();
+ const Type outputType = getOutput().getType();
+ if (verifySameElementTypes(*this, input1Type, outputType).failed())
+ return failure();
+
+ // Verify same shape
+ const SmallVector<Type, 2> types = {input1Type, outputType};
+ if (failed(verifyCompatibleShapes(types)))
+ return emitOpError() << "requires the same shape for input1 and output";
+
+ const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
+ const Type input1ZpEType =
+ getStorageElementTypeOrSelf(getInput1Zp().getType());
+ if (input1EType != input1ZpEType) {
+ return emitOpError("expect both input1 and its zero point are the same "
+ "element type, got ")
+ << input1EType << " and " << input1ZpEType;
+ }
+ const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
+ const Type outputZpEType =
+ getStorageElementTypeOrSelf(getOutputZp().getType());
+ if (outputEType != outputZpEType) {
+ return emitOpError("expect both output and its zero point are the same "
+ "element type, got ")
+ << outputEType << " and " << outputZpEType;
+ }
+
+ FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
+ if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
+ if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
+ return failure();
+
+ return success();
+}
+
static LogicalResult poolingInferReturnTypes(
ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
ArrayRef<int64_t> pad,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 6ca260a5324a9..e9dfc68c89357 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -477,7 +477,9 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
// CHECK: linalg.generic
// CHECK: arith.negf
- %5 = tosa.negate %0 : (tensor<1xf32>) -> tensor<1xf32>
+ %in_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %out_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %5 = tosa.negate %0, %in_zp, %out_zp : (tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
// CHECK: linalg.generic
// CHECK: pow
@@ -662,10 +664,12 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
%40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
- %5 = tosa.negate %arg0 : (tensor<1xi32>) -> tensor<1xi32>
+ %in_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %out_zp = "tosa.const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
// CHECK: and
@@ -852,40 +856,22 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK-LABEL: @test_negate_quantized
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
- // CHECK: linalg.yield [[SUB]]
- %0 = tosa.negate %arg0 {input_zp1 = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[C32639:%.+]] = arith.constant 32639
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
- // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
- // CHECK: [[MIN:%.+]] = arith.constant -128
- // CHECK: [[MAX:%.+]] = arith.constant 127
- // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
- // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
- // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
- // CHECK: linalg.yield [[TRUNC]]
- %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
-
- // CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
- // CHECK: [[C32640:%.+]] = arith.constant 32640
- // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
- // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+ // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
// CHECK: [[MIN:%.+]] = arith.constant -128
// CHECK: [[MAX:%.+]] = arith.constant 127
// CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
// CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
// CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
// CHECK: linalg.yield [[TRUNC]]
- %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+ %in_zp0 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %out_zp0 = "tosa.const"() <{value = dense<7> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
// CHECK: [[C_128:%.+]] = arith.constant -128
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -895,14 +881,18 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Tai78641, had a couple of comments, otherwise LGTM. Note: I won't explicitly approve since I authored some of this change
Accidental approval - I authored some of this patch, so don't want to explicitly approve
rebased and resolved merge conflict in invalid.mlir |
3287c35
to
edba22f
Compare
This commit changes the zero point attribute to an input to align with the 1.0 spec. Change-Id: Ibc9e5959b36c182a9e0c5c23a2f9d42a572a1184 Signed-off-by: Tai Ly <tai.ly@arm.com>
This commit changes the zero point attribute to an input to align with the 1.0 spec.