In [2]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time
from typing import Tuple, Optional
from pathlib import Path

@dataclass(frozen=True)
class Config:
    """Tuning constants for the stitching pipeline.

    The values are read directly from the class (``Config.SCALE`` etc.), so
    the defaults below are what the pipeline actually uses.
    """

    # --- Image processing -------------------------------------------------
    # Resize factor applied to both axes of every image after loading.
    SCALE: float = 0.25
    # Canvas padding, expressed before scaling (multiplied by SCALE when used).
    WIDEN: int = 3_000
    # Side length of the search window; half of it is taken on each side of
    # the last match centre.
    SEARCH_WINDOW: int = 2_000
    # Extra margin added on every side of the search window.
    OVERLAP_MARGIN: int = 500

    # --- Feature detection and matching -----------------------------------
    # Passed to SIFT as ``nfeatures``.
    NFEAT: int = 100_000
    # Ratio-test factor: keep a match when best < MATCH_RATIO * second best.
    MATCH_RATIO: float = 0.7
    # Fraction of the distance-sorted good matches that is kept.
    KEEP_PERCENT: float = 0.75
    # Passed to the RANSAC estimator as ``ransacReprojThreshold``.
    THRESH: float = 250.0
    # Minimum inlier count, as a fraction of the number of matched points.
    MIN_INLIERS: float = 0.01

