In [None]:
from __future__ import division, print_function
%matplotlib inline
%load_ext Cython

# Reproduction notebook for Panorama stitching

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from skimage import io
# Collect every file matching the glob pattern. Later cells index this
# collection as 0 = left, 1 = middle, 2 = right.
# NOTE(review): that ordering presumably comes from the filenames — confirm.
pano_images = io.ImageCollection(
    './images/JDW_9*')

In [None]:
def compare(*images, **kwargs):
    """
    Utility function to display images side by side.

    Parameters
    ----------
    image0, image1, image2, ... : ndarray
        Images to display.
    labels : list, optional
        Titles for the different images. Defaults to no titles.
    vertical : bool, optional
        If True, stack the images top-to-bottom instead of
        left-to-right. Default is False.
    **kwargs
        Any remaining keyword arguments are forwarded to
        ``plt.subplots`` (e.g. ``figsize``).

    Returns
    -------
    f : matplotlib Figure
        The new figure.
    axes : ndarray of Axes
        At least 1-d, one entry per image.
    """
    # Remove this function's own options *before* forwarding kwargs to
    # plt.subplots; any leftover key (e.g. ``labels``) would be rejected
    # by the figure constructor.
    vertical = kwargs.pop('vertical', False)
    labels = kwargs.pop('labels', None)
    if labels is None:
        labels = [''] * len(images)

    if vertical is not True:
        f, axes = plt.subplots(1, len(images), **kwargs)
    else:
        f, axes = plt.subplots(len(images), 1, **kwargs)

    # plt.subplots returns a bare Axes for a single image; normalise to 1-d
    axes = np.array(axes, ndmin=1)

    for n, (image, label) in enumerate(zip(images, labels)):
        axes[n].imshow(image, interpolation='nearest', cmap='gray')
        axes[n].set_title(label)
        axes[n].axis('off')

    f.subplots_adjust(left=0, right=1, top=1, bottom=0, hspace=0.01, wspace=0.01)
    return f, axes

In [None]:
# Display all input images side by side
f, axes = compare(*pano_images, figsize=(12, 10));
# f.savefig('./pano0-originals.png', dpi=300, pad_inches=0, bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt
from skimage.color import rgb2gray
from skimage.feature import (ORB, match_descriptors,
                             plot_matches)

# Initialize ORB
orb = ORB(n_keypoints=800, fast_threshold=0.05)
keypoints = []
descriptors = []

# Detect features on a grayscale copy of each image. The same ORB object
# is reused, so its keypoints/descriptors are appended after every call.
for image in pano_images:
    orb.detect_and_extract(rgb2gray(image))
    keypoints.append(orb.keypoints)
    descriptors.append(orb.descriptors)

# Match features from images 0 -> 1 and 1 -> 2.
# Column 0 of each match array indexes the first descriptor set passed,
# column 1 the second. matches12 is therefore (middle, right), and the
# RANSAC cell reads its columns in swapped order to fit right -> middle.
matches01 = match_descriptors(descriptors[0],
                              descriptors[1],
                              cross_check=True)
matches12 = match_descriptors(descriptors[1],
                              descriptors[2],
                              cross_check=True)

# Show raw matched features from left to center
fig, ax = plt.subplots()
plot_matches(ax, pano_images[0], pano_images[1],
             keypoints[0], keypoints[1], matches01)
ax.axis('off');
# fig.savefig('./pano1_ORB-raw.png', dpi=500, pad_inches=0, bbox_inches='tight')

In [None]:
from skimage.measure import ransac
from skimage.transform import ProjectiveTransform

# Keypoints from left (src) to middle (dst) images.
# ``[:, ::-1]`` reorders each keypoint from (row, col) to (x, y), the
# coordinate order the transforms work in.
src = keypoints[0][matches01[:, 0]][:, ::-1]
dst = keypoints[1][matches01[:, 1]][:, ::-1]

# Robustly fit a projective transform: 4 point pairs per trial,
# residual threshold of 1, at most 300 trials.
# NOTE(review): no seed is passed, so the chosen inliers (and model) can
# differ between runs.
model_ransac01, inliers01 = ransac(
    (src, dst), ProjectiveTransform, min_samples=4,
    residual_threshold=1, max_trials=300)

