In [None]:
import tensorflow as tf
import numpy as np
import os, shutil
from tensorflow.keras.layers import Conv2D, Add, ReLU, Concatenate, Lambda
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Input, Model
from tensorflow.keras.models import load_model
import glob
import random
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import matplotlib as mpl

HYPERPARAMETERS

In [None]:
# Optimisation schedule: number of passes and images per batch.
n_epochs, batch_size = 8, 10
# Adam step size and the L2 weight-decay coefficient (both 1e-4 here).
learning_rate = weight_decay = 1e-4

NETWORK DEFINITION

In [None]:
def haze_net(input_shape=(480, 640, 3), weight_decay=0.0):
    """Build the AOD-Net-style dehazing model (concatenation skip connections).

    Five 3-filter convolutions (kernel sizes 1, 3, 5, 7, 3) estimate a
    per-pixel map ``K``. The output is ``K * X - K + 1`` clamped to [0, 1]
    by ``ReLU(max_value=1.0)``.

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
        weight_decay: L2 regularisation factor applied to every conv kernel.

    Returns:
        ``(model, trainable_variables)``: the Keras model and the trainable
        variables of its layers, gathered layer by layer.
    """
    def conv_stage(inputs, kernel_size):
        # All stages share: 3 filters, ReLU, bias, N(0, 0.02) kernel init,
        # L2(weight_decay). A fresh initializer/regularizer is made per layer.
        stage = Conv2D(3, (kernel_size, kernel_size), padding="SAME",
                       activation="relu", use_bias=True,
                       kernel_initializer=RandomNormal(stddev=0.02),
                       kernel_regularizer=l2(weight_decay))
        return stage(inputs)

    image = Input(shape=input_shape)

    # Layers are created in the same order as before:
    # conv, conv, concat, conv, concat, conv, concat, conv.
    f1 = conv_stage(image, 1)
    f2 = conv_stage(f1, 3)
    f3 = conv_stage(Concatenate(axis=-1)([f1, f2]), 5)
    f4 = conv_stage(Concatenate(axis=-1)([f2, f3]), 7)
    k_map = conv_stage(Concatenate(axis=-1)([f1, f2, f3, f4]), 3)

    dehazed = ReLU(max_value=1.0)(k_map * image - k_map + 1.0)

    model = Model(inputs=image, outputs=dehazed)
    trainable_variables = [var for layer in model.layers
                           for var in layer.trainable_variables]
    return model, trainable_variables
    
def haze_res_net(X, weight_decay=0.0):
    """Build a residual variant of the dehazing model on an existing input tensor.

    This follows the same ``K``-map idea as ``haze_net``. The differences:

    * skip connections are element-wise ``Add`` instead of channel concatenation;
    * kernels use ``tf.initializers.RandomNormal()`` with its library defaults.

    Args:
        X: symbolic Keras input tensor (e.g. from ``Input(shape=...)``). It
            becomes the model's input.
        weight_decay: L2 regularisation factor applied to every conv kernel.

    Returns:
        ``(model, trainable_variables)``: the Keras model and the trainable
        variables of its layers, gathered layer by layer.
    """
    def conv(tensor, k):
        # 3 filters, k x k kernel, ReLU, bias, L2-regularised kernel.
        return Conv2D(3, (k, k), padding="SAME", activation="relu", use_bias=True,
                      kernel_initializer=tf.initializers.RandomNormal(),
                      kernel_regularizer=l2(weight_decay))(tensor)

    # Creation order preserved: conv, conv, add, conv, conv, add, conv, add.
    c1 = conv(X, 1)
    c2 = conv(c1, 3)
    c3 = conv(Add()([c1, c2]), 5)
    c4 = conv(c3, 7)
    c5 = conv(Add()([c3, c4]), 3)
    k_map = Add()([c5, c1])  # long skip from the first stage

    # Output = K*X - K + 1, clamped to [0, 1] by ReLU(max_value=1.0).
    restored = ReLU(max_value=1.0)(tf.math.multiply(k_map, X) - k_map + 1.0)

    model = tf.keras.Model(inputs=X, outputs=restored)

    trainable_variables = []
    for layer in model.layers:
        trainable_variables.extend(layer.trainable_variables)
    return model, trainable_variables

DATA LOADING AND PRE-PROCESSING

