In [1]:
import cv2
import numpy as np
from dataclasses import dataclass
from time import time
from typing import Tuple, Optional
from pathlib import Path
import numba
from concurrent.futures import ThreadPoolExecutor, Future
import os
import psutil  # For memory monitoring

# Configuration settings for the stitching process
@dataclass(frozen=True)
class Config:
    SCALE: float = 0.2  # Scale factor for resizing images
    NFEAT: int = 50000  # Number of SIFT features
    MATCH_RATIO: float = 0.65  # Ratio for filtering good matches
    KEEP_PERCENT: float = 0.65  # Percentage of best matches to keep
    THRESH: float = 5.0  # RANSAC reprojection threshold
    SIFT_EDGE_THRESHOLD: float = 10.0  # SIFT edge threshold
    SIFT_CONTRAST_THRESHOLD: float = 0.04  # SIFT contrast threshold
    NUM_THREADS: int = os.cpu_count() if os.cpu_count() else 4  # Number of threads
    SEARCH_FIELD_MULTIPLIER: float = 3.0  # Reduced to avoid excessive search regions
    TILE_SIZE: int = 1024  # Size of each tile in pixels
    CACHE_SIZE: int = 500  # Reduced cache size to limit memory usage

class TileManager:
    def __init__(self, tile_size: int, cache_size: int):
        """Initialize the TileManager for handling image tiles."""
        self.tile_size = tile_size
        self.cache = {}  # (i,j) -> tile data
        self.cache_order = []  # For LRU eviction
        self.cache_size = cache_size
        self.tile_dir = Path("tiles")
        self.tile_dir.mkdir(exist_ok=True)
        self.used_tiles = set()  # Track unique tiles accessed

    def get_tile(self, i: int, j: int) -> np.ndarray:
        """Retrieve a tile from cache or disk, creating it if necessary."""
        if (i, j) in self.cache:
            print(f"Tile ({i}, {j}) retrieved from cache")
            return self.cache[(i, j)]
        tile_path = self.tile_dir / f"tile_{i}_{j}.png"
        if tile_path.exists():
            tile = cv2.imread(str(tile_path), cv2.IMREAD_UNCHANGED)
            print(f"Tile ({i}, {j}) loaded from disk")
        else:
            tile = np.zeros((self.tile_size, self.tile_size, 4), dtype=np.uint8)
            print(f"Tile ({i}, {j}) created as zero array")
        self.cache[(i, j)] = tile
        self.cache_order.append((i, j))
        if len(self.cache) > self.cache_size:
            oldest = self.cache_order.pop(0)
            del self.cache[oldest]
            print(f"Evicted tile {oldest} from cache")
        self.used_tiles.add((i, j))
        return tile

    def set_tile(self, i: int, j: int, tile: np.ndarray):
        """Store a tile in cache and save it to disk if modified."""
        self.cache[(i, j)] = tile
        if (i, j) not in self.cache_order:
            self.cache_order.append((i, j))
        if len(self.cache) > self.cache_size:
            oldest = self.cache_order.pop(0)
            del self.cache[oldest]
            print(f"Evicted tile {oldest} from cache")
        tile_path = self.tile_dir / f"tile_{i}_{j}.png"
        cv2.imwrite(str(tile_path), tile)
        print(f"Tile ({i}, {j}) saved to disk")

    def get_region(self, x1: int, y1: int, x2: int, y2: int) -> np.ndarray:
        """Extract a region from the stitched image by combining tiles."""
        i1 = y1 // self.tile_size
        i2 = (y2 - 1) // self.tile_size + 1
        j1 = x1 // self.tile_size
        j2 = (x2 - 1) // self.tile_size + 1
        print(f"Getting region from tiles ({i1},{j1}) to ({i2},{j2})")
        tiles = []
        for i in range(i1, i2):
            row = [self.get_tile(i, j) for j in range(j1, j2)]
            tiles.append(np.hstack(row))
        full_region = np.vstack(tiles)
        tx1, ty1 = j1 * self.tile_size, i1 * self.tile_size
        return full_region[y1 - ty1:y2 - ty1, x1 - tx1:x2 - tx1]

    def set_region(self, x1: int, y1: int, x2: int, y2: int, data: np.ndarray):
        """Update a region by modifying the overlapping tiles."""
        i1 = y1 // self.tile_size
        i2 = (y2 - 1) // self.tile_size + 1
        j1 = x1 // self.tile_size
        j2 = (x2 - 1) // self.tile_size + 1
        print(f"Setting region to tiles ({i1},{j1}) to ({i2},{j2})")
        for i in range(i1, i2):
            for j in range(j1, j2):
                tile = self.get_tile(i, j)
                tile_y1 = max(y1, i * self.tile_size)
                tile_y2 = min(y2, (i + 1) * self.tile_size)
                tile_x1 = max(x1, j * self.tile_size)
                tile_x2 = min(x2, (j + 1) * self.tile_size)
                rel_y1, rel_y2 = tile_y1 - i * self.tile_size, tile_y2 - i * self.tile_size
                rel_x1, rel_x2 = tile_x1 - j * self.tile_size, tile_x2 - j * self.tile_size
                data_y1, data_y2 = tile_y1 - y1, tile_y2 - y1
                data_x1, data_x2 = tile_x1 - x1, tile_x2 - x1
                tile[rel_y1:rel_y2, rel_x1:rel_x2] = data[data_y1:data_y2, data_x1:data_x2]
                self.set_tile(i, j, tile)

    def get_used_tiles_count(self) -> int:
        """Return the number of unique tiles that have been used."""
        return len(self.used_tiles)

    def get_cached_tiles_count(self) -> int:
        """Return the number of tiles currently stored in the cache."""
        return len(self.cache)

