# Notebook 3: Patch-Based Image Restoration

This notebook implements exemplar-based patch inpainting. The algorithm fills damaged areas by copying the best-matching patches from known regions, choosing which boundary patch to fill next by a priority computed from a confidence term and a data term.


In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from scipy.ndimage import binary_dilation, distance_transform_edt

sys.path.append('..')
from utils import load_image, save_image, ensure_dir, get_image_files

print("Libraries imported successfully!")


## 1. Patch-Based Inpainting Implementation

The exemplar-based algorithm:
1. Computes priority for each patch on the boundary of the damaged region
2. Selects the patch with highest priority
3. Finds the best matching patch from known regions
4. Copies the patch and updates confidence
5. Repeats until all damaged regions are filled


In [None]:
def compute_confidence(confidence_map, mask, patch_size):
    """Compute the confidence term for every patch position.

    For the ``patch_size`` x ``patch_size`` window whose top-left corner is
    ``(y, x)``, the score is the sum of ``confidence_map`` weighted by each
    pixel's known fraction ``1 - mask / 255`` divided by the window area.
    Only windows that straddle the fill front (contain at least one pixel
    equal to 0 and at least one equal to 255) get a score; every other
    position is 0.

    Args:
        confidence_map: 2-D float array of per-pixel confidence.
        mask: 2-D array with the same shape; 0 = known, 255 = damaged.
        patch_size: Side length of the square patch.

    Returns:
        Float array of shape ``(h - patch_size + 1, w - patch_size + 1)``
        indexed by patch top-left corner.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    window = (patch_size, patch_size)
    # Same per-pixel weighting as the per-patch formula, but summed over all
    # windows at once instead of in a Python double loop (this function runs
    # on the whole image every fill iteration).
    weighted = confidence_map * (1 - mask / 255.0)
    conf_sums = sliding_window_view(weighted, window).sum(axis=(-2, -1))
    has_known = sliding_window_view(mask == 0, window).any(axis=(-2, -1))
    has_damaged = sliding_window_view(mask == 255, window).any(axis=(-2, -1))
    on_boundary = has_known & has_damaged
    return np.where(on_boundary, conf_sums / (patch_size * patch_size), 0.0)


def compute_data_term(image, mask, patch_size):
    """Compute the data term (isophote strength) for every patch position.

    For the patch whose top-left corner is ``(y, x)``, the score is
    ``|isophote . n|`` evaluated at the patch centre. The isophote is the
    Sobel image gradient rotated by 90 degrees, ``(-dI/dy, dI/dx)``, and
    ``n`` is the unit vector of the Sobel gradient of ``mask / 255``.
    Positions whose patch does not straddle the fill front (no pixel equal to
    0 or none equal to 255), or whose centre has a zero mask gradient, score 0.

    Args:
        image: 3-channel RGB image (converted to grayscale) or a
            single-channel image.
        mask: 2-D array; 0 = known, 255 = damaged.
        patch_size: Side length of the square patch.

    Returns:
        Float array of shape ``(h - patch_size + 1, w - patch_size + 1)``.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    h, w = image.shape[:2]

    # Convert to grayscale
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image

    # Image gradients
    grad_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)

    # Normal to the fill front, from the gradient of the mask
    mask_float = mask.astype(np.float64) / 255.0
    normal_x = cv2.Sobel(mask_float, cv2.CV_64F, 1, 0, ksize=3)
    normal_y = cv2.Sobel(mask_float, cv2.CV_64F, 0, 1, ksize=3)

    # Per-pixel |isophote . unit_normal| for the whole image at once, instead
    # of a Python loop over every patch position.
    norm = np.sqrt(normal_x * normal_x + normal_y * normal_y)
    has_normal = norm > 0
    safe_norm = np.where(has_normal, norm, 1.0)  # avoid 0/0; zeroed below
    strength = np.abs((-grad_y) * (normal_x / safe_norm) + grad_x * (normal_y / safe_norm))
    strength = np.where(has_normal, strength, 0.0)

    # The score for top-left corner (y, x) is read at centre (y + half, x + half).
    half = patch_size // 2
    out_h, out_w = h - patch_size + 1, w - patch_size + 1
    centre_strength = strength[half:half + out_h, half:half + out_w]

    window = (patch_size, patch_size)
    has_known = sliding_window_view(mask == 0, window).any(axis=(-2, -1))
    has_damaged = sliding_window_view(mask == 255, window).any(axis=(-2, -1))
    return np.where(has_known & has_damaged, centre_strength, 0.0)


