## Processor

In [1]:
import pdb

import os
from os.path import join

import rawpy
from rawpy import RawPy
import exifread

import numpy as np
import torch

import yaml
import tifffile as tff

import argparse
from tqdm import tqdm

In [2]:
def process_ratio(value, simplify=False):
    """Render an EXIF numeric value in ``"num/den"`` form.

    Args:
        value: a plain ``int`` or an ``exifread.utils.Ratio``.
        simplify: when true, an ``int`` is handed back as-is instead of being
            rendered as ``"<value>/1"``. It has no effect on ``Ratio`` inputs.

    Returns:
        The formatted string (or the untouched ``int`` when ``simplify``
        applies); ``None`` for any other input type.
    """
    if not isinstance(value, int):
        if isinstance(value, exifread.utils.Ratio):
            return f"{value.num}/{value.den}"
        # Unsupported type: fall through to an implicit "no value".
        return None
    return value if simplify else f"{value}/1"


def process_exposure_time(values):
    """Format the first ``EXIF ExposureTime`` value via ``process_ratio``."""
    exposure = values[0]
    return process_ratio(exposure)


def process_f_number(values):
    """Format the first ``EXIF FNumber`` value via ``process_ratio``."""
    aperture = values[0]
    return process_ratio(aperture)


def process_focal_length(values):
    """Format the first focal-length value (35mm-equivalent or native) via ``process_ratio``."""
    focal = values[0]
    return process_ratio(focal)


def process_iso_sensitivity(values):
    """Return the first ``EXIF ISOSpeedRatings`` value, keeping a plain int as an int."""
    iso = values[0]
    return process_ratio(iso, simplify=True)


def process_orientation(values):
    """Report the orientation as normal, whatever the EXIF tag says.

    ``values`` is ignored. NOTE(review): presumably this is because images are
    rendered with ``postprocess(user_flip=0)`` (no rotation applied), so the
    stored output is never rotated — confirm.
    """
    normal = "Horizontal (normal)"
    return normal


def extract_extra(tags):
    """Collect the camera/exposure fields written to ``extra.yml``.

    Each output field has an ordered list of candidate EXIF tag names; the
    first candidate present in ``tags`` wins. The tag's ``.values`` are passed
    through the field's formatter, or stored raw when the formatter is None.

    Args:
        tags: mapping returned by ``exifread.process_file``.

    Returns:
        Dict of field name -> formatted value, or ``None`` as soon as some
        field has none of its candidate tags available.
    """
    field_tags = (
        ("camera_name", ["Image Model"], None),
        ("exposure_time", ["EXIF ExposureTime"], process_exposure_time),
        ("f_number", ["EXIF FNumber"], process_f_number),
        (
            "focal_length",
            # Prefer the 35mm-equivalent value, fall back to the native one.
            ["EXIF FocalLengthIn35mmFilm", "EXIF FocalLength"],
            process_focal_length,
        ),
        ("iso_sensitivity", ["EXIF ISOSpeedRatings"], process_iso_sensitivity),
        ("orientation", ["Image Orientation"], process_orientation),
    )

    collected = {}
    for name, candidates, formatter in field_tags:
        tag_name = next((cand for cand in candidates if cand in tags), None)
        if tag_name is None:
            # A required field is missing entirely: abandon the whole record.
            return None
        raw_values = tags[tag_name].values
        collected[name] = raw_values if formatter is None else formatter(raw_values)

    return collected


