<a href="https://colab.research.google.com/github/soumik12345/image-restoration-primer/blob/main/01_train_aodnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!--- @wandbcode{tfug-kol} -->

In [None]:
# Quietly upgrade wandb; the notebook below imports callbacks from wandb.keras.
!pip install -q --upgrade wandb

In [None]:
import os
import random
import numpy as np
from glob import glob
from functools import partial
from tqdm.autonotebook import tqdm

import wandb
from wandb.keras import (
    WandbMetricsLogger,
    WandbModelCheckpoint,
    WandbEvalCallback
)

import tensorflow as tf
from tensorflow import keras

In [None]:
# Log in to Weights & Biases so that the wandb.init call in the next cell can
# create a run.
wandb.login()

In [None]:
# W&B project/entity that this training run is logged under (Colab form fields).
wandb_project = "image-dehazing" #@param {type:"string"}
wandb_entity = "geekyrakshit" #@param {type:"string"}
wandb.init(
    project=wandb_project, entity=wandb_entity, job_type="train"
)

# All hyperparameters are stored on wandb.config so they are recorded with the run.
config = wandb.config
config.dataset_artifact = 'geekyrakshit/image-dehazing/dehaze-dataset:v0' #@param {type:"string"}
config.seed = 42 #@param {type:"raw"}

# Seed the global RNGs (see keras.utils.set_random_seed) so runs are more repeatable.
keras.utils.set_random_seed(config.seed)

# Fetch the dataset artifact; artifact_dir is the local directory it is downloaded to.
artifact = wandb.use_artifact(config.dataset_artifact, type='dataset')
artifact_dir = artifact.download()

In [None]:
def get_image_file_list_from_dehazy(dataset_path):
    """Pair every hazy training image with its ground-truth (haze-free) image.

    Hazy images are collected from ``<dataset_path>/train_images/*.jpg`` and
    sorted. For each one, the ground-truth file name is formed from the first
    two ``_``-separated fields of the hazy file's stem plus ``.jpg``; for
    example, a hazy file ``a_b_c.jpg`` maps to ``original_images/a_b.jpg``.

    Args:
        dataset_path: Root directory containing ``train_images/`` and
            ``original_images/``.

    Returns:
        A tuple ``(hazy_image_paths, ground_truth_files)`` of equal-length
        lists, where ``ground_truth_files[i]`` is the target for
        ``hazy_image_paths[i]``. Ground-truth paths are derived by name only;
        their existence is not checked.
    """
    hazy_image_paths = sorted(
        glob(os.path.join(dataset_path, "train_images", "*.jpg"))
    )
    ground_truth_files = []
    for image_path in hazy_image_paths:
        # basename is separator-aware (unlike split('/')), and stripping the
        # extension first keeps it from leaking into the derived name
        # (e.g. 'a_b.jpg' previously became 'a_b.jpg.jpg').
        stem = os.path.splitext(os.path.basename(image_path))[0]
        ground_truth_file_name = "_".join(stem.split("_")[:2]) + ".jpg"
        ground_truth_files.append(
            os.path.join(dataset_path, "original_images", ground_truth_file_name)
        )
    return hazy_image_paths, ground_truth_files


# The paired images live in the "Dehazing" folder of the downloaded artifact.
dehazy_dataset_path = os.path.join(artifact_dir, "Dehazing")
(
    dehazy_hazy_image_paths,
    dehazy_ground_truth_paths,
) = get_image_file_list_from_dehazy(dehazy_dataset_path)
# Both counts should match: one ground-truth path is derived per hazy image.
print("Number of Hazy Images:", len(dehazy_hazy_image_paths))
print("Number of Ground-truth Images:", len(dehazy_ground_truth_paths))

In [None]:
config.val_split = 0.2 #@param {type:"slider", min:0, max:1, step:0.01}
# The last `val_split` fraction of the (sorted) file lists is held out for validation.
num_train_images = len(dehazy_hazy_image_paths) - int(len(dehazy_hazy_image_paths) * config.val_split)

train_hazy_image_paths = dehazy_hazy_image_paths[:num_train_images]
train_ground_truth_image_paths = dehazy_ground_truth_paths[:num_train_images]

val_hazy_image_paths = dehazy_hazy_image_paths[num_train_images:]
# Validation targets must come from the ground-truth list (previously this was
# sliced from the hazy list, so validation compared predictions to hazy inputs).
val_ground_truth_image_paths = dehazy_ground_truth_paths[num_train_images:]

