In [1]:
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import time
import random


In [14]:
'''
Read, plot, resize data
'''

def plot(img, title = " "):
    """
    Show ``img`` in a new 16x8 inch (80 dpi) matplotlib figure.

    img   : array-like accepted by ``Axes.imshow``
    title : text used as the figure's suptitle (a single space by default)

    Returns None; the only effect is the displayed figure.
    """
    figure, axes = plt.subplots(figsize=(16, 8), dpi=80)
    figure.suptitle(title)
    axes.imshow(img)
    plt.show()

def read_image(path, flag = 1):
    """
    Load an image from disk with OpenCV.

    path : location of the image file
    flag : read mode forwarded to ``cv.imread``. With the default ``1`` the
           decoded image is converted from BGR to RGB channel order before it
           is returned; for any other value the ``cv.imread`` result is
           returned as-is.

    Returns the decoded image array.

    Raises FileNotFoundError when ``cv.imread`` could not load ``path``.
    """
    image = cv.imread(path, flag)

    # cv.imread reports failure by returning None instead of raising; fail
    # here with a clear message rather than deep inside a later OpenCV call.
    if image is None:
        raise FileNotFoundError(f"Could not read image from {path!r}")

    if flag == 1:
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

    return image
 
def resize(img, scale = 20):
    """
    Rescale ``img`` to ``scale`` percent of its size along both axes.

    The target size is handed to ``cv.resize`` as (width, height), each
    truncated to an integer, and resampled with ``cv.INTER_AREA``.
    """
    # shape[1] is the column count (width), shape[0] the row count (height)
    target_size = tuple(int(img.shape[axis] * scale / 100) for axis in (1, 0))
    return cv.resize(img, target_size, interpolation = cv.INTER_AREA)


In [15]:
'''
Utility functions for SIFT keypoints
'''

def SIFT_descriptors(image):
    """
    Run SIFT on ``image`` (no mask) and return ``(keypoints, descriptors)``.
    """
    detector = cv.SIFT_create()
    found = detector.detectAndCompute(image, None)
    keypoints, descriptors = found
    return keypoints, descriptors


def draw_and_display(image, keypoints):
    """
    Return the result of ``cv.drawKeypoints`` for ``keypoints`` on ``image``.

    A zero array with the shape of ``image`` is passed as the output-image
    argument. Nothing is shown on screen; the caller receives the drawn image.
    """
    canvas = np.zeros(image.shape)
    return cv.drawKeypoints(image, keypoints, canvas)


def SIFT_matching(descriptorL, descriptorR, img1, kp1, img2, kp2, display = False):
    """
    Brute-force match two descriptor sets and return the best matched points.

    descriptorL, descriptorR : descriptor arrays of the first / second image
    img1, img2               : the images (only used when ``display`` is True)
    kp1, kp2                 : keypoints belonging to descriptorL / descriptorR
    display                  : if True, print the match count and draw the
                               retained matches in a matplotlib figure

    Matches are found with an L2 ``cv.BFMatcher`` using cross-checking, sorted
    by distance, and only the closest ones are kept: 10% of all matches,
    but never fewer than 4 and never more than 100.

    Returns ``(pts1, pts2)``: two float32 arrays of shape (k, 2) holding the
    matched keypoint coordinates in the first and second image.

    Raises ValueError if either descriptor set is missing/empty or fewer than
    4 matches are found.
    """
    # A descriptor set can be None when no keypoints were detected; the
    # matcher cannot work with that, so report it explicitly.
    for descriptors in (descriptorL, descriptorR):
        if descriptors is None or len(descriptors) == 0:
            raise ValueError("Not enough correspondences: an image has no descriptors")

    matcher = cv.BFMatcher(cv.NORM_L2, crossCheck = True)
    matches = matcher.match(descriptorL, descriptorR)

    # Raise instead of calling exit(0): exiting would shut down the notebook
    # kernel and signal success although matching failed.
    if len(matches) < 4:
        raise ValueError(
            f"Not enough correspondences: found {len(matches)}, need at least 4"
        )

    # Sorting wrt the distance object of each match.
    # The closer the matches are, the less erroneous they are likely to be
    numPoints = min(max(4, int(0.1 * len(matches))), 100)
    matches = sorted(matches, key = lambda m: m.distance)[:numPoints]

    if display:

        print(f"The number of matches: {len(matches)}")
        print(f"\033[35m The number of corresponding descriptors: \033[35m {len(matches)}")
        fig, ax = plt.subplots(figsize = (16, 4))
        # flags 2|4: presumably NOT_DRAW_SINGLE_POINTS | DRAW_RICH_KEYPOINTS -- confirm
        matchesImage = cv.drawMatches(img1, kp1, img2, kp2, matches, None, flags = 2|4)
        fig.suptitle(f'The {numPoints} closest correspondences', fontsize = 16)
        ax.imshow(matchesImage)

    # queryIdx indexes the first descriptor set, trainIdx the second
    pts1 = np.float32([ kp1[m.queryIdx].pt for m in matches ]).reshape(-1,2)
    pts2 = np.float32([ kp2[m.trainIdx].pt for m in matches ]).reshape(-1,2)
    return pts1, pts2
   
    
