In [None]:
import pandas as pd
import numpy as np
import imageio
import cv2
import matplotlib.pyplot as plt
import os
from pathlib import Path
import re
from PIL import Image
import re
from collections import defaultdict

from colour_demosaicing import (
    ROOT_RESOURCES_EXAMPLES,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer)

In [None]:
from RawHandler.RawHandler import RawHandler
from RawHandler.utils import linear_to_srgb
from src.training.load_config import load_config

def apply_gamma(x, gamma=2.2):
    """Gamma-encode `x` by raising it to the power `1 / gamma`."""
    exponent = 1 / gamma
    return x ** exponent

def reverse_gamma(x, gamma=2.2):
    """Undo `apply_gamma`: raise `x` to the power `gamma`."""
    linear = x ** gamma
    return linear

In [None]:
# Read the run configuration and derive the paths/settings used by the
# RAW -> DNG conversion cells below.
run_config = load_config()
# Directory that is listed for input RAW files.
raw_path = Path(run_config['base_data_dir'])
# Directory the converted DNG files are written to (see the conversion loop).
outpath = Path(run_config['cropped_raw_subdir'])
alignment_csv =  outpath / run_config['align_csv']
colorspace = run_config['colorspace']
# Side length of the centre crop taken by get_file().
# NOTE(review): `crop_size` is rebound later in the notebook from
# run_config['crop_size']; get_file() captures this value as its default
# argument at definition time, so cell execution order matters.
crop_size = run_config['cropped_raw_size']
# Bare file names (not full paths) of everything in raw_path.
file_list = os.listdir(raw_path)

In [None]:
def pair_images_by_scene(file_list, min_iso=100):
    """
    Build (image, ground-truth) pairs per scene from RAW file names.

    A file is considered only if its base name contains ``_ISO<digits>_``,
    its ISO is at least ``min_iso`` and its name does not contain
    ``X-Trans``. The scene name is the part of the base name before
    ``_GT_`` when present, otherwise the part before ``_ISO``. Within each
    scene the lowest-ISO file is the ground truth and every other file is
    paired with it.

    Args:
        file_list (list of str): Paths (or bare names) of RAW files.
        min_iso (int): Minimum ISO to keep (default=100).

    Returns:
        dict: {scene_name: [(img_path, gt_path), ...]}, pairs ordered by
        ascending ISO. A scene with a single usable file maps to [].
    """
    iso_regex = re.compile(r"_ISO(\d+)_")

    # Collect (iso, path) entries per scene in input order.
    by_scene = defaultdict(list)
    for candidate in file_list:
        base = os.path.basename(candidate)
        found = iso_regex.search(base)
        if found is None:
            continue  # no ISO token in the name
        sensitivity = int(found.group(1))
        if sensitivity < min_iso or 'X-Trans' in base:
            continue  # below the ISO floor, or an X-Trans file
        separator = "_GT_" if "_GT_" in base else "_ISO"
        by_scene[base.split(separator)[0]].append((sensitivity, candidate))

    # The lowest-ISO entry of each scene is the reference image.
    result = {}
    for scene_name, entries in by_scene.items():
        ordered = sorted(entries, key=lambda entry: entry[0])
        reference = ordered[0][1]
        result[scene_name] = [
            (candidate, reference)
            for _, candidate in ordered
            if candidate != reference
        ]
    return result


In [None]:
# Map each scene to its list of (noisy_name, gt_name) pairs; file_list holds
# bare file names, so the pairs are names relative to raw_path.
pair_file_list = pair_images_by_scene(file_list)

