In [None]:
import pandas as pd
import numpy as np
import imageio
import cv2
import matplotlib.pyplot as plt
import os
from pathlib import Path
import re
from PIL import Image
import re
from collections import defaultdict

from colour_demosaicing import (
    ROOT_RESOURCES_EXAMPLES,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer)

In [None]:
from RawHandler.RawHandler import RawHandler
from RawHandler.utils import linear_to_srgb
from src.training.load_config import load_config

def apply_gamma(x, gamma=2.2):
    """Gamma-encode linear values: return ``x`` raised to the power ``1 / gamma``."""
    exponent = 1 / gamma
    return pow(x, exponent)

def reverse_gamma(x, gamma=2.2):
    """Invert ``apply_gamma``: raise gamma-encoded values back to the power ``gamma``."""
    return pow(x, gamma)

In [None]:
# Paths and crop settings come from the project's training config.
run_config = load_config()
raw_path = Path(run_config['base_data_dir'])       # directory holding the source RAW files
outpath = Path(run_config['cropped_raw_subdir'])   # destination for the cropped *.f16.raw files
alignment_csv = outpath / run_config['align_csv']
colorspace = run_config['colorspace']              # passed to RawHandler.apply_colorspace_transform
crop_size = run_config['cropped_raw_size']         # side length (pixels) of the centre crop
# os.listdir() returns entries in arbitrary order; sorting keeps the pairing
# and processing order reproducible across machines/filesystems.
file_list = sorted(os.listdir(raw_path))

In [None]:
def pair_images_by_scene(file_list, min_iso=100):
    """
    Pair each RAW image with the lowest-ISO capture of the same scene.

    Steps:
      1. Extract the ISO from each filename (pattern ``_ISO<digits>_``);
         files without one are skipped.
      2. Drop files with ISO < ``min_iso`` and any file whose name contains
         ``X-Trans``.
      3. Group the remaining files by scene name (the part before ``_GT_``
         if present, otherwise the part before ``_ISO``).
      4. Within each scene, use the lowest-ISO file as ground truth and pair
         every other file with it. Ties at the lowest ISO are broken by path
         so the result does not depend on the input order.

    Args:
        file_list (list of str): RAW file names or paths. Only the basename
            is parsed; the original entries are returned unchanged.
        min_iso (int): Minimum ISO to keep (default=100).

    Returns:
        dict: ``{scene_name: [(img_path, gt_path), ...]}``. A scene with a
        single usable file maps to an empty list.
    """
    iso_pattern = re.compile(r"_ISO(\d+)_")

    # Steps 1-3: filter and group by scene.
    grouped = defaultdict(list)
    for path in file_list:
        filename = os.path.basename(path)
        if 'X-Trans' in filename:
            continue  # excluded; the cropping code assumes a Bayer layout
        match = iso_pattern.search(filename)
        if not match:
            continue  # skip if no ISO
        iso = int(match.group(1))
        if iso < min_iso:
            continue  # filter out low ISOs

        if "_GT_" in filename:
            scene = filename.split("_GT_")[0]
        else:
            scene = filename.split("_ISO")[0]
        grouped[scene].append((iso, path))

    # Step 4: lowest ISO is the ground truth for its scene.
    scene_pairs = {}
    for scene, iso_paths in grouped.items():
        # Sorting on (iso, path) rather than iso alone makes the GT choice
        # deterministic when several captures share the lowest ISO.
        ordered = sorted(iso_paths)
        _, gt_path = ordered[0]  # lowest ISO >= min_iso
        scene_pairs[scene] = [(path, gt_path) for _, path in ordered if path != gt_path]

    return scene_pairs


In [None]:
# Map each scene name to its (noisy, ground-truth) file pairs; min_iso is the function's default.
pair_file_list = pair_images_by_scene(file_list, min_iso=100)

In [None]:
def get_file(impath, crop_size=crop_size):
    """
    Load a RAW file, apply the colour-space transform and return it as float16.

    When both raw dimensions are at least ``crop_size``, only a centred
    ``crop_size`` x ``crop_size`` window is transformed. Its top-left corner is
    snapped down to even coordinates to keep Bayer alignment. Smaller images
    are transformed whole.

    Args:
        impath: Path to the RAW file, passed straight to ``RawHandler``.
        crop_size (int): Side length of the centre crop. Defaults to the
            configured ``cropped_raw_size``; assumed even (see below).

    Returns:
        np.ndarray: The transformed image cast to ``float16`` in both branches,
        so the bytes the caller writes with ``tofile`` match its ``.f16.raw``
        naming.
    """
    rh = RawHandler(impath)

    # NOTE(review): numpy shapes are (rows, cols), so `width` here is really the
    # row count and `height` the column count. The crop is centred on both axes
    # either way; whether `dims` expects (row0, row1, col0, col1) depends on
    # RawHandler.apply_colorspace_transform -- confirm.
    width, height = rh.raw.shape

    if width < crop_size or height < crop_size:
        # Too small to crop: transform the full frame.
        # NOTE(review): this output is not crop_size x crop_size, and tofile()
        # stores no shape header, so readers must handle the odd size.
        im = rh.apply_colorspace_transform(colorspace=colorspace)
    else:
        # Centre the crop, then snap the corner down to even coordinates.
        left = (width - crop_size) // 2
        top = (height - crop_size) // 2
        left -= left % 2
        top -= top % 2

        # right/bottom stay even only if crop_size is even (assumed by config).
        right = left + crop_size
        bottom = top + crop_size

        im = rh.apply_colorspace_transform(dims=(left, right, top, bottom), colorspace=colorspace)

    # Cast in both branches: the uncropped branch previously skipped this, so
    # its *.f16.raw output contained bytes of a different dtype.
    return im.astype(np.float16)

In [None]:
from tqdm import tqdm

In [None]:
# Crop every noisy image and its ground-truth reference into `outpath` as raw
# float16 bytes. A GT file is shared by all pairs of its scene, so it is only
# written if it does not exist yet.
outpath.mkdir(parents=True, exist_ok=True)  # tofile() fails if the directory is missing

for key, image_pairs in tqdm(pair_file_list.items()):
    for noisy, gt in image_pairs:
        try:
            noisy_bayer = get_file(str(raw_path / noisy))
            noisy_path = outpath / (noisy + ".f16.raw")
            noisy_bayer.tofile(noisy_path)

            gt_path = outpath / (gt + ".f16.raw")
            if not gt_path.exists():
                gt_bayer = get_file(str(raw_path / gt))
                gt_bayer.tofile(gt_path)
        except Exception as err:
            # Best-effort: log and move on, but report the cause and let
            # KeyboardInterrupt/SystemExit stop the run.
            print(f"Skipping {raw_path}/{noisy}, {raw_path}/{gt}: {err!r}")