In [None]:
import pandas as pd
import numpy as np
import imageio
import cv2
import matplotlib.pyplot as plt
import os
from pathlib import Path
import re
from PIL import Image
import re
from collections import defaultdict

from colour_demosaicing import (
    ROOT_RESOURCES_EXAMPLES,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer)

In [None]:
from RawHandler.RawHandler import RawHandler
from RawHandler.utils import linear_to_srgb
from src.training.load_config import load_config

def apply_gamma(x, gamma=2.2):
    """
    Gamma-encode a value by raising it to the power ``1 / gamma``.

    Args:
        x: Scalar or array supporting the ``**`` operator.
        gamma (float): Encoding gamma (default=2.2).

    Returns:
        ``x`` raised to ``1 / gamma``; same kind of object as ``x ** float``.
    """
    exponent = 1 / gamma
    return pow(x, exponent)

def reverse_gamma(x, gamma=2.2):
    """
    Undo apply_gamma by raising a value to the power ``gamma``.

    Args:
        x: Scalar or array supporting the ``**`` operator.
        gamma (float): Gamma that was used for encoding (default=2.2).

    Returns:
        ``x`` raised to ``gamma``.
    """
    return pow(x, gamma)

In [None]:
# Resolve every input/output location from the run configuration.
run_config = load_config()
raw_path = Path(run_config['base_data_dir'])            # directory listed below for RAW files
outpath = Path(run_config['jpeg_output_subdir'])        # destination of the aligned JPEGs
alignment_csv = outpath / run_config['align_csv']       # alignment table, stored next to the JPEGs
# Wrapped in Path for consistency with the other locations; it is only passed
# to os.* functions, which accept path-like objects.
outpath_cropped = Path(run_config['cropped_jpeg_subdir'])
colorspace = run_config['colorspace']

# os.listdir returns entries in arbitrary order. Sorting makes the run
# reproducible: when several files of a scene share the lowest ISO, the one
# picked as ground truth no longer depends on the filesystem.
file_list = sorted(os.listdir(raw_path))

In [None]:
def pair_images_by_scene(file_list, min_iso=100):
    """
    Build (image, ground-truth) pairs for every scene found in a file listing.

    A file is considered only if its base name contains ``_ISO<digits>_``,
    its ISO is at least ``min_iso`` and the name does not contain
    ``X-Trans``. The scene name is the part of the base name before
    ``_GT_`` when that marker is present, otherwise the part before
    ``_ISO``. Within a scene, the file with the lowest ISO acts as ground
    truth and every other file is paired with it.

    Args:
        file_list (list of str): Paths (or bare names) of RAW files.
        min_iso (int): Minimum ISO to keep (default=100).

    Returns:
        dict: {scene_name: [(img_path, gt_path), ...]}; a scene with a
        single usable file maps to an empty list.
    """
    iso_regex = re.compile(r"_ISO(\d+)_")

    # Pass 1: bucket the usable files by scene, remembering each file's ISO.
    by_scene = defaultdict(list)
    for entry in file_list:
        name = os.path.basename(entry)
        found = iso_regex.search(name)
        if found is None:
            continue  # no ISO tag in the name
        iso_value = int(found.group(1))
        if iso_value < min_iso or 'X-Trans' in name:
            continue  # ISO too low, or an X-Trans file
        marker = "_GT_" if "_GT_" in name else "_ISO"
        by_scene[name.split(marker)[0]].append((iso_value, entry))

    # Pass 2: the lowest-ISO file of each scene (ISO >= min_iso) is the reference.
    result = {}
    for scene_name, members in by_scene.items():
        ordered = sorted(members, key=lambda item: item[0])  # stable, ascending ISO
        reference = ordered[0][1]
        result[scene_name] = [
            (candidate, reference)
            for _, candidate in ordered
            if candidate != reference
        ]

    return result