In [None]:
def get_file(impath, crop_size=crop_size):
    """
    Open a RAW file and return a centre crop of its raw mosaic.

    Args:
        impath: Path passed to ``RawHandler``.
        crop_size: Side length of the square crop. Assumed to be even so
            that the crop covers whole 2x2 CFA tiles (TODO confirm for the
            configured value).

    Returns:
        tuple: ``(raw_crop, rh)`` where ``raw_crop`` is a slice of ``rh.raw``
        and ``rh`` is the ``RawHandler`` instance. If the image is smaller
        than ``crop_size`` along either axis, the full ``rh.raw`` is
        returned uncropped.
    """
    rh = RawHandler(impath)

    # Sizes along the first and second axis of the raw array, in the same
    # order they are used for slicing below.
    size0, size1 = rh.raw.shape

    # Too small to crop: hand back the whole mosaic. Previously this branch
    # left the crop bounds undefined and the return raised UnboundLocalError.
    if size0 < crop_size or size1 < crop_size:
        return rh.raw, rh

    # Centre the crop, then snap the start down to an even index so the
    # crop begins on the same CFA phase as the full image.
    start0 = (size0 - crop_size) // 2
    start1 = (size1 - crop_size) // 2
    start0 -= start0 % 2
    start1 -= start1 % 2

    return rh.raw[start0:start0 + crop_size, start1:start1 + crop_size], rh

In [None]:
from tqdm import tqdm

In [None]:
from pidng.core import RAW2DNG, DNGTags, Tag
from pidng.defs import *

def get_ratios(string, rh):
    """
    Return ``as_integer_ratio()`` for every entry of
    ``rh.full_metadata[string].values``, as a list.
    """
    entries = rh.full_metadata[string].values
    ratios = []
    for entry in entries:
        ratios.append(entry.as_integer_ratio())
    return ratios


def rational_wb(rh, denominator=1000):
    """
    Express ``rh.core_metadata.camera_white_balance`` as
    ``[numerator, denominator]`` pairs, rounding each numerator to the
    nearest integer.
    """
    multipliers = np.array(rh.core_metadata.camera_white_balance)
    scaled = np.round(multipliers * denominator).astype(int)
    pairs = []
    for numerator in scaled:
        pairs.append([numerator, denominator])
    return pairs
def convert_ccm_to_rational(matrix_3x3, denominator=10000):
    """
    Flatten a colour matrix into ``[numerator, denominator]`` pairs in
    row-major order, rounding each numerator to the nearest integer.
    """
    scaled = np.round(matrix_3x3 * denominator).astype(int)
    pairs = []
    for numerator in scaled.flatten():
        pairs.append([numerator, denominator])
    return pairs


def get_as_shot_neutral(rh, denominator=10000):
    """
    Derive neutral values from the camera white-balance multipliers.

    Reads ``rh.core_metadata.camera_white_balance`` and returns the
    green-normalised reciprocals ``G/R, 1, G/B`` as rationals.

    Args:
        rh: Object exposing ``core_metadata.camera_white_balance`` with at
            least three entries (index 0 = R, 1 = G, 2 = B multipliers, as
            used below).
        denominator: Common denominator of the returned rationals.

    Returns:
        list: Three ``[numerator, denominator]`` pairs with integer
        numerators. If any of the first three multipliers is zero, the
        neutral ``[denominator, denominator]`` triple is returned instead.
    """
    cam_mul = rh.core_metadata.camera_white_balance

    # A zero multiplier would divide by zero (R, B) or yield an all-zero
    # neutral (G); fall back to a unit white point in either case.
    if cam_mul[0] == 0 or cam_mul[1] == 0 or cam_mul[2] == 0:
        return [[denominator, denominator] for _ in range(3)]

    neutral = (cam_mul[1] / cam_mul[0], 1.0, cam_mul[1] / cam_mul[2])

    # Round to nearest (like the other rational helpers here) rather than
    # truncating, which biased every value downwards.
    return [[int(round(value * denominator)), denominator] for value in neutral]


