In [9]:
import argparse
import glob
import numpy as np
import os
import time
import cv2
import torch
import matplotlib.pyplot as plt
import random
from skimage.metrics import structural_similarity as ssim
from skimage.measure import ransac
from skimage.transform import FundamentalMatrixTransform, AffineTransform

# Stub to warn about opencv version.
if int(cv2.__version__[0]) < 3: # pragma: no cover
  print('Warning: OpenCV 3 is not installed')

# Jet colormap for visualization.
myjet = np.array([[0.        , 0.        , 0.5       ],
                  [0.        , 0.        , 0.99910873],
                  [0.        , 0.37843137, 1.        ],
                  [0.        , 0.83333333, 1.        ],
                  [0.30044276, 1.        , 0.66729918],
                  [0.66729918, 1.        , 0.30044276],
                  [1.        , 0.90123457, 0.        ],
                  [1.        , 0.48002905, 0.        ],
                  [0.99910873, 0.07334786, 0.        ],
                  [0.5       , 0.        , 0.        ]])


In [10]:
class SuperPointNet(torch.nn.Module):
  """ Pytorch definition of SuperPoint Network. """
  def __init__(self):
    super(SuperPointNet, self).__init__()
    self.relu = torch.nn.ReLU(inplace=True)
    self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
    c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
    # Shared Encoder.
    self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
    self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
    self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
    self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
    self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
    self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
    self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
    self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
    # Detector Head.
    self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
    self.convPb = torch.nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
    # Descriptor Head.
    self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
    self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)

  def forward(self, x):
    """ Forward pass that jointly computes unprocessed point and descriptor
    tensors.
    Input
      x: Image pytorch tensor shaped N x 1 x H x W.
    Output
      semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
      desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
    """
    # Shared Encoder.
    x = self.relu(self.conv1a(x))
    x = self.relu(self.conv1b(x))
    x = self.pool(x)
    x = self.relu(self.conv2a(x))
    x = self.relu(self.conv2b(x))
    x = self.pool(x)
    x = self.relu(self.conv3a(x))
    x = self.relu(self.conv3b(x))
    x = self.pool(x)
    x = self.relu(self.conv4a(x))
    x = self.relu(self.conv4b(x))
    # Detector Head.
    cPa = self.relu(self.convPa(x))
    semi = self.convPb(cPa)
    # Descriptor Head.
    cDa = self.relu(self.convDa(x))
    desc = self.convDb(cDa)
    dn = torch.norm(desc, p=2, dim=1) # Compute the norm.
    desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize.
    return semi, desc


