<a href="https://colab.research.google.com/github/ssundar6087/simple_pano/blob/main/build_your_own_pano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import glob
%matplotlib inline

# Get & load the images

In [None]:
# Fetch the sample images; the paths used below assume this clone lands at /content/simple_pano.
!git clone https://github.com/ssundar6087/simple_pano.git

In [None]:
# Inspect the downloaded images (the loader below only picks up files ending in .png).
!ls /content/simple_pano/images/

In [None]:
img_dir = "/content/simple_pano/images/"
# Collect every .png in the image folder, then order by path so the inputs
# always come out in the same, deterministic sequence.
img_list = [os.path.join(img_dir, entry)
            for entry in os.listdir(img_dir)
            if entry.endswith(".png")]
sorted_img_list = sorted(img_list)
print(sorted_img_list)

In [None]:
imgs = []
render_imgs = []
for fname in sorted_img_list:
  img = cv2.imread(fname)
  # cv2.imread reports failure by returning None instead of raising.
  if img is None:
    raise FileNotFoundError(f"could not read image: {fname}")
  imgs.append(img)
  # Separate buffer for drawing, so in-place annotation can never alter the
  # pixels used for feature detection and stitching.
  render_imgs.append(img.copy())

for img in imgs:
  img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; matplotlib expects RGB
  plt.imshow(img_rgb)
  plt.axis("off")
  # With the inline backend, show() emits each image as its own output;
  # plt.pause would only add a delay.
  plt.show()



# Quick and dirty: OpenCV's built-in Stitcher

In [None]:
# OpenCV's high-level Stitcher: a single call returns a status code plus the panorama.
stitched = cv2.Stitcher_create()
status, pano = stitched.stitch(imgs)

In [None]:
if status != cv2.STITCHER_OK:
  # Include the status code so the failure reason can be looked up.
  print(f"Error generating panorama (status code {status})")
