In [None]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time
from typing import Tuple, Optional
from pathlib import Path
import numba
from concurrent.futures import ThreadPoolExecutor, Future
import os

@dataclass(frozen=True)
class Config:
    """Immutable tuning constants for the stitching pipeline.

    The stitcher reads these through the class itself (``Config.SCALE``),
    so the defaults below are the values actually in effect.
    """

    # Resize factor applied to every frame right after it is read.
    SCALE: float = 0.4
    # ``nfeatures`` handed to ``cv2.SIFT_create``.
    NFEAT: int = 50000
    # Ratio test: the best match must be closer than this fraction of the
    # second-best match to be kept.
    MATCH_RATIO: float = 0.65
    # Fraction of ratio-test survivors kept after sorting by distance.
    KEEP_PERCENT: float = 0.55
    # ``ransacReprojThreshold`` handed to ``cv2.estimateAffinePartial2D``.
    THRESH: float = 250.0
    # ``edgeThreshold`` / ``contrastThreshold`` handed to ``cv2.SIFT_create``.
    SIFT_EDGE_THRESHOLD: float = 10.0
    SIFT_CONTRAST_THRESHOLD: float = 0.04
    # Worker threads for the shared pool; ``os.cpu_count()`` may return
    # None, in which case 4 is used.
    NUM_THREADS: int = os.cpu_count() or 4
    # Squared and multiplied with the base frame size to get the search
    # window; ImageStitcher.__init__ rejects values below 1.0.
    SEARCH_FIELD_MULTIPLIER: float = 1.6