def to_dng(uint_img, rh, filepath):
    """
    Write a single-channel CFA image to a DNG file with pidng.

    Mandatory tags (geometry, CFA layout, black/white level, colour matrix,
    as-shot neutral) come from ``uint_img`` and ``rh.core_metadata``.
    Descriptive EXIF tags are copied from ``rh.full_metadata`` on a
    best-effort basis: each one is attempted on its own, and a tag that
    cannot be read or set is skipped with a warning.

    Args:
        uint_img: 2-D mosaic; ``shape[0]`` is used as image length and
            ``shape[1]`` as image width. Declared to pidng as 16 bits per
            sample.
        rh: RawHandler-like object providing ``core_metadata`` and
            ``full_metadata``.
        filepath: Passed to ``RAW2DNG.convert`` as ``filename``. The calling
            loop checks for ``filepath + '.dng'``, so pidng presumably
            appends the extension (verify against pidng).

    NOTE(review): the CFA pattern is hard-coded to RGGB and the illuminant
    to D65 regardless of the source camera.
    """
    import warnings

    height, width = uint_img.shape[0], uint_img.shape[1]
    bpp = 16

    ccm1 = convert_ccm_to_rational(rh.core_metadata.rgb_xyz_matrix[:3, :])
    t = DNGTags()

    # Geometry: the whole image is written as a single tile.
    t.set(Tag.ImageWidth, width)
    t.set(Tag.ImageLength, height)
    t.set(Tag.TileWidth, width)
    t.set(Tag.TileLength, height)
    t.set(Tag.BitsPerSample, bpp)

    t.set(Tag.SamplesPerPixel, 1)
    t.set(Tag.PlanarConfiguration, 1)

    # CFA description and sensor levels.
    t.set(Tag.PhotometricInterpretation, PhotometricInterpretation.Color_Filter_Array)
    t.set(Tag.CFARepeatPatternDim, [2,2])
    t.set(Tag.CFAPattern, CFAPattern.RGGB)
    bl = rh.core_metadata.black_level_per_channel
    t.set(Tag.BlackLevelRepeatDim, [2,2])
    t.set(Tag.BlackLevel, bl)
    t.set(Tag.WhiteLevel, rh.core_metadata.white_level)

    # Colour rendering.
    t.set(Tag.ColorMatrix1, ccm1)
    t.set(Tag.CalibrationIlluminant1, CalibrationIlluminant.D65)
    wb = get_as_shot_neutral(rh)
    t.set(Tag.AsShotNeutral, wb)
    t.set(Tag.BaselineExposure, [[0,100]])

    # Optional descriptive metadata. Each entry is evaluated lazily inside
    # its own try block so one missing EXIF key no longer drops the others.
    optional_tags = (
        ("Make", lambda: t.set(Tag.Make, rh.full_metadata['Image Make'].values)),
        ("Model", lambda: t.set(Tag.Model, rh.full_metadata['Image Model'].values)),
        ("FocalLength", lambda: t.set(Tag.FocalLength, get_ratios('EXIF FocalLength', rh))),
        ("EXIFPhotoLensModel", lambda: t.set(Tag.EXIFPhotoLensModel, rh.full_metadata['EXIF LensModel'].values)),
        ("ExposureBiasValue", lambda: t.set(Tag.ExposureBiasValue, get_ratios('EXIF ExposureBiasValue', rh))),
        ("ExposureTime", lambda: t.set(Tag.ExposureTime, get_ratios('EXIF ExposureTime', rh))),
        ("FNumber", lambda: t.set(Tag.FNumber, get_ratios('EXIF FNumber', rh))),
        ("PhotographicSensitivity", lambda: t.set(Tag.PhotographicSensitivity, rh.full_metadata['EXIF ISOSpeedRatings'].values)),
        ("Orientation", lambda: t.set(Tag.Orientation, rh.full_metadata['Image Orientation'].values[0])),
    )
    for label, apply_tag in optional_tags:
        try:
            apply_tag()
        except Exception as exc:  # best effort, but no longer silent
            warnings.warn(f"to_dng: skipping optional tag {label}: {exc!r}")

    t.set(Tag.DNGVersion, DNGVersion.V1_4)
    t.set(Tag.DNGBackwardVersion, DNGVersion.V1_2)
    t.set(Tag.PreviewColorSpace, PreviewColorSpace.Adobe_RGB)

    r = RAW2DNG()

    r.options(t, path="", compress=False)

    r.convert(uint_img, filename=filepath)