In [None]:
def read_image(image_path):
    """Read an image file and return it as a float32 tensor scaled to [0, 1].

    Args:
        image_path: Path of the image file (Python str or scalar string tensor).

    Returns:
        A rank-3 ``(height, width, 3)`` float32 tensor with values in ``[0, 1]``.
    """
    image = tf.io.read_file(image_path)
    # The dataset files are .jpg; decode_image picks the decoder from the file
    # contents instead of relying on decode_png's JPEG fallback.
    # expand_animations=False keeps the result rank-3, matching the old output.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    image = tf.cast(image, dtype=tf.float32) / 255.0
    return image


def random_crop(input_image, enhanced_image, image_size):
    """Cut the same random ``image_size`` x ``image_size`` window from two images.

    The window offset is drawn from the spatial size of ``input_image`` and
    applied to both images, so the returned crops stay pixel-aligned.
    NOTE(review): assumes both images share the same height/width and are at
    least ``image_size`` on each side — confirm for the dataset in use.

    Returns:
        ``(input_image_cropped, enhanced_image_cropped)``.
    """
    spatial_dims = tf.shape(input_image)[:2]
    # Draw the column offset first, then the row offset (same order as before,
    # so seeded runs produce identical crops).
    offset_x = tf.random.uniform(
        shape=(), maxval=spatial_dims[1] - image_size + 1, dtype=tf.int32
    )
    offset_y = tf.random.uniform(
        shape=(), maxval=spatial_dims[0] - image_size + 1, dtype=tf.int32
    )
    rows = slice(offset_y, offset_y + image_size)
    cols = slice(offset_x, offset_x + image_size)
    return input_image[rows, cols], enhanced_image[rows, cols]


def load_data(input_image_path, enhanced_image_path, image_size):
    """Load a (hazy, ground-truth) image pair and return aligned random crops.

    Both files are decoded with ``read_image`` (hazy first, then ground truth)
    and cropped together by ``random_crop`` to ``image_size`` x ``image_size``.
    """
    return random_crop(
        read_image(input_image_path),
        read_image(enhanced_image_path),
        image_size,
    )