In [None]:
def setup_data_paths(orig_images_path, hazy_images_path, train_fraction=0.90, seed=None):
    """Pair every hazy image with its clean original and split train/val by original.

    The split is made over *original* images, so all hazy variants of one
    original land in the same subset (no train/val leakage). A hazy file
    named ``<a>_<b>_<anything>.jpg`` is matched to the original ``<a>_<b>.jpg``.

    Args:
        orig_images_path: directory containing the clean ``*.jpg`` images.
        hazy_images_path: directory containing the hazy ``*.jpg`` images.
        train_fraction: fraction of originals assigned to training; the rest
            go to validation. Defaults to 0.90.
        seed: when given, shuffling uses a private ``random.Random(seed)``,
            so the split is reproducible regardless of global RNG state.
            When ``None``, the module-level ``random`` generator is used.

    Returns:
        ``(train_data, val_data)``: lists of ``[hazy_path, orig_path]`` pairs.
        Hazy files whose original cannot be derived or is not present are
        reported with a warning and skipped.
    """
    # glob order is filesystem-dependent; sort so a fixed seed yields a fixed split.
    orig_image_paths = sorted(glob.glob(os.path.join(orig_images_path, "*.jpg")))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(orig_image_paths)

    n_train = int(train_fraction * len(orig_image_paths))
    # Keyed by file name so the lookup does not depend on how paths are spelled.
    split_dict = {os.path.basename(p): 'train' for p in orig_image_paths[:n_train]}
    split_dict.update({os.path.basename(p): 'val' for p in orig_image_paths[n_train:]})

    train_data = []
    val_data = []
    for path in sorted(glob.glob(os.path.join(hazy_images_path, "*.jpg"))):
        parts = os.path.basename(path).split('_')
        if len(parts) < 2:
            print(f"Warning: cannot derive original filename from {path}; skipping.")
            continue
        orig_filename = parts[0] + '_' + parts[1] + ".jpg"
        orig_path = os.path.join(orig_images_path, orig_filename)

        split = split_dict.get(orig_filename)
        if split is None:
            print(f"Warning: {orig_path} not found in split_dict.")
        elif split == 'train':
            train_data.append([path, orig_path])
        else:
            val_data.append([path, orig_path])

    return train_data, val_data

In [None]:
def load_image(X, size=(480, 640)):
    """Read a JPEG file into a float32 RGB tensor scaled to [0, 1].

    Args:
        X: path of a JPEG file (Python ``str`` or scalar string tensor, as
            passed by ``tf.data`` ``map``).
        size: ``(height, width)`` to resize to. The default (480, 640)
            matches ``haze_net``'s default input shape.

    Returns:
        Tensor of shape ``(height, width, 3)``, dtype float32, values in [0, 1].
    """
    raw_bytes = tf.io.read_file(X)
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    # The default (bilinear) resize returns float32 within the uint8 range,
    # so dividing by 255 yields [0, 1].
    image = tf.image.resize(image, size)
    return image / 255.0

def showImage(x, title=None):
    """Display a single RGB image whose values are nominally in [0, 1].

    Args:
        x: NumPy array or eager tensor of shape ``(H, W, 3)`` (assumed — a
            batched 4-D input is not supported; TODO confirm callers).
        title: optional figure title.
    """
    # Round (instead of truncating) to 8-bit, and clip first so any
    # out-of-range value saturates rather than wrapping on the uint8 cast.
    scaled = np.asarray(x, dtype=np.float64) * 255.0
    pixels = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    fig, ax = plt.subplots()
    ax.imshow(pixels)
    if title is not None:
        ax.set_title(title)
    plt.show()

In [None]:
def _make_pair_dataset(pairs, batch_size, shuffle):
    """Build a repeating, batched dataset of (hazy, original) image tensors from path pairs."""
    hazy_paths = [hazy for hazy, _ in pairs]
    orig_paths = [orig for _, orig in pairs]
    ds = tf.data.Dataset.from_tensor_slices((hazy_paths, orig_paths))
    if shuffle:
        # Shuffle the cheap path strings, over the whole list, *before* decoding.
        # A buffer of decoded 480x640x3 float32 images costs ~3.7 MB per slot.
        ds = ds.shuffle(len(pairs))
    ds = ds.map(lambda hazy, orig: (load_image(hazy), load_image(orig)),
                num_parallel_calls=tf.data.AUTOTUNE)
    return ds.repeat().batch(batch_size).prefetch(tf.data.AUTOTUNE)


def create_datasets(train_data, val_data, batch_size):
    """Create infinite batch iterators over (hazy, original) image pairs.

    Args:
        train_data: list of ``[hazy_path, orig_path]`` pairs for training.
        val_data: list of ``[hazy_path, orig_path]`` pairs for validation.
        batch_size: number of pairs per batch.

    Returns:
        ``(train_init_op, val_init_op, iterator)``:
        * ``train_init_op`` — iterator over training batches, reshuffled
          every pass;
        * ``val_init_op`` — iterator over validation batches;
        * ``iterator`` — the same object as ``train_init_op`` (kept for
          compatibility).

        Both datasets repeat indefinitely, so callers must bound how many
        batches they draw.

    Raises:
        ValueError: if ``train_data`` or ``val_data`` is empty.
    """
    if not train_data:
        raise ValueError("train_data is empty; cannot build the training dataset.")
    if not val_data:
        raise ValueError("val_data is empty; cannot build the validation dataset.")

    train_iterator = iter(_make_pair_dataset(train_data, batch_size, shuffle=True))
    val_iterator = iter(_make_pair_dataset(val_data, batch_size, shuffle=False))
    return train_iterator, val_iterator, train_iterator