In [None]:
def get_initial_warp_matrix(img1_gray, img2_gray, num_features=2000):
    """
    Estimate a coarse 2x3 warp between two grayscale images from ORB matches.

    ORB keypoints are detected in both images and brute-force matched with
    cross-checking. The 50 closest matches are fed to a RANSAC fit
    (cv2.estimateAffinePartial2D) that maps the img2 keypoint locations onto
    the matching img1 locations. Per the OpenCV documentation that fit has
    four degrees of freedom (rotation, uniform scale, translation), so the
    result is not guaranteed to be a pure Euclidean transform.

    Args:
        img1_gray (np.array): First grayscale image (destination points of the fit).
        img2_gray (np.array): Second grayscale image (source points of the fit).
        num_features (int): The number of features for ORB to detect.

    Returns:
        np.array: float32 2x3 matrix from the fit. A 2x3 identity is returned
        when either image yields no descriptors, fewer than 10 matches are
        found, the fit fails, or OpenCV raises an error.
    """
    max_matches = 50   # only the closest matches are used for the fit
    min_matches = 10   # below this the estimate is not considered robust

    def identity():
        return np.eye(2, 3, dtype=np.float32)

    try:
        detector = cv2.ORB_create(nfeatures=num_features)
        kp_a, desc_a = detector.detectAndCompute(img1_gray, None)
        kp_b, desc_b = detector.detectAndCompute(img2_gray, None)

        # detectAndCompute yields None descriptors when nothing was detected.
        if desc_a is None or desc_b is None:
            return identity()

        # Hamming distance is the metric for binary descriptors such as ORB.
        matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
        best = sorted(matcher.match(desc_a, desc_b), key=lambda m: m.distance)[:max_matches]
        if len(best) < min_matches:
            return identity()

        # (x, y) locations of the matched keypoints, row-aligned per match.
        pts_a = np.array([kp_a[m.queryIdx].pt for m in best], dtype=np.float32)
        pts_b = np.array([kp_b[m.trainIdx].pt for m in best], dtype=np.float32)

        # RANSAC fit mapping img2 points onto img1 points; the inlier mask is unused.
        estimate = cv2.estimateAffinePartial2D(pts_b, pts_a, method=cv2.RANSAC)[0]
        if estimate is None:
            return identity()

        return estimate.astype(np.float32)

    except cv2.error as e:
        print(f"OpenCV error during feature matching: {e}")
        return identity()

In [None]:
def save_warp_dataframe(warp_matrix):
    """
    Flatten a warp matrix into a dict suitable for one DataFrame/CSV row.

    Nothing is written to disk here; the caller merges the returned dict
    into its own record.

    Args:
        warp_matrix (np.array): 2x3 or 3x3 matrix.

    Returns:
        dict: {"m<row><col>": value} in row-major order, e.g. "m00" ... "m12"
        for a 2x3 matrix.
    """
    n_rows, n_cols = warp_matrix.shape[0], warp_matrix.shape[1]
    labels = (f"m{r}{c}" for r in range(n_rows) for c in range(n_cols))
    return {label: value for label, value in zip(labels, warp_matrix.flatten())}

