In [None]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time

@dataclass
class Config:
    """Tuning constants for the stitching pipeline.

    The code visible in this file reads these values as class attributes
    (e.g. ``Config.SCALE``) rather than through an instance.
    """
    # Fraction of the ratio-test survivors that is kept after sorting the
    # matches by ascending descriptor distance.
    KEEP_PERCENT: float = 0.75
    # Passed as ``ransacReprojThreshold`` to cv2.estimateAffinePartial2D.
    THRESH: float = 250.0
    # Passed as ``nfeatures`` to cv2.SIFT_create.
    NFEAT: int = 100000
    # Canvas padding added by widen_image; multiplied by SCALE before use.
    WIDEN: int = 4000
    # Resize factor applied to both axes of every image as it is read.
    SCALE: float = 0.15
    # Ratio-test factor: a match is kept when its distance is below
    # MATCH_RATIO times the distance of the second-best candidate.
    MATCH_RATIO: float = 0.7
    # Minimum fraction of the supplied matches that must be RANSAC inliers.
    MIN_INLIERS: float = 0.01
    # Search window size in pixels (after scaling)
    SEARCH_WINDOW: int = 1000
    # Overlap margin to ensure we don't miss matches
    OVERLAP_MARGIN: int = 200

class ImageStitcher:
    """Incrementally stitch a numbered image sequence into one BGRA mosaic.

    Each new image is matched against the mosaic built so far with SIFT
    features, aligned with a partial affine transform estimated by RANSAC,
    and alpha-blended on top of the mosaic.

    Attributes:
        path: Directory prefix prepended to every image file name.
        filename_template: ``str.format`` template for the file name; it
            receives the image number as ``idx``.
        sift: SIFT detector limited to ``Config.NFEAT`` features.
        matcher: FLANN matcher (``algorithm=1`` is FLANN_INDEX_KDTREE).
        last_match_region: ``(x, y)`` centre of the most recently aligned
            image, or ``None`` before the first successful alignment.
            ``stitch`` keeps it expressed in the coordinate frame of the
            next widened mosaic so ``get_search_region`` can use it directly.
    """

    def __init__(self, path: str = "./folder2/",
                 filename_template: str = "2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"):
        self.path = path
        self.filename_template = filename_template
        self.sift = cv2.SIFT_create(nfeatures=Config.NFEAT)
        self.matcher = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5), dict(checks=50))
        self.last_match_region = None

    @staticmethod
    def _padding() -> int:
        """Return the total padding (in scaled pixels) added per axis by widen_image."""
        return int(Config.WIDEN * Config.SCALE)

    @staticmethod
    def _content_bounds(img: np.ndarray) -> "tuple | None":
        """Return ``(y1, x1, y2, x2)`` of the non-transparent area, or ``None`` if empty."""
        visible = img[:, :, 3] > 0
        # Per-axis reductions avoid materialising one index pair per pixel.
        rows = np.flatnonzero(visible.any(axis=1))
        if rows.size == 0:
            return None
        cols = np.flatnonzero(visible.any(axis=0))
        return int(rows[0]), int(cols[0]), int(rows[-1]) + 1, int(cols[-1]) + 1

    def read_image(self, idx: int) -> np.ndarray:
        """Load image number ``idx``, downscale it and return it as opaque BGRA.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        img = cv2.imread(self.path + self.filename_template.format(idx=idx))
        if img is None:
            raise FileNotFoundError(f"Cannot read image {idx}")

        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE)
        rgba = cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
        rgba[:, :, 3] = 255
        return rgba

    def get_search_region(self, img: np.ndarray) -> tuple:
        """Determine the search region based on the last successful match.

        Returns:
            ``(x1, y1, x2, y2)`` bounds inside ``img``. Without a previous
            match this is the right half of the image (plus the overlap
            margin). If the tracked centre yields an empty window, the whole
            image is returned instead of a degenerate region.
        """
        h, w = img.shape[:2]

        if self.last_match_region is None:
            # For the first match, use the right half of the image
            return (max(0, w // 2 - Config.OVERLAP_MARGIN), 0, w, h)

        cx, cy = self.last_match_region
        half = Config.SEARCH_WINDOW // 2

        # Window around the last match ...
        x1 = max(0, cx - half)
        y1 = max(0, cy - half)
        x2 = min(w, x1 + Config.SEARCH_WINDOW)
        y2 = min(h, y1 + Config.SEARCH_WINDOW)

        # ... grown by the overlap margin and clamped to the image.
        x1 = max(0, x1 - Config.OVERLAP_MARGIN)
        y1 = max(0, y1 - Config.OVERLAP_MARGIN)
        x2 = min(w, x2 + Config.OVERLAP_MARGIN)
        y2 = min(h, y2 + Config.OVERLAP_MARGIN)

        if x2 <= x1 or y2 <= y1:
            # The tracked centre lies outside the image: search everywhere.
            return (0, 0, w, h)

        return (x1, y1, x2, y2)

    def update_last_match_region(self, matrix: np.ndarray, img_shape: tuple):
        """Store where the centre of an image of ``img_shape`` lands under ``matrix``.

        Args:
            matrix: 2x3 affine transform mapping image to mosaic coordinates.
            img_shape: Shape of the transformed image; only ``(h, w)`` is used.
        """
        h, w = img_shape[:2]
        # A 1-D homogeneous point keeps the product 1-D, so each component is
        # a scalar and int() does not have to convert a one-element array.
        cx, cy = np.asarray(matrix, dtype=np.float64) @ np.array([w / 2, h / 2, 1.0])
        self.last_match_region = (int(cx), int(cy))

    def find_matches(self, img1: np.ndarray, img2: np.ndarray):
        """Match SIFT features between a region of ``img1`` and all of ``img2``.

        Returns:
            ``(pts1, pts2)`` float32 arrays of corresponding points; ``pts1``
            is expressed in full ``img1`` coordinates.

        Raises:
            ValueError: If features or good matches are too few.
        """
        x1, y1, x2, y2 = self.get_search_region(img1)
        img1_roi = img1[y1:y2, x1:x2]

        gray1 = cv2.cvtColor(img1_roi, cv2.COLOR_BGRA2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_BGRA2GRAY)

        kp1, desc1 = self.sift.detectAndCompute(gray1, None)
        kp2, desc2 = self.sift.detectAndCompute(gray2, None)

        if desc1 is None or desc2 is None or len(kp1) == 0 or len(kp2) == 0:
            raise ValueError("No features detected in one or both images")

        # The matcher may return fewer than two neighbours for a query; such
        # entries cannot take the ratio test and are skipped.
        pairs = (p for p in self.matcher.knnMatch(desc1, desc2, k=2) if len(p) == 2)
        good = [m for m, n in pairs if m.distance < Config.MATCH_RATIO * n.distance]
        good.sort(key=lambda match: match.distance)
        good = good[:int(len(good) * Config.KEEP_PERCENT)]

        if len(good) < 4:
            raise ValueError("Not enough good matches found")

        # Keypoints were detected inside the ROI; shift them back to img1.
        pts1 = np.float32([kp1[m.queryIdx].pt for m in good])
        pts1 += np.float32([x1, y1])
        pts2 = np.float32([kp2[m.trainIdx].pt for m in good])

        return pts1, pts2

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray, img_shape: tuple):
        """Estimate the partial affine transform that maps ``pts2`` onto ``pts1``.

        On success the tracked match centre is updated from the transform.

        Raises:
            ValueError: If estimation fails or the inlier count is below
                ``Config.MIN_INLIERS * len(pts1)``.
        """
        matrix, inliers = cv2.estimateAffinePartial2D(pts2, pts1, method=cv2.RANSAC,
                                                     ransacReprojThreshold=Config.THRESH)

        if matrix is None or np.count_nonzero(inliers) < Config.MIN_INLIERS * len(pts1):
            raise ValueError("Not enough matches")

        self.update_last_match_region(matrix, img_shape)
        return matrix

    def widen_image(self, img: np.ndarray) -> np.ndarray:
        """Return ``img`` centred on a larger, fully transparent BGRA canvas."""
        h, w = img.shape[:2]
        pad = self._padding()
        canvas = np.zeros((h + pad, w + pad, 4), dtype=np.uint8)
        offset = pad // 2
        canvas[offset:offset + h, offset:offset + w] = img
        return canvas

    def blend_images(self, img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
        """Alpha-composite ``img2`` over ``img1`` (same-shape BGRA, uint8).

        Colour and alpha are rounded to the nearest level instead of being
        truncated, so repeated blending does not darken the mosaic.
        """
        a1 = img1[:, :, 3].astype(np.float64) / 255
        a2 = img2[:, :, 3].astype(np.float64) / 255
        under = a1 * (1 - a2)  # contribution of img1 that shows through img2
        a_out = a2 + under
        mask = a_out > 0

        result = np.zeros_like(img1)
        if mask.any():
            # Index with the mask once and process all colour channels together.
            top = img2[mask][:, :3] * a2[mask][:, None]
            bottom = img1[mask][:, :3] * under[mask][:, None]
            colour = (top + bottom) / a_out[mask][:, None]
            result[mask, :3] = np.clip(np.rint(colour), 0, 255).astype(np.uint8)

        result[:, :, 3] = np.clip(np.rint(a_out * 255), 0, 255).astype(np.uint8)
        return result

    def crop_result(self, img: np.ndarray) -> np.ndarray:
        """Return a copy of ``img`` cropped to its non-transparent bounding box.

        A fully transparent image is returned unchanged (not copied).
        """
        bounds = self._content_bounds(img)
        if bounds is None:
            return img
        y1, x1, y2, x2 = bounds
        return img[y1:y2, x1:x2].copy()

    def stitch(self, start: int = 3, end: int = 36) -> np.ndarray:
        """Stitch images ``start`` .. ``end - 1`` and return the mosaic.

        Images that fail to load, match or align are reported and skipped;
        the mosaic is left untouched by a failed image. When at least one
        image was added, the mosaic is written to ``result_<idx>.png`` where
        ``<idx>`` is the last successfully added image.

        Raises:
            FileNotFoundError: If the first image cannot be read.
        """
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        self.last_match_region = None  # do not reuse state from an earlier run
        result = self.read_image(start)
        total_images = end - start - 1
        processed = 0
        failures = 0
        last_ok = None
        pad_offset = self._padding() // 2

        for idx in range(start + 1, end):
            try:
                iter_start = time()
                print(f"\nProcessing image {idx}/{end-1} ", end="")

                current = self.widen_image(self.read_image(idx))
                # Widen into a local so a failure below does not leave the
                # mosaic permanently padded.
                canvas = self.widen_image(result)

                pts1, pts2 = self.find_matches(canvas, current)
                matrix = self.align_images(pts1, pts2, current.shape)

                # BORDER_TRANSPARENT would leave uncovered pixels of the new
                # output buffer unwritten; a constant transparent border
                # makes them well defined.
                aligned = cv2.warpAffine(current, matrix, (canvas.shape[1], canvas.shape[0]),
                                         flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
                                         borderValue=(0, 0, 0, 0))
                blended = self.blend_images(canvas, aligned)
                bounds = self._content_bounds(blended)
                result = self.crop_result(blended)

                # The match centre was computed in canvas coordinates. Move
                # it into the frame of the next widened mosaic: subtract the
                # crop origin, add the padding offset widen_image applies.
                if self.last_match_region is not None:
                    top, left = (bounds[0], bounds[1]) if bounds is not None else (0, 0)
                    cx, cy = self.last_match_region
                    self.last_match_region = (cx - left + pad_offset, cy - top + pad_offset)

                processed += 1
                last_ok = idx
                print(f"[{time() - iter_start:.1f}s]")

            except Exception as e:
                failures += 1
                print(f"\nError at image {idx}: {e}")
                continue

        if last_ok is not None:
            out_name = f"result_{last_ok}.png"
            if cv2.imwrite(out_name, result):
                print(f"[Saved {out_name}]")
            else:
                print(f"[Could not save {out_name}]")

        total_time = time() - start_time
        avg_time = total_time / total_images if total_images > 0 else 0.0
        print("\nStitching completed:")
        print(f"Total images processed: {processed}/{total_images} ({failures} failed)")
        print(f"Total time: {total_time:.1f}s (avg {avg_time:.1f}s per image)")
        return result

# Script / notebook entry point: stitch the default image range from the
# default folder. `stitcher` is bound at module level, so it stays available
# to code that runs afterwards in the same namespace.
if __name__ == "__main__":
    stitcher = ImageStitcher()
    stitcher.stitch()



Starting image stitching process from 3 to 36

Processing image 4/35 

  self.last_match_region = (int(transformed_center[0]), int(transformed_center[1]))


[1.8s]

Processing image 5/35 [2.1s]

Processing image 6/35 [2.2s]

Processing image 7/35 [2.2s]

Processing image 8/35 [2.4s]

Processing image 9/35 [2.4s]

Processing image 10/35 [2.4s]

Processing image 11/35 [2.6s]

Processing image 12/35 [2.4s]

Processing image 13/35 [2.5s]

Processing image 14/35 [2.5s]

Processing image 15/35 