class SuperPointFrontend(object):
  """ Wrapper around pytorch net to help with pre and post image processing. """
  def __init__(self, weights_path, nms_dist, conf_thresh, nn_thresh,
               cuda=False):
    self.name = 'SuperPoint'
    self.cuda = cuda
    self.nms_dist = nms_dist
    self.conf_thresh = conf_thresh
    self.nn_thresh = nn_thresh # L2 descriptor distance for good match.
    self.cell = 8 # Size of each output cell. Keep this fixed.
    self.border_remove = 4 # Remove points this close to the border.

    # Load the network in inference mode.
    self.net = SuperPointNet()
    if cuda:
      # Train on GPU, deploy on GPU.
      self.net.load_state_dict(torch.load(weights_path))
      self.net = self.net.cuda()
    else:
      # Train on GPU, deploy on CPU.
      self.net.load_state_dict(torch.load(weights_path,
                               map_location=lambda storage, loc: storage))
    self.net.eval()

  def nms_fast(self, in_corners, H, W, dist_thresh):
    """
    Run a faster approximate Non-Max-Suppression on numpy corners shaped:
      3xN [x_i,y_i,conf_i]^T
  
    Algo summary: Create a grid sized HxW. Assign each corner location a 1, rest
    are zeros. Iterate through all the 1's and convert them either to -1 or 0.
    Suppress points by setting nearby values to 0.
  
    Grid Value Legend:
    -1 : Kept.
     0 : Empty or suppressed.
     1 : To be processed (converted to either kept or supressed).
  
    NOTE: The NMS first rounds points to integers, so NMS distance might not
    be exactly dist_thresh. It also assumes points are within image boundaries.
  
    Inputs
      in_corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
      H - Image height.
      W - Image width.
      dist_thresh - Distance to suppress, measured as an infinty norm distance.
    Returns
      nmsed_corners - 3xN numpy matrix with surviving corners.
      nmsed_inds - N length numpy vector with surviving corner indices.
    """
    grid = np.zeros((H, W)).astype(int) # Track NMS data.
    inds = np.zeros((H, W)).astype(int) # Store indices of points.
    # Sort by confidence and round to nearest int.
    inds1 = np.argsort(-in_corners[2,:])
    corners = in_corners[:,inds1]
    rcorners = corners[:2,:].round().astype(int) # Rounded corners.
    # Check for edge case of 0 or 1 corners.
    if rcorners.shape[1] == 0:
      return np.zeros((3,0)).astype(int), np.zeros(0).astype(int)
    if rcorners.shape[1] == 1:
      out = np.vstack((rcorners, in_corners[2])).reshape(3,1)
      return out, np.zeros((1)).astype(int)
    # Initialize the grid.
    for i, rc in enumerate(rcorners.T):
      grid[rcorners[1,i], rcorners[0,i]] = 1
      inds[rcorners[1,i], rcorners[0,i]] = i
    # Pad the border of the grid, so that we can NMS points near the border.
    pad = dist_thresh
    grid = np.pad(grid, ((pad,pad), (pad,pad)), mode='constant')
    # Iterate through points, highest to lowest conf, suppress neighborhood.
    count = 0
    for i, rc in enumerate(rcorners.T):
      # Account for top and left padding.
      pt = (rc[0]+pad, rc[1]+pad)
      if grid[pt[1], pt[0]] == 1: # If not yet suppressed.
        grid[pt[1]-pad:pt[1]+pad+1, pt[0]-pad:pt[0]+pad+1] = 0
        grid[pt[1], pt[0]] = -1
        count += 1
    # Get all surviving -1's and return sorted array of remaining corners.
    keepy, keepx = np.where(grid==-1)
    keepy, keepx = keepy - pad, keepx - pad
    inds_keep = inds[keepy, keepx]
    out = corners[:, inds_keep]
    values = out[-1, :]
    inds2 = np.argsort(-values)
    out = out[:, inds2]
    out_inds = inds1[inds_keep[inds2]]
    return out, out_inds

  def run(self, img):
    """ Process a numpy image to extract points and descriptors.
    Input
      img - HxW numpy float32 input image in range [0,1].
    Output
      corners - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
      desc - 256xN numpy array of corresponding unit normalized descriptors.
      heatmap - HxW numpy heatmap in range [0,1] of point confidences.
      """
    assert img.ndim == 2, 'Image must be grayscale.'
    assert img.dtype == np.float32, 'Image must be float32.'
    H, W = img.shape[0], img.shape[1]
    inp = img.copy()
    inp = (inp.reshape(1, H, W))
    inp = torch.from_numpy(inp)
    inp = torch.autograd.Variable(inp).view(1, 1, H, W)
    if self.cuda:
      inp = inp.cuda()
    # Forward pass of network.
    outs = self.net.forward(inp)
    semi, coarse_desc = outs[0], outs[1]
    # Convert pytorch -> numpy.
    semi = semi.data.cpu().numpy().squeeze()
    # --- Process points.
    dense = np.exp(semi) # Softmax.
    dense = dense / (np.sum(dense, axis=0)+.00001) # Should sum to 1.
    # Remove dustbin.
    nodust = dense[:-1, :, :]
    # Reshape to get full resolution heatmap.
    Hc = int(H / self.cell)
    Wc = int(W / self.cell)
    nodust = nodust.transpose(1, 2, 0)
    heatmap = np.reshape(nodust, [Hc, Wc, self.cell, self.cell])
    heatmap = np.transpose(heatmap, [0, 2, 1, 3])
    heatmap = np.reshape(heatmap, [Hc*self.cell, Wc*self.cell])
    xs, ys = np.where(heatmap >= self.conf_thresh) # Confidence threshold.
    if len(xs) == 0:
      return np.zeros((3, 0)), None, None
    pts = np.zeros((3, len(xs))) # Populate point data sized 3xN.
    pts[0, :] = ys
    pts[1, :] = xs
    pts[2, :] = heatmap[xs, ys]
    pts, _ = self.nms_fast(pts, H, W, dist_thresh=self.nms_dist) # Apply NMS.
    inds = np.argsort(pts[2,:])
    pts = pts[:,inds[::-1]] # Sort by confidence.
    # Remove points along border.
    bord = self.border_remove
    toremoveW = np.logical_or(pts[0, :] < bord, pts[0, :] >= (W-bord))
    toremoveH = np.logical_or(pts[1, :] < bord, pts[1, :] >= (H-bord))
    toremove = np.logical_or(toremoveW, toremoveH)
    pts = pts[:, ~toremove]
    # --- Process descriptor.
    D = coarse_desc.shape[1]
    if pts.shape[1] == 0:
      desc = np.zeros((D, 0))
    else:
      # Interpolate into descriptor map using 2D point locations.
      samp_pts = torch.from_numpy(pts[:2, :].copy())
      samp_pts[0, :] = (samp_pts[0, :] / (float(W)/2.)) - 1.
      samp_pts[1, :] = (samp_pts[1, :] / (float(H)/2.)) - 1.
      samp_pts = samp_pts.transpose(0, 1).contiguous()
      samp_pts = samp_pts.view(1, 1, -1, 2)
      samp_pts = samp_pts.float()
      if self.cuda:
        samp_pts = samp_pts.cuda()
      desc = torch.nn.functional.grid_sample(coarse_desc, samp_pts)
      desc = desc.data.cpu().numpy().reshape(D, -1)
      desc /= np.linalg.norm(desc, axis=0)[np.newaxis, :]
    return pts, desc, heatmap