# Keypoints from right (src) to middle (dst) images.
# matches12 was computed as (image 1, image 2), so column 1 indexes the
# right image and column 0 the middle one.
src = keypoints[2][matches12[:, 1]][:, ::-1]
dst = keypoints[1][matches12[:, 0]][:, ::-1]

model_ransac12, inliers12 = ransac(
    (src, dst), ProjectiveTransform, min_samples=4,
    residual_threshold=1, max_trials=300)

# Show robust, RANSAC-matched features (only the inlier matches)
fig, ax = plt.subplots()
plot_matches(ax, pano_images[0], pano_images[1],
             keypoints[0], keypoints[1],
             matches01[inliers01])
ax.axis('off');
# fig.savefig('./pano2_ORB-RANSAC.png', dpi=500, pad_inches=0, bbox_inches='tight')

In [None]:
# All three images have the same size
# NOTE(review): this is assumed, not checked — only image 1's shape is read.
r, c = pano_images[1].shape[:2]

# Note that transformations take coordinates in
# (x, y) format, not (row, column), for literature
# consistency
corners = np.array([[0, 0],
                    [0, r],
                    [c, 0],
                    [c, r]])

# Warp image corners to their new positions
# (left and right corners mapped into the middle image's frame)
warped_corners01 = model_ransac01(corners)
warped_corners12 = model_ransac12(corners)

# Extents of both target and warped images
all_corners = np.vstack((warped_corners01,
                         warped_corners12,
                         corners))

# Overall output shape is max - min
corner_min = np.min(all_corners, axis=0)
corner_max = np.max(all_corners, axis=0)
output_shape = (corner_max - corner_min)

# Ensure integer shape; ``[::-1]`` turns (x, y) extents into (rows, cols)
output_shape = np.ceil(
    output_shape[::-1]).astype(int)

In [None]:
from skimage.transform import warp, SimilarityTransform

# Shift everything so the minimum warped corner lands at the origin
offset1 = SimilarityTransform(translation=-corner_min)

# Translate pano1 into place. Pixels outside the source image are filled
# with -1 so they can be told apart from real (non-negative) data.
pano1_warped = warp(
    pano_images[1], offset1.inverse, order=3,
    output_shape=output_shape, cval=-1)

# Acquire the image mask for later use
# Mask == 1 inside image, then return background to 0
pano1_mask = (pano1_warped != -1)[..., 0]
pano1_warped[~pano1_mask] = 0

In [None]:
# Warp left image: chain the left->middle model with the same offset used
# for the middle image; warp() is given the inverse (output -> input) map.
transform01 = (model_ransac01 + offset1).inverse
pano0_warped = warp(
    pano_images[0], transform01, order=3,
    output_shape=output_shape, cval=-1)

# Mask == 1 inside image, then return background to 0
pano0_mask = (pano0_warped != -1)[..., 0]
pano0_warped[~pano0_mask] = 0

# Warp right image
transform12 = (model_ransac12 + offset1).inverse
pano2_warped = warp(
    pano_images[2], transform12, order=3,
    output_shape=output_shape, cval=-1)

# Mask == 1 inside image, then return background to 0.
# This must zero *pano2*'s background: otherwise the -1 fill value leaks
# into the 1|2 cost array and the final composite.
pano2_mask = (pano2_warped != -1)[..., 0]
pano2_warped[~pano2_mask] = 0

In [None]:
# Inspect each warped image in the shared output frame, stacked vertically
f, ax = compare(pano0_warped, pano1_warped, pano2_warped, vertical=True)
# f.savefig('./pano3_warped.png', dpi=500, pad_inches=0, bbox_inches='tight')

In [None]:
# NB: despite the names, ``xmax`` is the last *row* index and ``ymax`` the
# last *column* index (output_shape is (rows, cols)). The points below are
# (row, col) pairs, matching how the resulting paths are indexed later.
ymax = output_shape[1] - 1
xmax = output_shape[0] - 1