class ImageStitcher:
    """Incrementally stitch a numbered sequence of photos into one BGRA mosaic.

    Each new image is matched against the mosaic built so far with SIFT
    features, aligned with a partial affine transform estimated by RANSAC,
    alpha-blended on top of the mosaic and the result is cropped to its
    non-transparent content.

    ``last_match_region`` holds the centre (x, y) of the most recently added
    image. Between calls to :meth:`stitch` iterations it is expressed in the
    coordinate frame of the current (cropped) mosaic; it is used to restrict
    feature detection on the mosaic to a window around that centre.
    """

    DEFAULT_NAME_TEMPLATE = "2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"

    def __init__(self, path: str = "./folder2/",
                 name_template: str = DEFAULT_NAME_TEMPLATE):
        """Create the feature detector and matcher.

        Args:
            path: Directory that contains the input images.
            name_template: ``str.format`` template for the file name; it
                receives the image index as ``idx``.
        """
        self.path = Path(path)
        self.name_template = name_template
        self.sift = cv2.SIFT_create(nfeatures=Config.NFEAT)
        self.matcher = cv2.FlannBasedMatcher(
            dict(algorithm=1, trees=5),  # 1 = FLANN_INDEX_KDTREE in OpenCV's tutorials
            dict(checks=50),
        )
        self.last_match_region: Optional[Tuple[int, int]] = None

        # Grayscale scratch buffers, reused while the input shape is unchanged.
        self._gray_buffer1: Optional[np.ndarray] = None
        self._gray_buffer2: Optional[np.ndarray] = None

    def read_image(self, idx: int) -> np.ndarray:
        """Load image ``idx``, downscale it by ``Config.SCALE`` and add alpha.

        Returns:
            A 4-channel (BGRA) ``uint8`` image with a fully opaque alpha plane.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img_path = self.path / self.name_template.format(idx=idx)
        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image {idx} at {img_path}")

        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE,
                         interpolation=cv2.INTER_LINEAR)

        bgra = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = 255
        return bgra

    def get_search_region(self, img: np.ndarray) -> Tuple[int, int, int, int]:
        """Return the ``(x1, y1, x2, y2)`` window of ``img`` to search.

        Without a previous match the whole image is returned. Otherwise the
        window is centred on ``last_match_region`` and clamped to the image.
        If the clamped window is empty (the stored centre lies outside the
        image) the whole image is returned instead of an empty ROI.
        """
        h, w = img.shape[:2]
        if self.last_match_region is None:
            return 0, 0, w, h

        cx, cy = self.last_match_region
        reach = Config.SEARCH_WINDOW // 2 + Config.OVERLAP_MARGIN

        x1 = max(0, cx - reach)
        y1 = max(0, cy - reach)
        x2 = min(w, cx + reach)
        y2 = min(h, cy + reach)

        if x1 >= x2 or y1 >= y2:
            # An empty ROI would make the colour conversion fail.
            return 0, 0, w, h
        return x1, y1, x2, y2

    def update_last_match_region(self, matrix: np.ndarray,
                                 img_shape: Tuple[int, ...]):
        """Store where ``matrix`` maps the centre of an image of ``img_shape``.

        Args:
            matrix: 2x3 affine matrix.
            img_shape: ``(height, width, ...)``; only the first two entries
                are used.
        """
        h, w = img_shape[:2]
        centre = np.array([w / 2.0, h / 2.0, 1.0], dtype=np.float64)
        cx, cy = np.asarray(matrix, dtype=np.float64) @ centre
        self.last_match_region = (int(cx), int(cy))

    @staticmethod
    def _to_gray(img: np.ndarray, buffer: Optional[np.ndarray]) -> np.ndarray:
        """Convert BGRA to grayscale, writing into ``buffer`` when it fits."""
        if buffer is not None and buffer.shape == img.shape[:2]:
            # Use the returned array rather than assuming the write was in place.
            return cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY, dst=buffer)
        return cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)

    def find_matches(self, img1: np.ndarray,
                     img2: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Match SIFT features between the search region of ``img1`` and ``img2``.

        Returns:
            ``(pts1, pts2)``: float32 arrays of matched coordinates, ``pts1``
            in full ``img1`` coordinates and ``pts2`` in ``img2`` coordinates.

        Raises:
            ValueError: If no features are detected or fewer than four
                matches pass the ratio test.
        """
        x1, y1, x2, y2 = self.get_search_region(img1)
        roi = img1[y1:y2, x1:x2]

        self._gray_buffer1 = self._to_gray(roi, self._gray_buffer1)
        self._gray_buffer2 = self._to_gray(img2, self._gray_buffer2)

        kp1, desc1 = self.sift.detectAndCompute(self._gray_buffer1, None)
        kp2, desc2 = self.sift.detectAndCompute(self._gray_buffer2, None)

        if not kp1 or not kp2 or desc1 is None or desc2 is None:
            raise ValueError("No features detected in one or both images")

        matches = self.matcher.knnMatch(desc1, desc2, k=2)
        # A query can come back with fewer than two neighbours; skip those
        # instead of failing on tuple unpacking.
        good = [
            pair[0] for pair in matches
            if len(pair) == 2
            and pair[0].distance < Config.MATCH_RATIO * pair[1].distance
        ]
        if len(good) < 4:
            raise ValueError("Not enough good matches found")

        good.sort(key=lambda m: m.distance)
        good = good[:int(len(good) * Config.KEEP_PERCENT)]

        # Shift ROI-relative keypoints back into img1 coordinates; the float32
        # offset keeps pts1 in the same dtype as pts2.
        pts1 = np.float32([kp1[m.queryIdx].pt for m in good]) + np.float32([x1, y1])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in good])
        return pts1, pts2

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray,
                     img_shape: Tuple[int, ...]) -> np.ndarray:
        """Estimate the partial affine transform that maps ``pts2`` onto ``pts1``.

        On success ``last_match_region`` is updated with the transformed
        centre of an image of ``img_shape``.

        Raises:
            ValueError: If estimation fails or the inlier count is below
                ``Config.MIN_INLIERS * len(pts1)``.
        """
        matrix, inliers = cv2.estimateAffinePartial2D(
            pts2, pts1,
            method=cv2.RANSAC,
            ransacReprojThreshold=Config.THRESH,
        )

        if matrix is None or np.count_nonzero(inliers) < Config.MIN_INLIERS * len(pts1):
            raise ValueError("Not enough inliers for alignment")

        self.update_last_match_region(matrix, img_shape)
        return matrix

    @staticmethod
    def _padding() -> int:
        """Total number of pixels :meth:`widen_image` adds to each dimension."""
        return int(Config.WIDEN * Config.SCALE)

    def widen_image(self, img: np.ndarray) -> np.ndarray:
        """Place ``img`` on a larger, fully transparent 4-channel canvas.

        The canvas is ``_padding()`` pixels larger in each dimension and the
        image is placed ``_padding() // 2`` pixels from the top-left corner.
        """
        h, w = img.shape[:2]
        pad = self._padding()
        canvas = np.zeros((h + pad, w + pad, 4), dtype=np.uint8)
        offset = pad // 2
        canvas[offset:offset + h, offset:offset + w] = img
        return canvas

    def blend_images(self, img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
        """Composite ``img2`` over ``img1`` using their alpha channels.

        Both inputs are 4-channel ``uint8`` images of identical shape. Pixels
        where the combined alpha is zero stay fully zero. Results are rounded
        to the nearest integer rather than truncated.

        Raises:
            ValueError: If the two images differ in shape.
        """
        if img1.shape != img2.shape:
            raise ValueError(
                f"Cannot blend images of different shapes: {img1.shape} vs {img2.shape}"
            )

        a1 = img1[:, :, 3].astype(np.float32) / 255.0
        a2 = img2[:, :, 3].astype(np.float32) / 255.0
        under = a1 * (1.0 - a2)          # weight of img1 where img2 is not opaque
        a_out = a2 + under
        visible = a_out > 0

        premult = (img2[:, :, :3] * a2[:, :, None]
                   + img1[:, :, :3] * under[:, :, None])
        colour = np.zeros_like(premult)
        # Divide only where something is visible; elsewhere colour stays 0.
        np.divide(premult, a_out[:, :, None], out=colour,
                  where=visible[:, :, None])

        result = np.empty_like(img1)
        result[:, :, :3] = np.clip(np.rint(colour), 0, 255).astype(np.uint8)
        result[:, :, 3] = np.clip(np.rint(a_out * 255.0), 0, 255).astype(np.uint8)
        return result

    @staticmethod
    def _crop_with_offset(img: np.ndarray) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Crop ``img`` to its non-transparent content.

        Returns:
            ``(cropped, (x1, y1))`` where ``(x1, y1)`` is the position of the
            crop's top-left corner in ``img``. When no pixel has alpha > 0
            the input itself is returned with offset ``(0, 0)``.
        """
        opaque = img[:, :, 3] > 0
        # Row/column projections avoid materialising one coordinate pair per
        # opaque pixel.
        rows = np.flatnonzero(opaque.any(axis=1))
        if rows.size == 0:
            return img, (0, 0)
        cols = np.flatnonzero(opaque.any(axis=0))
        y1, y2 = int(rows[0]), int(rows[-1]) + 1
        x1, x2 = int(cols[0]), int(cols[-1]) + 1
        return img[y1:y2, x1:x2].copy(), (x1, y1)

    def crop_result(self, img: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` cropped to the pixels with alpha > 0.

        If no pixel has alpha > 0 the input is returned unchanged.
        """
        return self._crop_with_offset(img)[0]

    def _add_image(self, result: np.ndarray, idx: int) -> np.ndarray:
        """Align image ``idx`` to ``result`` and return the new cropped mosaic.

        ``result`` itself is never modified, and ``last_match_region`` is put
        back to its previous value if any step raises, so a failed image
        leaves the stitcher state untouched.
        """
        region_before = self.last_match_region
        try:
            current = self.widen_image(self.read_image(idx))
            base = self.widen_image(result)

            # Widening moves the mosaic content by pad // 2; move the stored
            # centre along with it before it is used as the search centre.
            if region_before is not None:
                shift = self._padding() // 2
                self.last_match_region = (region_before[0] + shift,
                                          region_before[1] + shift)

            pts1, pts2 = self.find_matches(base, current)
            matrix = self.align_images(pts1, pts2, current.shape)
            # BORDER_TRANSPARENT leaves destination pixels outside the source
            # unmodified, and no destination is supplied here, so their
            # content is not guaranteed to be zero. A constant transparent
            # border gives a well-defined result.
            aligned = cv2.warpAffine(
                current, matrix,
                (base.shape[1], base.shape[0]),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
                borderValue=(0, 0, 0, 0),
            )
            blended = self.blend_images(base, aligned)
            cropped, (dx, dy) = self._crop_with_offset(blended)

            # align_images stored the centre in the widened frame; express it
            # in the cropped mosaic's frame for the next image.
            cx, cy = self.last_match_region
            self.last_match_region = (cx - dx, cy - dy)
            return cropped
        except Exception:
            self.last_match_region = region_before
            raise

    def stitch(self, start: int = 3, end: int = 21) -> np.ndarray:
        """Stitch images ``start`` .. ``end - 1`` into one mosaic.

        Image ``start`` seeds the mosaic; every following index is added in
        order. An image that fails is reported and skipped. After the last
        index (``end - 1``) is added the mosaic is written to
        ``result_<idx>.png``.

        Returns:
            The final BGRA mosaic.

        Raises:
            FileNotFoundError: If the seed image ``start`` cannot be read.
        """
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        result = self.read_image(start)
        total_images = max(end - start - 1, 0)
        last_idx = end - 1
        processed = 0
        failures = 0

        for idx in range(start + 1, end):
            try:
                iter_start = time()
                print(f"\nProcessing image {idx}/{end-1} ", end="")

                result = self._add_image(result, idx)

                if idx == last_idx:
                    out_name = f"result_{idx}.png"
                    # imwrite reports failure through its return value.
                    if cv2.imwrite(out_name, result):
                        print(f"[Saved {out_name}]", end="")
                    else:
                        print(f"[Could not save {out_name}]", end="")

                processed += 1
                print(f"[{time() - iter_start:.1f}s]")

            except Exception as e:
                # Best effort: report the failure and carry on with the
                # mosaic as it was before this image.
                failures += 1
                print(f"\nError at image {idx}: {str(e)}")
                continue

        total_time = time() - start_time
        # Guard the average against an empty range.
        avg_time = total_time / total_images if total_images else 0.0
        print("\nStitching completed:")
        print(f"Total images processed: {processed}/{total_images} ({failures} failed)")
        print(f"Total time: {total_time:.1f}s "
              f"(avg {avg_time:.1f}s per image)")
        return result

# Script entry point: build a stitcher for the default folder ("./folder2/")
# and stitch the default index range (start=3, end=21). `stitcher` is left
# bound at module level after the run.
if __name__ == "__main__":
    stitcher = ImageStitcher()
    stitcher.stitch()


Starting image stitching process from 3 to 21

Processing image 4/20 [4.8s]

Processing image 5/20 [5.5s]

Processing image 6/20 [5.9s]

Processing image 7/20 [6.6s]

Processing image 8/20 [6.2s]

Processing image 9/20 [6.2s]

Processing image 10/20 [6.7s]

Processing image 11/20 [6.5s]

Processing image 12/20 [6.5s]

Processing image 13/20 [6.4s]

Processing image 14/20 [6.6s]

Processing image 15/20 [6.9s]

Processing image 16/20 [8.0s]

Processing image 17/20 [8.1s]

Processing image 18/20 [8.3s]

Processing image 19/20 [8.3s]

Processing image 20/20 [Saved result_20.png][9.2s]

Stitching completed:
Total images processed: 17/17 (0 failed)
Total time: 117.0s (avg 6.9s per image)