class PointTracker(object):
  """ Class to manage a fixed memory of points and descriptors that enables
  sparse optical flow point tracking.

  Internally, the tracker stores a 'tracks' matrix sized M x (2+L), of M
  tracks with maximum length L, where each row corresponds to:
  row_m = [track_id_m, avg_desc_score_m, point_id_0_m, ..., point_id_L-1_m].
  """

  def __init__(self, max_length, nn_thresh):
    if max_length < 2:
      raise ValueError('max_length must be greater than or equal to 2.')
    self.maxl = max_length
    self.nn_thresh = nn_thresh
    self.all_pts = []
    for n in range(self.maxl):
      self.all_pts.append(np.zeros((2, 0)))
    self.last_desc = None
    self.tracks = np.zeros((0, self.maxl+2))
    self.track_count = 0
    self.max_score = 9999

  def nn_match_two_way(self, desc1, desc2, nn_thresh):
    """
    Performs two-way nearest neighbor matching of two sets of descriptors, such
    that the NN match from descriptor A->B must equal the NN match from B->A.

    Inputs:
      desc1 - NxM numpy matrix of N corresponding M-dimensional descriptors.
      desc2 - NxM numpy matrix of N corresponding M-dimensional descriptors.
      nn_thresh - Optional descriptor distance below which is a good match.

    Returns:
      matches - 3xL numpy array, of L matches, where L <= N and each column i is
                a match of two descriptors, d_i in image 1 and d_j' in image 2:
                [d_i index, d_j' index, match_score]^T
    """
    assert desc1.shape[0] == desc2.shape[0]
    if desc1.shape[1] == 0 or desc2.shape[1] == 0:
      return np.zeros((3, 0))
    if nn_thresh < 0.0:
      raise ValueError('\'nn_thresh\' should be non-negative')
    # Compute L2 distance. Easy since vectors are unit normalized.
    dmat = np.dot(desc1.T, desc2)
    dmat = np.sqrt(2-2*np.clip(dmat, -1, 1))
    # Get NN indices and scores.
    idx = np.argmin(dmat, axis=1)
    scores = dmat[np.arange(dmat.shape[0]), idx]
    # Threshold the NN matches.  <<< ======================= this threshold is not good
    keep = scores < nn_thresh
    # Check if nearest neighbor goes both directions and keep those.
    idx2 = np.argmin(dmat, axis=0)
    keep_bi = np.arange(len(idx)) == idx2[idx]
    keep = np.logical_and(keep, keep_bi)
    idx = idx[keep]
    scores = scores[keep]
    # Get the surviving point indices.
    m_idx1 = np.arange(desc1.shape[1])[keep]
    m_idx2 = idx
    # Populate the final 3xN match data structure.
    matches = np.zeros((3, int(keep.sum())))
    matches[0, :] = m_idx1
    matches[1, :] = m_idx2
    matches[2, :] = scores
    return matches
  
  def nn_match_two_way_with_ransac(self, points1, points2, matches, max_reproj_error=5.0):  # <<=========================================== TODO Here is the RANSAC
    '''find matching points between two images using RANSAC'''
    
    # estimate affine transform model using all coordinates
    model = AffineTransform()
    model.estimate(points1, points2)

    # try min_samples = 3, if fail, try min_samples = 2
    try:
      # Find the best fundamental matrix using RANSAC
      best_model, best_inliers = ransac((points1, points2),
                                        AffineTransform, min_samples=3,
                                        residual_threshold=max_reproj_error, max_trials=100)
    except:
      try:
        best_model, best_inliers = ransac((points1, points2),
                                          AffineTransform, min_samples=2,
                                          residual_threshold=max_reproj_error, max_trials=100)
        
      except:
        # if ransac failed, return an array of TRUE with the original matches shape
        print('ransac failed')
        print('matches shape: ', matches.shape)
        return np.ones((matches.shape[1])).astype(bool)
    
    # the inliners are the matching points
    matches = np.array(best_inliers).T

    print(f'matches: {matches.shape}')
    return matches

  def get_offsets(self):
    """ Iterate through list of points and accumulate an offset value. Used to
    index the global point IDs into the list of points.

    Returns
      offsets - N length array with integer offset locations.
    """
    # Compute id offsets.
    offsets = []
    offsets.append(0)
    for i in range(len(self.all_pts)-1): # Skip last camera size, not needed.
      offsets.append(self.all_pts[i].shape[1])
    offsets = np.array(offsets)
    offsets = np.cumsum(offsets)
    return offsets

  def update(self, pts, desc):
    """ Add a new set of point and descriptor observations to the tracker.

    Inputs
      pts - 3xN numpy array of 2D point observations.
      desc - DxN numpy array of corresponding D dimensional descriptors.
    """
    if pts is None or desc is None:
      print('PointTracker: Warning, no points were added to tracker.')
      return
    assert pts.shape[1] == desc.shape[1]
    # Initialize last_desc.
    if self.last_desc is None:
      self.last_desc = np.zeros((desc.shape[0], 0))
    # Remove oldest points, store its size to update ids later.
    remove_size = self.all_pts[0].shape[1]
    self.all_pts.pop(0)
    self.all_pts.append(pts)
    # Remove oldest point in track.
    self.tracks = np.delete(self.tracks, 2, axis=1)
    # Update track offsets.
    for i in range(2, self.tracks.shape[1]):
      self.tracks[:, i] -= remove_size
    self.tracks[:, 2:][self.tracks[:, 2:] < -1] = -1
    offsets = self.get_offsets()
    # Add a new -1 column.
    self.tracks = np.hstack((self.tracks, -1*np.ones((self.tracks.shape[0], 1))))
    # Try to append to existing tracks.
    matched = np.zeros((pts.shape[1])).astype(bool)
    matches = self.nn_match_two_way(self.last_desc, desc, self.nn_thresh)
    for match in matches.T:
      # Add a new point to it's matched track.
      id1 = int(match[0]) + offsets[-2]
      id2 = int(match[1]) + offsets[-1]
      found = np.argwhere(self.tracks[:, -2] == id1)
      if found.shape[0] > 0:
        matched[int(match[1])] = True
        row = int(found)
        self.tracks[row, -1] = id2
        if self.tracks[row, 1] == self.max_score:
          # Initialize track score.
          self.tracks[row, 1] = match[2]
        else:
          # Update track score with running average.
          # NOTE(dd): this running average can contain scores from old matches
          #           not contained in last max_length track points.
          track_len = (self.tracks[row, 2:] != -1).sum() - 1.
          frac = 1. / float(track_len)
          self.tracks[row, 1] = (1.-frac)*self.tracks[row, 1] + frac*match[2]
    # Add unmatched tracks.
    new_ids = np.arange(pts.shape[1]) + offsets[-1]
    new_ids = new_ids[~matched]
    new_tracks = -1*np.ones((new_ids.shape[0], self.maxl + 2))
    new_tracks[:, -1] = new_ids
    new_num = new_ids.shape[0]
    new_trackids = self.track_count + np.arange(new_num)
    new_tracks[:, 0] = new_trackids
    new_tracks[:, 1] = self.max_score*np.ones(new_ids.shape[0])
    self.tracks = np.vstack((self.tracks, new_tracks))
    self.track_count += new_num # Update the track count.
    # Remove empty tracks.
    keep_rows = np.any(self.tracks[:, 2:] >= 0, axis=1)
    self.tracks = self.tracks[keep_rows, :]
    # Store the last descriptors.
    self.last_desc = desc.copy()
    return

  def get_tracks(self, min_length):
    """ Retrieve point tracks of a given minimum length.
    Input
      min_length - integer >= 1 with minimum track length
    Output
      returned_tracks - M x (2+L) sized matrix storing track indices, where
        M is the number of tracks and L is the maximum track length.
    """
    if min_length < 1:
      raise ValueError('\'min_length\' too small.')
    valid = np.ones((self.tracks.shape[0])).astype(bool)
    good_len = np.sum(self.tracks[:, 2:] != -1, axis=1) >= min_length
    # Remove tracks which do not have an observation in most recent frame.
    not_headless = (self.tracks[:, -1] != -1)
    keepers = np.logical_and.reduce((valid, good_len, not_headless))
    returned_tracks = self.tracks[keepers, :].copy()
    return returned_tracks

  def draw_tracks(self, out, tracks):
    """ Visualize tracks all overlayed on a single image.
    Inputs
      out - numpy uint8 image sized HxWx3 upon which tracks are overlayed.
      tracks - M x (2+L) sized matrix storing track info.
    """
    # Store the number of points per camera.
    pts_mem = self.all_pts
    N = len(pts_mem) # Number of cameras/images.
    # Get offset ids needed to reference into pts_mem.
    offsets = self.get_offsets()
    # Width of track and point circles to be drawn.
    stroke = 1
    # Iterate through each track and draw it.
    for track in tracks:
      clr = myjet[int(np.clip(np.floor(track[1]*10), 0, 9)), :]*255
      for i in range(N-1):
        if track[i+2] == -1 or track[i+3] == -1:
          continue
        offset1 = offsets[i]
        offset2 = offsets[i+1]
        idx1 = int(track[i+2]-offset1)
        idx2 = int(track[i+3]-offset2)
        pt1 = pts_mem[i][:2, idx1]
        pt2 = pts_mem[i+1][:2, idx2]
        p1 = (int(round(pt1[0])), int(round(pt1[1])))
        p2 = (int(round(pt2[0])), int(round(pt2[1])))
        cv2.line(out, p1, p2, clr, thickness=stroke, lineType=16)
        # Draw end points of each track.
        if i == N-2:
          clr2 = (255, 0, 0)
          cv2.circle(out, p2, stroke, clr2, -1, lineType=16)

  def read_image(self, impath, img_size):
    """ Read image as grayscale and resize to img_size.
    Inputs
      impath: Path to input image.
      img_size: (W, H) tuple specifying resize size.
    Returns
      grayim: float32 numpy array sized H x W with values in range [0, 1].
    """
    grayim = cv2.imread(impath, 0)
    if grayim is None:
      raise Exception('Error reading image %s' % impath)
    # Image is resized via opencv.
    interp = cv2.INTER_AREA
    grayim = cv2.resize(grayim, (img_size[1], img_size[0]), interpolation=interp)
    grayim = (grayim.astype('float32') / 255.)
    return grayim

  def next_frame(self):
    """ Return the next frame, and increment internal counter.
    Returns
       image: Next H x W image.
       status: True or False depending whether image was loaded.
    """
    if self.i == self.maxlen:
      return (None, False)
    if self.camera:
      ret, input_image = self.cap.read()
      if ret is False:
        print('VideoStreamer: Cannot get image from camera (maybe bad --camid?)')
        return (None, False)
      if self.video_file:
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.listing[self.i])
      input_image = cv2.resize(input_image, (self.sizer[1], self.sizer[0]),
                               interpolation=cv2.INTER_AREA)
      input_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2GRAY)
      input_image = input_image.astype('float')/255.0
    else:
      image_file = self.listing[self.i]
      input_image = self.read_image(image_file, self.sizer)
    # Increment internal counter.
    self.i = self.i + 1
    input_image = input_image.astype('float32')
    return (input_image, True)
  