class ImageStitcher:
    def __init__(self, path: str = "./folder2/"):
        """Initialize the ImageStitcher with a directory path."""
        self.path = Path(path)
        if Config.SEARCH_FIELD_MULTIPLIER < 1.0:
            raise ValueError("SEARCH_FIELD_MULTIPLIER must be >= 1.0")
        self.sift = cv2.SIFT_create(
            nfeatures=Config.NFEAT,
            contrastThreshold=Config.SIFT_CONTRAST_THRESHOLD,
            edgeThreshold=Config.SIFT_EDGE_THRESHOLD
        )
        self.matcher = cv2.FlannBasedMatcher(dict(algorithm=1, trees=5), dict(checks=32))
        self.last_match_region: Optional[Tuple[int, int]] = None
        self._gray_buffer1: Optional[np.ndarray] = None
        self._gray_buffer2: Optional[np.ndarray] = None
        self._base_image_size: Optional[Tuple[int, int]] = None
        self.tile_manager = TileManager(Config.TILE_SIZE, Config.CACHE_SIZE)
        print(f"Using {Config.NUM_THREADS} worker threads")
        self.thread_pool = ThreadPoolExecutor(max_workers=Config.NUM_THREADS)
        self.process = psutil.Process(os.getpid())
        # Initialize attributes for bounding box tracking and stitch counting
        self.overall_min_x = None
        self.overall_min_y = None
        self.overall_max_x = None
        self.overall_max_y = None
        self.stitched_count = 0

    def get_memory_usage(self) -> float:
        """Return current memory usage in MB."""
        return self.process.memory_info().rss / 1024**2

    def log_memory(self, message: str):
        """Log memory usage with a descriptive message."""
        print(f"{message}: {self.get_memory_usage():.2f}MB")

    def calculate_array_size(self, arr: np.ndarray) -> float:
        """Calculate memory size of a NumPy array in MB."""
        return arr.nbytes / 1024**2

    def read_image(self, idx: int) -> Tuple[np.ndarray, dict]:
        """Read and preprocess an image from disk."""
        timings = {}
        path_start = time()
        img_path = self.path / f"2023_09_01_SonyRX1RM2_g201b20538_f001_{idx:04}.JPG"
        timings['construct_path'] = time() - path_start
        read_start = time()
        img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        timings['read_from_disk'] = time() - read_start
        resize_start = time()
        img = cv2.resize(img, None, fx=Config.SCALE, fy=Config.SCALE, interpolation=cv2.INTER_LINEAR)
        timings['resize'] = time() - resize_start
        if self._base_image_size is None:
            self._base_image_size = (img.shape[1], img.shape[0])
        rgba_start = time()
        rgba = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
        rgba[:, :, :3] = img
        rgba[:, :, 3] = 255
        timings['convert_to_rgba'] = time() - rgba_start
        return rgba, timings

    def get_search_region(self) -> Tuple[int, int, int, int]:
        """Calculate the search region in global coordinates."""
        if self.last_match_region is None:
            raise ValueError("Last match region is not set")
        cx, cy = self.last_match_region
        base_w, base_h = self._base_image_size
        search_w = int(round(base_w * Config.SEARCH_FIELD_MULTIPLIER))
        search_h = int(round(base_h * Config.SEARCH_FIELD_MULTIPLIER))
        search_w = max(base_w, search_w)
        search_h = max(base_h, search_h)
        x1 = int(round(cx - search_w / 2))
        y1 = int(round(cy - search_h / 2))
        x2 = x1 + search_w
        y2 = y1 + search_h
        print(f"Search region: ({x1}, {y1}) to ({x2}, {y2})")
        return x1, y1, x2, y2

    def find_matches(self, img2: np.ndarray) -> Tuple[np.ndarray, np.ndarray, dict]:
        """Find feature matches between the stitched image and the new image."""
        timings = {}
        self.log_memory("  - Memory before find_matches")
        x1, y1, x2, y2 = self.get_search_region()
        roi_start = time()
        img1_roi = self.tile_manager.get_region(x1, y1, x2, y2)
        timings['extract_roi'] = time() - roi_start
        self.log_memory("  - Memory after extracting ROI")
        print(f"  - ROI size: {self.calculate_array_size(img1_roi):.2f}MB")

        buffer_start = time()
        if self._gray_buffer1 is None or self._gray_buffer1.shape != img1_roi.shape[:2]:
            self._gray_buffer1 = np.empty(img1_roi.shape[:2], dtype=np.uint8)
        if self._gray_buffer2 is None or self._gray_buffer2.shape != img2.shape[:2]:
            self._gray_buffer2 = np.empty(img2.shape[:2], dtype=np.uint8)
        timings['allocate_buffers'] = time() - buffer_start
        self.log_memory("  - Memory after allocating buffers")

        gray_start = time()
        futures = [
            self.thread_pool.submit(cv2.cvtColor, img1_roi, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer1),
            self.thread_pool.submit(cv2.cvtColor, img2, cv2.COLOR_BGRA2GRAY, dst=self._gray_buffer2)
        ]
        for future in futures:
            future.result()
        timings['convert_to_grayscale'] = time() - gray_start
        self.log_memory("  - Memory after converting to grayscale")

        def detect_compute(img, sift):
            return sift.detectAndCompute(img, None)
        sift_start = time()
        future_kp1 = self.thread_pool.submit(detect_compute, self._gray_buffer1, self.sift)
        future_kp2 = self.thread_pool.submit(detect_compute, self._gray_buffer2, self.sift)
        kp1, desc1 = future_kp1.result()
        kp2, desc2 = future_kp2.result()
        timings['sift_detection'] = time() - sift_start
        self.log_memory("  - Memory after SIFT detection")
        if desc1 is not None and desc2 is not None:
            print(f"  - Descriptors size (img1): {self.calculate_array_size(desc1):.2f}MB")
            print(f"  - Descriptors size (img2): {self.calculate_array_size(desc2):.2f}MB")

        flann_start = time()
        matches = self.matcher.knnMatch(desc1, desc2, k=2)
        timings['flann_matching'] = time() - flann_start
        self.log_memory("  - Memory after FLANN matching")

        filter_start = time()
        good = [m for m, n in matches if m.distance < Config.MATCH_RATIO * n.distance]
        good = sorted(good, key=lambda x: x.distance)[:int(len(good) * Config.KEEP_PERCENT)]
        timings['filter_matches'] = time() - filter_start
        self.log_memory("  - Memory after filtering matches")

        points_start = time()
        query_idx = np.array([m.queryIdx for m in good], dtype=int)
        train_idx = np.array([m.trainIdx for m in good], dtype=int)
        pts1 = np.float32([kp.pt for kp in kp1])[query_idx] + [x1, y1]  # Global coordinates
        pts2 = np.float32([kp.pt for kp in kp2])[train_idx]
        timings['extract_points'] = time() - points_start
        self.log_memory("  - Memory after extracting points")
        return pts1, pts2, timings

    def align_images(self, pts1: np.ndarray, pts2: np.ndarray) -> np.ndarray:
        """Estimate the affine transformation between two sets of points."""
        matrix, _ = cv2.estimateAffinePartial2D(
            pts2, pts1,
            method=cv2.RANSAC,
            ransacReprojThreshold=Config.THRESH,
            confidence=0.995,
            maxIters=1000
        )
        if matrix is not None:
            L = matrix[:, :2]
            s_x = np.linalg.norm(L[:, 0])
            s_y = np.linalg.norm(L[:, 1])
            s_avg = (s_x + s_y) / 2
            if s_avg > 0:
                matrix[:, :2] /= s_avg
                print(f"Adjusted scaling factor from {s_avg:.4f} to 1.0")
            else:
                print("Warning: s_avg is zero, cannot adjust scaling")
        else:
            print("Warning: No transformation matrix found")
        return matrix

    @staticmethod
    @numba.jit(nopython=True, parallel=True)
    def overlay_images_numba(base: np.ndarray, overlay: np.ndarray, result: np.ndarray):
        """Overlay two images using Numba for performance."""
        h, w = base.shape[:2]
        for y in numba.prange(h):
            for x in range(w):
                if overlay[y, x, 3] > 0:
                    result[y, x] = overlay[y, x]
                else:
                    result[y, x] = base[y, x]

    def overlay_images(self, base: np.ndarray, overlay: np.ndarray) -> np.ndarray:
        """Overlay the aligned image onto the base region."""
        result = np.zeros_like(base, dtype=np.uint8)
        ImageStitcher.overlay_images_numba(base, overlay, result)
        return result

    def update_last_match_region(self, matrix: np.ndarray, img_shape: Tuple[int, int]):
        """Update the last match region based on the transformation matrix."""
        h, w = img_shape[:2]
        center = np.array([[w / 2], [h / 2], [1]], dtype=np.float32)
        transformed = matrix @ center
        self.last_match_region = (int(round(transformed[0, 0])), int(round(transformed[1, 0])))

    def stitch(self, start: int = 3, end: int = 2249):
        """Stitch images from start to end index with memory logging and save every 5 photos."""
        print(f"\nStarting image stitching process from {start} to {end}")
        start_time = time()
        self.log_memory("Memory before starting stitch")
        first_image, read_timings = self.read_image(start)
        self.log_memory(f"Memory after reading image {start}")
        print(f"  - First image size: {self.calculate_array_size(first_image):.2f}MB")

        # Initialize the first image and bounding box
        h, w = first_image.shape[:2]
        self.tile_manager.set_region(0, 0, w, h, first_image)
        self.last_match_region = (w // 2, h // 2)
        # Initialize bounding box and stitched count
        self.overall_min_x = 0
        self.overall_min_y = 0
        self.overall_max_x = w
        self.overall_max_y = h
        self.stitched_count = 1  # First image is stitched

        future: Optional[Future] = None
        for idx in range(start + 1, end + 1):
            iter_start = time()
            print(f"\nProcessing image {idx}/{end}")
            self.log_memory("  - Memory at start of iteration")

            # Read the current image
            if future is not None:
                current, read_timings = future.result()
            else:
                current, read_timings = self.read_image(idx)
            if idx < end:
                future = self.thread_pool.submit(self.read_image, idx + 1)
            self.log_memory(f"  - Memory after reading image {idx}")
            print(f"  - Current image size: {self.calculate_array_size(current):.2f}MB")

            # Find matches
            match_start = time()
            pts1, pts2, match_timings = self.find_matches(current)
            match_time = time() - match_start
            self.log_memory("  - Memory after find_matches")

            # Align images
            align_start = time()
            matrix = self.align_images(pts1, pts2)
            align_time = time() - align_start
            # Check if alignment failed
            if matrix is None:
                print(f"Warning: Could not find transformation for image {idx}, skipping.")
                continue

            # Calculate bounding box of the warped image
            h, w = current.shape[:2]
            corners = np.array([[0, 0, 1], [w, 0, 1], [w, h, 1], [0, h, 1]], dtype=np.float32).T
            transformed_corners = matrix @ corners
            xs, ys = transformed_corners[0, :], transformed_corners[1, :]
            bx1, by1 = int(np.floor(min(xs))), int(np.floor(min(ys)))
            bx2, by2 = int(np.ceil(max(xs))), int(np.ceil(max(ys)))

            # Warp and overlay the image
            warp_start = time()
            base_region = self.tile_manager.get_region(bx1, by1, bx2, by2)
            M_adjusted = np.vstack([matrix, [0, 0, 1]])
            T = np.array([[1, 0, -bx1], [0, 1, -by1], [0, 0, 1]])
            M_adjusted = (T @ M_adjusted)[:2, :]
            aligned = cv2.warpAffine(
                current, M_adjusted,
                (bx2 - bx1, by2 - by1),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_TRANSPARENT
            )
            result_region = self.overlay_images(base_region, aligned)
            self.tile_manager.set_region(bx1, by1, bx2, by2, result_region)
            warp_time = time() - warp_start
            self.log_memory("  - Memory after warp and overlay")

            # Update overall bounding box and stitched count
            self.overall_min_x = min(self.overall_min_x, bx1)
            self.overall_min_y = min(self.overall_min_y, by1)
            self.overall_max_x = max(self.overall_max_x, bx2)
            self.overall_max_y = max(self.overall_max_y, by2)
            self.stitched_count += 1

            # Save the stitched image every 5 successfully stitched photos
            if self.stitched_count % 30 == 0:
                save_start = time()
                stitched_image = self.tile_manager.get_region(
                    self.overall_min_x, self.overall_min_y,
                    self.overall_max_x, self.overall_max_y
                )
                cv2.imwrite(f"stitched_up_to_{idx}.png", stitched_image)
                save_time = time() - save_start
                print(f"  - Saved stitched image up to {idx} in {save_time:.3f}s")

            # Update last match region
            update_start = time()
            self.update_last_match_region(matrix, current.shape)
            update_time = time() - update_start

            # Print timing details
            print(f"  - Find matches: {match_time:.3f}s")
            print(f"    - Extract ROI: {match_timings['extract_roi']:.3f}s")
            print(f"    - Allocate buffers: {match_timings['allocate_buffers']:.3f}s")
            print(f"    - Convert to grayscale: {match_timings['convert_to_grayscale']:.3f}s")
            print(f"    - SIFT detection: {match_timings['sift_detection']:.3f}s")
            print(f"    - FLANN matching: {match_timings['flann_matching']:.3f}s")
            print(f"    - Filter matches: {match_timings['filter_matches']:.3f}s")
            print(f"    - Extract points: {match_timings['extract_points']:.3f}s")
            print(f"  - Align images: {align_time:.3f}s")
            print(f"  - Warp and overlay: {warp_time:.3f}s")
            print(f"  - Update match region: {update_time:.3f}s")
            print(f"[Total for image {idx}: {time() - iter_start:.3f}s]")

            # Print tile usage information
            print(f"  - Total unique tiles used: {self.tile_manager.get_used_tiles_count()}")
            print(f"  - Current tiles in cache: {self.tile_manager.get_cached_tiles_count()}")

        total_time = time() - start_time
        self.log_memory("Final memory usage")
        print(f"\nStitching completed in {total_time:.3f}s")

    def __del__(self):
        """Clean up resources."""
        self.thread_pool.shutdown()

if __name__ == "__main__":
    stitcher = ImageStitcher()
    stitcher.stitch()

Using 16 worker threads

Starting image stitching process from 3 to 2249
Memory before starting stitch: 113.88MB
Memory after reading image 3: 122.83MB
  - First image size: 6.44MB
Setting region to tiles (0,0) to (2,2)
Tile (0, 0) loaded from disk
Tile (0, 0) saved to disk
Tile (0, 1) loaded from disk
Tile (0, 1) saved to disk
Tile (1, 0) loaded from disk
Tile (1, 0) saved to disk
Tile (1, 1) loaded from disk
Tile (1, 1) saved to disk

Processing image 4/2249
  - Memory at start of iteration: 139.07MB
  - Memory after reading image 4: 145.62MB
  - Current image size: 6.44MB
  - Memory before find_matches: 145.62MB
Search region: (-1590, -1062) to (3180, 2121)
Getting region from tiles (-2,-2) to (3,4)
Tile (-2, -2) created as zero array
Tile (-2, -1) loaded from disk
Tile (-2, 0) loaded from disk
Tile (-2, 1) loaded from disk
Tile (-2, 2) created as zero array
Tile (-2, 3) created as zero array
Tile (-1, -2) created as zero array
Tile (-1, -1) loaded from disk
Tile (-1, 0) loaded from

MemoryError: Unable to allocate 4.00 MiB for an array with shape (1024, 1024, 4) and data type uint8