TRAINING

In [None]:
# Seed every random number generator the pipeline uses:
#   - NumPy;
#   - Python's `random` (setup_data_paths shuffles with it);
#   - TensorFlow (weight initialisers and tf.data shuffling).
np.random.seed(9999)
random.seed(9999)
tf.random.set_seed(9999)

train_data, val_data = setup_data_paths(
    orig_images_path="./All-In-One-Image-Dehazing-Tensorflow/data/orig_images",
    hazy_images_path="./All-In-One-Image-Dehazing-Tensorflow/data/hazy_images"
)

# Repeating batch iterators; train_init_op and iterator are the same object.
train_init_op, val_init_op, iterator = create_datasets(train_data, val_data, batch_size)

# Apply the configured L2 penalty. haze_net's second return value (a variable
# list) is not needed; the list is read from the model below.
model, _ = haze_net(input_shape=(480, 640, 3), weight_decay=weight_decay)

optimizer = Adam(learning_rate)
# Training uses a custom GradientTape loop; compile() only attaches the
# optimizer and loss to the model.
model.compile(optimizer=optimizer, loss='mean_squared_error')

trainable_variables = model.trainable_variables


In [None]:
# Training loop.
# The training iterator repeats indefinitely (`.repeat()`), so each epoch is
# bounded explicitly: enough batches to cover train_data once (rounded up).
steps_per_epoch = max(1, (len(train_data) + batch_size - 1) // batch_size)
grad_clip_norm = 0.1   # global-norm threshold for gradient clipping
early_stop_loss = 0.1  # stop after an epoch whose mean training MSE is below this
# NOTE(review): images and outputs lie in [0, 1], so the MSE may already be
# below 0.1 after the first epoch. Confirm this threshold is intended.

for epoch in range(n_epochs):
    epoch_loss = 0.0

    for step in range(steps_per_epoch):
        hazy_batch, original_batch = next(train_init_op)

        with tf.GradientTape() as tape:
            dehazed_batch = model(hazy_batch, training=True)
            current_loss = tf.reduce_mean(tf.square(dehazed_batch - original_batch))
            # A custom loop must add the layers' L2 kernel penalties itself;
            # Keras collects them in model.losses.
            total_loss = current_loss
            if model.losses:
                total_loss = total_loss + tf.add_n(model.losses)

        gradients = tape.gradient(total_loss, trainable_variables)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, grad_clip_norm)
        optimizer.apply_gradients(zip(clipped_gradients, trainable_variables))

        epoch_loss += current_loss.numpy()
        if step % 10 == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {current_loss.numpy()}")

    # Evaluate the stop criterion once per epoch on the full-epoch mean, after
    # checkpointing. `break` here ends training entirely, not just the epoch.
    avg_epoch_loss = epoch_loss / steps_per_epoch
    print(f"Epoch {epoch}, Average loss: {avg_epoch_loss}")
    model.save("dehazer1.h5")

    if avg_epoch_loss < early_stop_loss:
        print(f"Average epoch loss below {early_stop_loss}. Training stopped.")
        break

EVALUATION

In [None]:
# Evaluation: mean per-pixel MSE over every validation pair exactly once.
# A finite, unshuffled dataset is built here because the validation iterator
# from create_datasets repeats forever. Counting next() calls on it can skip
# or duplicate pairs, and gives zero steps when len(val_data) < batch_size.
if val_data:
    eval_ds = (
        tf.data.Dataset.from_tensor_slices(
            ([hazy for hazy, _ in val_data], [orig for _, orig in val_data]))
        .map(lambda hazy, orig: (load_image(hazy), load_image(orig)))
        .batch(batch_size)
    )

    total_val_loss = 0.0
    n_val_images = 0
    for hazy_batch, original_batch in eval_ds:
        dehazed_batch = model(hazy_batch, training=False)
        val_loss = tf.reduce_mean(tf.square(dehazed_batch - original_batch))
        # Weight by batch size so a short final batch counts proportionally.
        batch_len = int(hazy_batch.shape[0])
        total_val_loss += val_loss.numpy() * batch_len
        n_val_images += batch_len

    avg_val_loss = total_val_loss / n_val_images
    print(f"Validation Loss: {avg_val_loss}")
else:
    print("No validation data; skipping evaluation.")