In [None]:
# Convert every noisy/GT file of every scene to a cropped DNG in `outpath`.
# A file is skipped when "<outpath>/<name>.dng" already exists, so the cell
# can be re-run without redoing finished conversions.
for key in tqdm(pair_file_list.keys()):
    image_pairs = pair_file_list[key]
    for noisy, gt in image_pairs:
        noisy_path = outpath / (noisy)
        if not os.path.exists(str(noisy_path)+'.dng'):
            print(noisy_path)
            # Only names with one of these (lower-case) RAW extensions are converted.
            if noisy.endswith(('.cr2', '.nef', '.arw', '.orf', '.raf', '.pef', '.crw', '.dng')):
                bayer, rh = get_file(f'{raw_path}/{noisy}')
                to_dng(bayer, rh, str(noisy_path))


        # NOTE(review): unlike the noisy branch, the GT file is converted
        # without an extension check - confirm whether that is intended.
        # The same GT is revisited for every pair of a scene; the existence
        # check above makes the repeats cheap.
        gt_path = outpath / (gt)
        if not os.path.exists(str(gt_path)+'.dng'):
            print(gt_path)
            bayer, rh = get_file(f'{raw_path}/{gt}')
            to_dng(bayer, rh, str(gt_path))


In [None]:
# Check that the data was copied correctly: compare a converted DNG with its source RAW.

In [None]:
# Load one converted DNG and the RAW file it was generated from. Both paths
# are built from the configured directories rather than a machine-specific
# absolute path, matching how the conversion loop reads and writes files.
sample_name = "Bayer_MuseeL-sol-A7C-brighter_ISO100_sha1=18eaa9931d9a0f6f0511552ef6bf2fd040d82878.arw"
rhdng = RawHandler(str(outpath / f"{sample_name}.dng"))
rh = RawHandler(str(raw_path / sample_name))


In [None]:
# RGB rendering of the full converted DNG (no crop requested).
imdng = rhdng.as_rgb()

In [None]:
# Recompute the centre-crop bounds with the same arithmetic as get_file(),
# then render that region of the original RAW for comparison with the DNG.
# NOTE(review): `width`/`height` are the first/second axis of rh.raw.shape;
# confirm this ordering matches what RawHandler expects in `dims`.
width, height = rh.raw.shape

left = (width - crop_size) // 2
top = (height - crop_size) // 2

# Snap the start of the crop down to an even index.
if left % 2 != 0:
    left -= 1
if top % 2 != 0:
    top -= 1

right = left + crop_size
bottom = top + crop_size

im = rh.as_rgb(dims=(left, right, top, bottom))

In [None]:
# Mean signed difference between the DNG rendering and the cropped RAW
# rendering. NOTE(review): positive and negative errors cancel in a signed
# mean, so a value near 0 does not by itself prove the images are identical.
(imdng-im).mean()

In [None]:
import pandas as pd
import os
from  torch.utils.data import Dataset
import imageio
from colour_demosaicing import (
    ROOT_RESOURCES_EXAMPLES,
    demosaicing_CFA_Bayer_bilinear,
    demosaicing_CFA_Bayer_Malvar2004,
    demosaicing_CFA_Bayer_Menon2007,
    mosaicing_CFA_Bayer)

from src.training.utils import inverse_gamma_tone_curve, cfa_to_sparse
import numpy as np
import torch 
from src.training.align_images import apply_alignment, align_clean_to_noisy
from pathlib import Path
from RawHandler.RawHandler import RawHandler



def global_affine_match(A, D, mask=None):
    """
    Fit D ≈ a + b*A with ordinary least squares.

    Args:
        A, D: Arrays of the same shape (linear values). Both are flattened
            before fitting.
        mask: Optional array of the same shape; truthy entries mark pixels
            to use. Non-finite values in A or D are always excluded.

    Returns:
        tuple: ``(a, b, D_pred)`` where ``a`` is the offset, ``b`` the gain
        and ``D_pred = a + b*A`` is a flat float64 array covering every
        input pixel (including those excluded from the fit).
    """
    A = np.asarray(A).ravel().astype(np.float64)
    D = np.asarray(D).ravel().astype(np.float64)

    valid = np.isfinite(A) & np.isfinite(D)
    if mask is not None:
        # Cast so an integer 0/1 mask selects pixels instead of being
        # treated as fancy indices.
        valid &= np.asarray(mask).ravel().astype(bool)

    A0 = A[valid]
    D0 = D[valid]
    # Design matrix with columns [1, A].
    X = np.vstack([np.ones_like(A0), A0]).T
    coef, *_ = np.linalg.lstsq(X, D0, rcond=None)
    a, b = coef[0], coef[1]

    return a, b, (a + b * A)