In [None]:
def get_align_hybrid(noisy_fname, gt_fname, path, downsample_factor=4):
    """
    Align the ground-truth frame of a scene to one of its noisy frames.

    A coarse warp comes from ORB feature matching (get_initial_warp_matrix)
    and is refined with OpenCV's ECC algorithm, optionally on downsampled
    grayscale images. The warp maps noisy-image pixel coordinates to
    ground-truth pixel coordinates and is therefore applied to the
    ground-truth image with cv2.WARP_INVERSE_MAP.

    Args:
        noisy_fname (str): File name of the noisy RAW image inside `path`.
        gt_fname (str): File name of the ground-truth RAW image inside `path`.
        path (str or Path): Directory containing both RAW files.
        downsample_factor (int): Shrink factor applied to the grayscale
            images before ECC; values <= 1 run ECC at full resolution.

    Returns:
        tuple: (info_dict, noisy_image, gt_aligned, noisy_bayer[0], gt_image)
            info_dict holds the ECC score ("cc", -1.0 if ECC failed and the
            feature-based warp was used), file names, image means, the
            per-channel std of noisy - gt, the flattened warp matrix
            (m00..m12) and the ISO parsed from the noisy file name (0 if
            absent). The images are the gamma-encoded, demosaiced frames
            clipped to [0, 1]; gt_image is the unaligned ground truth.
    """
    # 1. Load both RAW files (RawHandler uses the module-level `colorspace`).
    noisy_handler = RawHandler(f'{path}/{noisy_fname}', colorspace=colorspace)
    gt_handler = RawHandler(f'{path}/{gt_fname}', colorspace=colorspace)

    # NOTE(review): the transform target is hard-coded to 'lin_rec2020' while the
    # handlers are built with the configured `colorspace` - confirm this is intended.
    noisy_bayer = noisy_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)
    gt_bayer = gt_handler.apply_colorspace_transform(colorspace='lin_rec2020', clip=True).astype(np.float32)
    noisy_bayer = apply_gamma(noisy_bayer)
    gt_bayer = apply_gamma(gt_bayer)

    # 2. Demosaic and clip back to [0, 1] (the interpolation can overshoot).
    noisy_image = np.clip(demosaicing_CFA_Bayer_Malvar2004(noisy_bayer), 0, 1)
    gt_image = np.clip(demosaicing_CFA_Bayer_Malvar2004(gt_bayer), 0, 1)

    # 3. 8-bit grayscale copies for feature detection / ECC.
    # colour_demosaicing returns RGB, so the RGB->gray conversion is the
    # matching one (BGR2GRAY would swap the red and blue weights).
    gt_image_uint8 = (gt_image * 255.0).clip(0, 255).astype(np.uint8)
    noisy_image_uint8 = (noisy_image * 255.0).clip(0, 255).astype(np.uint8)
    noisy_gray = cv2.cvtColor(noisy_image_uint8, cv2.COLOR_RGB2GRAY)
    gt_gray = cv2.cvtColor(gt_image_uint8, cv2.COLOR_RGB2GRAY)
    h, w = noisy_gray.shape

    # 4. Coarse warp from feature matching on the full-resolution grayscale
    # images. With this argument order the matrix maps noisy -> gt coordinates.
    warp_matrix = get_initial_warp_matrix(gt_gray, noisy_gray)

    # 5. Optionally downsample for ECC. Only the translation column depends on
    # the image scale; the 2x2 rotation part is unchanged by uniform resizing.
    if downsample_factor > 1:
        warp_matrix_scaled = warp_matrix.copy()
        warp_matrix_scaled[0, 2] /= downsample_factor
        warp_matrix_scaled[1, 2] /= downsample_factor

        small_size = (w // downsample_factor, h // downsample_factor)
        noisy_gray_small = cv2.resize(noisy_gray, small_size, interpolation=cv2.INTER_AREA)
        gt_gray_small = cv2.resize(gt_gray, small_size, interpolation=cv2.INTER_AREA)
    else:
        noisy_gray_small = noisy_gray
        gt_gray_small = gt_gray
        warp_matrix_scaled = warp_matrix

    # 6. ECC refinement, seeded with the feature-based warp.
    # findTransformECC returns the warp that maps *template* coordinates to
    # *input* coordinates, and the input image is the one to be warped with
    # WARP_INVERSE_MAP. We warp the ground truth onto the noisy frame, so the
    # noisy image must be the template and the ground truth the input. (The
    # previous order refined the inverse of the matrix that steps 4 and 8 use.)
    warp_mode = cv2.MOTION_EUCLIDEAN
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 200, 1e-5)
    try:
        (cc, warp_matrix_final_scaled) = cv2.findTransformECC(
            noisy_gray_small, gt_gray_small, warp_matrix_scaled, warp_mode, criteria
        )
    except cv2.error:
        # ECC did not converge: fall back to the feature-based warp and flag it.
        cc = -1.0
        warp_matrix_final_scaled = warp_matrix_scaled

    # 7. Scale the translation back to full resolution.
    warp_matrix_final = warp_matrix_final_scaled.copy()
    if downsample_factor > 1:
        warp_matrix_final[0, 2] *= downsample_factor
        warp_matrix_final[1, 2] *= downsample_factor

    # 8. Resample the full-resolution float ground truth on the noisy pixel grid.
    gt_aligned = cv2.warpAffine(gt_image, warp_matrix_final, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)

    # 9. Metadata / stats
    try:
        iso = float(re.findall('ISO([0-9]+)', noisy_fname)[0])
    except (IndexError, ValueError):
        iso = 0

    info_dict = {
        "cc": cc,
        "noisy_image": noisy_fname,
        "gt_image": gt_fname,
        "gt_mean": gt_image.mean(),
        "noisy_mean": noisy_image.mean(),
        # NOTE(review): computed against the *unaligned* ground truth, so any
        # misalignment inflates this value - confirm whether gt_aligned is wanted.
        "noise_level": (noisy_image-gt_image).std(axis=(0,1)),
        **save_warp_dataframe(warp_matrix_final),
        "iso": iso,
    }

    return info_dict, noisy_image, gt_aligned, noisy_bayer[0], gt_image

