In [None]:
import os
import numpy as np
import cv2
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# Folders holding the three aligned image sets. A sample is identified by its
# filename, which must exist in all three folders: the grayscale input
# (bw_folder), the colour reference input and the colour target.
bw_folder = "./training_data/filtered_data/Input_2"
color_input_folder = "./training_data/filtered_data/Input_1"
color_output_folder = "./training_data/filtered_data/Output"


def _list_image_files(folder):
    """Return the sorted names of visible regular files in ``folder``.

    Hidden entries (e.g. ``.DS_Store``) and sub-directories are skipped so that
    OS metadata files neither break the filename-alignment check below nor get
    handed to ``cv2.imread`` later on.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if not name.startswith(".") and os.path.isfile(os.path.join(folder, name))
    )


# List files
bw_files = _list_image_files(bw_folder)
color_input_files = _list_image_files(color_input_folder)
color_output_files = _list_image_files(color_output_folder)

# Ensure filenames match: samples are paired across the three folders by name.
assert bw_files == color_input_files == color_output_files, "Filenames do not match!"

# Split filenames into training and validation sets (reproducible via random_state).
# A single array is enough; the split only depends on the number of samples.
train_ratio = 0.8
bw_train_files, bw_val_files = train_test_split(bw_files, test_size=1 - train_ratio, random_state=42)

# Ensure no file ends up in both the training and the validation set.
assert set(bw_train_files).isdisjoint(bw_val_files), "Overlap between training and validation files!"

# Load a batch of images from a list of filenames and folder
def load_batch_from_folder(folder, files, start_idx, batch_size, color=True, dtype=np.float32):
    """Read ``files[start_idx:start_idx + batch_size]`` from ``folder`` into one array.

    Args:
        folder: Directory containing the images.
        files: Sequence of filenames inside ``folder``.
        start_idx: Index of the first file of the batch.
        batch_size: Maximum number of files to read; the batch is truncated at
            the end of ``files``.
        color: If truthy, read with ``cv2.IMREAD_COLOR``. OpenCV's BGR channel
            order is kept (no RGB conversion is applied). Otherwise read with
            ``cv2.IMREAD_GRAYSCALE`` and append a trailing channel axis.
        dtype: Floating dtype of the result. float32 matches Keras' default
            precision and uses half the memory of float64.

    Returns:
        An array stacking the images along a new first axis, with pixel values
        divided by 255.0. All images in the batch must share the same shape.
        An empty range yields an empty array of shape ``(0,)``.

    Raises:
        OSError: If OpenCV cannot read one of the files.
    """
    end_idx = min(start_idx + batch_size, len(files))
    read_flag = cv2.IMREAD_COLOR if color else cv2.IMREAD_GRAYSCALE

    images = []
    for i in range(start_idx, end_idx):
        img_path = os.path.join(folder, files[i])
        img = cv2.imread(img_path, read_flag)
        if img is None:
            # cv2.imread reports missing/unreadable files by returning None
            # instead of raising; fail here with the offending path.
            raise OSError(f"cv2.imread could not read image: {img_path}")
        if not color:
            img = np.expand_dims(img, axis=-1)  # Add channel dimension
        images.append(img.astype(dtype) / 255.0)
    return np.array(images, dtype=dtype)

# Separate generators for training and validation
def image_data_generator(files, batch_size, shuffle=False, seed=None):
    """Return an infinite generator of ``([bw_batch, color_input_batch], color_output_batch)``.

    Each batch loads the same filenames from the module-level ``bw_folder``,
    ``color_input_folder`` and ``color_output_folder``, so the three arrays stay
    aligned. After the last (possibly smaller) batch, a new pass over ``files``
    starts. This suits ``model.fit`` with an explicit ``steps_per_epoch``.

    Args:
        files: Filenames present in all three folders.
        batch_size: Samples per batch; the final batch of a pass may be smaller.
        shuffle: If True, draw a new random file order at the start of every
            pass. The default False keeps the fixed order of ``files``.
        seed: Seed for the shuffle RNG; unused when ``shuffle`` is False.

    Raises:
        ValueError: If ``files`` is empty or ``batch_size`` < 1. Validation happens
            eagerly. Without it the infinite loop would spin forever without
            ever yielding.
    """
    total_images = len(files)
    if total_images == 0:
        raise ValueError("image_data_generator needs at least one file")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    rng = np.random.default_rng(seed) if shuffle else None

    def _batches():
        while True:
            order = [files[i] for i in rng.permutation(total_images)] if shuffle else files
            for start_idx in range(0, total_images, batch_size):
                bw_batch = load_batch_from_folder(bw_folder, order, start_idx, batch_size, color=False)
                color_input_batch = load_batch_from_folder(color_input_folder, order, start_idx, batch_size)
                color_output_batch = load_batch_from_folder(color_output_folder, order, start_idx, batch_size)
                yield ([bw_batch, color_input_batch], color_output_batch)

    return _batches()

# Example of usage
# Small batch size: every sample is a full-resolution image (the model below
# expects 2048x1400 inputs), presumably chosen to fit in GPU memory.
batch_size = 2
train_generator = image_data_generator(bw_train_files, batch_size)
val_generator = image_data_generator(bw_val_files, batch_size)

# Optional sanity check. Each item is ([bw_batch, color_input_batch], color_output_batch).
# The generators never end, so take a single batch with next() rather than a
# for-loop, which would never terminate. Consuming a batch advances the generator.
# (bw_batch, color_input_batch), color_output_batch = next(train_generator)
# print('Train:', bw_batch.shape, color_input_batch.shape, color_output_batch.shape)
# (bw_batch, color_input_batch), color_output_batch = next(val_generator)
# print('Validation:', bw_batch.shape, color_input_batch.shape, color_output_batch.shape)


In [None]:
# from tensorflow.keras.preprocessing.image import ImageDataGenerator

# batch_size = 32

# # Data Augmentation
# datagen = ImageDataGenerator(
#     rotation_range=20,
#     zoom_range=0.15,
#     width_shift_range=0.2,
#     height_shift_range=0.2,
#     shear_range=0.15,
#     horizontal_flip=True,
#     fill_mode="nearest"
# )

# # Generator for augmented data
# def augmented_data_generator(batch_size):
#     base_generator = image_data_generator(batch_size)
    
#     for bw_batch, color_input_batch, color_output_batch in base_generator:
#         # Augment each batch
#         # Note: We're using the same seed for both black & white and color input images
#         # to ensure they undergo the same transformations.
        
#         # Augmenting BW images
#         bw_gen = datagen.flow(bw_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         # Augmenting color input images
#         color_input_gen = datagen.flow(color_input_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         # Augmenting color output images. Since we need to match outputs with inputs, 
#         # we're not shuffling and using a consistent seed.
#         color_output_gen = datagen.flow(color_output_batch, batch_size=batch_size, shuffle=False, seed=42)
        
#         yield [next(bw_gen), next(color_input_gen)], next(color_output_gen)

# # Example of usage
# aug_gen = augmented_data_generator(batch_size)
# for (bw_aug_batch, color_input_aug_batch), color_output_aug_batch in aug_gen:
#     print(bw_aug_batch.shape, color_input_aug_batch.shape, color_output_aug_batch.shape)


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Concatenate, MaxPooling2D, UpSampling2D

def _conv_encoder(x):
    """Apply the convolutional encoder stack used by each input branch of ``DualInput``.

    Three of the convolutions use stride 2, so height and width shrink by a
    factor of 8; the output has 256 channels. Every call creates new layers,
    so branches built with it do not share weights.
    """
    x = Conv2D(64, (3, 3), activation="relu", padding="same", strides=2)(x)
    x = Conv2D(128, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(128, (3, 3), activation="relu", padding="same", strides=2)(x)
    x = Conv2D(256, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(256, (3, 3), activation="relu", padding="same", strides=2)(x)
    x = Conv2D(512, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(512, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(256, (3, 3), activation="relu", padding="same")(x)
    return x


def DualInput(height=2048, width=1400, output_activation="sigmoid"):
    """Build the two-input colourisation CNN.

    A 1-channel (black & white) image and a 3-channel colour reference of the
    same size are each encoded by their own convolutional stack (downsampling
    by 8). The two feature maps are concatenated and decoded back to full
    resolution with three 2x upsampling stages into a 3-channel image.

    Args:
        height: Input/output height. Must be divisible by 8 so the three
            stride-2 convolutions and three 2x upsamplings restore it exactly.
        width: Input/output width, with the same divisibility requirement.
        output_activation: Activation of the final 1x1 convolution. Defaults to
            "sigmoid" because the training targets are 8-bit images divided by
            255.0, i.e. in [0, 1]. The previous hard-coded "tanh" outputs
            [-1, 1]; pass "tanh" to rebuild the old model.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs ``[bw_input, color_input]``.

    Raises:
        ValueError: If ``height`` or ``width`` is not a multiple of 8.
    """
    if height % 8 or width % 8:
        raise ValueError(f"height and width must be multiples of 8, got {height}x{width}")

    # Black & white input (built first so auto-generated layer names/order match
    # the original model and existing checkpoints).
    bw_input = Input(shape=(height, width, 1))
    bw_features = _conv_encoder(bw_input)

    # Colored reference input
    color_input = Input(shape=(height, width, 3))
    color_features = _conv_encoder(color_input)

    # Merge both feature maps along the channel axis, then decode to full size.
    x = Concatenate()([bw_features, color_features])
    x = Conv2D(128, (3, 3), activation="relu", padding="same")(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(32, (3, 3), activation="relu", padding="same")(x)
    # NOTE(review): a 2-channel ReLU bottleneck feeds the 3-channel output; this
    # looks inherited from Lab (ab-channel) colourisation designs - confirm intent.
    x = Conv2D(2, (3, 3), activation="relu", padding="same")(x)
    x = UpSampling2D((2, 2))(x)

    # Output layer to produce a 3-channel image
    outputs = Conv2D(3, (1, 1), activation=output_activation)(x)

    return tf.keras.Model(inputs=[bw_input, color_input], outputs=outputs)

model = DualInput()
# NOTE: Keras resolves the 'accuracy' string to a classification metric (for this
# 3-channel output presumably categorical accuracy, i.e. argmax-channel agreement),
# which says little about a per-pixel regression. It is kept because the save cell
# reads history['accuracy']; 'mae' adds an interpretable mean absolute pixel error.
model.compile(optimizer='adam', loss='mse', metrics=['accuracy', 'mae'])


In [None]:
import tensorflow as tf
# Report how many GPUs TensorFlow can see (uses the stable tf.config API rather
# than the tf.config.experimental alias).
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

In [None]:
# Print the model architecture: each layer with its output shape and parameter count.
model.summary()

In [None]:
import math
import datetime
# One step per batch, so each epoch makes exactly one pass over the files.
steps_per_epoch = math.ceil(len(bw_train_files) / batch_size)
validation_steps = math.ceil(len(bw_val_files) / batch_size)

# Timestamp shared by the TensorBoard log dir, the checkpoint path and the
# saved-model filename in the next cell. NOTE: this name shadows the stdlib
# `time` module in the notebook namespace; it is kept because the save cell reads it.
time = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_dir = "./logs/fit/" + time
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

checkpoint_path = f"./checkpoints/{time}/cp_{time}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
os.makedirs(checkpoint_dir, exist_ok=True)

# Save the model's weights after every epoch. The path has no {epoch} placeholder,
# so each epoch overwrites the same checkpoint.
# NOTE(review): Keras 3 (TF >= 2.16) requires filepaths ending in `.weights.h5`
# when save_weights_only=True - confirm against the installed version.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Monitor the validation loss, not the training loss. Training loss keeps
# falling while the model overfits, so it is a poor signal both for stopping
# and for choosing which weights restore_best_weights brings back.
earlystop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                                      patience=2,
                                                      verbose=1,
                                                      restore_best_weights=True)

# 3. Train the model
with tf.device('/GPU:0'):
    history = model.fit(
        train_generator,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        epochs=3,
        verbose=1,
        callbacks=[tensorboard_callback, cp_callback, earlystop_callback]
    )

In [None]:
# Access the last epoch's loss and accuracy values from the history object,
# rounded to 2 decimals for use in the filename.
# NOTE(review): with restore_best_weights=True the saved weights may come from an
# earlier epoch than these last-epoch numbers if early stopping triggered.
training_loss = round(history.history['loss'][-1], 2)
training_accuracy = round(history.history['accuracy'][-1], 2)
validation_loss = round(history.history['val_loss'][-1], 2)
validation_accuracy = round(history.history['val_accuracy'][-1], 2)

model_save_path = "./models"
# Create the output directory up front instead of relying on model.save to do it.
os.makedirs(model_save_path, exist_ok=True)
# Save the entire model as a `.keras` zip archive.
model.save(os.path.join(
    model_save_path,
    f'model-{time}-{training_loss}-{training_accuracy}-{validation_loss}-{validation_accuracy}.keras',
))