else:
  img_rgb = cv2.cvtColor(pano, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib
  plt.figure(figsize=(20, 8))
  plt.imshow(img_rgb)
  plt.axis("off")

# Deep dive

## Image Matching

### Detect Features

In [None]:
def detect_features(in_img, n_features=500):
  """Detect ORB keypoints and compute their binary descriptors.

  Args:
    in_img: BGR color image, or an already-grayscale 2-D image.
    n_features: maximum number of ORB features to retain (``nfeatures``).

  Returns:
    ``(keypoints, descriptors)`` as returned by ``detectAndCompute``.

  Raises:
    ValueError: if ``in_img`` is None (e.g. a failed ``cv2.imread``).
  """
  if in_img is None:
    raise ValueError("in_img is None (image failed to load?)")
  # BGR2GRAY rejects single-channel input, so pass grayscale images through.
  gray = in_img if in_img.ndim == 2 else cv2.cvtColor(in_img, cv2.COLOR_BGR2GRAY)
  feat_detector = cv2.ORB_create(nfeatures=n_features)
  keypts, ftrs = feat_detector.detectAndCompute(gray, None)

  return (keypts, ftrs)

In [None]:
def draw_features(src_img, keypoints):
  """Display ``src_img`` with ``keypoints`` overlaid, leaving ``src_img`` unchanged.

  Args:
    src_img: BGR image the keypoints were detected on.
    keypoints: keypoints to draw.
  """
  # Passing None as the output lets OpenCV allocate a fresh image instead of
  # drawing onto (and mutating) the caller's array.
  annotated = cv2.drawKeypoints(src_img, keypoints, None)
  plt_img = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib
  plt.imshow(plt_img)
  plt.axis("off")

In [None]:
# Re-read each file so the preview works on a private copy; the drawing helper
# may annotate the image it is handed.
for img_path in sorted_img_list:
  fresh = cv2.imread(img_path)
  kps, _ = detect_features(fresh)
  draw_features(fresh, kps)
  plt.pause(1)

### Match features

In [None]:
def match_features(ftrs1, ftrs2, keep_fraction=0.10, min_matches=5):
  """Brute-force match binary descriptors and keep the strongest matches.

  Args:
    ftrs1: query descriptors; indexed by each match's ``queryIdx``.
    ftrs2: train descriptors; indexed by each match's ``trainIdx``.
    keep_fraction: fraction of matches (lowest Hamming distance first) to keep.
    min_matches: minimum number of raw matches required, and also the minimum
      number kept, so homography estimation (needs >= 4) gets enough input.

  Returns:
    List of matches sorted by ascending distance.

  Raises:
    Exception: if either descriptor set is None or fewer than ``min_matches``
      raw matches are found.
  """
  # detectAndCompute yields None descriptors for an image without keypoints.
  if ftrs1 is None or ftrs2 is None:
    raise Exception("insufficient matches")

  matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
  sorted_matches = sorted(matcher.match(ftrs1, ftrs2), key=lambda m: m.distance)

  if len(sorted_matches) < min_matches:
    raise Exception("insufficient matches")

  # A plain percentage cut would leave fewer than min_matches whenever the raw
  # count is small, so clamp the kept count from below.
  match_count = max(int(keep_fraction * len(sorted_matches)), min_matches)
  return sorted_matches[:match_count]

In [None]:
# Detect once per image; keypoints and descriptors are kept in parallel lists
# indexed like `imgs`.
detections = [detect_features(img) for img in imgs]
keypoint_list = [kpts for kpts, _ in detections]
feature_list = [ftrs for _, ftrs in detections]

print(len(feature_list[0]), len(feature_list[1]))
matches_01 = match_features(feature_list[0], feature_list[1])
print(len(matches_01))

In [None]:
def draw_matches(img1, kpt1, img2, kpt2, matches):
  """Plot OpenCV's drawMatches visualization of ``matches`` between two images.

  Keypoints without a match are omitted (NOT_DRAW_SINGLE_POINTS).

  Args:
    img1, kpt1: query image and its keypoints.
    img2, kpt2: train image and its keypoints.
    matches: matches linking ``kpt1`` to ``kpt2``.
  """
  only_matched = cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS
  canvas = cv2.drawMatches(img1, kpt1, img2, kpt2, matches, None, flags=only_matched)
  rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib
  plt.figure(figsize=(20, 8))
  plt.imshow(rgb)
  plt.axis("off")

In [None]:
draw_matches(img1=imgs[0], kpt1=keypoint_list[0],
             img2=imgs[1], kpt2=keypoint_list[1],
             matches=matches_01)

In [None]:
# Image 2 is the query side here, so the resulting homography maps 2 -> 1.
matches_12 = match_features(feature_list[2], feature_list[1])
draw_matches(img1=render_imgs[2], kpt1=keypoint_list[2],
             img2=render_imgs[1], kpt2=keypoint_list[1],
             matches=matches_12)

### Determine the transformation (homography)

In [None]:
def compute_homography(kpts1, kpts2, ftrs1, ftrs2, matches, thresh):
  """Estimate, with RANSAC, the homography mapping image-1 points onto image 2.

  Args:
    kpts1, kpts2: keypoint sequences; each element exposes ``.pt`` as (x, y).
    ftrs1, ftrs2: descriptors of the two images. Not used by the estimate;
      kept so existing call sites keep working.
    matches: matches whose ``queryIdx`` indexes ``kpts1`` and ``trainIdx``
      indexes ``kpts2``.
    thresh: RANSAC reprojection threshold in pixels.

  Returns:
    ``(H, status)``: the 3x3 homography and the per-match inlier mask from
    ``cv2.findHomography``.

  Raises:
    Exception: if there are fewer than 4 matches, or OpenCV cannot estimate a
      homography from them.
  """
  if len(matches) < 4:  # need 4 matches min for homography
    raise Exception("insufficient matches to compute homography")

  matched_pts1 = np.float32([kpts1[m.queryIdx].pt for m in matches])
  matched_pts2 = np.float32([kpts2[m.trainIdx].pt for m in matches])

  H, status = cv2.findHomography(matched_pts1, matched_pts2, cv2.RANSAC, thresh)
  # findHomography returns None instead of raising when no model fits; fail
  # here rather than deep inside a later perspectiveTransform/warp call.
  if H is None:
    raise Exception("homography estimation failed")

  return (H, status)

In [None]:
# H_01 maps image-0 pixel coordinates into image 1's frame.
H_01, status = compute_homography(
    kpts1=keypoint_list[0], kpts2=keypoint_list[1],
    ftrs1=feature_list[0], ftrs2=feature_list[1],
    matches=matches_01, thresh=4)
print(H_01)

In [None]:
# H_12 maps image-2 pixel coordinates into image 1's frame.
H_12, status = compute_homography(
    kpts1=keypoint_list[2], kpts2=keypoint_list[1],
    ftrs1=feature_list[2], ftrs2=feature_list[1],
    matches=matches_12, thresh=4)
print(H_12)

## Build Panorama

### Warp into place

In [None]:
# Reference: https://stackoverflow.com/questions/13063201/how-to-show-the-whole-image-when-using-opencv-warpperspective
def stitch_images(img1, img2, H):
  """Warp ``img2`` into ``img1``'s frame and paste ``img1`` on top.

  Args:
    img1: reference image (H x W or H x W x C), drawn unwarped.
    img2: image to warp; must have the same channel layout as ``img1``.
    H: 3x3 homography mapping ``img2`` pixel coordinates into ``img1``'s frame.

  Returns:
    A canvas wide enough to hold both images horizontally, cropped vertically
    to ``img1``'s rows.
  """
  h1, w1 = img1.shape[:2]
  h2, w2 = img2.shape[:2]
  corners1 = np.float32([[0, 0], [0, h1], [w1, h1], [w1, 0]]).reshape(-1, 1, 2)
  corners2 = np.float32([[0, 0], [0, h2], [w2, h2], [w2, 0]]).reshape(-1, 1, 2)
  warped_corners2 = cv2.perspectiveTransform(corners2, H)
  all_corners = np.concatenate((corners1, warped_corners2), axis=0).reshape(-1, 2)

  # Round the bounding box outward. Integer truncation rounds toward zero,
  # which would clip up to one column/row on the negative side.
  xmin, ymin = np.floor(all_corners.min(axis=0)).astype(int)
  xmax, ymax = np.ceil(all_corners.max(axis=0)).astype(int)
  # corners1 contains (0, 0), so xmin, ymin <= 0 and the shift is non-negative.
  tx, ty = -xmin, -ymin
  Ht = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]], dtype=np.float64)  # translate

  result = cv2.warpPerspective(img2, Ht.dot(H), (int(xmax - xmin), int(ymax - ymin)))
  result[ty:ty + h1, tx:tx + w1] = img1

  # Crop to img1's rows; indexing only the first axis works for 2-D and 3-D images.
  return result[ty:ty + h1]


  

