-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[CIR] Implement EqualOp for ComplexType #145769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Implement EqualOp for ComplexType #145769
Conversation
|
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds support for equal operation for ComplexType Full diff: https://github.com/llvm/llvm-project/pull/145769.diff 5 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 5f24ab7816cbc..6eef525f52f8e 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2429,6 +2429,31 @@ def ComplexImagOp : CIR_Op<"complex.imag", [Pure]> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// ComplexEqualOp
+//===----------------------------------------------------------------------===//
+
+def ComplexEqualOp : CIR_Op<"complex.eq", [Pure, SameTypeOperands]> {
+
+ let summary = "Computes whether two complex values are equal";
+ let description = [{
+ The `complex.equal` op takes two complex numbers and returns whether
+ they are equal.
+
+ ```mlir
+ %r = cir.complex.eq %a, %b : !cir.complex<!cir.float>
+ ```
+ }];
+
+ let results = (outs CIR_BoolType:$result);
+ let arguments = (ins CIR_ComplexType:$lhs, CIR_ComplexType:$rhs);
+
+ let assemblyFormat = [{
+ $lhs `,` $rhs
+ `:` qualified(type($lhs)) attr-dict
+ }];
+}
+
//===----------------------------------------------------------------------===//
// Assume Operations
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
index 7f8dcd96a6bff..4bcbc6d7ce798 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
@@ -894,9 +894,17 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
} else {
// Complex Comparison: can only be an equality comparison.
- assert(!cir::MissingFeatures::complexType());
- cgf.cgm.errorNYI(loc, "complex comparison");
- result = builder.getBool(false, loc);
+ assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
+
+ BinOpInfo boInfo = emitBinOps(e);
+ if (e->getOpcode() == BO_EQ) {
+ result =
+ builder.create<cir::ComplexEqualOp>(loc, boInfo.lhs, boInfo.rhs);
+ } else {
+ assert(!cir::MissingFeatures::complexType());
+ cgf.cgm.errorNYI(loc, "complex not equal");
+ result = builder.getBool(false, loc);
+ }
}
return emitScalarConversion(result, cgf.getContext().BoolTy, e->getType(),
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index d41afbdd0b69e..1d33b00d026f4 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1905,7 +1905,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
CIRToLLVMVecTernaryOpLowering,
CIRToLLVMComplexCreateOpLowering,
CIRToLLVMComplexRealOpLowering,
- CIRToLLVMComplexImagOpLowering
+ CIRToLLVMComplexImagOpLowering,
+ CIRToLLVMComplexEqualOpLowering
// clang-format on
>(converter, patterns.getContext());
@@ -2227,6 +2228,43 @@ mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
return mlir::success();
}
+mlir::LogicalResult CIRToLLVMComplexEqualOpLowering::matchAndRewrite(
+ cir::ComplexEqualOp op, OpAdaptor adaptor,
+ mlir::ConversionPatternRewriter &rewriter) const {
+ mlir::Value lhs = adaptor.getLhs();
+ mlir::Value rhs = adaptor.getRhs();
+
+ auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
+ mlir::Type complexElemTy =
+ getTypeConverter()->convertType(complexType.getElementType());
+
+ mlir::Location loc = op.getLoc();
+ auto lhsReal =
+ rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
+ auto lhsImag =
+ rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
+ auto rhsReal =
+ rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
+ auto rhsImag =
+ rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
+
+ if (complexElemTy.isInteger()) {
+ auto realCmp = rewriter.create<mlir::LLVM::ICmpOp>(
+ loc, mlir::LLVM::ICmpPredicate::eq, lhsReal, rhsReal);
+ auto imagCmp = rewriter.create<mlir::LLVM::ICmpOp>(
+ loc, mlir::LLVM::ICmpPredicate::eq, lhsImag, rhsImag);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
+ return mlir::success();
+ }
+
+ auto realCmp = rewriter.create<mlir::LLVM::FCmpOp>(
+ loc, mlir::LLVM::FCmpPredicate::oeq, lhsReal, rhsReal);
+ auto imagCmp = rewriter.create<mlir::LLVM::FCmpOp>(
+ loc, mlir::LLVM::FCmpPredicate::oeq, lhsImag, rhsImag);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, realCmp, imagCmp);
+ return mlir::success();
+}
+
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 8502cb1ae5d9f..25cf218cf8b6c 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -463,6 +463,16 @@ class CIRToLLVMComplexImagOpLowering
mlir::ConversionPatternRewriter &) const override;
};
+class CIRToLLVMComplexEqualOpLowering
+ : public mlir::OpConversionPattern<cir::ComplexEqualOp> {
+public:
+ using mlir::OpConversionPattern<cir::ComplexEqualOp>::OpConversionPattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(cir::ComplexEqualOp op, OpAdaptor,
+ mlir::ConversionPatternRewriter &) const override;
+};
+
} // namespace direct
} // namespace cir
diff --git a/clang/test/CIR/CodeGen/complex.cpp b/clang/test/CIR/CodeGen/complex.cpp
index ad3720097a795..1e9ce0e29fd46 100644
--- a/clang/test/CIR/CodeGen/complex.cpp
+++ b/clang/test/CIR/CodeGen/complex.cpp
@@ -368,4 +368,77 @@ int foo17(int _Complex a, int _Complex b) {
// OGCG: %[[B_REAL:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
// OGCG: %[[TMP_B:.*]] = load i32, ptr %[[B_REAL]], align 4
// OGCG: %[[ADD:.*]] = add nsw i32 %[[TMP_A]], %[[TMP_B]]
-// OGCG: ret i32 %[[ADD]]
\ No newline at end of file
+// OGCG: ret i32 %[[ADD]]
+
+bool foo18(int _Complex a, int _Complex b) {
+ return a == b;
+}
+
+// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
+// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!s32i>>, !cir.complex<!s32i>
+// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!s32i>
+
+// LLVM: %[[COMPLEX_A:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
+// LLVM: %[[COMPLEX_B:.*]] = load { i32, i32 }, ptr {{.*}}, align 4
+// LLVM: %[[A_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 0
+// LLVM: %[[A_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_A]], 1
+// LLVM: %[[B_REAL:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 0
+// LLVM: %[[B_IMAG:.*]] = extractvalue { i32, i32 } %[[COMPLEX_B]], 1
+// LLVM: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
+// LLVM: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
+// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+// OGCG: %[[COMPLEX_A:.*]] = alloca { i32, i32 }, align 4
+// OGCG: %[[COMPLEX_B:.*]] = alloca { i32, i32 }, align 4
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: %[[A_REAL:.*]] = load i32, ptr %[[A_REAL_PTR]], align 4
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: %[[A_IMAG:.*]] = load i32, ptr %[[A_IMAG_PTR]], align 4
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: %[[B_REAL:.*]] = load i32, ptr %[[B_REAL_PTR]], align 4
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { i32, i32 }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: %[[B_IMAG:.*]] = load i32, ptr %[[B_IMAG_PTR]], align 4
+// OGCG: %[[CMP_REAL:.*]] = icmp eq i32 %[[A_REAL]], %[[B_REAL]]
+// OGCG: %[[CMP_IMAG:.*]] = icmp eq i32 %[[A_IMAG]], %[[B_IMAG]]
+// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+bool foo19(double _Complex a, double _Complex b) {
+ return a == b;
+}
+
+// CIR: %[[COMPLEX_A:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
+// CIR: %[[COMPLEX_B:.*]] = cir.load{{.*}} {{.*}} : !cir.ptr<!cir.complex<!cir.double>>, !cir.complex<!cir.double>
+// CIR: %[[RESULT:.*]] = cir.complex.eq %[[COMPLEX_A]], %[[COMPLEX_B]] : !cir.complex<!cir.double>
+
+// LLVM: %[[COMPLEX_A:.*]] = load { double, double }, ptr {{.*}}, align 8
+// LLVM: %[[COMPLEX_B:.*]] = load { double, double }, ptr {{.*}}, align 8
+// LLVM: %[[A_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 0
+// LLVM: %[[A_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_A]], 1
+// LLVM: %[[B_REAL:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 0
+// LLVM: %[[B_IMAG:.*]] = extractvalue { double, double } %[[COMPLEX_B]], 1
+// LLVM: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
+// LLVM: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
+// LLVM: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
+// OGCG: %[[COMPLEX_A:.*]] = alloca { double, double }, align 8
+// OGCG: %[[COMPLEX_B:.*]] = alloca { double, double }, align 8
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: store double {{.*}}, ptr %[[A_REAL_PTR]], align 8
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: store double {{.*}}, ptr %[[A_IMAG_PTR]], align 8
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: store double {{.*}}, ptr %[[B_REAL_PTR]], align 8
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: store double {{.*}}, ptr %[[B_IMAG_PTR]], align 8
+// OGCG: %[[A_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 0
+// OGCG: %[[A_REAL:.*]] = load double, ptr %[[A_REAL_PTR]], align 8
+// OGCG: %[[A_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_A]], i32 0, i32 1
+// OGCG: %[[A_IMAG:.*]] = load double, ptr %[[A_IMAG_PTR]], align 8
+// OGCG: %[[B_REAL_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 0
+// OGCG: %[[B_REAL:.*]] = load double, ptr %[[B_REAL_PTR]], align 8
+// OGCG: %[[B_IMAG_PTR:.*]] = getelementptr inbounds nuw { double, double }, ptr %[[COMPLEX_B]], i32 0, i32 1
+// OGCG: %[[B_IMAG:.*]] = load double, ptr %[[B_IMAG_PTR]], align 8
+// OGCG: %[[CMP_REAL:.*]] = fcmp oeq double %[[A_REAL]], %[[B_REAL]]
+// OGCG: %[[CMP_IMAG:.*]] = fcmp oeq double %[[A_IMAG]], %[[B_IMAG]]
+// OGCG: %[[RESULT:.*]] = and i1 %[[CMP_REAL]], %[[CMP_IMAG]]
+
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great. Just one very minor nit.
@@ -1905,7 +1905,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { | |||
CIRToLLVMVecTernaryOpLowering, | |||
CIRToLLVMComplexCreateOpLowering, | |||
CIRToLLVMComplexRealOpLowering, | |||
CIRToLLVMComplexImagOpLowering | |||
CIRToLLVMComplexImagOpLowering, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this in previous reviews, but we were trying to keep these in lexicographical order.
This change adds support for equal operation for ComplexType llvm#141365
This change adds support for equal operation for ComplexType llvm#141365
This change adds support for equal operation for ComplexType
#141365