In [1]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time
from typing import Tuple, Optional
from pathlib import Path
import numba
from concurrent.futures import ThreadPoolExecutor, Future
import os
import psutil  # For memory monitoring

# Configuration settings for the stitching process
@dataclass(frozen=True)
class Config:
    SCALE: float = 0.20  # Initial scale for processing
    NFEAT: int = 50000
    MATCH_RATIO: float = 0.75
    KEEP_PERCENT: float = 0.75
    THRESH: float = 250.0
    SIFT_EDGE_THRESHOLD: float = 10.0
    SIFT_CONTRAST_THRESHOLD: float = 0.04
    NUM_THREADS: int = os.cpu_count() if os.cpu_count() else 4
    SEARCH_FIELD_MULTIPLIER: float = 2.4
    DOWNSCALE_FACTOR: float = 0.5  # Downscale by 2x before saving

class TileManager:
    """Manages a tiled representation of an image with caching to disk."""
    def __init__(self, tile_size=1000, cache_dir="cache"):
        self.tile_size = tile_size
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(exist_ok=True)
        self.tiles = {}  # (tile_x, tile_y) -> np.ndarray or None
        self.active_tiles = set()  # Tiles with non-zero alpha
        self.tile_bboxes = {}  # Bounding boxes within each tile

    def get_tile(self, tile_x: int, tile_y: int) -> Optional[np.ndarray]:
        """Retrieve a tile, loading from disk if necessary."""
        if (tile_x, tile_y) in self.tiles:
            return self.tiles[(tile_x, tile_y)]
        file_path = self.cache_dir / f"tile_{tile_x}_{tile_y}.png"
        if file_path.exists():
            tile = cv2.imread(str(file_path), cv2.IMREAD_UNCHANGED)
            self.tiles[(tile_x, tile_y)] = tile
            return tile
        return None

    def set_tile(self, tile_x: int, tile_y: int, data: np.ndarray):
        """Set a tile and update its bounding box if it has non-zero alpha."""
        self.tiles[(tile_x, tile_y)] = data
        mask = data[:, :, 3] > 0
        if np.any(mask):
            rows = np.any(mask, axis=1)
            cols = np.any(mask, axis=0)
            ty1, ty2 = np.where(rows)[0][[0, -1]]
            tx1, tx2 = np.where(cols)[0][[0, -1]]
            self.tile_bboxes[(tile_x, tile_y)] = (tx1, ty1, tx2 + 1, ty2 + 1)
            self.active_tiles.add((tile_x, tile_y))
        else:
            self.tile_bboxes.pop((tile_x, tile_y), None)
            self.active_tiles.discard((tile_x, tile_y))

    def cache_tile(self, tile_x: int, tile_y: int):
        """Cache a tile to disk and remove it from memory."""
        if (tile_x, tile_y) in self.tiles:
            tile = self.tiles[(tile_x, tile_y)]
            if tile is not None:
                file_path = self.cache_dir / f"tile_{tile_x}_{tile_y}.png"
                cv2.imwrite(str(file_path), tile)
                del self.tiles[(tile_x, tile_y)]

    def get_region(self, x1: int, y1: int, x2: int, y2: int) -> np.ndarray:
        """Extract a contiguous region from the tiled image."""
        tile_x1 = x1 // self.tile_size
        tile_y1 = y1 // self.tile_size
        tile_x2 = (x2 - 1) // self.tile_size
        tile_y2 = (y2 - 1) // self.tile_size
        w = (tile_x2 - tile_x1 + 1) * self.tile_size
        h = (tile_y2 - tile_y1 + 1) * self.tile_size
        region = np.zeros((h, w, 4), dtype=np.uint8)
        for ty in range(tile_y1, tile_y2 + 1):
            for tx in range(tile_x1, tile_x2 + 1):
                tile = self.get_tile(tx, ty)
                if tile is None:
                    tile = np.zeros((self.tile_size, self.tile_size, 4), dtype=np.uint8)
                px = (tx - tile_x1) * self.tile_size
                py = (ty - tile_y1) * self.tile_size
                region[py:py + self.tile_size, px:px + self.tile_size] = tile
        ox = x1 - tile_x1 * self.tile_size
        oy = y1 - tile_y1 * self.tile_size
        return region[oy:oy + (y2 - y1), ox:ox + (x2 - x1)]

    def set_region(self, x1: int, y1: int, x2: int, y2: int, data: np.ndarray):
        """Set a region into the tiled image, updating affected tiles."""
        tile_x1 = x1 // self.tile_size
        tile_y1 = y1 // self.tile_size
        tile_x2 = (x2 - 1) // self.tile_size
        tile_y2 = (y2 - 1) // self.tile_size
        for ty in range(tile_y1, tile_y2 + 1):
            for tx in range(tile_x1, tile_x2 + 1):
                tx_start = max(x1, tx * self.tile_size)
                ty_start = max(y1, ty * self.tile_size)
                tx_end = min(x2, (tx + 1) * self.tile_size)
                ty_end = min(y2, (ty + 1) * self.tile_size)
                if tx_start < tx_end and ty_start < ty_end:
                    tile = self.get_tile(tx, ty)
                    if tile is None:
                        tile = np.zeros((self.tile_size, self.tile_size, 4), dtype=np.uint8)
                    tpx_start = tx_start - tx * self.tile_size
                    tpy_start = ty_start - ty * self.tile_size
                    tpx_end = tpx_start + (tx_end - tx_start)
                    tpy_end = tpy_start + (ty_end - ty_start)
                    dx_start = tx_start - x1
                    dy_start = ty_start - y1
                    dx_end = dx_start + (tx_end - tx_start)
                    dy_end = dy_start + (ty_end - ty_start)
                    tile[tpy_start:tpy_end, tpx_start:tpx_end] = data[dy_start:dy_end, dx_start:dx_end]
                    self.set_tile(tx, ty, tile)

    def set_large_image(self, x0: int, y0: int, large_img: np.ndarray):
        """Set a large image into the tiled structure at position (x0, y0)."""
        h, w = large_img.shape[:2]
        tile_x1 = x0 // self.tile_size
        tile_y1 = y0 // self.tile_size
        tile_x2 = (x0 + w - 1) // self.tile_size
        tile_y2 = (y0 + h - 1) // self.tile_size
        for ty in range(tile_y1, tile_y2 + 1):
            for tx in range(tile_x1, tile_x2 + 1):
                tx_start = max(x0, tx * self.tile_size)
                ty_start = max(y0, ty * self.tile_size)
                tx_end = min(x0 + w, (tx + 1) * self.tile_size)
                ty_end = min(y0 + h, (ty + 1) * self.tile_size)
                if tx_start < tx_end and ty_start < ty_end:
                    lx_start = tx_start - x0
                    ly_start = ty_start - y0
                    lx_end = tx_end - x0
                    ly_end = ty_end - y0
                    tile = self.get_tile(tx, ty)
                    if tile is None:
                        tile = np.zeros((self.tile_size, self.tile_size, 4), dtype=np.uint8)
                    tpx_start = tx_start - tx * self.tile_size
                    tpy_start = ty_start - ty * self.tile_size
                    tpx_end = tpx_start + (lx_end - lx_start)
                    tpy_end = tpy_start + (ly_end - ly_start)
                    tile[tpy_start:tpy_end, tpx_start:tpx_end] = large_img[ly_start:ly_end, lx_start:lx_end]
                    self.set_tile(tx, ty, tile)

    def cache_far_tiles(self, center_x: int, center_y: int, threshold: float):
        """Cache tiles farther than the threshold from the center."""
        to_cache = []
        for (tx, ty), tile in list(self.tiles.items()):
            if tile is not None:
                tile_center_x = (tx + 0.5) * self.tile_size
                tile_center_y = (ty + 0.5) * self.tile_size
                distance = np.sqrt((tile_center_x - center_x)**2 + (tile_center_y - center_y)**2)
                if distance > threshold:
                    to_cache.append((tx, ty))
        for tx, ty in to_cache:
            self.cache_tile(tx, ty)

    def get_bounding_box(self) -> Optional[Tuple[int, int, int, int]]:
        """Get the bounding box of all active tiles."""
        if not self.active_tiles:
            return None
        min_x = min((tx * self.tile_size + self.tile_bboxes[(tx, ty)][0]) for tx, ty in self.active_tiles)
        min_y = min((ty * self.tile_size + self.tile_bboxes[(tx, ty)][1]) for tx, ty in self.active_tiles)
        max_x = max((tx * self.tile_size + self.tile_bboxes[(tx, ty)][2]) for tx, ty in self.active_tiles)
        max_y = max((ty * self.tile_size + self.tile_bboxes[(tx, ty)][3]) for tx, ty in self.active_tiles)
        return min_x, min_y, max_x, max_y