def random_crop_dim(shape, crop_size, buffer, validation=False):
    """
    Pick the bounds of a square crop inside an image of the given shape.

    With ``validation=True`` the crop is centred. Otherwise the top-left
    corner is drawn with ``np.random.randint`` so that at least ``buffer``
    pixels are left before it, and the draw's upper limit is
    ``size - crop_size - buffer``. In both cases odd start coordinates are
    decremented by one so the crop starts on an even index.

    Args:
        shape: ``(h, w)`` of the image.
        crop_size: Side length of the crop.
        buffer: Margin kept around random crops.
        validation: Use the deterministic centre crop.

    Returns:
        tuple: ``(left, right, top, bottom)``.
    """
    h, w = shape
    if validation:
        top = (h - crop_size) // 2
        left = (w - crop_size) // 2
    else:
        # Draw order (top, then left) is kept so seeded runs are unchanged.
        top = np.random.randint(buffer, h - crop_size - buffer)
        left = np.random.randint(buffer, w - crop_size - buffer)

    # Snap odd starts down to the previous even index.
    top -= top % 2
    left -= left % 2

    return (left, left + crop_size, top, top + crop_size)

class RawDatasetDNG(Dataset):
    """
    Dataset of noisy/ground-truth crops read from DNG files.

    Each row of the CSV names a noisy image (``bayer_path``), a ground-truth
    image (``gt_path``) and an ``iso`` value; the whole row is also passed
    to ``apply_alignment`` as a dict. Only the file names of the two paths
    are used: they are looked up inside ``path`` with their ``.jpg`` suffix
    replaced by ``.dng``.
    """

    def __init__(self, path, csv, colorspace, crop_size=180, buffer=10, validation=False, run_align=False, dimensions=2000):
        """
        Args:
            path: Directory containing the DNG files.
            csv: CSV file readable by ``pd.read_csv`` with one row per pair.
            colorspace: Forwarded to the RawHandler rendering calls.
            crop_size: Side length of the returned crops.
            buffer: Extra margin read around the ground-truth crop before
                alignment and trimmed off afterwards.
            validation: Use a deterministic centre crop instead of a random one.
            run_align: Stored but not used by this class.
            dimensions: Stored but not used by this class.
        """
        super().__init__()
        self.df = pd.read_csv(csv)
        self.path = path
        self.crop_size = crop_size
        self.buffer = buffer
        # ISO value used to normalise the conditioning input.
        self.coordinate_iso = 6400
        self.validation=validation
        self.run_align = run_align
        self.dtype = np.float16
        self.dimensions = dimensions
        self.colorspace = colorspace

    def __len__(self):
        """Number of image pairs (rows in the CSV)."""
        return len(self.df)

    @staticmethod
    def _expand_dims(dims, buffer):
        """Grow a (start, stop, start, stop) crop by `buffer` on every side."""
        lo_a, hi_a, lo_b, hi_b = dims
        return [lo_a - buffer, hi_a + buffer, lo_b - buffer, hi_b + buffer]

    def _trim_buffer(self, image):
        """Remove the `buffer`-wide border from the first two axes of `image`."""
        b = self.buffer
        # Explicit stop indices: a negative-slice `[b:-b]` is empty when b == 0.
        return image[b:image.shape[0] - b, b:image.shape[1] - b]

    def __getitem__(self, idx):
        """
        Load one pair and return a dict of float tensors clipped to [0, 1]:
        ``bayer``, ``gt_non_aligned``, ``aligned``, ``noisy``, ``rggb``,
        plus ``conditioning`` (``iso / coordinate_iso``, not clipped).
        """
        row = self.df.iloc[idx]
        root = Path(self.path)

        # Load images
        noisy_name = Path(f"{row.bayer_path}").name
        noisy_rh = RawHandler(str(root / noisy_name.replace('_bayer.jpg', '.dng')))

        gt_name = Path(f"{row.gt_path}").name
        gt_rh = RawHandler(str(root / gt_name.replace('.jpg', '.dng')))

        dims = random_crop_dim(noisy_rh.raw.shape, self.crop_size, self.buffer, validation=self.validation)
        bayer_data = noisy_rh.apply_colorspace_transform(dims=dims, colorspace=self.colorspace)
        noisy = noisy_rh.as_rgb(dims=dims, colorspace=self.colorspace)
        rggb = noisy_rh.as_rggb(dims=dims, colorspace=self.colorspace)

        # Read the ground truth with a margin around the *same* crop as the
        # noisy image. The second axis previously reused dims[0]/dims[1],
        # which pointed the GT crop at a different region.
        expanded_dims = self._expand_dims(dims, self.buffer)
        gt_expanded = gt_rh.as_rgb(dims=expanded_dims, colorspace=self.colorspace)
        gt_hwc = gt_expanded.transpose(1, 2, 0)
        aligned = self._trim_buffer(apply_alignment(gt_hwc, row.to_dict()))
        gt_non_aligned = self._trim_buffer(gt_hwc)

        # Convert to tensors
        output = {
            "bayer": torch.tensor(bayer_data).to(float).clip(0,1),
            "gt_non_aligned": torch.tensor(gt_non_aligned).to(float).permute(2, 0, 1).clip(0,1),
            "aligned": torch.tensor(aligned).to(float).permute(2, 0, 1).clip(0,1),
            "noisy": torch.tensor(noisy).to(float).clip(0,1),
            "rggb": torch.tensor(rggb).to(float).clip(0,1),
            "conditioning": torch.tensor([row.iso/self.coordinate_iso]).to(float),
        }
        return output