def get_dataset(input_images, enhanced_images, image_size, batch_size):
    """Build a batched tf.data pipeline of aligned (hazy, ground-truth) crops.

    Args:
        input_images: List of hazy image paths.
        enhanced_images: List of ground-truth image paths, aligned index-wise
            with ``input_images``.
        image_size: Side length of the square random crop taken from each pair.
        batch_size: Number of pairs per batch. A trailing partial batch is dropped.

    Returns:
        A ``tf.data.Dataset`` yielding ``(hazy_batch, ground_truth_batch)``.
        Cropping happens inside ``map``, so each pass over the dataset draws
        fresh random crops. No shuffling is applied; elements keep the order
        of the input lists.
    """
    dataset = tf.data.Dataset.from_tensor_slices((input_images, enhanced_images))
    dataset = dataset.map(
        partial(load_data, image_size=image_size),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    # drop_remainder=True keeps the cardinality known, which the LR schedule
    # reads via tf.data.experimental.cardinality.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Overlap input decoding/cropping with model execution; prefetch does not
    # change the elements or the cardinality.
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
config.image_size = 256 #@param {type:"integer"}
config.batch_size = 16 #@param {type:"integer"}

# image_size is the side length of the square random crops fed to the model.
train_dataset = get_dataset(train_hazy_image_paths, train_ground_truth_image_paths, config.image_size, config.batch_size)
val_dataset = get_dataset(val_hazy_image_paths, val_ground_truth_image_paths, config.image_size, config.batch_size)

In [None]:
class AODNet(tf.keras.Model):
    """AOD-Net all-in-one dehazing model.

    A K-estimation module estimates K(x) from the hazy input I(x). It has five
    3-filter convolutions whose outputs are fused by multi-scale
    concatenations. The clean image is then recovered as
    J(x) = K(x) * I(x) - K(x) + 1 and clipped to [0, 1].

    Args:
        stddev: Standard deviation of the normal kernel initializer.
        weight_decay: L2 regularization factor applied to every conv kernel.
        **kwargs: Forwarded to ``keras.Model`` (e.g. ``name``).
    """

    def __init__(self, stddev: float = 0.02, weight_decay: float = 1e-4, **kwargs):
        super(AODNet, self).__init__(**kwargs)
        # Kernel sizes follow the AOD-Net paper (Li et al., ICCV 2017):
        # conv1 1x1, conv2 3x3, conv3 5x5, conv4 7x7, conv5 3x3.
        # conv_layer_2 was previously 1x1, which deviates from the paper.
        self.conv_layer_1 = self._make_conv(1, stddev, weight_decay)
        self.conv_layer_2 = self._make_conv(3, stddev, weight_decay)
        self.conv_layer_3 = self._make_conv(5, stddev, weight_decay)
        self.conv_layer_4 = self._make_conv(7, stddev, weight_decay)
        self.conv_layer_5 = self._make_conv(3, stddev, weight_decay)
        # ReLU with max_value=1.0 clips the reconstructed image to [0, 1].
        self.relu = keras.layers.ReLU(max_value=1.0)

    @staticmethod
    def _make_conv(kernel_size, stddev, weight_decay):
        """Return a 3-filter, stride-1, 'same'-padded ReLU Conv2D layer."""
        return keras.layers.Conv2D(
            filters=3, kernel_size=kernel_size, strides=1,
            padding="same", activation="relu", use_bias=True,
            kernel_initializer=keras.initializers.random_normal(stddev=stddev),
            kernel_regularizer=keras.regularizers.L2(weight_decay),
        )

    def call(self, inputs, *args, **kwargs):
        """Dehaze a batch of images.

        NOTE(review): ``inputs`` is presumably (batch, H, W, 3) RGB in [0, 1],
        as produced by ``read_image`` in this notebook — confirm for other callers.
        """
        conv_1 = self.conv_layer_1(inputs)
        conv_2 = self.conv_layer_2(conv_1)
        concat_1 = tf.concat([conv_1, conv_2], axis=-1)
        conv_3 = self.conv_layer_3(concat_1)
        concat_2 = tf.concat([conv_2, conv_3], axis=-1)
        conv_4 = self.conv_layer_4(concat_2)
        concat_3 = tf.concat([conv_1, conv_2, conv_3, conv_4], axis=-1)
        k = self.conv_layer_5(concat_3)
        # Reformulated atmospheric scattering model: J = K * I - K + b, with b = 1.
        j = tf.math.multiply(k, inputs) - k + 1.0
        output = self.relu(j)
        return output

In [None]:
# Model / optimization hyperparameters (Colab form fields, recorded on wandb.config).
config.stddev = 0.02 #@param {type:"number"}
config.weight_decay = 1e-4 #@param {type:"number"}
config.learning_rate = 1e-4 #@param {type:"number"}
config.use_cosine_decay = True #@param {type:"boolean"}
config.epochs = 30 #@param {type:"slider", min:1, max:30, step:1}
config.save_best_only = True #@param {type:"boolean"}


model = AODNet(
    stddev=config.stddev,
    weight_decay=config.weight_decay
)
# Build the subclassed model with a concrete input shape so its weights exist
# before compile/fit.
model.build((1, config.image_size, config.image_size, 3))

def peak_signal_noise_ratio(y_true, y_pred):
    """Keras metric: per-image PSNR of ``y_pred`` against ``y_true`` (pixel range [0, 1])."""
    per_image_psnr = tf.image.psnr(y_pred, y_true, max_val=1.0)
    return per_image_psnr

def structural_similarity(y_true, y_pred):
    """Keras metric: per-image SSIM of ``y_pred`` against ``y_true`` (pixel range [0, 1])."""
    per_image_ssim = tf.image.ssim(y_pred, y_true, max_val=1.0)
    return per_image_ssim

# Either anneal the learning rate with cosine decay across all training steps
# (batches per epoch * epochs) or keep it constant.
if config.use_cosine_decay:
    lr_schedule_fn = keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=config.learning_rate,
        decay_steps=tf.data.experimental.cardinality(train_dataset).numpy() * config.epochs,
        alpha=1e-6,
    )
else:
    lr_schedule_fn = config.learning_rate

# Pixel-wise MSE loss; PSNR and SSIM are tracked as metrics.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule_fn),
    loss=keras.losses.MeanSquaredError(),
    metrics=[peak_signal_noise_ratio, structural_similarity],
)

