# In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import math
import os


# ================================================================== #
#                     选择特征提取器函数
# ================================================================== #
def detectAndDescribe(image):
    """Detect SIFT keypoints in ``image`` and compute their descriptors.

    Args:
        image: Image handed unchanged to ``detectAndCompute`` (no mask is
            applied).

    Returns:
        Tuple ``(kps, features)``: the keypoints and descriptors produced by
        the SIFT extractor's ``detectAndCompute`` call.
    """
    # To use another feature type, swap the extractor created here.
    sift = cv2.SIFT_create()
    keypoints, descriptors = sift.detectAndCompute(image, None)
    return keypoints, descriptors


# ================================================================== #
#                     使用knn检测函数
# ================================================================== #
def matchKeyPointsKNN(featuresA, featuresB, ratio):
    """Match two descriptor sets with brute-force 2-NN and a ratio test.

    Args:
        featuresA: Query descriptors, or ``None`` when the image produced no
            keypoints.
        featuresB: Train descriptors, or ``None`` when the image produced no
            keypoints.
        ratio: A match is kept only when the best distance is smaller than
            ``ratio`` times the second-best distance.

    Returns:
        List of the best matches that passed the ratio test. The list is empty
        when either descriptor set is missing or empty.
    """
    # A keypoint-free image has no descriptors; knnMatch would raise on it.
    if featuresA is None or featuresB is None:
        return []
    if len(featuresA) == 0 or len(featuresB) == 0:
        return []

    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
    rawMatches = bf.knnMatch(featuresA, featuresB, 2)

    matches = []
    for candidates in rawMatches:
        # With fewer than two neighbours the ratio test cannot be evaluated,
        # so the candidate is skipped instead of failing on unpacking.
        if len(candidates) < 2:
            continue
        best, runnerUp = candidates[0], candidates[1]
        if best.distance < runnerUp.distance * ratio:
            matches.append(best)
    return matches


# ================================================================== #
#                     计算关键点单应性变化
# ================================================================== #
def getHomography(kpsA, kpsB, matches, reprojThresh):
    """Estimate the homography between matched keypoints using RANSAC.

    Args:
        kpsA: Keypoints indexed by each match's ``queryIdx``; every item must
            expose a ``pt`` attribute.
        kpsB: Keypoints indexed by each match's ``trainIdx``; every item must
            expose a ``pt`` attribute.
        matches: Matches linking ``kpsA`` to ``kpsB``.
        reprojThresh: Threshold forwarded to ``cv2.findHomography``.

    Returns:
        ``(matches, H, status)`` where ``H`` and ``status`` are the two values
        returned by ``cv2.findHomography``, or ``None`` when there are four or
        fewer matches.
    """
    # Keep only the coordinates of every keypoint.
    coordsA = np.float32([keypoint.pt for keypoint in kpsA])
    coordsB = np.float32([keypoint.pt for keypoint in kpsB])

    if len(matches) <= 4:
        return None

    srcPts = np.float32([coordsA[match.queryIdx] for match in matches])
    dstPts = np.float32([coordsB[match.trainIdx] for match in matches])

    homography, inlierStatus = cv2.findHomography(
        srcPts, dstPts, cv2.RANSAC, reprojThresh)
    return (matches, homography, inlierStatus)


# ================================================================== #
#                     去除图像黑边
# ================================================================== #
def cutBlack(pic):
    """Crop ``pic`` to the bounding box of its non-black pixels.

    A pixel counts as black only when every channel is zero, so content whose
    first channel happens to be zero is not cropped away.

    Args:
        pic: Image array indexed as ``[row, col, channel]``.

    Returns:
        The cropped slice of ``pic``. When the image is entirely black, an
        empty ``(0, 0, channels)`` slice is returned instead of raising.
    """
    content = pic.any(axis=2)
    # Reduce to one flag per row / column so the bounds come from small 1-D
    # arrays instead of scanning every pixel coordinate in Python.
    rows = np.flatnonzero(content.any(axis=1))
    cols = np.flatnonzero(content.any(axis=0))
    if rows.size == 0:
        return pic[:0, :0, :]
    return pic[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]


# ================================================================== #
#                          调换
# ================================================================== #
def swap(a, b):
    """Return the two arguments in reversed order as the tuple ``(b, a)``."""
    swapped = (b, a)
    return swapped


# ================================================================== #
#                            主要的函数
#   默认使用SIFT特征，修改为其他特征时注意修改detectAndDescribe函数中的特征提取器
#          和matchKeyPointsKNN函数中的距离计算，以达到好的效果
# ================================================================== #
def _loadRGB(source):
    """Return ``source`` as an RGB image, reading it from disk if it is a path."""
    if not isinstance(source, str):
        return source
    image = cv2.imread(source)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None.
        raise FileNotFoundError("cannot read image: {}".format(source))
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _warpOnto(image, other, kps, otherKps, matches):
    """Warp ``image`` by the homography estimated from ``matches``, or return None."""
    M = getHomography(kps, otherKps, matches, reprojThresh=4)
    if M is None or M[1] is None:
        print("Error!")
        return None
    H = M[1]
    # Canvas is twice the combined size of both images in each dimension.
    canvasSize = ((image.shape[1] + other.shape[1]) * 2,
                  (image.shape[0] + other.shape[0]) * 2)
    return cv2.warpPerspective(image, H, canvasSize)