def get_points_descriptors(imgTrain, imgQuery, display = False):
    """
    Detect SIFT features in both images, match them and fit a homography.

    imgTrain, imgQuery : the two images to relate
    display            : forwarded to SIFT_matching to visualise the matches

    Returns ``(pointsTrain, pointsQuery, H)`` where the point arrays are the
    matched coordinates and ``H`` is the result of get_homography_ransac on
    their homogeneous versions.
    """
    kpTrain, descTrain = SIFT_descriptors(imgTrain)
    kpQuery, descQuery = SIFT_descriptors(imgQuery)

    matched = SIFT_matching(
        descTrain, descQuery,
        imgTrain, kpTrain,
        imgQuery, kpQuery,
        display,
    )
    pointsTrain, pointsQuery = matched

    homogeneousTrain = get_homo(pointsTrain)
    homogeneousQuery = get_homo(pointsQuery)
    H = get_homography_ransac(homogeneousTrain, homogeneousQuery)

    return pointsTrain, pointsQuery, H



In [19]:
'''
Functions for Homography estimation and RANSAC
'''

def get_homo(x):
    """
    Append a column of ones to the 2-D array ``x`` (homogeneous coordinates).
    """
    ones_column = np.ones((x.shape[0], 1))
    return np.concatenate((x, ones_column), axis = 1)


def de_homo(arr):
    """
    Divide every row of ``arr`` by its last entry and drop that last column.

    Division warnings are suppressed; entries that come out as NaN
    (0 / 0) are replaced by 1e9. The input array is left untouched.
    """
    last_column = arr[:, -1].reshape(-1, 1)
    with np.errstate(divide='ignore', invalid='ignore'):
        scaled = arr / last_column

    nan_mask = np.isnan(scaled)
    scaled[nan_mask] = 1e9

    return scaled[:, :-1]


def get_homography(pointsTrain, pointsQuery):
    """
    Estimate the homography H_{12} that maps Query points onto Train points.

    Each correspondence contributes two rows to a linear system A h = 0,
    which is solved in the least-squares sense by taking the right singular
    vector of A for the smallest singular value. Only the first two entries
    of every point are read. The result is scaled so that H[2][2] equals 1.
    """
    equations = []
    for idx in range(pointsTrain.shape[0]):
        u, v = pointsTrain[idx][0], pointsTrain[idx][1]
        x, y = pointsQuery[idx][0], pointsQuery[idx][1]
        equations += [
            [x, y, 1, 0, 0, 0, -u * x, -u * y, -u],
            [0, 0, 0, x, y, 1, -v * x, -v * y, -v],
        ]

    system = np.array(equations)
    right_vectors = np.linalg.svd(system, full_matrices = True)[2]
    solution = right_vectors[-1].reshape(3, 3)
    return solution / solution[2][2]