# Modified code to load and feed 2 images 

In [11]:
# Function to overlay points on the image
def overlay_points(image, points, color=(0, 255, 0), radius=5):
    for point in points.T:
        x, y = int(point[0]), int(point[1])
        cv2.circle(image, (x, y), radius, color, -1)
    return image

# Predefined bright colors
bright_colors = [
    (255, 0, 0),    # Red
    (0, 255, 0),    # Green
    (0, 0, 255),    # Blue
    (255, 255, 0),  # Yellow
    (0, 255, 255),  # Cyan
    (255, 0, 255),  # Magenta
    (255, 128, 0),  # Orange
    (0, 255, 128),  # Spring Green
    (128, 0, 255),  # Purple
    (255, 0, 128),  # Rose
]

# Function to draw lines connecting points from image 1 to image 2 with random bright colors
def draw_lines(image1, image2, points1, points2, match, line_thickness=1, opacity=0.5):
    # Convert grayscale images to 3-channel with repeated intensity value
    if len(image1.shape) == 2:
        image1 = cv2.cvtColor(image1, cv2.COLOR_GRAY2BGR)
    if len(image2.shape) == 2:
        image2 = cv2.cvtColor(image2, cv2.COLOR_GRAY2BGR)

    combined_image = np.concatenate((image1, image2), axis=1)
    for pt1, pt2, value in zip(points1.T, points2.T, match):
        x1, y1 = int(pt1[0]), int(pt1[1])
        x2, y2 = int(pt2[0] + image1.shape[1]), int(pt2[1])

        # Randomly choose a color from the predefined bright colors
        line_color = random.choice(bright_colors)

        # Draw the line with the chosen color and specified thickness
        cv2.line(combined_image, (x1, y1), (x2, y2), line_color, line_thickness)

        # Add lines with the chosen color and opacity on top of the combined image
        overlay = combined_image.copy()
        cv2.line(overlay, (x1, y1), (x2, y2), line_color, line_thickness)
        cv2.addWeighted(overlay, opacity, combined_image, 1 - opacity, 0, combined_image)

        # Add value from 'match' array as text around the beginning of the line on the left image
        value_text = f"{value:.2f}"
        text_x = x1 - 50 if x1 > 50 else x1 + 10
        text_y = y1 + 15
        cv2.putText(combined_image, value_text, (text_x, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, line_color, 1)

    return combined_image

# Function to create a heatmap from keypoints
def create_heatmap(image, keypoints, size=5, sigma=1):
    heatmap = np.zeros_like(image)
    for kp in keypoints.T:
        x, y = int(kp[0]), int(kp[1])
        heatmap[y - size:y + size + 1, x - size:x + size + 1] = 1
    heatmap = cv2.GaussianBlur(heatmap, (0, 0), sigmaX=sigma, sigmaY=sigma)
    return heatmap

def create_subplot(name, image1_name, image2_name, image1, image2, points1, points2, 
                   heatmap1, heatmap2, match_score):
    # Create a subplot with 2 rows and 2 columns
    fig, axes = plt.subplot_mosaic("AAA;BCD", figsize=(15, 10))

    overlaid1 = overlay_points(image1.copy(), points1, radius=2)
    overlaid2 = overlay_points(image2.copy(), points2, radius=2)

    # offset the points in image 2 by the width of image 1
    #points2[:, 0] += image1.shape[1]
    
    # Row 1: Two images side-by-side with overlaid points and lines
    # show also MSE and SSIM between the two images
    axes["A"].imshow(draw_lines(overlaid1, overlaid2, points1, points2, match_score))
    axes["A"].set_title(f"(RANSAC) Image {image1_name} & {image2_name} with Points. MSE (x100): {100*np.mean((image1 - image2)**2):.2f} SSIM: {ssim(image1, image2):.2f}")
    axes["A"].axis('off')
    
    # Row 2: Heatmap for image 1 and image 2
    axes["B"].imshow(image1, cmap='gray')
    axes["B"].imshow(heatmap1, cmap='hot', alpha=0.8)
    axes["B"].set_title("Heatmap for Image 1")
    axes["B"].axis('off')

    axes["C"].imshow(image2, cmap='gray')
    axes["C"].imshow(heatmap2, cmap='hot', alpha=0.8)
    axes["C"].set_title("Heatmap for Image 2")
    axes["C"].axis('off')

    # New subplot for histograms of 'match_score' array
    axes["D"].hist(match_score, bins=20, color='blue', alpha=0.7)
    axes["D"].set_title(f"Match Score Histogram, {len(match_score)} matches")
    axes["D"].set_xlabel("Value")
    axes["D"].set_ylabel("Frequency")
    
    plt.tight_layout()  # Adjust the layout to leave space for the histogram
    os.makedirs('output_images', exist_ok=True)
    plt.savefig(f"output_images/{name}_{image1_name}_{image2_name}_RANSAC.png")
    plt.close()

def check_location(points):
    # check if the location of the points in array is within the image
    # return the number of points that are outside the image
    # points is a 2D array of shape (2, N)
    # where N is the number of points
    # the first row is the x coordinate
    # the second row is the y coordinate
    # the image size is 512 x 512
    # the points are assumed to be in the range of 0 to 511
    # return the number of points that are outside the image
    number_outside = np.sum((points < 0) | (points > 511))
    
    # print if there are points outside the image
    if number_outside > 0:
        print(f'Number of points outside the image: {number_outside}')
    # return number_outside

def load_images_from_folder(folder, img_size=(512, 512)):
    images = []
    for filename in os.listdir(folder):
        img_path = os.path.join(folder, filename)
        grayim = cv2.imread(img_path, 0)
        interp = cv2.INTER_AREA
        grayim = cv2.resize(grayim, (img_size[1], img_size[0]), interpolation=interp)
        image = (grayim.astype('float32') / 255.)
        if image is not None:
            images.append(image)
    return images

def main(folder_path, name, weights_path='superpoint_v1.pth', cuda=True):
    # Load all images from the folder
    images = load_images_from_folder(folder_path)

    if len(images) < 2:
        print("At least 2 images are required in the folder.")
        return

    # Initialize SuperPointFrontend
    superpoint = SuperPointFrontend(weights_path, nms_dist=4,
                          conf_thresh=0.015,
                          nn_thresh=0.7, cuda=cuda)

    # Process all pairs of images
    for i in range(len(images)):
        for j in range(i + 1, len(images)):
            print(f'\nImage {i} & {j}')
            image1 = images[i]
            image2 = images[j]

            # Process the first image
            points1, desc1, heatmap1 = superpoint.run(image1)

            # Process the second image
            points2, desc2, heatmap2 = superpoint.run(image2)

            # check if the location of the points in array is within the image
            check_location(points1)
            check_location(points2)

            # match the points between the two images
            tracker = PointTracker(5, nn_thresh=0.7)
            matches = tracker.nn_match_two_way(desc1, desc2, nn_thresh=0.7)

            # take the elements from points1 and points2 using the matches as indices
            matches1 = points1[:, matches[0, :].astype(int)]
            matches2 = points2[:, matches[1, :].astype(int)]
            print(f'Before: {matches1[:2, :].shape}, {matches2[:2, :].shape}')
            
            matches_RANSAC = tracker.nn_match_two_way_with_ransac(matches1[:2, :].T, matches2[:2, :].T, matches, max_reproj_error=2)
            # print amount of TRUE in matches_RANSAC
            print(f'After RANSAC: {np.sum(matches_RANSAC)}')
            
            # take elements from matches1 and matches2 where matches_RANSAC is TRUE
            matches_RANSAC = matches_RANSAC.reshape(-1)
            
            matches1_RANSAC = matches1[:2, matches_RANSAC]
            matches2_RANSAC = matches2[:2, matches_RANSAC]
            print(f'{matches1_RANSAC.shape}, {matches2_RANSAC.shape}')

            '''if j == i + 1:
                # Create the "output_points" folder if it does not exist
                output_folder = "output_points"
                os.makedirs(output_folder, exist_ok=True)

                # Save points1, desc1, as files in the "output_points" folder
                points1_file = os.path.join(output_folder, f"img_{i}_points.txt")
                desc1_file = os.path.join(output_folder, f"img_{i}_desc.txt")

                np.savetxt(points1_file, matches12)
                np.savetxt(desc1_file, desc1)'''

            # take the match score of the elements in matches_RANSAC
            match_score = matches[2, matches_RANSAC]
            # Create the subplot of the requested outputs
            create_subplot(name, i, j, image1, image2, matches1_RANSAC, matches2_RANSAC, \
                           heatmap1, heatmap2, match_score)
    print("Done!")
    # return points1, desc1, points2, desc2

In [12]:
# Set your file path and other arguments here
file_path = "../Dataset/Dataset-processed/05-05-2560/3084442-L/b1"

main(file_path, '05-05-2560_3084442-L_b1')


Image 0 & 1


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 55), (2, 55)
matches: (55,)
After RANSAC: 11
(2, 11), (2, 11)