In [None]:
# Map each scene to its (image, ground-truth) file pairs; the ground truth is
# the lowest-ISO file of the scene (see pair_images_by_scene).
pair_file_list = pair_images_by_scene(file_list)

In [None]:
def as_8bit(x):
    """
    Convert an image with values in [0, 1] to uint8 in [0, 255].

    Values outside [0, 1] are clipped before the cast; without the clip an
    out-of-range value would wrap around in uint8. The fractional part is
    truncated (not rounded), as before.

    Args:
        x (np.array): Float image, nominally in [0, 1].

    Returns:
        np.array: uint8 array of the same shape.
    """
    return np.clip(x * 255, 0, 255).astype(np.uint8)

In [None]:
# Align one (noisy, gt) pair so the alignment quality can be inspected in the
# plotting cell below.

# First pair of the first scene that actually has one; scenes with a single
# usable file map to an empty list and are skipped.
first_pair = next((pairs[0] for pairs in pair_file_list.values() if pairs), None)
if first_pair is None:
    raise RuntimeError(f"No (noisy, gt) image pairs found in {raw_path}")

noise, gt = first_pair
# The RAW files live in raw_path (the directory file_list was listed from).
output, noisy_image, gt_aligned, noisy_bayer, gt_image = get_align_hybrid(noise, gt, raw_path, downsample_factor=1)



In [None]:
# Visual check on a 100x100 patch: the noisy frame vs. the ground truth
# before and after alignment. Difference panels are offset by 0.5 so that
# "no difference" shows as mid-gray.
crop = (slice(1000, 1100), slice(3000, 3100))  # (rows, cols) of the inspected patch

panels = [
    ("noisy_image", noisy_image[crop]),
    ("noisy_bayer", noisy_bayer[crop]),
    ("gt_image (unaligned)", gt_image[crop]),
    ("noisy_image - gt_image (unaligned) + 0.5", noisy_image[crop] - gt_image[crop] + 0.5),
    ("noisy_image - gt_aligned + 0.5", noisy_image[crop] - gt_aligned[crop] + 0.5),
]

fig, axes = plt.subplots(2, 3, figsize=(30, 20))
for ax, (title, panel) in zip(axes.flat, panels):
    ax.imshow(panel)
    ax.set_title(title)
axes.flat[len(panels)].set_axis_off()  # sixth panel is intentionally unused
plt.show()

In [None]:
# Align and save the jpegs
list = []
idx = 0
for key in pair_file_list.keys():
    image_pairs = pair_file_list[key]
    print(idx, idx/len(pair_file_list))
    idx+=1
    jdx = 0
    for (noise, gt) in image_pairs:
        try:
            output, noisy_image, gt_aligned, noisy_bayer, gt_image = get_align_hybrid(noise, gt, path, downsample_factor=1)
            list.append(output)
            noisy_image
            imageio.imwrite(f"{outpath}/{noise}.jpg", as_8bit(noisy_image), quality=100)
            imageio.imwrite(f"{outpath}/{noise}_bayer.jpg", as_8bit(noisy_bayer), quality=100)
            if jdx==0:
                imageio.imwrite(f"{outpath}/{gt}.jpg", as_8bit(gt_image), quality=100)
            jdx+=1
        except:
            print(f"skipping {noise}")


In [None]:
df = pd.DataFrame(list)
df.to_csv(alignment_csv)

In [None]:
###
## The following cells reload the saved dataset to check the alignment, then optionally shrink it further by cropping each image to a center square
###

In [None]:
# Load saved dataset and visualize alignment

In [None]:
# Reload the alignment table from alignment_csv (one row per aligned pair,
# including the flattened warp matrix columns m00..m12).
df = pd.read_csv(alignment_csv)

In [None]:
# Reload one saved pair and re-apply its stored warp to the ground truth.
row = df.iloc[-2]  # arbitrary example: the second-to-last pair in the table

# Rebuild the 2x3 warp matrix from its flattened m00..m12 columns
# (read without mutating `row`).
shape = (2, 3)
cols = [f"m{i}{j}" for i in range(shape[0]) for j in range(shape[1])]
warp_matrix = np.array([row[c] for c in cols], dtype=np.float32).reshape(shape)

noisy_name = row.noisy_image
gt_name = row.gt_image


def _read_jpeg_unit(stem):
    """Read ``<outpath>/<stem>.jpg`` and scale its 8-bit values to [0, 1]."""
    with imageio.imopen(f"{outpath}/{stem}.jpg", "r") as image_resource:
        return image_resource.read() / 255


