In [None]:
from skimage.io import imread, imsave, imshow
import numpy as np
import matplotlib.pyplot as plt
import time
from skimage.color import rgb2gray, rgb2lab
from skimage.filters import laplace
from scipy.ndimage import convolve

In [None]:
def plot_image(working_image, working_mask, front):
    """Display the current inpainting state.

    Known pixels are drawn from ``working_image``; pixels on the fill front
    are painted red and the rest of the still-unknown region is painted white.

    A single named figure is reused on every call. The previous version
    called ``plt.clf()`` and then ``plt.figure(...)``, which cleared (or
    created) one figure and opened another on every iteration.

    Args:
        working_image: image being inpainted, indexed as (row, col, channel);
            presumably uint8 RGB as produced by ``inpaint_image`` -- TODO
            confirm for other callers.
        working_mask: 2-D array, 1 where the pixel is still unknown, 0 otherwise.
        front: 2-D 0/1 array marking the current fill front.
    """
    inverse_mask = 1 - working_mask
    image = working_image * convert_to_rgb(inverse_mask)

    # Unknown pixels were zeroed above. Assumes front pixels lie inside the
    # mask (the Laplace-based front in inpaint_image appears to satisfy this;
    # TODO confirm), so adding 255 paints them pure red without wrapping.
    image[:, :, 0] += front * 255

    # Remaining unknown pixels (mask minus front) are painted white.
    white_portion = (working_mask - front) * 255
    image += convert_to_rgb(white_portion)

    # Reuse one figure instead of opening a new one per iteration.
    fig = plt.figure(num="Inpainting progress", figsize=(6, 4))
    fig.clf()
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(image)
    ax.set_axis_off()
    fig.canvas.draw_idle()
    plt.pause(0.001)

In [None]:
def calculate_normal(working_mask):
    """Return the unit-length gradient direction of the mask.

    The mask is convolved with two 3x3 Sobel-style kernels (weights scaled
    by 1/4) to obtain an x and a y component per pixel. Each 2-vector is then
    divided by its length; where the length is zero (flat regions) the vector
    stays (0, 0).

    Args:
        working_mask: 2-D array, 1 inside the target region, 0 elsewhere.

    Returns:
        Float array of shape (height, width, 2); channel 0 holds the
        x component and channel 1 the y component.
    """
    mask_as_float = working_mask.astype(float)

    horizontal_kernel = np.array([[.25, 0, -.25],
                                  [.5, 0, -.5],
                                  [.25, 0, -.25]])
    vertical_kernel = np.array([[-.25, -.5, -.25],
                                [0, 0, 0],
                                [.25, .5, .25]])

    dx = convolve(mask_as_float, horizontal_kernel)
    dy = convolve(mask_as_float, vertical_kernel)
    components = np.stack([dx, dy], axis=-1)

    magnitude = np.sqrt(dy ** 2 + dx ** 2)
    # Zero-length vectors are divided by 1 so they stay (0, 0).
    magnitude[magnitude == 0] = 1

    return components / magnitude[:, :, np.newaxis]

In [None]:
def find_patch(working_image, point, patch_size):
    """Return the bounds of a square patch centred on ``point``.

    The patch extends ``(patch_size - 1) // 2`` pixels on each side of the
    centre. Bounds are clipped to the image, so patches that touch an image
    edge are smaller than ``patch_size``.

    Args:
        working_image: array whose first two axes are (height, width).
        point: (row, col) of the patch centre.
        patch_size: side length of the unclipped patch (odd values give a
            symmetric patch).

    Returns:
        ``[[row_start, row_end], [col_start, col_end]]`` with inclusive ends.
    """
    radius = (patch_size - 1) // 2
    n_rows, n_cols = working_image.shape[:2]
    row, col = point[0], point[1]

    row_bounds = [max(0, row - radius), min(row + radius, n_rows - 1)]
    col_bounds = [max(0, col - radius), min(col + radius, n_cols - 1)]
    return [row_bounds, col_bounds]