Image 0 & 2
Before: (2, 28), (2, 28)
matches: (28,)
After RANSAC: 5
(2, 5), (2, 5)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 0 & 3


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 10), (2, 10)
matches: (10,)
After RANSAC: 4
(2, 4), (2, 4)

Image 0 & 4
Before: (2, 19), (2, 19)
matches: (19,)
After RANSAC: 5
(2, 5), (2, 5)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 0 & 5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 11), (2, 11)
matches: (11,)
After RANSAC: 4
(2, 4), (2, 4)

Image 0 & 6
Before: (2, 10), (2, 10)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (10,)
After RANSAC: 3
(2, 3), (2, 3)

Image 0 & 7
Before: (2, 10), (2, 10)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (10,)
After RANSAC: 5
(2, 5), (2, 5)

Image 0 & 8
Before: (2, 245), (2, 245)
matches: (245,)
After RANSAC: 44
(2, 44), (2, 44)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 0 & 9
Before: (2, 188), (2, 188)
matches: (188,)
After RANSAC: 36
(2, 36), (2, 36)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 0 & 10
Before: (2, 17), (2, 17)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (17,)
After RANSAC: 6
(2, 6), (2, 6)

Image 0 & 11


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 18), (2, 18)
matches: (18,)
After RANSAC: 5
(2, 5), (2, 5)