In [None]:
from src.training.load_config import load_config

# Reload the configuration and pull out the settings for the dataset check
# below. NOTE(review): this rebinds `run_config`, `crop_size` and
# `colorspace`, which the conversion cells earlier in the notebook also use,
# so re-running those cells after this one sees different values.
run_config = load_config()
dataset_path = Path(run_config['cropped_raw_subdir'])
align_csv = dataset_path / run_config['secondary_align_csv']


device=run_config['device']

batch_size = run_config['batch_size']
# Learning rate is scaled linearly with the batch size.
lr = run_config['lr_base'] * batch_size
clipping =  run_config['clipping']

num_epochs = run_config['num_epochs_pretraining']
val_split = run_config['val_split']
crop_size = run_config['crop_size']
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
colorspace = run_config['colorspace']

rggb = True

In [None]:
# Build the DNG-backed dataset defined above in validation mode (centre
# crop). `SmallRawDatasetNumpy` was never defined or imported in this
# notebook, so the previous call raised NameError on a fresh kernel.
dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size, validation=True)


In [None]:
# Fetch the first sample: a dict of tensors keyed by "bayer", "noisy",
# "aligned", "gt_non_aligned", "rggb" and "conditioning".
output = dataset[0]
import matplotlib.pyplot as plt


In [None]:
import matplotlib.pyplot as plt
# Show the noisy crop; permute moves the first axis last for imshow and the
# 1/2.2 power applies a display gamma.
plt.imshow(output['noisy'].permute(1, 2, 0)**(1/2.2))

In [None]:
import matplotlib.pyplot as plt
# Difference between the noisy crop and the aligned ground truth, offset by
# 0.5 so that zero difference renders as mid-grey.
plt.imshow((output['noisy']-output['aligned']).permute(1, 2, 0)+0.5)

In [None]:
import matplotlib.pyplot as plt
# Same difference view against the non-aligned ground truth, for comparison
# with the aligned plot.
plt.imshow((output['noisy']-output['gt_non_aligned']).permute(1, 2, 0)+0.5)

In [None]:
# Sanity check on the ground-truth tensor's shape after the buffer is trimmed.
output['gt_non_aligned'].shape