In [None]:
def patch_data(source, patch):
    """Return the region of ``source`` covered by ``patch``.

    ``patch`` is ``[[row_start, row_end], [col_start, col_end]]`` with
    inclusive ends, as returned by ``find_patch``. Trailing axes (e.g. colour
    channels) are kept whole. Basic slicing is used, so for a NumPy array
    the result is a view.
    """
    rows = slice(patch[0][0], patch[0][1] + 1)
    cols = slice(patch[1][0], patch[1][1] + 1)
    return source[rows, cols]

def patch_shape(patch):
    """Return ``(n_rows, n_cols)`` of a patch given with inclusive bounds."""
    n_rows = patch[0][1] - patch[0][0] + 1
    n_cols = patch[1][1] - patch[1][0] + 1
    return n_rows, n_cols

def patch_area(patch):
    """Return the number of pixels in a patch given with inclusive bounds."""
    height = patch[0][1] - patch[0][0] + 1
    width = patch[1][1] - patch[1][0] + 1
    return height * width

In [None]:
def convert_to_rgb(image):
    """Stack a 2-D array into three identical channels.

    This is not a colour conversion. It replicates a single-channel array
    (a mask, in the callers in this notebook) so it can be multiplied or
    added element-wise with an RGB image.

    Args:
        image: 2-D array; any other rank raises ValueError from the unpack.

    Returns:
        Array of shape (height, width, 3) with the input's dtype.
    """
    rows, cols = image.shape
    as_column = image.reshape(rows, cols, 1)
    return np.repeat(as_column, 3, axis=2)

In [None]:
def calculate_gradient(working_image, working_mask, front, patch_size):
    """For each front pixel, pick the strongest known gradient in its patch.

    The image is converted to grey and unknown (masked) pixels are set to
    NaN. Any finite difference from ``np.gradient`` that involves an unknown
    pixel therefore becomes NaN, and ``nan_to_num`` turns it into 0. For every
    front pixel, the (row, col) gradient pair with the largest magnitude
    inside that pixel's patch is recorded.

    Args:
        working_image: RGB image being inpainted (passed to ``rgb2gray``).
        working_mask: 2-D array, 1 where pixels are still unknown.
        front: 2-D 0/1 array marking the fill front.
        patch_size: side length used to build each front pixel's patch.

    Returns:
        Float array (height, width, 2). Channel 0 is the derivative along
        rows (y) and channel 1 along columns (x). It is zero off the front.
    """
    n_rows, n_cols = working_image.shape[:2]

    grey = rgb2gray(working_image)
    grey[working_mask == 1] = np.nan
    d_rows, d_cols = np.nan_to_num(np.array(np.gradient(grey)))
    magnitude = np.sqrt(d_rows ** 2 + d_cols ** 2)

    strongest = np.zeros([n_rows, n_cols, 2])
    for row, col in np.argwhere(front == 1):
        patch = find_patch(working_image, (row, col), patch_size)
        patch_magnitude = patch_data(magnitude, patch)
        best = np.unravel_index(patch_magnitude.argmax(), patch_magnitude.shape)
        strongest[row, col, 0] = patch_data(d_rows, patch)[best]
        strongest[row, col, 1] = patch_data(d_cols, patch)[best]

    return strongest

In [None]:
def patch_diff(image, working_mask, target_patch, source_patch):
    """Score how different a candidate source patch is from a target patch.

    Only pixels that are already known in the target patch (mask == 0)
    contribute to the colour term. The score is the sum of squared
    differences over those pixels plus the Euclidean distance between the
    two patches' top-left corners. Lower means more similar.

    Args:
        image: array the patches are compared in (the caller in this
            notebook passes a Lab image).
        working_mask: 2-D array, 1 where pixels are still unknown.
        target_patch: inclusive ``[[r0, r1], [c0, c1]]`` bounds of the patch
            being filled.
        source_patch: inclusive bounds of the candidate, same shape as
            ``target_patch``.
    """
    known = convert_to_rgb(1 - patch_data(working_mask, target_patch))
    residual = (patch_data(image, target_patch) * known
                - patch_data(image, source_patch) * known)
    colour_term = (residual ** 2).sum()

    row_offset = target_patch[0][0] - source_patch[0][0]
    col_offset = target_patch[1][0] - source_patch[1][0]
    spatial_term = np.sqrt(row_offset ** 2 + col_offset ** 2)

    return colour_term + spatial_term