Image 1 & 2
Before: (2, 201), (2, 201)
matches: (201,)
After RANSAC: 45
(2, 45), (2, 45)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 1 & 3
Before: (2, 25), (2, 25)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (25,)
After RANSAC: 6
(2, 6), (2, 6)

Image 1 & 4


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 15), (2, 15)
matches: (15,)
After RANSAC: 7
(2, 7), (2, 7)

Image 1 & 5
Before: (2, 20), (2, 20)
matches: (20,)
After RANSAC: 4
(2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 1 & 6
Before: (2, 2), (2, 2)
ransac failed
matches shape:  (3, 2)
After RANSAC: 2
(2, 2), (2, 2)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 1 & 7


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 21), (2, 21)
matches: (21,)
After RANSAC: 4
(2, 4), (2, 4)

Image 1 & 8
Before: (2, 88), (2, 88)
matches: (88,)
After RANSAC: 41
(2, 41), (2, 41)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 1 & 9


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 59), (2, 59)
matches: (59,)
After RANSAC: 17
(2, 17), (2, 17)

Image 1 & 10
Before: (2, 10), (2, 10)
matches: (10,)
After RANSAC: 4
(2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 1 & 11
Before: (2, 168), (2, 168)
matches: (168,)
After RANSAC: 38
(2, 38), (2, 38)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 3
Before: (2, 45), (2, 45)
matches: (45,)
After RANSAC: 10
(2, 10), (2, 10)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 4
Before: (2, 13), (2, 13)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (13,)
After RANSAC: 5
(2, 5), (2, 5)

Image 2 & 5


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 37), (2, 37)
matches: (37,)
After RANSAC: 9
(2, 9), (2, 9)