In [None]:
# H_01 maps image 0 into image 1's frame, so image 1 is the reference here.
tmp_res1 = stitch_images(img1=imgs[1], img2=imgs[0], H=H_01)
tmp_render1 = cv2.cvtColor(tmp_res1, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib

In [None]:
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(tmp_render1)
ax.axis("off");

In [None]:
# H_12 maps image 2 into image 1's frame, so image 1 is the reference here.
tmp_res2 = stitch_images(img1=imgs[1], img2=imgs[2], H=H_12)
tmp_render2 = cv2.cvtColor(tmp_res2, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib

In [None]:
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(tmp_render2)
ax.axis("off");

### Match and stitch the left half with the right image

In [None]:
# The composite has its own (shifted, wider) coordinate frame, so its
# keypoints must be detected afresh rather than reused from imgs[1].
keypoint_pano, feature_pano = detect_features(tmp_res1)
matches_final = match_features(feature_list[2], feature_pano)
print(len(matches_final))

In [None]:
# H_final maps image-2 pixel coordinates into the left-half composite's frame.
H_final, status = compute_homography(
    kpts1=keypoint_list[2], kpts2=keypoint_pano,
    ftrs1=feature_list[2], ftrs2=feature_pano,
    matches=matches_final, thresh=4)
print(H_final)

In [None]:
res = stitch_images(img1=tmp_res1, img2=imgs[2], H=H_final)
res_render = cv2.cvtColor(res, cv2.COLOR_BGR2RGB)  # BGR -> RGB for matplotlib

In [None]:
fig, ax = plt.subplots(figsize=(20, 8))
ax.imshow(res_render)
ax.axis("off");