In [None]:
def find_source_patch(working_image, working_mask, target_pixel, patch_size):
    """Exhaustively search for the best fully-known patch to copy from.

    Every placement of a window with the target patch's (possibly
    edge-clipped) shape is visited in row-major order. Windows that overlap
    any unknown pixel are skipped; the rest are scored with ``patch_diff`` on
    the Lab-converted image. The first window with the strictly smallest
    score wins.

    Returns:
        Inclusive ``[[r0, r1], [c0, c1]]`` bounds of the chosen source patch,
        or ``None`` if every window overlaps the unknown region.
    """
    target_patch = find_patch(working_image, target_pixel, patch_size)
    n_rows, n_cols = working_image.shape[:2]
    window_h, window_w = patch_shape(target_patch)

    lab_image = rgb2lab(working_image)

    best_patch, best_score = None, 0
    for top in range(n_rows - window_h + 1):
        for left in range(n_cols - window_w + 1):
            candidate = [
                [top, top + window_h - 1],
                [left, left + window_w - 1],
            ]
            if patch_data(working_mask, candidate).sum() != 0:
                continue  # window touches the unknown region

            score = patch_diff(lab_image, working_mask, target_patch, candidate)
            if best_patch is None or score < best_score:
                best_patch, best_score = candidate, score

    return best_patch

In [None]:
def update_confidence_value(confidence, working_image, working_mask, patch_size, target_pixel, source_patch):
    """Fill the target patch from the source patch and update bookkeeping.

    Despite its name, this does three things, in order:
      1. Every still-unknown pixel of the target patch receives the
         confidence currently stored at ``target_pixel``. ``confidence`` is
         modified in place.
      2. Those unknown pixels are overwritten with the corresponding pixels
         of ``source_patch``. Already-known pixels keep their values.
      3. The whole target patch is marked as known in ``working_mask``.

    ``working_image`` and ``working_mask`` are modified in place and also
    returned.

    Returns:
        ``(working_image, working_mask)``.
    """
    target_patch = find_patch(working_image, target_pixel, patch_size)
    rows = slice(target_patch[0][0], target_patch[0][1] + 1)
    cols = slice(target_patch[1][0], target_patch[1][1] + 1)

    target_mask = patch_data(working_mask, target_patch)

    # 1. Propagate the target pixel's confidence to the pixels being filled
    #    (writes through the basic-slice view into `confidence`).
    fill_confidence = confidence[target_pixel[0], target_pixel[1]]
    confidence[rows, cols][target_mask == 1] = fill_confidence

    # 2. Take source pixels where the target is unknown, keep the rest.
    fill = convert_to_rgb(target_mask)
    from_source = patch_data(working_image, source_patch)
    current = patch_data(working_image, target_patch)
    working_image[rows, cols] = from_source * fill + current * (1 - fill)

    # 3. The whole target patch is now known.
    working_mask[rows, cols] = 0

    return working_image, working_mask

In [None]:
def _read_inputs(image_path, mask_path):
    """Read the image as 3-channel uint8 and the mask as a 0/1 uint8 array."""
    image = imread(image_path)
    if image.ndim == 2:
        # Greyscale input: replicate into three channels, because the colour
        # conversions used later (rgb2gray, rgb2lab) expect RGB.
        image = np.dstack([image, image, image])
    elif image.ndim == 3 and image.shape[2] == 4:
        # Drop an alpha channel for the same reason.
        image = image[:, :, :3]
    image = image.astype('uint8')

    raw_mask = imread(mask_path, as_gray=True)
    # imread(as_gray=True) does not convert files that are already
    # single-channel, so an 8-bit grey mask arrives as 0..255 rather than
    # 0..1. Rescale before rounding. Integer masks that already hold 0/1
    # values are left as they are.
    if np.issubdtype(raw_mask.dtype, np.integer) and raw_mask.size and raw_mask.max() > 1:
        raw_mask = raw_mask / np.iinfo(raw_mask.dtype).max
    mask = np.round(np.asarray(raw_mask, dtype=float)).astype('uint8')
    return image, mask