Image 2 & 6
Before: (2, 2), (2, 2)
ransac failed
matches shape:  (3, 2)
After RANSAC: 2
(2, 2), (2, 2)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 7
Before: (2, 21), (2, 21)
matches: (21,)
After RANSAC: 6
(2, 6), (2, 6)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 8
Before: (2, 99), (2, 99)
matches: (99,)
After RANSAC: 37
(2, 37), (2, 37)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 9
Before: (2, 98), (2, 98)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (98,)
After RANSAC: 58
(2, 58), (2, 58)

Image 2 & 10
Before: (2, 11), (2, 11)
matches: (11,)
After RANSAC: 6
(2, 6), (2, 6)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 2 & 11
Before: (2, 399), (2, 399)
matches: (399,)
After RANSAC: 280
(2, 280), (2, 280)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 3 & 4


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 22), (2, 22)
matches: (22,)
After RANSAC: 7
(2, 7), (2, 7)

Image 3 & 5
Before: (2, 184), (2, 184)
matches: (184,)
After RANSAC: 82
(2, 82), (2, 82)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 3 & 6
Before: (2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (4,)
After RANSAC: 3
(2, 3), (2, 3)

Image 3 & 7
Before: (2, 157), (2, 157)
matches: (157,)
After RANSAC: 72
(2, 72), (2, 72)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 3 & 8
Before: (2, 13), (2, 13)
matches: (13,)
After RANSAC: 4
(2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 3 & 9
Before: (2, 8), (2, 8)
matches: (8,)
After RANSAC: 4
(2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 3 & 10
Before: (2, 63), (2, 63)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (63,)
After RANSAC: 18
(2, 18), (2, 18)

Image 3 & 11


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 24), (2, 24)
matches: (24,)
After RANSAC: 10
(2, 10), (2, 10)

Image 4 & 5
Before: (2, 21), (2, 21)
matches: (21,)
After RANSAC: 4
(2, 4), (2, 4)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 4 & 6


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 17), (2, 17)
matches: (17,)
After RANSAC: 5
(2, 5), (2, 5)

Image 4 & 7
Before: (2, 19), (2, 19)
matches: (19,)
After RANSAC: 5
(2, 5), (2, 5)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 4 & 8
Before: (2, 30), (2, 30)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (30,)
After RANSAC: 8
(2, 8), (2, 8)

Image 4 & 9


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 18), (2, 18)
matches: (18,)
After RANSAC: 6
(2, 6), (2, 6)