# Seam 0|1: top row to bottom row at 1/3 of the width. The exact column is
# not critical: generate_costs zeroes the top and bottom rows by default.
mask_pts01 = [[0,    ymax // 3],
              [xmax, ymax // 3]]

# Seam 1|2: top row to bottom row at 2/3 of the width
mask_pts12 = [[0,    2*ymax // 3],
              [xmax, 2*ymax // 3]]

## Cost array and flood fill functions from appendix

In [None]:
%%cython
import cython
import numpy as np
cimport numpy as cnp


# Compiler directives
@cython.cdivision(True)
@cython.boundscheck(False)
@cython.nonecheck(False)
@cython.wraparound(False)
def flood_fill(unsigned char[:, ::1] data, tuple start_coords,
               Py_ssize_t fill_value):
    """
    Flood fill algorithm (4-connected).

    Parameters
    ----------
    data : (M, N) ndarray of uint8 type
        C-contiguous image with flood to be filled. Modified inplace.
    start_coords : tuple
        Length-2 tuple of ints defining (row, col) start coordinates.
        Must lie inside ``data``.
    fill_value : int
        Value the flooded area will take after the fill.

    Returns
    -------
    None, ``data`` is modified inplace.

    Raises
    ------
    IndexError
        If ``start_coords`` lies outside ``data``. Bounds checking and
        wraparound are disabled below for speed, so without this guard an
        out-of-range start would read/write outside the buffer.
    ValueError
        If ``fill_value`` equals the value already at ``start_coords``.
    """
    cdef:
        Py_ssize_t x, y, xsize, ysize, orig_value, ystart, xstart
        set stack

    xsize = data.shape[0]
    ysize = data.shape[1]
    xstart = start_coords[0]
    ystart = start_coords[1]

    if xstart < 0 or xstart >= xsize or ystart < 0 or ystart >= ysize:
        raise IndexError("start_coords (%d, %d) outside image of shape "
                         "(%d, %d)" % (xstart, ystart, xsize, ysize))

    orig_value = data[xstart, ystart]

    if fill_value == orig_value:
        raise ValueError("Filling region with same value "
                         "already present is unsupported. "
                         "Did you already fill this region?")

    # A set serves as the work list, so each pending pixel is stored only
    # once; pop order is arbitrary, which does not affect the final fill.
    stack = set(((xstart, ystart),))

    while stack:
        x, y = stack.pop()

        if data[x, y] == orig_value:
            data[x, y] = fill_value
            # Queue the 4-neighbours that lie inside the image
            if x > 0:
                stack.add((x - 1, y))
            if x < (xsize - 1):
                stack.add((x + 1, y))
            if y > 0:
                stack.add((x, y - 1))
            if y < (ysize - 1):
                stack.add((x, y + 1))

In [None]:
def generate_costs(diff_image, mask, vertical=True,
                   gradient_cutoff=2.,
                   zero_edges=True):
    """
    Ensure equal-cost paths from edges to
    region of interest.

    Within the overlap's column span, pixels between the top (bottom)
    edge and the overlap get a per-column cost of ``min(dist) / dist``,
    so reaching the overlap from that edge costs roughly the same in
    every column. Inside the overlap the cost is the absolute grayscale
    difference.

    Parameters
    ----------
    diff_image : (M, N) or (M, N, 3) ndarray of floats
        Difference of two overlapping images. A 3-channel input is
        converted to grayscale with ``rgb2gray`` first.
    mask : (M, N) ndarray of bools
        Mask representing the region of interest in
        ``diff_image``.
    vertical : bool
        Control if stitching line is vertical or
        horizontal.
    gradient_cutoff : float
        Controls how far out of parallel lines can
        be to edges before correction is terminated.
        The default (2.) is good for most cases.
    zero_edges : bool
        If True, the edges are set to zero so the
        seed is not bound to any specific horizontal
        location. Applies to both orientations.

    Returns
    -------
    costs_arr : (M, N) ndarray of floats
        Adjusted costs array, ready for use.

    Raises
    ------
    ValueError
        If ``mask`` selects no pixels (the images do not overlap).
    """
    # Collapse colour differences to one channel *before* any
    # transposition: ``.T`` on an (M, N, 3) array would also move the
    # channel axis to the front.
    if diff_image.ndim == 3:
        diff_image = rgb2gray(diff_image)

    if vertical is not True:  # run transposed
        return generate_costs(
            diff_image.T, mask.T, vertical=True,
            gradient_cutoff=gradient_cutoff,
            zero_edges=zero_edges).T

    # Start with a high-cost array of 1's
    costs_arr = np.ones_like(diff_image)

    # Obtain extent of overlap
    _, col = mask.nonzero()
    if col.size == 0:
        raise ValueError("mask is empty: the images do not overlap")
    cmin = col.min()
    cmax = col.max()

    # Label discrete regions. ``.copy()`` guarantees the C-contiguous
    # uint8 buffer flood_fill's typed memoryview needs (``astype`` alone
    # may keep Fortran order for transposed input).
    cslice = slice(cmin, cmax + 1)
    labels = mask[:, cslice].astype(np.uint8).copy()

    # Fill with unique labels: 2 = (first component of) the overlap,
    # 1 = background connected to the top row, 3 = connected to the
    # bottom row. Both edge fills start at the band's middle *column*.
    masked_pts = np.where(labels)
    flood_fill(labels, (masked_pts[0][0],
                        masked_pts[1][0]), 2)
    flood_fill(labels, (0, labels.shape[1] // 2), 1)
    flood_fill(labels, (labels.shape[0] - 1,
                        labels.shape[1] // 2), 3)

    # Find distance from edge to region
    upper = (labels == 1).sum(axis=0)
    lower = (labels == 3).sum(axis=0)

    # Reject areas of high change
    ugood = np.abs(
        np.gradient(upper)) < gradient_cutoff
    lgood = np.abs(
        np.gradient(lower)) < gradient_cutoff

    # Cost break to areas slightly farther from edge
    costs_upper = np.ones_like(upper,
                               dtype=np.float64)
    costs_lower = np.ones_like(lower,
                               dtype=np.float64)
    costs_upper[ugood] = (
        upper.min() / np.maximum(upper[ugood], 1))
    costs_lower[lgood] = (
        lower.min() / np.maximum(lower[lgood], 1))

    # Expand from 1d back to 2d
    vdis = mask.shape[0]
    costs_upper = (
        costs_upper[np.newaxis, :].repeat(vdis, axis=0))
    costs_lower = (
        costs_lower[np.newaxis, :].repeat(vdis, axis=0))

    # Place these in output array
    costs_arr[:, cslice] = costs_upper * (labels==1)
    costs_arr[:, cslice] += costs_lower * (labels==3)

    # Finally, place the difference image
    costs_arr[mask] = np.abs(diff_image[mask])

    if zero_edges is True:  # top & bottom rows = zero
        costs_arr[0, :] = 0
        costs_arr[-1, :] = 0

    return costs_arr

In [None]:
# Use the generate_costs function: one cost array per seam, built from the
# pairwise difference image and the intersection of the two image masks
costs01 = generate_costs(pano0_warped - pano1_warped,
                         pano0_mask & pano1_mask)
costs12 = generate_costs(pano1_warped - pano2_warped,
                         pano1_mask & pano2_mask)

In [None]:
from skimage.graph import route_through_array

# Find the MCP (minimum-cost path) from the top to the bottom seam endpoint
pts01, _ = route_through_array(
    costs01, mask_pts01[0], mask_pts01[1],
    fully_connected=True)

# (row, col) coordinates of every pixel on the path
pts01 = np.array(pts01)

# Create final mask for the left image: draw the seam as a wall of 1s,
# then flood the region connected to the top-left corner (left of the seam)
mask0 = np.zeros_like(pano0_warped[..., 0],
                      dtype=np.uint8)
mask0[pts01[:, 0], pts01[:, 1]] = 1
flood_fill(mask0, (0, 0), 1)

In [None]:
fig, ax = plt.subplots(figsize=(12, 12))
# NOTE(review): `morph` is not used in any visible cell
import skimage.morphology as morph

# Plot the difference image (the seam 0|1 cost array). The costs are
# non-negative, so vmin=-max places them in the lighter half of the ramp.
ax.imshow(costs01, cmap='gray', vmin=-1 * costs01.max(), vmax=costs01.max())

# Overlay the minimum-cost path
ax.plot(pts01[:, 1], pts01[:, 0])  

plt.tight_layout()
ax.axis('off');
# fig.savefig('./pano4_mcp.png', dpi=600, bbox_inches='tight')

In [None]:
# New constraint modifying cost array: give the left image's region
# (mask0) a uniform cost of 1 to discourage the second seam from crossing
# into it. This is a penalty, not a hard barrier.
costs12[mask0 > 0] = 1

pts12, _ = route_through_array(
    costs12, mask_pts12[0], mask_pts12[1],
    fully_connected=True)

pts12 = np.array(pts12)

# Final mask for right image: seam pixels, then flood the region
# connected to the bottom-right corner (right of the seam)
mask2 = np.zeros_like(mask0, dtype=np.uint8)
mask2[pts12[:, 0], pts12[:, 1]] = 1
flood_fill(mask2, (mask2.shape[0] - 1,
                   mask2.shape[1] - 1), 1)

# Mask for middle image is one of exclusion: everything not claimed by the
# left or right masks. ``.astype(bool)`` binds before ``~``, so this is a
# logical NOT rather than a bitwise NOT on uint8.
mask1 = ~(mask0 | mask2).astype(bool)

In [None]:
fig, ax = plt.subplots(figsize=(12, 12))
# NOTE(review): `morph` is not used in any visible cell
import skimage.morphology as morph

# Plot the difference image (the seam 1|2 cost array, including the
# mask0 penalty)
ax.imshow(costs12, cmap='gray', vmin=-1 * costs12.max(), vmax=costs12.max())

# Overlay the minimum-cost path
ax.plot(pts12[:, 1], pts12[:, 0])  

plt.tight_layout()
ax.axis('off');
# NOTE(review): same filename as the seam 0|1 figure; saving both would
# overwrite it — use a distinct name such as './pano4_mcp12.png'.
# fig.savefig('./pano4_mcp.png', dpi=600, bbox_inches='tight')

In [None]:
# Convenience function for alpha blending
def add_alpha(img, mask=None):
    """
    Adds a masked alpha channel to an image.

    Parameters
    ----------
    img : (M, N[, 3]) ndarray
        Image data, should be rank-2 or rank-3
        with RGB channels. Rank-2 (grayscale) input is
        replicated into three identical channels.
    mask : (M, N[, 3]) ndarray, optional
        Mask to be applied. If None, the alpha channel
        is added with full opacity assumed (1) for all
        locations. For a rank-3 mask only the first
        channel is used.

    Returns
    -------
    rgba : (M, N, 4) ndarray
        ``img`` as RGB with ``mask`` appended as the alpha channel.
    """
    if mask is None:
        # One alpha value per pixel. ``np.ones_like(img)`` would be
        # (M, N, 3) for colour input and produce a 6-channel result.
        mask = np.ones(img.shape[:2], dtype=img.dtype)
    elif mask.ndim == 3:
        mask = mask[..., 0]

    if img.ndim == 2:
        # Grayscale -> RGB: three copies stacked on a new last axis
        img = np.stack((img,) * 3, axis=-1)

    return np.dstack((img, mask))

# Applying this function: each warped image becomes RGBA with its seam
# mask as alpha, so it is opaque only inside its own region
left_final = add_alpha(pano0_warped, mask0)
middle_final = add_alpha(pano1_warped, mask1)
right_final = add_alpha(pano2_warped, mask2)

In [None]:
fig, ax = plt.subplots()

# Draw the three RGBA layers on top of each other; each layer's alpha
# (its seam mask) makes it opaque only inside its own region.
# Turn off matplotlib's interpolation
ax.imshow(left_final, interpolation='none')
ax.imshow(middle_final, interpolation='none')
ax.imshow(right_final, interpolation='none')

ax.axis('off')
fig.tight_layout()
# plt.show() rather than fig.show(): Figure.show() warns on non-GUI
# backends such as the inline backend enabled at the top of the notebook.
plt.show()

In [None]:
from skimage.color import gray2rgb

# Start with empty image
pano_combined = np.zeros_like(pano0_warped)

# Place the masked portion of each image into the array
# masks are 2d, they need to be (M, N, 3) to match the color images
# NOTE(review): mask1 excludes mask0|mask2, but mask0 and mask2 are not
# forced apart, so a pixel claimed by both would be summed twice.
pano_combined += pano0_warped * gray2rgb(mask0)
pano_combined += pano1_warped * gray2rgb(mask1)
pano_combined += pano2_warped * gray2rgb(mask2)

# Save the output - precision loss warning is expected
# moving from floating point -> uint8
fig, ax = plt.subplots()
ax.imshow(pano_combined)
ax.axis('off');
plt.show()
# io.imsave('./pano5_final.png', pano_combined)
# io.imsave('./pano5_final.png', pano_combined)