class ImageStitcher:
    def __init__(self, path: str = "./folder2/"):
        """Create the detector, matcher, tracking state and worker pool.

        Args:
            path: Directory that holds the input frames.

        Raises:
            ValueError: If ``Config.SEARCH_FIELD_MULTIPLIER`` is below 1.0.
        """
        self.path = Path(path)
        if Config.SEARCH_FIELD_MULTIPLIER < 1.0:
            raise ValueError("SEARCH_FIELD_MULTIPLIER must be >= 1.0")

        sift_options = {
            "nfeatures": Config.NFEAT,
            "contrastThreshold": Config.SIFT_CONTRAST_THRESHOLD,
            "edgeThreshold": Config.SIFT_EDGE_THRESHOLD,
        }
        self.sift = cv2.SIFT_create(**sift_options)

        # NOTE(review): algorithm=1 is the value OpenCV's tutorials use for
        # the KD-tree index (FLANN_INDEX_KDTREE) -- confirm.
        index_params = {"algorithm": 1, "trees": 5}
        search_params = {"checks": 32}
        self.matcher = cv2.FlannBasedMatcher(index_params, search_params)

        # Centre of the most recently placed frame, in mosaic coordinates.
        self.last_match_region: Optional[Tuple[int, int]] = None
        # Grayscale scratch buffers, reallocated only when the shape changes.
        self._gray_buffer1: Optional[np.ndarray] = None
        self._gray_buffer2: Optional[np.ndarray] = None
        # (width, height) of the first frame read; sizes the search window.
        self._base_image_size: Optional[Tuple[int, int]] = None
        # Running sum of the left/top padding added to the mosaic.
        self._canvas_offset: Tuple[int, int] = (0, 0)

        worker_count = Config.NUM_THREADS
        print(f"Using {worker_count} worker threads")
        self.thread_pool = ThreadPoolExecutor(max_workers=worker_count)

    def read_image(self, idx: int) -> Tuple[np.ndarray, dict]:
        """Load frame ``idx``, downscale it and add an opaque alpha channel.

        Args:
            idx: Frame number, zero-padded to four digits in the file name.

        Returns:
            A ``(image, timings)`` pair. ``image`` is a 4-channel uint8 array
            whose first three channels are the resized pixels exactly as
            ``cv2.imread`` delivered them and whose fourth channel is 255
            everywhere. ``timings`` maps step names to elapsed seconds.

        Raises:
            FileNotFoundError: If the image file does not exist.
            OSError: If the file exists but OpenCV could not decode it.
        """
        timings = {}
        path_start = time()
        img_path = self.path / f"2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"
        timings['construct_path'] = time() - path_start

        read_start = time()
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        timings['read_from_disk'] = time() - read_start
        if img is None:
            # cv2.imread reports failure by returning None rather than
            # raising; fail here with the path instead of inside cv2.resize.
            if not img_path.is_file():
                raise FileNotFoundError(f"Image not found: {img_path}")
            raise OSError(f"Could not decode image: {img_path}")

        resize_start = time()
        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE, interpolation=cv2.INTER_LINEAR)
        timings['resize'] = time() - resize_start

        # The first frame read fixes the reference size for search windows.
        if self._base_image_size is None:
            self._base_image_size = (img.shape[1], img.shape[0])

        rgba_start = time()
        rgba = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
        rgba[:, :, :3] = img
        # Fully opaque: alpha > 0 later marks pixels that carry image data.
        rgba[:, :, 3] = 255
        timings['convert_to_rgba'] = time() - rgba_start
        return rgba, timings

    def get_search_region_and_expand(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int, int, int], Tuple[int, int]]:
        """Pick the window to search for the next frame, padding if needed.

        The window is centred on ``self.last_match_region`` and sized from
        the base frame size times ``Config.SEARCH_FIELD_MULTIPLIER`` squared.
        Where the window reaches past the image, the image is copied onto a
        larger zero-filled 4-channel canvas.

        Args:
            img: Current mosaic (4-channel uint8, as the canvas copy assumes).

        Returns:
            ``(image, (x1, y1, x2, y2), (pad_left, pad_top))`` where ``image``
            is ``img`` itself or the padded canvas, the box is clamped to that
            image, and the last tuple is the shift applied to old coordinates.
        """
        img_h, img_w = img.shape[:2]

        if self.last_match_region is None:
            # No match yet: seed tracking at the centre, search everything.
            self.last_match_region = (img_w // 2, img_h // 2)
            return img, (0, 0, img_w, img_h), (0, 0)

        factor = Config.SEARCH_FIELD_MULTIPLIER**2
        base_w, base_h = self._base_image_size
        scaled_w = int(round(base_w * factor))
        scaled_h = int(round(base_h * factor))
        # The window is never smaller than one base frame.
        win_w = max(base_w, scaled_w)
        win_h = max(base_h, scaled_h)

        centre_x, centre_y = self.last_match_region
        left = int(round(centre_x - win_w / 2))
        top = int(round(centre_y - win_h / 2))
        right = left + win_w
        bottom = top + win_h

        # Overshoot on each side, capped at the scaled base size.
        grow_left = min(max(0, -left), scaled_w)
        grow_right = min(max(0, right - img_w), scaled_w)
        grow_top = min(max(0, -top), scaled_h)
        grow_bottom = min(max(0, bottom - img_h), scaled_h)

        if any((grow_left, grow_right, grow_top, grow_bottom)):
            canvas_h = img_h + grow_top + grow_bottom
            canvas_w = img_w + grow_left + grow_right
            expanded = np.zeros((canvas_h, canvas_w, 4), dtype=np.uint8)
            expanded[grow_top:grow_top + img_h, grow_left:grow_left + img_w] = img
            # Re-express the window in canvas coordinates.
            left += grow_left
            right += grow_left
            top += grow_top
            bottom += grow_top
            shift = (grow_left, grow_top)
            prev_x, prev_y = self._canvas_offset
            self._canvas_offset = (prev_x + grow_left, prev_y + grow_top)
        else:
            expanded = img
            shift = (0, 0)

        out_h, out_w = expanded.shape[:2]
        window = (max(0, left), max(0, top), min(out_w, right), min(out_h, bottom))
        return expanded, window, shift

    def update_last_match_region(self, matrix: np.ndarray, img_shape: Tuple[int, int], 
                                 offset: Tuple[int, int], crop_offset: Optional[Tuple[int, int]] = None):
        """Store where the newest frame's centre landed in the cropped mosaic.

        ``matrix`` is estimated in ``align_images`` from points that
        ``find_matches`` already expressed in padded-canvas coordinates (it
        adds the ROI origin after the padding shift). The transformed centre
        therefore already includes the padding, and adding ``offset`` again
        would push the next search window right/down by the padding amount.

        Args:
            matrix: 2x3 affine matrix mapping frame coordinates to canvas
                coordinates. If None, the stored region is left unchanged.
            img_shape: Shape of the frame; only height and width are used.
            offset: ``(pad_left, pad_top)`` applied to the canvas. Accepted
                for interface compatibility and intentionally not added.
            crop_offset: ``(x, y)`` origin of the crop taken from the canvas;
                subtracted so the result is in cropped-mosaic coordinates.
        """
        if matrix is None:
            return
        h, w = img_shape[:2]
        centre = np.array([[w / 2], [h / 2], [1]], dtype=np.float32)
        mapped = matrix @ centre
        new_cx = int(mapped[0, 0])
        new_cy = int(mapped[1, 0])
        if crop_offset is not None:
            # int() keeps the stored tuple free of numpy scalar types.
            new_cx -= int(crop_offset[0])
            new_cy -= int(crop_offset[1])
        self.last_match_region = (new_cx, new_cy)

    def find_matches(self, img1: np.ndarray, img2: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Tuple[int, int], dict]:
        """Match SIFT features between the mosaic's search window and a frame.

        Args:
            img1: Current mosaic (4-channel). It may be replaced by a padded
                copy, which is what gets returned.
            img2: New frame (4-channel).

        Returns:
            ``(mosaic, pts1, pts2, offset, timings)``. ``pts1`` are matched
            points in (possibly padded) mosaic coordinates and ``pts2`` the
            corresponding points in ``img2``. Both have shape ``(0, 2)`` when
            no usable match exists. ``offset`` is the left/top padding added
            to the mosaic; ``timings`` maps step names to elapsed seconds.

        Raises:
            ValueError: If the search window is empty.
        """
        timings = {}
        img1, (x1, y1, x2, y2), offset = self.get_search_region_and_expand(img1)
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"Invalid search region: x1={x1}, x2={x2}, y1={y1}, y2={y2}")

        roi_start = time()
        img1_roi = img1[y1:y2, x1:x2]
        timings['extract_roi'] = time() - roi_start

        buffer_start = time()
        if self._gray_buffer1 is None or self._gray_buffer1.shape != img1_roi.shape[:2]:
            self._gray_buffer1 = np.empty(img1_roi.shape[:2], dtype=np.uint8)
        if self._gray_buffer2 is None or self._gray_buffer2.shape != img2.shape[:2]:
            self._gray_buffer2 = np.empty(img2.shape[:2], dtype=np.uint8)
        timings['allocate_buffers'] = time() - buffer_start

        gray_start = time()
        futures = [
            self.thread_pool.submit(cv2.cvtColor, img1_roi, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer1),
            self.thread_pool.submit(cv2.cvtColor, img2, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer2)
        ]
        for future in futures:
            future.result()
        timings['convert_to_grayscale'] = time() - gray_start

        sift_start = time()
        future_kp1 = self.thread_pool.submit(self.sift.detectAndCompute, self._gray_buffer1, None)
        future_kp2 = self.thread_pool.submit(self.sift.detectAndCompute, self._gray_buffer2, None)
        kp1, desc1 = future_kp1.result()
        kp2, desc2 = future_kp2.result()
        timings['sift_detection'] = time() - sift_start

        flann_start = time()
        # Descriptors come back as None when an image yields no keypoints,
        # and a k=2 query needs at least two candidates on the train side.
        if desc1 is None or desc2 is None or len(desc2) < 2:
            matches = ()
        else:
            matches = self.matcher.knnMatch(desc1, desc2, k=2)
        timings['flann_matching'] = time() - flann_start

        filter_start = time()
        # Ratio test; entries with fewer than two neighbours are dropped
        # instead of breaking the tuple unpacking.
        good = [
            pair[0] for pair in matches
            if len(pair) == 2 and pair[0].distance < Config.MATCH_RATIO * pair[1].distance
        ]
        good.sort(key=lambda m: m.distance)
        good = good[:int(len(good) * Config.KEEP_PERCENT)]
        timings['filter_matches'] = time() - filter_start

        points_start = time()
        if good:
            # Convert only matched keypoints; adding (x1, y1) moves ROI
            # coordinates back into mosaic coordinates.
            pts1 = np.float32([kp1[m.queryIdx].pt for m in good]) + [x1, y1]
            pts2 = np.float32([kp2[m.trainIdx].pt for m in good])
        else:
            # Same dtypes as the non-empty branch produces.
            pts1 = np.empty((0, 2), dtype=np.float64)
            pts2 = np.empty((0, 2), dtype=np.float32)
        timings['extract_points'] = time() - points_start
        return img1, pts1, pts2, offset, timings

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray, 
                     img_shape: Tuple[int, int], offset: Tuple[int, int]) -> Optional[np.ndarray]:
        """Estimate a scale-free partial affine transform from pts2 to pts1.

        Args:
            pts1: Target points (mosaic coordinates).
            pts2: Source points (frame coordinates), paired with ``pts1``.
            img_shape: Unused; kept for interface compatibility.
            offset: Unused; kept for interface compatibility.

        Returns:
            A 2x3 matrix whose linear part has been divided by its average
            column norm (the translation column is left as estimated), or
            None when no transform could be found.
        """
        # A partial affine transform has four degrees of freedom, so at
        # least two point pairs are required; skip the OpenCV call otherwise.
        if pts1 is None or pts2 is None or len(pts1) < 2 or len(pts2) < 2:
            print("Warning: No transformation matrix found")
            return None

        matrix, _ = cv2.estimateAffinePartial2D(
            pts2, pts1,
            method=cv2.RANSAC,
            ransacReprojThreshold=Config.THRESH,
            confidence=0.995,
            maxIters=1000
        )
        if matrix is None:
            print("Warning: No transformation matrix found")
            return None

        # Column norms of the linear part are the per-axis scale factors.
        linear = matrix[:, :2]
        s_x = np.linalg.norm(linear[:, 0])
        s_y = np.linalg.norm(linear[:, 1])
        s_avg = (s_x + s_y) / 2
        if s_avg > 0:
            # Normalize the linear part to remove scaling.
            matrix[:, :2] /= s_avg
            print(f"Adjusted scaling factor from {s_avg:.4f} to 1.0")
        else:
            print("Warning: s_avg is zero, cannot adjust scaling")
        return matrix

    @staticmethod
    @numba.jit(nopython=True, parallel=True)
    def overlay_images_numba(base: np.ndarray, overlay: np.ndarray, result: np.ndarray):
        """Composite ``overlay`` over ``base`` into ``result``, pixel by pixel.

        For every (y, x) inside ``base``'s height and width, ``result``
        receives the ``overlay`` pixel when its channel 3 is greater than
        zero, otherwise the ``base`` pixel. ``result`` is written in place
        and nothing is returned.

        The loop bounds come from ``base`` alone, so this assumes ``overlay``
        and ``result`` are at least as large and that ``overlay`` has a
        fourth channel -- TODO confirm callers guarantee this.
        """
        h, w = base.shape[:2]
        # Rows are iterated with numba.prange; no row reads another's output.
        for y in numba.prange(h):
            for x in range(w):
                if overlay[y, x, 3] > 0:
                    result[y, x] = overlay[y, x]
                else:
                    result[y, x] = base[y, x]

    def overlay_images(self, base: np.ndarray, overlay: np.ndarray) -> np.ndarray:
        """Return a new uint8 image with ``overlay`` drawn over ``base``.

        Pixels of ``overlay`` whose fourth channel is greater than zero
        replace the matching ``base`` pixels; neither input is modified.
        """
        composite = np.zeros_like(base, dtype=np.uint8)
        ImageStitcher.overlay_images_numba(base, overlay, composite)
        return composite

    def crop_result(self, img: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Trim rows and columns that contain no pixel with alpha > 0.

        Args:
            img: Image whose fourth channel is the alpha mask.

        Returns:
            ``(cropped, (x, y))`` where ``(x, y)`` is the origin of the crop
            in ``img`` as plain ints. If ``img`` has no pixel with alpha > 0,
            it is returned unchanged with origin ``(0, 0)``.
        """
        opaque = img[:, :, 3] > 0
        row_idx = np.where(np.any(opaque, axis=1))[0]
        col_idx = np.where(np.any(opaque, axis=0))[0]
        if row_idx.size == 0:
            # Nothing visible: indexing first/last of an empty array would
            # raise IndexError, so hand the image back as-is.
            return img, (0, 0)
        y1, y2 = int(row_idx[0]), int(row_idx[-1])
        x1, x2 = int(col_idx[0]), int(col_idx[-1])
        return img[y1:y2 + 1, x1:x2 + 1], (x1, y1)

    def stitch(self, start: int = 3, end: int = 100):
        """Stitch frames ``start`` through ``end`` into one growing mosaic.

        Frame ``start`` seeds the mosaic. Each later frame is matched against
        the mosaic, aligned, warped onto the padded canvas, overlaid and
        cropped. The next frame is prefetched on the thread pool while the
        current one is processed. A frame for which no transform is found is
        skipped and leaves the mosaic and tracking state untouched. The
        mosaic is written to ``result_<idx>.png`` whenever ``idx`` is a
        multiple of 10.

        Args:
            start: Index of the first frame.
            end: Index of the last frame (inclusive).
        """
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        result, _ = self.read_image(start)
        print(f"  - Read initial image: {time() - start_time:.3f}s")
        expand_start = time()
        result, _, _ = self.get_search_region_and_expand(result)
        expand_time = time() - expand_start
        print(f"  - Expand initial image: {expand_time:.3f}s")
        future: Optional[Future] = None
        for idx in range(start + 1, end + 1):
            iter_start = time()
            print(f"\nProcessing image {idx}/{end}")
            if future is not None:
                current, _ = future.result()
            else:
                current, _ = self.read_image(idx)
            # Prefetch the next frame while this one is being processed.
            future = self.thread_pool.submit(self.read_image, idx + 1) if idx < end else None

            match_start = time()
            # Keep the padded canvas separate so a failed alignment does not
            # replace the mosaic that last_match_region refers to.
            expanded, pts1, pts2, offset, match_timings = self.find_matches(result, current)
            match_time = time() - match_start
            align_start = time()
            matrix = self.align_images(pts1, pts2, current.shape, offset)
            align_time = time() - align_start

            if matrix is None:
                print(f"  - Skipping image {idx}: no transformation found, mosaic unchanged")
            else:
                warp_start = time()
                # BORDER_TRANSPARENT leaves destination pixels outside the
                # warped frame unmodified, so start from a zeroed buffer;
                # alpha 0 then marks "no data" for the overlay step.
                aligned = cv2.warpAffine(
                    current, matrix,
                    (expanded.shape[1], expanded.shape[0]),
                    dst=np.zeros_like(expanded),
                    flags=cv2.INTER_LINEAR,
                    borderMode=cv2.BORDER_TRANSPARENT
                )
                result = self.overlay_images(expanded, aligned)
                warp_time = time() - warp_start
                crop_start = time()
                result, crop_offset = self.crop_result(result)
                crop_time = time() - crop_start
                update_start = time()
                self.update_last_match_region(matrix, current.shape, offset, crop_offset)
                update_time = time() - update_start
                print(f"  - Find matches: {match_time:.3f}s")
                print(f"    - Extract ROI: {match_timings['extract_roi']:.3f}s")
                print(f"    - Allocate buffers: {match_timings['allocate_buffers']:.3f}s")
                print(f"    - Convert to grayscale: {match_timings['convert_to_grayscale']:.3f}s")
                print(f"    - SIFT detection: {match_timings['sift_detection']:.3f}s")
                print(f"    - FLANN matching: {match_timings['flann_matching']:.3f}s")
                print(f"    - Filter matches: {match_timings['filter_matches']:.3f}s")
                print(f"    - Extract points: {match_timings['extract_points']:.3f}s")
                print(f"  - Align images: {align_time:.3f}s")
                print(f"  - Warp and overlay: {warp_time:.3f}s")
                print(f"  - Crop result: {crop_time:.3f}s")
                print(f"  - Update match region: {update_time:.3f}s")

            print(f"[Total for image {idx}: {time() - iter_start:.3f}s]")
            if idx % 10 == 0:
                cv2.imwrite(f"result_{idx}.png", result)
                print(f"  - Saved result_{idx}.png")
        total_time = time() - start_time
        print(f"\nStitching completed in {total_time:.3f}s")

    def __del__(self):
        """Shut down the worker pool when the stitcher is garbage-collected."""
        # __init__ can raise before thread_pool is assigned (invalid config,
        # detector creation failure); the finalizer still runs on that
        # half-built object and must not raise AttributeError.
        pool = getattr(self, "thread_pool", None)
        if pool is not None:
            pool.shutdown()

# Script entry point: stitch using the default image folder and the default
# frame range of ImageStitcher.stitch.
if __name__ == "__main__":
    # Bound to a module-level name, so the instance stays reachable after
    # the run.
    stitcher = ImageStitcher()
    stitcher.stitch()

Using 16 worker threads

Starting image stitching process from 3 to 100
  - Read initial image: 0.345s
  - Expand initial image: 0.000s

Processing image 4/100
Adjusted scaling factor from 0.9907 to 1.0
  - Find matches: 8.782s
    - Extract ROI: 0.000s
    - Allocate buffers: 0.000s
    - Convert to grayscale: 0.017s
    - SIFT detection: 7.005s
    - FLANN matching: 1.685s
    - Filter matches: 0.007s
    - Extract points: 0.049s
  - Align images: 0.005s
  - Warp and overlay: 1.585s
  - Crop result: 0.036s
  - Update match region: 0.001s
[Total for image 4: 10.704s]

Processing image 5/100
Adjusted scaling factor from 0.9581 to 1.0
  - Find matches: 9.689s
    - Extract ROI: 0.000s
    - Allocate buffers: 0.000s
    - Convert to grayscale: 0.011s
    - SIFT detection: 7.817s
    - FLANN matching: 1.731s
    - Filter matches: 0.007s
    - Extract points: 0.098s
  - Align images: 0.001s
  - Warp and overlay: 0.107s
  - Crop result: 0.043s
  - Update match region: 0.000s
[Total for imag