def find_best_patch(source_patch, image, mask, patch_size, search_window=None, source_mask=None):
    """Find the fully-known patch in ``image`` that best matches ``source_patch``.

    Candidate top-left corners are scanned on a stride-2 grid. A candidate
    must lie entirely in the known region (``mask == 0``). The match error is
    the sum of squared differences over the source patch's known pixels only.
    Ties keep the first candidate in raster order.

    Args:
        source_patch: Target patch (``patch_size`` x ``patch_size`` [x C]).
        image: Image to search.
        mask: Damage mask for ``image``; 0 = known, 255 = damaged.
        patch_size: Side length of the square patch.
        search_window: Accepted for backward compatibility; currently unused
            (the whole image is searched).
        source_mask: Optional damage mask for ``source_patch`` (0 = known).
            When given, it decides which source pixels are compared. When
            omitted, non-zero pixels are treated as known (legacy behaviour,
            which misclassifies genuinely black pixels).

    Returns:
        ``(best_patch, (y, x))``, or ``(None, None)`` when no fully-known
        candidate exists or the source patch has no known pixels.
    """
    h, w = image.shape[:2]

    # Which source pixels carry real information. This does not depend on
    # the candidate position, so compute it once.
    if source_mask is not None:
        source_known = np.asarray(source_mask) == 0
    elif source_patch.ndim == 3:
        source_known = (source_patch != 0).all(axis=2)
    else:
        source_known = source_patch != 0
    if not np.any(source_known):
        return None, None

    # Compare in float so uint8 inputs cannot wrap around when subtracted.
    source_values = source_patch[source_known].astype(np.float64)
    known_mask = mask == 0

    best_match = None
    best_error = float('inf')
    best_pos = None
    for y in range(0, h - patch_size + 1, 2):  # Step by 2 for speed
        for x in range(0, w - patch_size + 1, 2):
            # Candidates must be completely inside the known region.
            if not known_mask[y:y + patch_size, x:x + patch_size].all():
                continue
            candidate = image[y:y + patch_size, x:x + patch_size]
            diff = candidate[source_known].astype(np.float64) - source_values
            error = np.sum(diff ** 2)
            if error < best_error:
                best_error = error
                best_match = candidate
                best_pos = (y, x)

    return best_match, best_pos


def patch_inpaint(image, mask, patch_size=9, max_iterations=10000):
    """
    Exemplar-based patch inpainting.

    Each iteration picks the patch straddling the fill front with the highest
    priority, confidence * (data + 0.001). It copies the best-matching
    fully-known patch into that patch's damaged pixels, gives those pixels
    the patch's confidence term, and marks them as known.

    Args:
        image: Input damaged image (H x W x C, or H x W grayscale)
        mask: Binary mask (255 for damaged regions, 0 for known pixels)
        patch_size: Size of patches (odd number recommended)
        max_iterations: Maximum number of iterations

    Returns:
        Restored image (uint8)
    """
    result = image.copy().astype(np.float64)
    # cv2.Canny below requires an 8-bit mask.
    mask_work = mask.astype(np.uint8, copy=True)

    # Known pixels start fully confident, damaged ones at zero.
    confidence = (mask == 0).astype(np.float64)

    iteration = 0
    while np.any(mask_work == 255) and iteration < max_iterations:
        # No edge in the mask means there is no fill front left to work on.
        boundary = cv2.Canny(mask_work, 50, 150)
        if not np.any(boundary):
            break

        conf_map = compute_confidence(confidence, mask_work, patch_size)
        data_map = compute_data_term(result.astype(np.uint8), mask_work, patch_size)

        # Priority = confidence * data. The epsilon keeps patches with a zero
        # data term selectable as long as their confidence is non-zero.
        priority = conf_map * (data_map + 0.001)
        if np.max(priority) <= 0:
            break
        y, x = np.unravel_index(np.argmax(priority), priority.shape)

        window = (slice(y, y + patch_size), slice(x, x + patch_size))
        source_patch = result[window].copy()
        source_mask = mask_work[window]
        # Capture these before mask_work (which source_mask views) is updated.
        damaged = source_mask == 255
        known = source_mask == 0

        # Match on the target's known pixels as given by the mask, not by
        # pixel value (known pixels may be black, damaged ones may not be).
        best_patch, _ = find_best_patch(
            source_patch, result.astype(np.uint8), mask_work, patch_size,
            source_mask=source_mask,
        )

        if best_patch is None:
            # No fully-known exemplar: fill damaged pixels with the mean of
            # the patch's known pixels (per channel for colour images).
            best_patch = source_patch.copy()
            if np.any(known):
                best_patch[damaged] = source_patch[known].mean(axis=0)

        if np.any(damaged):
            result[window][damaged] = best_patch[damaged]
            # Filled pixels inherit this patch's confidence term (known
            # confidence summed over the whole patch area, as computed by
            # compute_confidence), so confidence decays towards the interior.
            confidence[window][damaged] = conf_map[y, x]
            mask_work[window][damaged] = 0

        iteration += 1

        # Progress update
        if iteration % 100 == 0:
            remaining = np.sum(mask_work == 255)
            print(f"Iteration {iteration}, remaining pixels: {remaining}")

    return result.astype(np.uint8)