bayer_data = _read_jpeg_unit(f"{noisy_name}_bayer")
noisy = _read_jpeg_unit(noisy_name)
gt_image = _read_jpeg_unit(gt_name)

# Resample the ground truth on the noisy pixel grid with the stored warp.
h, w, _ = noisy.shape
gt = cv2.warpAffine(gt_image, warp_matrix, (w, h), flags=cv2.INTER_LINEAR + cv2.WARP_INVERSE_MAP)

warp_matrix



In [None]:
# Demosaic the reloaded `_bayer.jpg` data (Malvar 2004) so it can be compared
# with the saved noisy JPEG and the aligned ground truth below.
demosaiced = demosaicing_CFA_Bayer_Malvar2004(bayer_data)

In [None]:
# Visual check of the reloaded data on a 100x100 patch. Difference panels are
# offset by 0.5 so that "no difference" shows as mid-gray.
crop = (slice(1000, 1100), slice(3000, 3100))  # (rows, cols) of the inspected patch

panels = [
    ("noisy (saved JPEG)", noisy[crop]),
    ("demosaiced (from saved Bayer JPEG)", demosaiced[crop]),
    ("gt (aligned)", gt[crop]),
    ("gt (aligned) - demosaiced + 0.5", gt[crop] - demosaiced[crop] + 0.5),
    ("noisy - demosaiced + 0.5", noisy[crop] - demosaiced[crop] + 0.5),
    ("noisy - gt_image (unaligned) + 0.5", noisy[crop] - gt_image[crop] + 0.5),
]

fig, axes = plt.subplots(2, 3, figsize=(30, 20))
for ax, (title, panel) in zip(axes.flat, panels):
    ax.imshow(panel)
    ax.set_title(title)
plt.show()

In [None]:
##
## Crop to center to save even more space
##

In [None]:

def crop_center_square(input_dir, output_dir, crop_size):
    """
    Write a centered square crop of every image in a directory.

    Each regular file in `input_dir` is opened with Pillow; files that are
    not readable as images, or that are smaller than the crop, are reported
    and skipped. The crop's top-left corner is snapped to even coordinates
    so the crop keeps the Bayer pattern phase. Crops are saved under the
    same file name in `output_dir`.

    Args:
        input_dir (str): Directory containing the original images.
        output_dir (str): Directory where the cropped images are saved
            (created if missing).
        crop_size (int): Side length of the square crop (must be an even number).
    """
    # Reject unusable arguments before touching the filesystem.
    if not os.path.isdir(input_dir):
        print(f"Error: Input directory not found at '{input_dir}'")
        return
    if crop_size % 2:
        print(f"Error: Crop size must be an even number. You provided {crop_size}.")
        return

    os.makedirs(output_dir, exist_ok=True)
    print(f"Output will be saved to: '{output_dir}'")

    n_cropped = 0
    for name in os.listdir(input_dir):
        source_path = os.path.join(input_dir, name)
        if not os.path.isfile(source_path):
            continue  # subdirectories are ignored

        try:
            with Image.open(source_path) as picture:
                full_w, full_h = picture.size
                if min(full_w, full_h) < crop_size:
                    print(f"Skipping '{name}': smaller than crop size.")
                    continue

                # Centered origin, rounded down to an even pixel in both axes.
                x0 = (full_w - crop_size) // 2
                y0 = (full_h - crop_size) // 2
                x0 -= x0 % 2
                y0 -= y0 % 2

                # (left, upper, right, lower); crop_size is even, so the far
                # corner lands on even coordinates as well.
                box = (x0, y0, x0 + crop_size, y0 + crop_size)
                pixels = np.array(picture.crop(box))

                imageio.imwrite(os.path.join(output_dir, name), pixels, quality=95)
                n_cropped += 1

        except OSError as e:
            # Raised, among others, when the file is not a valid image.
            print(f"Could not process '{name}'. It might not be an image file. Error: {e}")
        except Exception as e:
            print(f"An unexpected error occurred with file '{name}': {e}")

    print(f"\nProcessing complete. Cropped {n_cropped} images.")

In [None]:
# Crop every image file in outpath to a centered 2000x2000 square and write
# the results, under the same file names, to outpath_cropped.
crop_center_square(outpath, outpath_cropped, crop_size=2000)