In [2]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time
from typing import Tuple, Optional
from pathlib import Path
import numba
from concurrent.futures import ThreadPoolExecutor
import os

@dataclass(frozen=True)
class Config:
    SCALE: float = 0.25
    NFEAT: int = 100000
    MATCH_RATIO: float = 0.7
    KEEP_PERCENT: float = 0.75
    THRESH: float = 250.0
    SIFT_EDGE_THRESHOLD: float = 10.0
    SIFT_CONTRAST_THRESHOLD: float = 0.04
    NUM_THREADS: int = os.cpu_count() if os.cpu_count() else 4
    SEARCH_FIELD_MULTIPLIER: float = 1.6  # Multiplier for search field size in each dimension
                                          # Results in (SEARCH_FIELD_MULTIPLIER^2)x larger search area
                                          # Also determines maximum canvas expansion when needed
                                          # Must be >= 1.0

class ImageStitcher:
    def __init__(self, path: str = "./folder2/"):
        self.path = Path(path)
        
        # Validate SEARCH_FIELD_MULTIPLIER
        if Config.SEARCH_FIELD_MULTIPLIER < 1.0:
            raise ValueError("SEARCH_FIELD_MULTIPLIER must be >= 1.0")
            
        self.sift = cv2.SIFT_create(
            nfeatures=Config.NFEAT,
            contrastThreshold=Config.SIFT_CONTRAST_THRESHOLD,
            edgeThreshold=Config.SIFT_EDGE_THRESHOLD
        )
        self.matcher = cv2.FlannBasedMatcher(
            dict(algorithm=1, trees=5),
            dict(checks=32)
        )
        self.last_match_region: Optional[Tuple[int, int]] = None
        self._gray_buffer1: Optional[np.ndarray] = None
        self._gray_buffer2: Optional[np.ndarray] = None
        self._base_image_size: Optional[Tuple[int, int]] = None  # Store base image size after scaling
        self._canvas_offset: Tuple[int, int] = (0, 0)  # Track cumulative offset from expansions
        print(f"Using {Config.NUM_THREADS} worker threads (all available cores)")
        print(f"Search field area multiplier: {Config.SEARCH_FIELD_MULTIPLIER}x each dimension "
              f"({Config.SEARCH_FIELD_MULTIPLIER**2:.2f}x total area)")
        self.thread_pool = ThreadPoolExecutor(max_workers=Config.NUM_THREADS)

    def read_image(self, idx: int) -> np.ndarray:
        img_path = self.path / f"2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE,
                        interpolation=cv2.INTER_AREA)
        
        # Store base image size if not already set
        if self._base_image_size is None:
            self._base_image_size = (img.shape[1], img.shape[0])  # (width, height)
        
        rgba = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
        rgba[:, :, :3] = img
        rgba[:, :, 3] = 255
        return rgba

    def get_search_region_and_expand(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int, int, int], Tuple[int, int]]:
        """
        Calculate search region and expand image if necessary.
        Returns expanded image, search region coordinates, and new offset.
        Expansion is based on SEARCH_FIELD_MULTIPLIER to ensure consistency.
        """
        h, w = img.shape[:2]
        
        if self.last_match_region is None:
            self.last_match_region = (w // 2, h // 2)  # Initialize to center
            return img, (0, 0, w, h), (0, 0)

        # Get base image dimensions after scaling
        base_w, base_h = self._base_image_size
        
        # Calculate desired search window dimensions
        search_w = int(round(base_w * Config.SEARCH_FIELD_MULTIPLIER**2))
        search_h = int(round(base_h * Config.SEARCH_FIELD_MULTIPLIER**2))
        search_w = max(base_w, search_w)  # Ensure search region is at least base size
        search_h = max(base_h, search_h)

        # Adjust last match region for current canvas offset
        cx, cy = self.last_match_region

        # Calculate desired search region (before boundary check)
        x1 = int(round(cx - search_w / 2))
        y1 = int(round(cy - search_h / 2))
        x2 = x1 + search_w
        y2 = y1 + search_h

        # Check if search region extends beyond image boundaries
        pad_left = max(0, -x1)
        pad_right = max(0, x2 - w)
        pad_top = max(0, -y1)
        pad_bottom = max(0, y2 - h)

        # Maximum allowed expansion based on SEARCH_FIELD_MULTIPLIER
        # Corrected to use multiplier^2 for consistency with search area
        max_pad_w = int(round(base_w * Config.SEARCH_FIELD_MULTIPLIER**2))
        max_pad_h = int(round(base_h * Config.SEARCH_FIELD_MULTIPLIER**2))

        if pad_left or pad_right or pad_top or pad_bottom:
            # Limit padding to maximum allowed
            pad_left = min(pad_left, max_pad_w)
            pad_right = min(pad_right, max_pad_w)
            pad_top = min(pad_top, max_pad_h)
            pad_bottom = min(pad_bottom, max_pad_h)

            # Create new canvas with padding
            new_w = w + pad_left + pad_right
            new_h = h + pad_top + pad_bottom
            canvas = np.zeros((new_h, new_w, 4), dtype=np.uint8)
            canvas[pad_top:pad_top + h, pad_left:pad_left + w] = img
            
            # Update coordinates for new canvas
            x1 += pad_left
            x2 += pad_left
            y1 += pad_top
            y2 += pad_top
            
            result_img = canvas
            offset = (pad_left, pad_top)
            
            # Update canvas offset
            self._canvas_offset = (self._canvas_offset[0] + pad_left, 
                                 self._canvas_offset[1] + pad_top)
        else:
            result_img = img
            offset = (0, 0)

        # Clip search region to image boundaries
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(result_img.shape[1], x2)
        y2 = min(result_img.shape[0], y2)

        return result_img, (x1, y1, x2, y2), offset

    def update_last_match_region(self, matrix: np.ndarray, img_shape: Tuple[int, int], 
                               offset: Tuple[int, int], crop_offset: Optional[Tuple[int, int]] = None):
        h, w = img_shape[:2]
        center = np.array([[w/2], [h/2], [1]], dtype=np.float32)
        transformed = matrix @ center
        
        # Adjust for expansion offset
        new_cx = int(transformed[0, 0]) + offset[0]
        new_cy = int(transformed[1, 0]) + offset[1]
        
        # Adjust for cropping if applicable
        if crop_offset:
            new_cx -= crop_offset[0]
            new_cy -= crop_offset[1]
            
        self.last_match_region = (new_cx, new_cy)

    def find_matches(self, img1: np.ndarray, img2: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, int]]:
        img1, (x1, y1, x2, y2), offset = self.get_search_region_and_expand(img1)
        
        # Verify that slice indices are valid
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"Invalid search region: x1={x1}, x2={x2}, y1={y1}, y2={y2}")
            
        img1_roi = img1[y1:y2, x1:x2]

        if (self._gray_buffer1 is None or 
            self._gray_buffer1.shape != img1_roi.shape[:2]):
            self._gray_buffer1 = np.empty(img1_roi.shape[:2], dtype=np.uint8)
        if (self._gray_buffer2 is None or 
            self._gray_buffer2.shape != img2.shape[:2]):
            self._gray_buffer2 = np.empty(img2.shape[:2], dtype=np.uint8)

        futures = [
            self.thread_pool.submit(cv2.cvtColor, img1_roi, cv2.COLOR_BGRA2GRAY, 
                                  dst=self._gray_buffer1),
            self.thread_pool.submit(cv2.cvtColor, img2, cv2.COLOR_BGRA2GRAY, 
                                  dst=self._gray_buffer2)
        ]
        for future in futures:
            future.result()

        def detect_compute(img, sift):
            return sift.detectAndCompute(img, None)
        
        future_kp1 = self.thread_pool.submit(detect_compute, self._gray_buffer1, self.sift)
        future_kp2 = self.thread_pool.submit(detect_compute, self._gray_buffer2, self.sift)
        kp1, desc1 = future_kp1.result()
        kp2, desc2 = future_kp2.result()

        matches = self.matcher.knnMatch(desc1, desc2, k=2)
        good = [m for m, n in matches if m.distance < Config.MATCH_RATIO * n.distance]
        good = sorted(good, key=lambda x: x.distance)[:int(len(good) * Config.KEEP_PERCENT)]

        query_idx = np.array([m.queryIdx for m in good])
        train_idx = np.array([m.trainIdx for m in good])
        pts1 = np.float32([kp.pt for kp in kp1])[query_idx] + [x1, y1]
        pts2 = np.float32([kp.pt for kp in kp2])[train_idx]
        return img1, pts1, pts2, offset

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray, 
                    img_shape: Tuple[int, int], offset: Tuple[int, int]) -> np.ndarray:
        matrix, _ = cv2.estimateAffinePartial2D(
            pts2, pts1,
            method=cv2.RANSAC,
            ransacReprojThreshold=Config.THRESH,
            confidence=0.995,
            maxIters=1000
        )
        return matrix

    @staticmethod
    @numba.jit(nopython=True, parallel=True)
    def overlay_images_numba(base: np.ndarray, overlay: np.ndarray, result: np.ndarray):
        h, w = base.shape[:2]
        for y in numba.prange(h):
            for x in range(w):
                if overlay[y, x, 3] > 0:
                    result[y, x] = overlay[y, x]
                else:
                    result[y, x] = base[y, x]

    def overlay_images(self, base: np.ndarray, overlay: np.ndarray) -> np.ndarray:
        result = np.zeros_like(base, dtype=np.uint8)
        ImageStitcher.overlay_images_numba(base, overlay, result)
        return result

    def crop_result(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
        mask = img[:, :, 3] > 0
        rows = np.any(mask, axis=1)
        cols = np.any(mask, axis=0)
        y1, y2 = np.where(rows)[0][[0, -1]]
        x1, x2 = np.where(cols)[0][[0, -1]]
        crop_offset = (x1, y1)
        return img[y1:y2+1, x1:x2+1], crop_offset

    def stitch(self, start: int = 3, end: int = 21):
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        result = self.read_image(start)
        total_images = end - start - 1
        processed = 0

        for idx in range(start, end):
            iter_start = time()
            print(f"\nProcessing image {idx}/{end-1} ", end="")

            current = self.read_image(idx)
            result, pts1, pts2, offset = self.find_matches(result, current)
            matrix = self.align_images(pts1, pts2, current.shape, offset)
            aligned = cv2.warpAffine(
                current, matrix,
                (result.shape[1], result.shape[0]),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_TRANSPARENT
            )
            result = self.overlay_images(result, aligned)
            result, crop_offset = self.crop_result(result)
            
            # Update last match region with both expansion and crop offsets
            self.update_last_match_region(matrix, current.shape, offset, crop_offset)
            
            if idx == 20:
                cv2.imwrite(f"result_{idx}.png", result)
                print(f"[Saved result_{idx}.png]", end="")

            processed += 1
            print(f"[{time() - iter_start:.1f}s]")

        total_time = time() - start_time
        print(f"\nStitching completed:")
        print(f"Total images processed: {processed}/{total_images}")
        print(f"Total time: {total_time:.1f}s "
              f"(avg {total_time/total_images:.1f}s per image)")

    def __del__(self):
        self.thread_pool.shutdown()

if __name__ == "__main__":
    stitcher = ImageStitcher()
    stitcher.stitch()

Using 16 worker threads (all available cores)
Search field area multiplier: 1.6x each dimension (2.56x total area)

Starting image stitching process from 3 to 21

Processing image 3/20 [4.5s]

Processing image 4/20 [4.9s]

Processing image 5/20 [5.0s]

Processing image 6/20 [5.4s]

Processing image 7/20 [5.6s]

Processing image 8/20 [5.9s]

Processing image 9/20 [6.0s]

Processing image 10/20 [5.8s]

Processing image 11/20 [5.9s]

Processing image 12/20 [5.6s]

Processing image 13/20 [6.0s]

Processing image 14/20 [5.6s]

Processing image 15/20 [6.0s]

Processing image 16/20 [5.9s]

Processing image 17/20 [5.9s]

Processing image 18/20 [5.9s]

Processing image 19/20 [5.9s]

Processing image 20/20 [Saved result_20.png][6.8s]

Stitching completed:
Total images processed: 18/17
Total time: 102.8s (avg 6.0s per image)