def rawpy_to_meta(raw: RawPy, extra: dict) -> dict | None:
    """
    Convert a rawpy.RawPy object to a standardized metadata dictionary.
    We assume the two green channels are the same.
    """
    if raw.color_desc.decode() != "RGBG":
        return None

    # 1. Image size (height, width)
    image_size = (raw.sizes.height // 512 * 512, raw.sizes.width // 512 * 512)
    # 2. Bayer pattern as a 2×2 uint8 array
    bayer_pattern = raw.raw_pattern.astype(np.uint8)
    bayer_pattern[bayer_pattern == 3] = 1
    # 3. Per-channel black level and overall white (saturation) level (ignore G2)
    black_level = raw.black_level_per_channel[:3]
    white_level = raw.white_level
    # 4. Camera white-balance gains (ignore G2)
    cam_wb = np.array(raw.camera_whitebalance[:3], dtype=float)
    # 5. Camera color‐correction matrix (3×3), here taking the first 3 columns (ignore G2)
    full_mat = np.array(raw.color_matrix)
    camera_matrix = full_mat[:, :3].astype(np.float32)

    return {
        "image_size": image_size,
        "bayer_pattern": bayer_pattern,
        "black_level": black_level,
        "white_level": white_level,
        "white_balance": cam_wb,
        "camera_matrix": camera_matrix,
        'camera_name': extra["camera_name"]
    }


def crop(raw, rgb, target_shape):
    """Center-crop a raw mosaic and its RGB rendering to ``target_shape``.

    Both crop offsets are rounded down to even numbers, so the crop starts on
    an even row/column (presumably to keep the 2x2 Bayer phase unchanged).

    Args:
        raw: array whose first two axes are (height, width).
        rgb: (height, width, channels) array cut with the same offsets as
            ``raw`` — assumes it is pixel-aligned with ``raw``; TODO confirm.
        target_shape: (height, width) of the crop.

    Returns:
        Tuple ``(raw_crop, rgb_crop)`` of views into the inputs.

    Raises:
        ValueError: if ``target_shape`` is larger than ``raw`` in either axis
            (negative offsets would otherwise silently wrap around).
    """
    height, width = raw.shape[:2]
    target_h, target_w = target_shape
    if target_h > height or target_w > width:
        raise ValueError(
            f"target_shape {tuple(target_shape)} exceeds image shape {(height, width)}"
        )

    top = ((height - target_h) // 2) & -2
    left = ((width - target_w) // 2) & -2
    rows = slice(top, top + target_h)
    cols = slice(left, left + target_w)
    return raw[rows, cols], rgb[rows, cols, :]

## Inject

In [3]:
# Directory of the original FiveK raw photos. Despite the variable name, the
# entries are passed straight to rawpy.imread, so they are presumably raw files.
data_folder = "../patchsets/fivek_original/raw_photos/HQa1to700/photos/"
# Sorted so that load()'s "first entry containing the id" lookup is
# deterministic; os.listdir order is arbitrary and filesystem-dependent.
image_folders = sorted(os.listdir(data_folder))


def load(id):
    """Open the FiveK raw file whose name contains ``id`` and gather its data.

    Args:
        id: substring identifying the image (e.g. ``"a0002"``); the first
            entry of ``image_folders`` that contains it is used.

    Returns:
        ``(raw_data, raw, rgb, extra, meta)``: the open rawpy object, its
        visible raw mosaic, the rendered RGB image (``user_flip=0``, i.e. no
        orientation flip), the EXIF extras from ``extract_extra`` and the
        metadata from ``rawpy_to_meta``; or ``None`` if nothing matches.
    """
    match = next((name for name in image_folders if id in name), None)
    if match is None:
        return None

    path = join(data_folder, match)
    raw_data = rawpy.imread(path)

    # raw_data is returned (not closed) so the caller can keep using it.
    raw = raw_data.raw_image_visible
    rgb = raw_data.postprocess(user_flip=0)

    with open(path, "rb") as handle:
        tags = exifread.process_file(handle)
    extra = extract_extra(tags)

    meta = rawpy_to_meta(raw_data, extra)
    return raw_data, raw, rgb, extra, meta


def save(out_dir, raw, rgb, extra, meta):
    """Write one sample's metadata and 512x512 raw/RGB tile pairs to ``out_dir``.

    ``extra.yml`` (YAML) and ``metadata.pt`` (torch.save) are written only if
    they do not exist yet; the TIFF tiles are always (re)written.

    Args:
        out_dir: destination directory; created if missing.
        raw: 2-D raw mosaic.
        rgb: (height, width, channels) rendering aligned with ``raw``.
        extra: EXIF extras dumped to ``extra.yml``.
        meta: metadata dict; ``meta["image_size"]`` is the (height, width)
            crop, expected to be a multiple of the tile size.
    """
    # A new sample directory may not exist yet; open()/imwrite would fail.
    os.makedirs(out_dir, exist_ok=True)

    extra_path = join(out_dir, "extra.yml")
    if not os.path.exists(extra_path):
        with open(extra_path, "w") as f:
            yaml.dump(extra, f, sort_keys=True)

    meta_path = join(out_dir, "metadata.pt")
    if not os.path.exists(meta_path):
        torch.save(meta, meta_path)

    raw, rgb = crop(raw, rgb, meta["image_size"])
    height, width = meta["image_size"]
    tile = 512

    for r in range(0, height, tile):
        for c in range(0, width, tile):
            # The trailing None adds a channel axis: raw tiles are (512, 512, 1).
            raw_patch = raw[r : r + tile, c : c + tile, None]
            rgb_patch = rgb[r : r + tile, c : c + tile, :]

            suffix = f"{tile}-{r:05d}-{c:05d}.tif"
            tff.imwrite(join(out_dir, f"raw-{suffix}"), raw_patch)
            tff.imwrite(join(out_dir, f"rgb-{suffix}"), rgb_patch)


def change_pattern(raw, source_pattern, target_pattern):
    """Re-mosaic ``raw`` from one colour-filter layout to another.

    For every colour index, the pixels sitting at that colour's positions in
    the tiled ``source_pattern`` are read in row-major order and written, in
    the same order, to that colour's positions in the tiled ``target_pattern``.
    Positions whose colour never appears in ``target_pattern`` stay zero.

    Args:
        raw: 2-D mosaic (height, width); its dtype is preserved.
        source_pattern: small 2-D array of colour indices describing ``raw``'s
            layout (e.g. a 2x2 ``raw_pattern``).
        target_pattern: array of the same shape describing the desired layout.

    Returns:
        A new array with the same shape and dtype as ``raw``.

    Raises:
        ValueError: if the patterns differ in shape, or ``raw``'s height/width
            is not a multiple of the pattern size.
    """
    source_pattern = np.asarray(source_pattern)
    target_pattern = np.asarray(target_pattern)
    if source_pattern.shape != target_pattern.shape:
        raise ValueError(
            f"pattern shapes differ: {source_pattern.shape} vs {target_pattern.shape}"
        )

    height, width = raw.shape
    tile_h, tile_w = source_pattern.shape
    if height % tile_h or width % tile_w:
        raise ValueError(
            f"raw shape {(height, width)} is not a multiple of pattern shape "
            f"{(tile_h, tile_w)}"
        )
    reps = (height // tile_h, width // tile_w)

    # zeros_like keeps raw's dtype (a hard-coded uint16 would truncate others).
    new_raw = np.zeros_like(raw)
    for color in np.union1d(source_pattern, target_pattern):
        src_mask = np.tile(source_pattern == color, reps)
        dst_mask = np.tile(target_pattern == color, reps)
        new_raw[dst_mask] = raw[src_mask]

    return new_raw

In [4]:
# Load two FiveK samples; a0002's raw_pattern is later used as the target
# layout for a0011's mosaic.
(a0002, raw0002, rgb0002, extra0002, meta0002) = load("a0002")
(a0011, raw0011, rgb0011, extra0011, meta0011) = load("a0011")

In [5]:
# Display the 512-aligned (height, width) that save() will crop a0011 to.
meta0011["image_size"]

(1536, 2560)

In [6]:
# a0002.raw_pattern, a0011.raw_pattern

In [7]:
# a0011.raw_image_visible[:6, :6]

In [11]:
# Re-mosaic a0011 so its pixels follow a0002's raw_pattern layout.
new_raw0011 = change_pattern(
    raw0011,
    source_pattern=a0011.raw_pattern,
    target_pattern=a0002.raw_pattern,
)
# NOTE(review): meta0011 (including its bayer_pattern) is saved unchanged even
# though the mosaic now follows a0002's layout; confirm the mismatch is
# intended. To make them agree, uncomment:
# meta0011["bayer_pattern"] = meta0002["bayer_pattern"]
save("../patchsets/fivek/a0000", new_raw0011, rgb0011, extra0011, meta0011)

In [None]:
# save("../patchsets/fivek/a0000", raw0011, rgb0011, extra0011, meta0011)