In [None]:
class DehazingEvaluationCallback(WandbEvalCallback):
    """Log validation inputs, targets and per-epoch predictions to W&B tables.

    ``add_ground_truth`` fills ``self.data_table`` with (hazy, ground-truth)
    images. ``add_model_predictions`` adds one row per validation image to
    ``self.pred_table`` and to ``self.evaluation_summary_table``. The summary
    table accumulates every epoch and is logged once in ``on_train_end``.

    The validation batches are materialized once and reused, because
    re-iterating a tf.data pipeline that random-crops would yield different
    crops. Without this, predictions would not correspond to the logged
    ground-truth rows.

    Args:
        validation_data: A finite ``tf.data.Dataset`` of
            ``(hazy_batch, ground_truth_batch)`` (e.g. ``val_dataset.take(2)``).
        data_table_columns: Column names for the ground-truth data table.
        pred_table_columns: Column names for the prediction tables.
    """

    def __init__(
        self, validation_data, data_table_columns, pred_table_columns
    ):
        super().__init__(data_table_columns, pred_table_columns)
        self.evaluation_summary_table = wandb.Table(columns=pred_table_columns)
        self.validation_data = validation_data
        self.dataset_cardinality = tf.data.experimental.cardinality(validation_data).numpy()
        # Cached list of (hazy, ground_truth) numpy batches; filled on first use.
        self._validation_batches = None

    def _get_validation_batches(self):
        """Iterate ``validation_data`` once and cache its batches as numpy arrays."""
        if self._validation_batches is None:
            # Negative cardinality means unknown/infinite; no progress-bar total then.
            total = int(self.dataset_cardinality) if self.dataset_cardinality >= 0 else None
            self._validation_batches = [
                (hazy_batch.numpy(), ground_truth_batch.numpy())
                for hazy_batch, ground_truth_batch in tqdm(self.validation_data, total=total)
            ]
        return self._validation_batches

    def postprocess(self, image):
        """Convert float images in [0, 1] to uint8 in [0, 255] (values clipped)."""
        return (image * 255.0).clip(0, 255).astype(np.uint8)

    def add_ground_truth(self, logs=None):
        """Add one (hazy, ground-truth) row per validation image to ``self.data_table``."""
        for hazy_batch, ground_truth_batch in self._get_validation_batches():
            for hazy_image, ground_truth in zip(
                self.postprocess(hazy_batch), self.postprocess(ground_truth_batch)
            ):
                self.data_table.add_data(wandb.Image(hazy_image), wandb.Image(ground_truth))

    def add_model_predictions(self, epoch, logs=None):
        """Predict on the cached validation batches and add per-image rows.

        Each row holds the epoch, references to the hazy/ground-truth cells of
        the logged data table, the prediction, and its PSNR and SSIM.
        NOTE(review): assumes ``add_ground_truth`` has already been logged as
        ``self.data_table_ref`` with rows in the same batch-major order
        (WandbEvalCallback is expected to do this at train start — verify).
        """
        data_table_ref = self.data_table_ref
        # Global row index into data_table_ref (previously `idx + count`, which
        # mixed a per-batch image index with a batch index).
        row = 0
        for hazy_batch, ground_truth_batch in tqdm(self._get_validation_batches()):
            prediction_batch = self.model.predict(hazy_batch, verbose=0)
            psnr = tf.image.psnr(ground_truth_batch, prediction_batch, max_val=1.0).numpy()
            ssim = tf.image.ssim(ground_truth_batch, prediction_batch, max_val=1.0).numpy()
            for prediction, psnr_value, ssim_value in zip(
                self.postprocess(prediction_batch), psnr, ssim
            ):
                hazy_ref = data_table_ref.data[row][0]
                ground_truth_ref = data_table_ref.data[row][1]
                self.pred_table.add_data(
                    epoch, hazy_ref, ground_truth_ref,
                    wandb.Image(prediction), psnr_value, ssim_value
                )
                self.evaluation_summary_table.add_data(
                    epoch, hazy_ref, ground_truth_ref,
                    wandb.Image(prediction), psnr_value, ssim_value
                )
                row += 1

    def on_train_end(self, logs=None):
        """Log the table accumulated across all epochs to the active W&B run."""
        super().on_train_end(logs)
        wandb.run.log({"Evaluation-Table": self.evaluation_summary_table})

In [None]:
# Callbacks, in order: W&B metric logging, model checkpointing, and per-epoch
# prediction tables on the first two validation batches.
callbacks = [WandbMetricsLogger()]
callbacks.append(
    WandbModelCheckpoint(filepath="aodnet", save_best_only=config.save_best_only)
)
callbacks.append(
    DehazingEvaluationCallback(
        val_dataset.take(2),
        data_table_columns=["Hazy-Image", "Ground-Truth"],
        pred_table_columns=[
            "Epoch",
            "Hazy-Image",
            "Ground-Truth",
            "Predicted-Image",
            "Peak-Signal-To-Noise-Ratio",
            "Structural-Similarity",
        ],
    )
)

model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=config.epochs,
    callbacks=callbacks,
)

In [None]:
# Mark the W&B run started by wandb.init above as finished.
wandb.finish()