def handle(path1, path2, isShow=False):
    """Match two images and warp one into the other's frame.

    ``path2`` is treated as image A and ``path1`` as image B. Image A is
    warped with the homography estimated from A to B. If the cropped warp
    result is smaller than 95% of image A, the roles of the two images are
    swapped and the warp is recomputed.

    Args:
        path1: File path (str) of image B, or an already loaded RGB image.
        path2: File path (str) of image A, or an already loaded RGB image.
        isShow: When True, build matplotlib figures of the inputs, the
            keypoints and a sample of the feature matches.

    Returns:
        ``(result, matchCount)``: the warped image with black borders removed
        and the number of matches used. ``None`` is returned when fewer than
        10 matches are found or no homography can be estimated.

    Raises:
        FileNotFoundError: If a path is given and the image cannot be read.
    """
    # Load the input images.
    imageA = _loadRGB(path2)
    imageB = _loadRGB(path1)
    imageA_gray = cv2.cvtColor(imageA, cv2.COLOR_RGB2GRAY)
    imageB_gray = cv2.cvtColor(imageB, cv2.COLOR_RGB2GRAY)

    # Show the two input images.
    if isShow:
        f = plt.figure(figsize=(10, 4))
        f.add_subplot(1, 2, 1)
        plt.title("imageB")
        plt.imshow(imageB)
        plt.xticks([]), plt.yticks([])
        f.add_subplot(1, 2, 2)
        plt.title("imageA")
        plt.imshow(imageA)
        plt.xticks([]), plt.yticks([])

    # Extract the features of both images.
    kpsA, featuresA = detectAndDescribe(imageA_gray)
    kpsB, featuresB = detectAndDescribe(imageB_gray)

    # Show the keypoints.
    if isShow:
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), constrained_layout=False)
        ax1.imshow(cv2.drawKeypoints(imageA_gray, kpsA, None, color=(0, 255, 0)))
        ax1.set_xlabel("(a)key point", fontsize=14)
        ax2.imshow(cv2.drawKeypoints(imageB_gray, kpsB, None, color=(0, 255, 0)))
        ax2.set_xlabel("(b)key point", fontsize=14)

    # An image without keypoints yields no descriptors, so nothing can match.
    if featuresA is None or featuresB is None:
        return None

    # Match the features.
    matches = matchKeyPointsKNN(featuresA, featuresB, ratio=0.75)
    if len(matches) < 10:
        return None

    # Show a sample of the matches; it is only drawn when it will be displayed.
    if isShow:
        sampleSize = min(100, len(matches))
        picked = np.random.choice(len(matches), size=sampleSize, replace=False)
        sample = [matches[i] for i in picked]
        img3 = cv2.drawMatches(imageA, kpsA, imageB, kpsB, sample,
                               None, flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
        fig = plt.figure(figsize=(10, 4))
        plt.imshow(img3)
        plt.title("feature match")
        plt.axis('off')

    # Estimate the homography and warp image A.
    matchCount = len(matches)
    result = _warpOnto(imageA, imageB, kpsA, kpsB, matches)
    if result is None:
        return None

    resultAfterCut = cutBlack(result)

    # Check the image position: the crop is reused when the order is correct.
    if np.size(resultAfterCut) >= np.size(imageA) * 0.95:
        return resultAfterCut, matchCount

    print("图片位置不对,将自动调换")
    # Swap the roles of the two images and redo matching and warping.
    kpsA, kpsB = swap(kpsA, kpsB)
    imageA, imageB = swap(imageA, imageB)

    matches = matchKeyPointsKNN(featuresB, featuresA, ratio=0.75)
    if len(matches) < 10:
        return None
    matchCount = len(matches)
    result = _warpOnto(imageA, imageB, kpsA, kpsB, matches)
    if result is None:
        return None

    return cutBlack(result), matchCount





# ================================================================== #
#                     主函数
# ================================================================== #
if __name__ == "__main__":
    # isShow controls whether the input images, keypoints and feature matches
    # are displayed.
    outcome = handle('1_left.jpg','1_right.jpg', isShow=True)

    # handle() returns None rather than a tuple when it fails, so the result
    # is unpacked only after that check.
    result = None if outcome is None else outcome[0]

    if result is not None:
        # The result is RGB; OpenCV display and file output expect BGR.
        resultBGR = result[:, :, [2, 1, 0]]
        cv2.imshow("result", resultBGR)
        cv2.imwrite('result_image.jpg', resultBGR)
        plt.show()
        cv2.waitKey(0)
    else:
        print("没有找到对应特征点,无法计算")
    exit()