def patch_inpaint_fast(image, mask, patch_size=9, max_iterations=5000):
    """
    OpenCV Telea initialisation followed by exemplar-based patch refinement.

    ``cv2.inpaint`` (Telea) first fills the damaged pixels (``mask == 255``)
    so that every pixel has an initial estimate. ``patch_inpaint`` then
    re-fills those same pixels with patches copied from the known region.

    Args:
        image: Damaged image in a format accepted by ``cv2.inpaint``.
        mask: Damage mask; pixels equal to 255 are treated as damaged.
        patch_size: Patch side length forwarded to ``patch_inpaint``.
        max_iterations: Iteration cap forwarded to ``patch_inpaint``. Lower
            it to bound the refinement cost on large damaged regions.

    Returns:
        Restored image.
    """
    damaged = (mask == 255).astype(np.uint8)

    # Use OpenCV's built-in inpainting for initialization
    result = cv2.inpaint(image, damaged, 3, cv2.INPAINT_TELEA)

    # Refine exactly the originally damaged pixels (0/255 mask, as
    # patch_inpaint expects). The Telea values there are only a starting
    # estimate that the exemplar copy overwrites.
    remaining_mask = damaged * 255
    if not np.any(remaining_mask):
        return result

    return patch_inpaint(result, remaining_mask, patch_size, max_iterations=max_iterations)


print("Patch inpainting functions defined!")


## 2. Test on Sample Image


In [None]:
# Load sample data
DAMAGED_DIR = Path('../data/damaged')
MASKS_DIR = Path('../data/masks')
GROUND_TRUTH_DIR = Path('../data/ground_truth')
RESULTS_DIR = Path('../methods/Patch/results')

ensure_dir(RESULTS_DIR)

# Get sample files
damaged_files = get_image_files(DAMAGED_DIR)
if not damaged_files:
    print(f"No damaged images found in {DAMAGED_DIR}")
else:
    # Use a single image for the quick test (patch method is slower)
    sample_idx = 0
    sample_path = damaged_files[sample_idx]
    sample_damaged = load_image(sample_path)
    sample_mask_path = MASKS_DIR / f"{sample_path.stem}.png"
    sample_mask = cv2.imread(str(sample_mask_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising; fail here with the path instead of inside the inpainting.
    if sample_mask is None:
        raise FileNotFoundError(f"Could not read mask: {sample_mask_path}")
    sample_gt = load_image(GROUND_TRUTH_DIR / f"{sample_path.stem}.png")

    print(f"Testing on: {sample_path.name}")
    print(f"Image shape: {sample_damaged.shape}")

    # Test fast version (uses OpenCV + patch refinement)
    print("\nRunning fast patch inpainting...")
    result_fast = patch_inpaint_fast(sample_damaged, sample_mask, patch_size=9)

    # Visualize ground truth, input, output and absolute error side by side
    diff = np.abs(result_fast.astype(float) - sample_gt.astype(float))
    panels = [
        (sample_gt, 'Ground Truth'),
        (sample_damaged, 'Damaged'),
        (result_fast, 'Patch Inpainted'),
        (diff.astype(np.uint8), 'Difference'),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (panel_img, title) in zip(axes, panels):
        ax.imshow(panel_img)
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    print("\nSample test complete!")


In [None]:
# Process all images
print(f"Processing {len(damaged_files)} images with patch-based inpainting...")
print("Note: This method uses OpenCV's inpainting with patch-based refinement for efficiency.")

PATCH_SIZE = 9

for img_path in tqdm(damaged_files, desc="Patch Inpainting"):
    # Load damaged image and mask
    damaged_img = load_image(img_path)
    mask_path = MASKS_DIR / f"{img_path.stem}.png"
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (no exception) for a missing/unreadable file;
    # fail here with the offending path instead of deep inside the inpainting.
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    # Apply patch inpainting (fast version)
    restored = patch_inpaint_fast(damaged_img, mask, patch_size=PATCH_SIZE)

    # Save result
    save_image(restored, RESULTS_DIR / f"{img_path.stem}.png")

print(f"\n✓ All {len(damaged_files)} images processed and saved to {RESULTS_DIR}")


## 3. Summary

Patch-based restoration complete! All restored images saved to `methods/Patch/results/`.

**Parameters used:**
- Patch size: 9x9
- Method: Fast patch inpainting (OpenCV initialization + patch refinement)

The patch-based method works well for textured regions and can handle larger damaged areas by finding similar patterns in the image.