def _front_confidence(confidence, front, working_image, patch_size):
    """Return a copy of ``confidence`` with each front pixel set to its patch mean."""
    updated = np.copy(confidence)
    for point in np.argwhere(front == 1):
        patch = find_patch(working_image, point, patch_size)
        # Means are taken over the *previous* confidence map, not `updated`.
        updated[point[0], point[1]] = patch_data(confidence, patch).sum() / patch_area(patch)
    return updated


def inpaint_image(image_path, mask_path, patch_size=13, show_progress=False):
    """Fill the masked region of an image by copying patches from the rest of it.

    Each iteration:
      1. computes the fill front with a Laplace filter on the mask;
      2. scores front pixels by priority = confidence * data term;
      3. picks the highest-priority pixel;
      4. copies the most similar fully-known patch (searched in Lab space)
         into its unknown pixels;
      5. marks that patch as known.

    Args:
        image_path: path of the image to inpaint. Greyscale images are
            replicated to three channels; an alpha channel, if present, is
            dropped.
        mask_path: path of the mask. Pixels that round to 1 after scaling to
            [0, 1] mark the region to fill. Must match the image's height and
            width.
        patch_size: side length of the square patches (odd values give
            centred patches).
        show_progress: if True, plot and print every iteration.

    Returns:
        The inpainted image as a uint8 array with the input's height and width.

    Raises:
        AttributeError: if image and mask sizes differ (type kept for
            backward compatibility).
        ValueError: if no fully-known source patch of the required size
            exists (e.g. the mask covers too much of the image).
    """
    image, mask = _read_inputs(image_path, mask_path)

    if image.shape[:2] != mask.shape:
        raise AttributeError('Image and Mask are not of the same size!!')

    # Known pixels start fully trusted (1); target pixels start at 0.
    confidence = (1 - mask).astype(float)

    working_image = np.copy(image)
    working_mask = np.copy(mask)

    iteration = 0
    start_time = time.time()

    # Checking the mask up front avoids a pointless full search when the
    # mask is already empty.
    while working_mask.any():
        iteration += 1

        # Laplace of the mask highlights the edge of the target region.
        front = (laplace(working_mask) > 0).astype('uint8')

        if show_progress:
            plot_image(working_image, working_mask, front)
            print(f"Iteration {iteration}")

        confidence = _front_confidence(confidence, front, working_image, patch_size)

        normal = calculate_normal(working_mask)
        gradient = calculate_gradient(working_image, working_mask, front, patch_size)
        normal_gradient = normal * gradient
        # The small constant keeps the data term (and so priority) non-zero.
        data = np.sqrt(normal_gradient[:, :, 0] ** 2 + normal_gradient[:, :, 1] ** 2) + 0.001

        # Multiplying by `front` restricts the choice to front pixels.
        priority = confidence * data * front
        target_pixel = np.unravel_index(priority.argmax(), priority.shape)

        source_patch = find_source_patch(working_image, working_mask, target_pixel, patch_size)
        if source_patch is None:
            raise ValueError(
                "No fully known source patch of the required size exists; "
                "try a smaller patch_size or a smaller mask."
            )

        working_image, working_mask = update_confidence_value(
            confidence, working_image, working_mask, patch_size, target_pixel, source_patch
        )

    # `iteration` now equals the number of fill steps actually performed.
    print(f"Completed in {iteration} iterations. Took {time.time() - start_time} seconds!!")
    return working_image


### Uncomment any of the example calls below to run the algorithm (image paths are relative to the notebook's working directory)

In [5]:
# result1 = inpaint_image("jumper.jpg", "jumper_mask.jpg", patch_size= 13, show_progress=True)

In [4]:
# result2 = inpaint_image("baseball.jpg", "baseball_mask.jpg", patch_size= 11, show_progress=True)

In [3]:
# result3 = inpaint_image("Curved_lines.jpg", "Curved_lines_mask.jpg", patch_size= 11, show_progress=True)

In [2]:
# result4 = inpaint_image("aerial.jpg", "aerial_mask.jpg", patch_size= 13, show_progress=True)

In [1]:
# result5 = inpaint_image("plus.jpg", "circle_mask.jpg", patch_size= 13, show_progress=True)