def get_rmse(pointsL, pointsR, H):
    """
    Root-mean-square reprojection error of ``H`` mapping pointsR onto pointsL.

    pointsL, pointsR : (n, 3) arrays of homogeneous 2-D points
    H                : 3x3 homography applied to pointsR

    Both the projected points and pointsL are divided by their homogeneous
    coordinate before being compared, so the error is measured in image
    coordinates and does not depend on the overall scale of ``H``. The mean
    is taken over all x and y residuals.

    Returns the RMSE as a float, or ``np.inf`` when a point projects to
    infinity (homogeneous coordinate of 0), so that such a model never wins a
    "smallest error" comparison.
    """
    projected = (H @ pointsR.T).T

    # Normalise by w; without this the residual mixes in the projective scale
    # of each point and is not a geometric distance.
    with np.errstate(divide='ignore', invalid='ignore'):
        projected_xy = projected[:, :2] / projected[:, 2:3]
        target_xy = pointsL[:, :2] / pointsL[:, 2:3]

    residual = projected_xy - target_xy
    if not np.all(np.isfinite(residual)):
        return np.inf

    return np.sqrt(np.mean(residual ** 2))


def get_homography_ransac(pointsL, pointsR, displayError = False):
    """
    Pick the best homography among many fitted to random 4-point samples.

    pointsL, pointsR : (n, 3) arrays of corresponding homogeneous points
    displayError     : if True, print the smallest error that was reached

    For 9999 iterations, four correspondences are drawn at random, a
    homography is fitted to them with get_homography, and it is scored with
    get_rmse over ALL correspondences. The homography with the smallest score
    is returned. If no candidate ever scores below the initial bound of 10e8,
    the returned matrix is all zeros.

    Raises ValueError when fewer than 4 correspondences are supplied.
    """
    minPointCorrespondences = 4
    numPoints = pointsL.shape[0]

    # Raise instead of calling exit(0): exiting would shut down the notebook
    # kernel and signal success although estimation is impossible.
    if numPoints < minPointCorrespondences:
        raise ValueError(
            f"Not enough correspondences: got {numPoints}, "
            f"need at least {minPointCorrespondences}"
        )

    H = np.zeros((3,3))
    min_error = 10e8

    for _ in range(9999):
        # Draw 4 distinct indices directly rather than shuffling every index
        sample = random.sample(range(numPoints), minPointCorrespondences)
        iter_H = get_homography(pointsL[sample], pointsR[sample])
        rmse = get_rmse(pointsL, pointsR, iter_H)

        # A NaN score compares False here, so degenerate fits are skipped
        if rmse < min_error:
            min_error = rmse
            H = iter_H

    if displayError:
        print(f"\033[43;77m The minimum rmse error using RANSAC is: \033[43;77m {min_error}")
    return H
 


In [25]:
'''
Functions for image stitching
'''