Image 4 & 10
Before: (2, 132), (2, 132)
matches: (132,)
After RANSAC: 17
(2, 17), (2, 17)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 4 & 11
Before: (2, 9), (2, 9)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (9,)
After RANSAC: 3
(2, 3), (2, 3)

Image 5 & 6


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 4), (2, 4)
matches: (4,)
After RANSAC: 3
(2, 3), (2, 3)

Image 5 & 7
Before: (2, 187), (2, 187)
matches: (187,)
After RANSAC: 96
(2, 96), (2, 96)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 5 & 8
Before: (2, 12), (2, 12)
matches: (12,)
After RANSAC: 5
(2, 5), (2, 5)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 5 & 9
Before: (2, 6), (2, 6)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (6,)
After RANSAC: 4
(2, 4), (2, 4)

Image 5 & 10


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 72), (2, 72)
matches: (72,)
After RANSAC: 12
(2, 12), (2, 12)

Image 5 & 11
Before: (2, 24), (2, 24)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (24,)
After RANSAC: 10
(2, 10), (2, 10)

Image 6 & 7
Before: (2, 5), (2, 5)
matches: (5,)
After RANSAC: 3
(2, 3), (2, 3)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 6 & 8
Before: (2, 8), (2, 8)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (8,)
After RANSAC: 3
(2, 3), (2, 3)

Image 6 & 9


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 7), (2, 7)
matches: (7,)
After RANSAC: 3
(2, 3), (2, 3)

Image 6 & 10
Before: (2, 27), (2, 27)
matches: (27,)
After RANSAC: 9
(2, 9), (2, 9)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 6 & 11
Before: (2, 3), (2, 3)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (3,)
After RANSAC: 2
(2, 2), (2, 2)

Image 7 & 8


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 10), (2, 10)
matches: (10,)
After RANSAC: 4
(2, 4), (2, 4)

Image 7 & 9
Before: (2, 5), (2, 5)
matches: (5,)
After RANSAC: 3
(2, 3), (2, 3)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 7 & 10
Before: (2, 71), (2, 71)
matches: (71,)
After RANSAC: 14
(2, 14), (2, 14)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 7 & 11


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 16), (2, 16)
matches: (16,)
After RANSAC: 7
(2, 7), (2, 7)

Image 8 & 9
Before: (2, 391), (2, 391)
matches: (391,)
After RANSAC: 204
(2, 204), (2, 204)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 8 & 10


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Before: (2, 19), (2, 19)
matches: (19,)
After RANSAC: 6
(2, 6), (2, 6)

Image 8 & 11
Before: (2, 69), (2, 69)
matches: (69,)
After RANSAC: 18
(2, 18), (2, 18)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 9 & 10
Before: (2, 14), (2, 14)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


matches: (14,)
After RANSAC: 6
(2, 6), (2, 6)

Image 9 & 11
Before: (2, 67), (2, 67)
matches: (67,)
After RANSAC: 19
(2, 19), (2, 19)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).



Image 10 & 11
Before: (2, 9), (2, 9)
matches: (9,)
After RANSAC: 3
(2, 3), (2, 3)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


Done!