class ImageStitcher:
    def __init__(self, path: str = "./folder2/"):
        """Initialize the ImageStitcher with a directory path."""
        self.path = Path(path)
        if Config.SEARCH_FIELD_MULTIPLIER < 1.0:
            raise ValueError("SEARCH_FIELD_MULTIPLIER must be >= 1.0")
        self.sift = cv2.SIFT_create(
            nfeatures=Config.NFEAT,
            contrastThreshold=Config.SIFT_CONTRAST_THRESHOLD,
            edgeThreshold=Config.SIFT_EDGE_THRESHOLD
        )
        self.matcher = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5), dict(checks=61))
        self.last_match_region: Optional[Tuple[int, int]] = None
        self._gray_buffer1: Optional[np.ndarray] = None
        self._gray_buffer2: Optional[np.ndarray] = None
        self._base_image_size: Optional[Tuple[int, int]] = None
        self.tile_manager = TileManager(tile_size=1000, cache_dir="cache")
        print(f"Using {Config.NUM_THREADS} worker threads")
        self.thread_pool = ThreadPoolExecutor(max_workers=Config.NUM_THREADS)
        self.process = psutil.Process(os.getpid())
        self.transformations = []  # List to store (matrix, image) pairs

    def get_memory_usage(self) -> float:
        """Get current memory usage in MB."""
        return self.process.memory_info().rss / 1024**2

    def log_memory(self, message: str):
        """Log memory usage with a message."""
        print(f"{message}: {self.get_memory_usage():.2f}MB")

    def calculate_array_size(self, arr: np.ndarray) -> float:
        """Calculate the size of a numpy array in MB."""
        return arr.nbytes / 1024**2

    def read_image(self, idx: int) -> Tuple[np.ndarray, dict]:
        """Read and preprocess an image from disk."""
        timings = {}
        path_start = time()
        img_path = self.path / f"2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"
        timings['construct_path'] = time() - path_start
        read_start = time()
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        timings['read_from_disk'] = time() - read_start
        resize_start = time()
        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE, interpolation=cv2.INTER_LINEAR)
        timings['resize'] = time() - resize_start
        if self._base_image_size is None:
            self._base_image_size = (img.shape[1], img.shape[0])
            if self.last_match_region is None:
                self.last_match_region = (img.shape[1] // 2, img.shape[0] // 2)
        rgba_start = time()
        rgba = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
        rgba[:, :, :3] = img
        rgba[:, :, 3] = 255
        timings['convert_to_rgba'] = time() - rgba_start
        return rgba, timings

    def get_search_region_and_expand(self) -> Tuple[Tuple[int, int, int, int], Tuple[int, int]]:
        """Define the search region in absolute coordinates."""
        if self.last_match_region is None:
            raise ValueError("last_match_region not set")
        cx, cy = self.last_match_region
        base_w, base_h = self._base_image_size
        search_w = int(round(base_w * Config.SEARCH_FIELD_MULTIPLIER**2))
        search_h = int(round(base_h * Config.SEARCH_FIELD_MULTIPLIER**2))
        search_w = max(base_w, search_w)
        search_h = max(base_h, search_h)
        x1 = int(round(cx - search_w / 2))
        y1 = int(round(cy - search_h / 2))
        x2 = x1 + search_w
        y2 = y1 + search_h
        return (x1, y1, x2, y2), (0, 0)  # No padding with TileManager

    def update_last_match_region(self, matrix: np.ndarray, img_shape: Tuple[int, int]):
        """Update the last match region based on the transformation."""
        h, w = img_shape[:2]
        center = np.array([[w/2], [h/2], [1]], dtype=np.float32)
        transformed = matrix @ center
        new_cx = int(round(transformed[0, 0]))
        new_cy = int(round(transformed[1, 0]))
        self.last_match_region = (new_cx, new_cy)

    def find_matches(self, current: np.ndarray) -> Tuple[np.ndarray, np.ndarray, Tuple[int, int], dict]:
        """Find feature matches using the tiled result image."""
        timings = {}
        self.log_memory("  - Memory before find_matches")
        search_region, offset = self.get_search_region_and_expand()
        x1, y1, x2, y2 = search_region
        if x2 <= x1 or y2 <= y1:
            raise ValueError(f"Invalid search region: x1={x1}, x2={x2}, y1={y1}, y2={y2}")
        roi_start = time()
        img1_roi = self.tile_manager.get_region(x1, y1, x2, y2)
        timings['extract_roi'] = time() - roi_start
        self.log_memory("  - Memory after extracting ROI")
        print(f"  - ROI size: {self.calculate_array_size(img1_roi):.2f}MB")

        buffer_start = time()
        if self._gray_buffer1 is None or self._gray_buffer1.shape != img1_roi.shape[:2]:
            self._gray_buffer1 = np.empty(img1_roi.shape[:2], dtype=np.uint8)
        if self._gray_buffer2 is None or self._gray_buffer2.shape != current.shape[:2]:
            self._gray_buffer2 = np.empty(current.shape[:2], dtype=np.uint8)
        timings['allocate_buffers'] = time() - buffer_start

        gray_start = time()
        futures = [
            self.thread_pool.submit(cv2.cvtColor, img1_roi, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer1),
            self.thread_pool.submit(cv2.cvtColor, current, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer2)
        ]
        for future in futures:
            future.result()
        timings['convert_to_grayscale'] = time() - gray_start

        def detect_compute(img, sift):
            return sift.detectAndCompute(img, None)
        sift_start = time()
        future_kp1 = self.thread_pool.submit(detect_compute, self._gray_buffer1, self.sift)
        future_kp2 = self.thread_pool.submit(detect_compute, self._gray_buffer2, self.sift)
        kp1, desc1 = future_kp1.result()
        kp2, desc2 = future_kp2.result()
        timings['sift_detection'] = time() - sift_start
        if desc1 is not None and desc2 is not None:
            print(f"  - Descriptors size (img1): {self.calculate_array_size(desc1):.2f}MB")
            print(f"  - Descriptors size (img2): {self.calculate_array_size(desc2):.2f}MB")

        flann_start = time()
        matches = self.matcher.knnMatch(desc1, desc2, k=2)
        timings['flann_matching'] = time() - flann_start

        filter_start = time()
        good = [m for m, n in matches if m.distance < Config.MATCH_RATIO * n.distance]
        good = sorted(good, key=lambda x: x.distance)[:int(len(good) * Config.KEEP_PERCENT)]
        timings['filter_matches'] = time() - filter_start

        points_start = time()
        query_idx = np.array([m.queryIdx for m in good], dtype=int)
        train_idx = np.array([m.trainIdx for m in good], dtype=int)
        pts1 = np.float32([kp.pt for kp in kp1])[query_idx] + [x1, y1]
        pts2 = np.float32([kp.pt for kp in kp2])[train_idx]
        timings['extract_points'] = time() - points_start
        return pts1, pts2, offset, timings

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray) -> np.ndarray:
        """Estimate the affine transformation between two sets of points."""
        matrix, _ = cv2.estimateAffinePartial2D(
            pts2, pts1,
            method=cv2.RANSAC,
            ransacReprojThreshold=Config.THRESH,
            confidence=0.999,
            maxIters=10000
        )
        if matrix is not None:
            L = matrix[:, :2]
            s_x = np.linalg.norm(L[:, 0])
            s_y = np.linalg.norm(L[:, 1])
            s_avg = (s_x + s_y) / 2
            if s_avg > 0:
                matrix[:, :2] /= s_avg
                print(f"Adjusted scaling factor from {s_avg:.4f} to 1.0")
        return matrix

    @staticmethod
    @numba.jit(nopython=True, parallel=True)
    def overlay_images_numba(base: np.ndarray, overlay: np.ndarray, result: np.ndarray):
        """Overlay two images efficiently using Numba."""
        h, w = base.shape[:2]
        for y in numba.prange(h):
            for x in range(w):
                if overlay[y, x, 3] > 0:
                    result[y, x] = overlay[y, x]
                else:
                    result[y, x] = base[y, x]

    def overlay_images(self, base: np.ndarray, overlay: np.ndarray) -> np.ndarray:
        """Overlay one image onto another."""
        result = np.zeros_like(base, dtype=np.uint8)
        self.overlay_images_numba(base, overlay, result)
        return result

    def stitch(self, start: int = 3, end: int = 6):
        """Stitch images, downscaling by 2x before saving to avoid memory overload."""
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        self.log_memory("Memory before starting stitch")
        
        # Load the first image
        result, read_timings = self.read_image(start)
        self.tile_manager.set_large_image(0, 0, result)
        self.log_memory(f"Memory after setting initial image {start}")
        self.transformations.append((np.eye(2, 3, dtype=np.float32), result))  # Identity matrix for first image

        future: Optional[Future] = None
        for idx in range(start + 1, end + 1):
            iter_start = time()
            print(f"\nProcessing image {idx}/{end}")
            self.log_memory("  - Memory at start of iteration")

            # Read the current image
            if future is not None:
                current, read_timings = future.result()
            else:
                current, read_timings = self.read_image(idx)
            if idx < end:
                future = self.thread_pool.submit(self.read_image, idx + 1)
            self.log_memory(f"  - Memory after reading image {idx}")
            print(f"  - Current image size: {self.calculate_array_size(current):.2f}MB")

            # Find matches and align
            match_start = time()
            pts1, pts2, offset, match_timings = self.find_matches(current)
            match_time = time() - match_start
            self.log_memory("  - Memory after find_matches")

            align_start = time()
            matrix = self.align_images(pts1, pts2)
            align_time = time() - align_start

            # Store transformation and image
            self.transformations.append((matrix, current))

            # Update match region and cache tiles
            update_start = time()
            self.update_last_match_region(matrix, current.shape)
            update_time = time() - update_start

            cache_start = time()
            self.tile_manager.cache_far_tiles(self.last_match_region[0], self.last_match_region[1], 5000)
            cache_time = time() - cache_start
            self.log_memory("  - Memory after caching far tiles")

            # Log timings
            print(f"  - Find matches: {match_time:.3f}s")
            print(f"    - Extract ROI: {match_timings['extract_roi']:.3f}s")
            print(f"    - Allocate buffers: {match_timings['allocate_buffers']:.3f}s")
            print(f"    - Convert to grayscale: {match_timings['convert_to_grayscale']:.3f}s")
            print(f"    - SIFT detection: {match_timings['sift_detection']:.3f}s")
            print(f"    - FLANN matching: {match_timings['flann_matching']:.3f}s")
            print(f"    - Filter matches: {match_timings['filter_matches']:.3f}s")
            print(f"    - Extract points: {match_timings['extract_points']:.3f}s")
            print(f"  - Align images: {align_time:.3f}s")
            print(f"  - Update match region: {update_time:.3f}s")
            print(f"  - Cache far tiles: {cache_time:.3f}s")
            print(f"[Total for image {idx}: {time() - iter_start:.3f}s]")

        # Final downscaling and saving
        if self.transformations:
            # Compute full-size bounding box
            min_x, min_y, max_x, max_y = self.tile_manager.get_bounding_box()
            for matrix, current in self.transformations[1:]:  # Skip first image already in tile_manager
                h, w = current.shape[:2]
                pts = np.array([[0,0], [w,0], [w,h], [0,h]], dtype=np.float32)
                transformed = cv2.transform(pts[None, :, :], matrix)[0]
                min_x = min(min_x, int(np.floor(transformed[:,0].min())))
                min_y = min(min_y, int(np.floor(transformed[:,1].min())))
                max_x = max(max_x, int(np.ceil(transformed[:,0].max())))
                max_y = max(max_y, int(np.ceil(transformed[:,1].max())))

            # Create downscaled canvas (2x smaller)
            downscale = Config.DOWNSCALE_FACTOR  # 0.5
            canvas_width = int((max_x - min_x) * downscale)
            canvas_height = int((max_y - min_y) * downscale)
            canvas = np.zeros((canvas_height, canvas_width, 4), dtype=np.uint8)
            print(f"  - Downscaled canvas size: {canvas_width}x{canvas_height}")

            # Warp and overlay each image onto the downscaled canvas
            for matrix, current in self.transformations:
                # Adjust transformation for downscaling and offset
                M_downscaled = matrix.copy()
                M_downscaled[0, 2] = (M_downscaled[0, 2] - min_x) * downscale
                M_downscaled[1, 2] = (M_downscaled[1, 2] - min_y) * downscale
                M_downscaled[0, 0] *= downscale
                M_downscaled[0, 1] *= downscale
                M_downscaled[1, 0] *= downscale
                M_downscaled[1, 1] *= downscale

                # Warp the image directly onto the canvas
                aligned = cv2.warpAffine(
                    current, M_downscaled,
                    (canvas_width, canvas_height),
                    flags=cv2.INTER_LINEAR,
                    borderMode=cv2.BORDER_TRANSPARENT
                )

                # Overlay onto the canvas
                canvas = self.overlay_images(canvas, aligned)

            # Save the downscaled result
            cv2.imwrite(f"result_{end}.png", canvas)
            print(f"  - Saved downscaled result_{end}.png")

        total_time = time() - start_time
        self.log_memory("Final memory usage")
        print(f"\nStitching completed in {total_time:.3f}s")

    def __del__(self):
        """Shutdown the thread pool."""
        self.thread_pool.shutdown()

if __name__ == "__main__":
    stitcher = ImageStitcher()
    stitcher.stitch()

Using 16 worker threads

Starting image stitching process from 3 to 6
Memory before starting stitch: 114.21MB
Memory after setting initial image 3: 139.89MB

Processing image 4/6
  - Memory at start of iteration: 139.91MB
  - Memory after reading image 4: 146.41MB
  - Current image size: 6.44MB
  - Memory before find_matches: 146.41MB
  - Memory after extracting ROI: 511.54MB
  - ROI size: 213.49MB
  - Descriptors size (img1): 24.41MB
  - Descriptors size (img2): 14.14MB
  - Memory after find_matches: 31.56MB
Adjusted scaling factor from 0.9912 to 1.0
  - Memory after caching far tiles: 30.83MB
  - Find matches: 25.334s
    - Extract ROI: 0.703s
    - Allocate buffers: 0.000s
    - Convert to grayscale: 0.021s
    - SIFT detection: 22.275s
    - FLANN matching: 2.215s
    - Filter matches: 0.012s
    - Extract points: 0.080s
  - Align images: 0.019s
  - Update match region: 0.004s
  - Cache far tiles: 0.082s
[Total for image 4: 25.723s]

Processing image 5/6
  - Memory at start of iter