def get_stitched(imgTarget, imgQuery):
    """
    Warp imgTarget with an estimated homography and composite it with imgQuery.

    imgTarget, imgQuery : images as 3-D arrays (their shape is unpacked into
                          exactly three values below)

    Side effect: the warped target is shown through plot().

    Returns the composite image after trim() has been applied to it.
    """

    # The estimator documents its result as query-to-train, so the inverse
    # presumably maps target coordinates into the query frame -- confirm.
    # pointsTarget / pointsQuery are not used further in this function.
    pointsTarget, pointsQuery,  H = get_points_descriptors(imgTarget, imgQuery)
    H = np.linalg.inv(H)
    # shape is (rows, cols, channels): x1/x2 hold row counts, y1/y2 column counts
    x1, y1, _ = imgTarget.shape
    x2, y2, _ = imgQuery.shape
    
    '''
    Finding the corners of the image after applying the homography
    To find the right dimensions for warpPerspective
    '''
    
     
    # Target corners as (column, row) pairs, pushed through H and truncated to int32
    getCorners = np.asarray([[0, 0], [0, x1 - 1], [y1 - 1, 0], [y1 - 1, x1 - 1]])
    getCorners = np.int32(de_homo((H @ get_homo(getCorners).T).T))
    # Add the query image's own corners so the extents cover both images;
    # because [0, 0] is included, minDim can never be positive
    getCorners = np.vstack((getCorners,[0, 0], [0, x2 - 1], [y2 - 1, x2 - 1], [y2 - 1, 0]))
    
    # Column-wise extremes: index 0 is the column (x) axis, index 1 the row (y) axis
    maxDim = np.max(getCorners, axis = 0)
    minDim = np.min(getCorners, axis = 0)
    
    
    if np.any(minDim < 0):
        # If zero padding needs to be done
        # Pads -minDim[1] rows on top and -minDim[0] columns on the left, then
        # re-runs feature matching and homography estimation against the padded query
        imgQueryPadded = np.pad(imgQuery, ((-minDim[1], 0), (-minDim[0] ,0), (0 ,0)), 'constant', constant_values = 0)
        pointsTarget, pointsQuery, H = get_points_descriptors(imgTarget, imgQueryPadded)
        H = np.linalg.inv(H)
        
    
    # Output size is (width, height); areas not covered by the warp are filled
    # with (0, 255, 0), which trim() and the mask below treat as a marker colour
    warpedImage = cv.warpPerspective(imgTarget, H, (maxDim[0] - minDim[0] + 1, maxDim[1] - minDim[1] + 1 ), borderMode = cv.BORDER_CONSTANT, borderValue=(0, 255, 0))
    
    plot(warpedImage, title = "warped Image")
    
    # Range that selects pixels equal (within 1) to the marker colour
    minThreshold = np.array([0, 254, 0])    
    maxThreshold = np.array([1, 255, 1])
            
    # NOTE(review): the mask is built from the unpadded imgQuery; imgQueryPadded
    # is only used for re-estimating H above
    mask = cv.inRange(imgQuery, minThreshold, maxThreshold)
    #Returns 255 if within the threshold
    maskedImage = np.copy(imgQuery)
    maskedImage[mask != 0] = [0, 0, 0]
    # NOTE(review): `maskedImage == [0,0,0]` compares element by element, so the
    # choice between warped and query values is made per channel value, not per
    # whole pixel -- confirm this is intended
    stitchedImage = np.where( maskedImage == [0,0,0] , warpedImage[-minDim[1]:-minDim[1] + x2 ,-minDim[0]:-minDim[0] + y2, :], imgQuery)
    # Write the blended region back at the query image's offset inside the canvas
    warpedImage[-minDim[1]:(-minDim[1] + x2), -minDim[0]:(-minDim[0] + y2)] = stitchedImage
    warpedImage = trim(warpedImage)
    
    return warpedImage

def stitch_multiple(img_arr):
    """
    Stitch a sequence of images into one, irrespective of the given order.

    img_arr : non-empty sequence of images

    The first image seeds the result. Every other image is stitched onto the
    running result with get_stitched; if that yields an all-zero image the
    attempt counts as failed and the image is moved to the back of the queue
    to be retried after the others. Stitching stops once every remaining
    image has failed in a row against the current result.

    Returns the stitched image (only the first image if there is just one).
    Images that could not be stitched are left out and reported via print.

    Raises ValueError if ``img_arr`` is empty.
    """
    if len(img_arr) == 0:
        raise ValueError("stitch_multiple needs at least one image")

    out = img_arr[0]
    pending = list(range(1, len(img_arr)))
    # Failures since the last success; once it reaches len(pending) every
    # remaining image has been tried against the current result, so retrying
    # would loop forever.
    consecutive_failures = 0

    while pending and consecutive_failures < len(pending):
        index = pending.pop(0)
        candidate = get_stitched(img_arr[index], out)

        if np.all(candidate == 0):
            # Retry this image later, after the result has grown
            pending.append(index)
            consecutive_failures += 1
        else:
            out = candidate
            consecutive_failures = 0

    if pending:
        print(f"Could not stitch images at positions: {pending}")

    return out

def trim(frame):
    """
    Strip border rows and columns that consist solely of the colour (0, 255, 0).

    frame : image array of shape (rows, cols, 3)

    Exactly one row or column is removed per step (top, bottom, left, right)
    until no outer row or column is entirely (0, 255, 0). The loop is
    iterative, so arbitrarily thick borders are handled, and a frame made up
    only of that colour yields an empty array instead of an error.

    Returns a view of ``frame`` restricted to the remaining area.
    """
    green = np.array([0, 255, 0])

    while frame.shape[0] > 0 and frame.shape[1] > 0:
        #crop top
        if np.all(frame[0] == green):
            frame = frame[1:]
        #crop bottom
        elif np.all(frame[-1] == green):
            frame = frame[:-1]
        #crop left
        elif np.all(frame[:, 0] == green):
            frame = frame[:, 1:]
        #crop right
        elif np.all(frame[:, -1] == green):
            frame = frame[:, :-1]
        else:
            break

    return frame
