In [1]:
import os
import glob
from tqdm import tqdm
import cv2
from sklearn.model_selection import KFold

import numpy as np
import tensorflow as tf

In [2]:
for mode in ['sRGB', 'XYZ']:
    files = glob.glob(f"sRGB2XYZ/{mode}_*/*")
    file_dict = {}
    for file in tqdm(files[:5]):
        image = cv2.imread(file)
        file_dict[file] = image.shape

100%|██████████| 5/5 [00:00<00:00, 29.72it/s]
100%|██████████| 5/5 [00:01<00:00,  2.85it/s]


In [3]:
file_dict

{'sRGB2XYZ/XYZ_testing/a0024-_DSC8932.png': (1416, 2128, 3),
 'sRGB2XYZ/XYZ_testing/a0035-dgw_048.png': (1416, 2128, 3),
 'sRGB2XYZ/XYZ_testing/a0042-060813_155838__MG_6361.png': (2184, 1456, 3),
 'sRGB2XYZ/XYZ_testing/a0054-kme_097.png': (1296, 1944, 3),
 'sRGB2XYZ/XYZ_testing/a0070-IMG_4327.png': (2136, 1424, 3)}

In [None]:
class TFRWriter:
    def __init__(self):
        self.main_dir = "sRGB2XYZ"
        self.save_path = os.path.join(self.main_dir, "shards")
        return

    def _bytes_feature(self, value):
        """Returns a bytes_list from a string / byte."""
        if isinstance(value, type(tf.constant(0))):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def serialize_example(self, sample):
        feature = {
            "XYZ_image": self._bytes_feature(sample[0]),
            "sRGB_image": self._bytes_feature(sample[1]),
            "filename": self._bytes_feature(sample[2])}

        example_proto = tf.train.Example(
            features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()

    def get_samples(self, train_set):
        path = f"{self.main_dir}/sRGB_{train_set}/*.JPG"
        return glob.glob(path)

    def get_sample_data(self, sample):
        xyz_filename = sample.replace("sRGB_", "XYZ_").replace(".JPG", ".png")
        assert os.path.exists(xyz_filename), f"{xyz_filename} file does not exist."

        srgb_image = cv2.imread(sample)
        xyz_image = cv2.imread(xyz_filename)
        filename = os.path.basename(sample)
                
        return [
            tf.io.serialize_tensor(xyz_image, name="XYZ_image"),
            tf.io.serialize_tensor(srgb_image, name="sRGB_image"),
            tf.io.serialize_tensor(filename, name="filename")]

    def write(self):
        for train_set in ["training", "validation", "testing"]:
            shard_path = os.path.join(self.save_path, f"{train_set}.tfrec")
            samples = self.get_samples(train_set)
            with tf.io.TFRecordWriter(shard_path) as f:
                for sample in tqdm(samples, total=len(samples), desc=f"{train_set}"):
                    sample_data = self.get_sample_data(sample)
                    f.write(self.serialize_example(sample_data))

TFRWriter().write()

In [None]:
class DataLoader:
    def __init__(self):
        self.batch_size = 8
        self.buffer_size = 64
        return

    def read_tfrecord(self, example):
        feature_description = {
            'XYZ_image': tf.io.FixedLenFeature([], tf.string),
            'sRGB_image': tf.io.FixedLenFeature([], tf.string),
            'filename': tf.io.FixedLenFeature([], tf.string)}
        
        example = tf.io.parse_single_example(example, feature_description)
        example['XYZ_image'] = tf.io.parse_tensor(example['XYZ_image'], out_type=tf.uint8)
        example['sRGB_image'] = tf.io.parse_tensor(example['sRGB_image'], out_type=tf.uint8)
        example['filename'] = tf.io.parse_tensor(example['filename'], out_type=tf.string)
        return example
    
    def load_dataset(self, files):
        ignore_order = tf.data.Options()
        ignore_order.experimental_deterministic = False
        dataset = tf.data.TFRecordDataset(files)
        dataset = dataset.with_options(ignore_order)
        dataset = dataset.map(self.read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
        return dataset

    def get_dataset(self, train_set):
        dataset = self.load_dataset(f"sRGB2XYZ/shards/{train_set}.tfrec")
        dataset = dataset.shuffle(self.buffer_size)
        dataset = dataset.batch(self.batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

val = DataLoader().get_dataset('validation')
val

In [None]:
next(iter(val))