# **TRAIN LLE_UNET FOR IMAGE DENOISING**                    

**Dataset: smartphone-image-denoising-dataset (SIDD) from Kaggle**

https://www.kaggle.com/code/pamuduranasinghe/train-lle-unet-for-image-denoisng

This notebook can be run directly on Kaggle with the dataset attached.

In [None]:
import os
import cv2
import sys
import numpy as np
from matplotlib import pyplot as plt
import math
import random


import tensorflow as tf
from tensorflow.keras.layers import Dense, Activation, Concatenate, GlobalAveragePooling2D, Multiply,GlobalMaxPooling2D
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D, Conv2DTranspose, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16

In [None]:
# Set random seeds for reproducibility: Python's `random` (used by split_data to
# shuffle the file pairs), NumPy and TensorFlow.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)
# Run tf.function bodies eagerly so Python-side tensor access (e.g. `.numpy()`)
# works inside them.
# NOTE(review): this also makes Keras' compiled train step run eagerly, which is
# slower; drop it once nothing depends on eager execution inside a tf.function.
tf.config.run_functions_eagerly(True)

In [None]:
# Turn on memory growth for every visible GPU: TensorFlow then allocates GPU
# memory on demand instead of reserving it all at start-up (helps avoid OOM).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in gpus:
    tf.config.experimental.set_memory_growth(gpu_device, True)

**Padding mechanism to work with images of any width and height**

In [None]:
def padding_calc(input_dim, multiplier=32):
    """Return how many pixels to add to ``input_dim`` to reach the next multiple of ``multiplier``.

    Returns 0 when ``input_dim`` is already a multiple of ``multiplier``.
    """
    return math.ceil(input_dim / multiplier) * multiplier - input_dim

# Add Padding
def pad_image(image, mood="center_padding", multiplier=32):
    """Zero-pad ``image`` so its height and width become multiples of ``multiplier``.

    Args:
        image: array whose first two axes are height and width (e.g. H x W x C).
            Any trailing axes (channels) are left unpadded.
        mood: "center_padding" splits the padding between opposite sides, with the
            odd leftover pixel going to the bottom/right. "corner_padding" pads only
            the bottom and right edges.
        multiplier: target multiple for height and width (default 32).

    Returns:
        A new, padded array. ``image`` itself is not modified.

    Raises:
        ValueError: if ``mood`` is not one of the two supported modes.
    """
    pad_y = padding_calc(image.shape[0], multiplier)
    pad_x = padding_calc(image.shape[1], multiplier)

    if mood == "center_padding":
        # floor(pad/2) before, the remainder after. Height and width are handled
        # independently, so odd padding on BOTH axes is padded correctly.
        # inverse_padding assumes this split.
        top, left = pad_y // 2, pad_x // 2
        pad_width = [(top, pad_y - top), (left, pad_x - left)]
    elif mood == "corner_padding":
        pad_width = [(0, pad_y), (0, pad_x)]
    else:
        raise ValueError(
            f"unknown padding mode {mood!r}; expected 'center_padding' or 'corner_padding'"
        )

    # Never pad the channel (or any other trailing) axes.
    pad_width += [(0, 0)] * (image.ndim - 2)
    return np.pad(image, pad_width, mode='constant')

# Remove Padding
def inverse_padding(pad_image, image_dim, pad_method="center_padding"):
    """Crop the original-size region back out of an image padded by ``pad_image``.

    Note: the first parameter is named ``pad_image`` (kept for backward
    compatibility). Inside this function it shadows the ``pad_image`` function.

    Args:
        pad_image: padded array whose first two axes are height and width.
        image_dim: original size. Only ``image_dim[0]`` (height) and
            ``image_dim[1]`` (width) are used, so a full ``.shape`` tuple works.
        pad_method: the mode that was used for padding, either
            "center_padding" or "corner_padding".

    Returns:
        A view of ``pad_image`` with the original height and width.

    Raises:
        ValueError: if ``pad_method`` is not one of the two supported modes.
    """
    img_height, img_width = image_dim[0], image_dim[1]

    if pad_method == "center_padding":
        # Center padding puts floor(pad/2) pixels on the top/left side.
        top = (pad_image.shape[0] - img_height) // 2
        left = (pad_image.shape[1] - img_width) // 2
    elif pad_method == "corner_padding":
        top = left = 0
    else:
        raise ValueError(
            f"unknown padding method {pad_method!r}; expected 'center_padding' or 'corner_padding'"
        )

    return pad_image[top:top + img_height, left:left + img_width]

In [None]:
## test image padding mechanism
test_H = 300
test_W = 533
test_image = np.zeros((test_H, test_W, 3), dtype=np.uint8)
img_pad = pad_image(test_image,'center_padding')
print(test_image.shape,img_pad.shape)
# The padded height/width must be multiples of 32, and inverse_padding must give
# back the original size.
assert img_pad.shape[0] % 32 == 0 and img_pad.shape[1] % 32 == 0, img_pad.shape
assert inverse_padding(img_pad, test_image.shape).shape == test_image.shape

In [None]:
# Root of the Kaggle SIDD-Small sRGB data: one sub-folder per scene, each presumably
# holding GT_* and NOISY_* images (the loop below sorts files by that prefix).
data_path = r'/kaggle/input/smartphone-image-denoising-dataset/SIDD_Small_sRGB_Only/Data'

In [None]:
gt_image_paths    = []
noisy_image_paths = []
img_folder_names  = os.listdir(data_path)

# Route each file by its filename prefix (the text before the first "_").
_paths_by_label = {"GT": gt_image_paths, "NOISY": noisy_image_paths}

for img_folder_name in img_folder_names:
    folder_path = os.path.join(data_path, img_folder_name)
    for img_name in os.listdir(folder_path):
        img_path = os.path.join(folder_path, img_name)
        bucket = _paths_by_label.get(img_name.split("_")[0])
        if bucket is None:
            print(f'{img_path} NOT GT OR NOISY')
        else:
            bucket.append(img_path)

In [None]:
# Sort both lists in place. Index i then refers to the same scene in both lists,
# assuming every folder holds matching GT/NOISY files.
gt_image_paths.sort()
noisy_image_paths.sort()

print(f'N_GROUND TRUTH IMAGES : {len(gt_image_paths)}')
print(f'N_NOISY IMAGES : {len(noisy_image_paths)}')

In [None]:
def _unzip_pairs(pairs):
    """Split a list of (x, y) pairs into an x tuple and a y tuple (empty -> two empty tuples)."""
    if not pairs:
        return (), ()
    xs, ys = zip(*pairs)
    return xs, ys


def split_data(image_paths, ground_truth_paths, train_ratio, val_ratio, test_ratio, seed=None):
    """Shuffle paired paths and split them into train / validation / test subsets.

    Args:
        image_paths: model-input paths (x). Entry i is paired with ground_truth_paths[i].
        ground_truth_paths: target paths (y), aligned index-by-index with image_paths.
        train_ratio: fraction of pairs used for training (count truncated with int()).
        val_ratio: fraction of pairs used for validation (count truncated with int()).
        test_ratio: kept for readability only. The test split is whatever remains
            after the train and validation splits.
        seed: optional seed for a private ``random.Random``. When None, the global
            ``random`` module is used and its state is advanced.

    Returns:
        (train_x, train_y, val_x, val_y, test_x, test_y) as tuples. An empty split
        yields empty tuples instead of raising.

    Raises:
        ValueError: if the two path lists differ in length, which means the
            pairing would be wrong.
    """
    x_list = list(image_paths)
    y_list = list(ground_truth_paths)
    if len(x_list) != len(y_list):
        raise ValueError(
            f"image_paths ({len(x_list)}) and ground_truth_paths ({len(y_list)}) differ in length"
        )

    data_pairs = list(zip(x_list, y_list))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(data_pairs)

    total_samples = len(data_pairs)
    train_end = int(train_ratio * total_samples)
    val_end = train_end + int(val_ratio * total_samples)

    train_x, train_y = _unzip_pairs(data_pairs[:train_end])
    val_x, val_y = _unzip_pairs(data_pairs[train_end:val_end])
    test_x, test_y = _unzip_pairs(data_pairs[val_end:])

    return train_x, train_y, val_x, val_y, test_x, test_y

**Split data into Train, Validation and Test**

In [None]:
train_ratio = 0.8
val_ratio   = 0.1
test_ratio  = 0.1

# split_data returns its FIRST list as the model input (x) and its SECOND as the
# target (y). For denoising, the input is the NOISY image and the target is the
# clean GT image. Passing GT first would train the network to ADD noise and make
# infer() label the clean image as "Noisy Image".
train_x_paths, train_y_paths, val_x_paths, val_y_paths, test_x_paths, test_y_paths = split_data(
    noisy_image_paths, gt_image_paths, train_ratio, val_ratio, test_ratio)


def _sort_pairs(x_paths, y_paths):
    """Sort aligned (x, y) path sequences by x path without breaking the pairing."""
    pairs = sorted(zip(x_paths, y_paths))
    return [x for x, _ in pairs], [y for _, y in pairs]


# Sort each split jointly. Sorting x and y separately only stays aligned if both
# lists happen to sort in exactly the same order.
train_x_paths, train_y_paths = _sort_pairs(train_x_paths, train_y_paths)
val_x_paths, val_y_paths = _sort_pairs(val_x_paths, val_y_paths)
test_x_paths, test_y_paths = _sort_pairs(test_x_paths, test_y_paths)

In [None]:
# Report the size of every split. Each x count should equal its y count.
_split_sizes = (
    ('X_train : ', train_x_paths),
    ('Y_train : ', train_y_paths),
    ('X_val   : ', val_x_paths),
    ('Y_val   : ', val_y_paths),
    ('X_test  : ', test_x_paths),
    ('Y_test  : ', test_y_paths),
)
for split_label, split_paths in _split_sizes:
    print(f'{split_label}{len(split_paths)}')

**Create data-loading functions for the Kaggle SIDD dataset**

In [None]:
def load_image_file(file_path):
    """Load one image for the tf.data pipeline as float32 RGB scaled to [0, 1].

    Intended to be called through ``tf.py_function`` (see ``image_dataset``), where
    ``file_path`` arrives as an eager string tensor. Plain ``str``/``bytes`` paths
    are accepted as well.

    Steps: read with OpenCV, convert BGR -> RGB, resize to 1024x1024, zero-pad to a
    multiple of 32 with ``pad_image`` (a no-op at 1024x1024), divide by 255.

    Raises:
        OSError: if OpenCV cannot read the file.
    """
    # Deliberately not a @tf.function: the body uses .numpy() and OpenCV, which
    # need eager execution. tf.py_function already runs it eagerly.
    if hasattr(file_path, "numpy"):
        file_path = file_path.numpy()
    if isinstance(file_path, bytes):
        file_path = file_path.decode("utf-8")

    img = cv2.imread(file_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"cv2 could not read image: {file_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,(1024,1024))
    # img = cv2.resize(img,(0,0),fx=0.5, fy=0.5)
    preprocess_img = pad_image(img)
    # Return the normalized image. The model ends in a sigmoid and the PSNR/SSIM
    # metrics use max_val=1.0, so inputs and targets must both lie in [0, 1].
    # float32 matches the Tout declared in image_dataset.
    return (preprocess_img / 255).astype(np.float32)

def image_dataset(image_list):
    """Build a tf.data pipeline that loads each path in ``image_list`` with ``load_image_file``.

    Element order follows ``image_list``: a parallel ``map`` keeps input order
    unless determinism is explicitly disabled. This matters because the x and y
    datasets are later zipped element by element.
    """
    files = tf.data.Dataset.from_tensor_slices(image_list)

    def _load(path):
        image = tf.py_function(load_image_file, [path], tf.float32)
        # py_function drops static shape information. Restore rank and channel
        # count (cv2.imread's default mode yields 3 channels).
        image.set_shape([None, None, 3])
        return image

    return files.map(_load, num_parallel_calls=tf.data.AUTOTUNE)

In [None]:
train_x = image_dataset(list(train_x_paths))
train_y = image_dataset(list(train_y_paths))

BATCH_SIZE = 2

# combine input and output
train = tf.data.Dataset.zip((train_x, train_y))
# train = train.take(100)
# train = train.shuffle(100)
train = train.batch(BATCH_SIZE)
# prefetch() returns a NEW dataset; without reassignment it has no effect.
train = train.prefetch(tf.data.AUTOTUNE)

In [None]:
val_x = image_dataset(list(val_x_paths))
val_y = image_dataset(list(val_y_paths))

# combine input and output
val = tf.data.Dataset.zip((val_x, val_y))
val = val.batch(BATCH_SIZE)
# prefetch() returns a NEW dataset; without reassignment it has no effect.
val = val.prefetch(tf.data.AUTOTUNE)

In [None]:
test_x = image_dataset(list(test_x_paths))
test_y = image_dataset(list(test_y_paths))

# combine input and output
test = tf.data.Dataset.zip((test_x, test_y))
test = test.batch(BATCH_SIZE)
# prefetch() returns a NEW dataset; without reassignment it has no effect.
test = test.prefetch(tf.data.AUTOTUNE)

In [None]:
# sample = train.take(1)
# train_sample = sample.as_numpy_iterator()
# res = train_sample.next()
# print(len(res))

# res_low_1 = res[0][0]
# res_high_1 = res[1][0]

# fig, axs = plt.subplots(ncols=2);
# axs[0].imshow(res_low_1);
# axs[1].imshow(res_high_1);

# **LLE UNET**
[LLE UNET FULL CODE](https://github.com/pamudu123/Low_Light_Image_Enhancement)

**CBAM ATTENTION**

In [None]:
def channel_attention_module(x, ratio=8):
    """CBAM channel attention: rescale each channel of ``x`` by a learned gate in (0, 1).

    The global-average-pooled and global-max-pooled channel descriptors pass through
    the same two-layer MLP (shared Dense layers). Their sum goes through a sigmoid,
    and the result multiplies ``x`` channel-wise.

    Args:
        x: rank-4 feature map with channels last. The shape is unpacked into four
            dims below, and the 2-D global pooling layers reduce the middle axes.
        ratio: reduction factor for the MLP's hidden layer.

    Returns:
        A tensor with the same shape as ``x``.
    """
    _, _, _, channel = x.shape
    # Keep at least one hidden unit. channel // ratio is 0 when channel < ratio,
    # and a zero-width hidden layer would make the gate a constant sigmoid(0).
    hidden_units = max(channel // ratio, 1)

    ## Shared layers
    l1 = Dense(hidden_units, activation="relu", use_bias=False)
    l2 = Dense(channel, use_bias=False)

    ## Global Average Pooling
    x1 = GlobalAveragePooling2D()(x)
    x1 = l1(x1)
    x1 = l2(x1)

    ## Global Max Pooling
    x2 = GlobalMaxPooling2D()(x)
    x2 = l1(x2)
    x2 = l2(x2)

    ## Add both descriptors and squash to (0, 1) with a sigmoid
    feats = x1 + x2
    feats = Activation("sigmoid")(feats)
    feats = Multiply()([x, feats])

    return feats

def spatial_attention_module(x):
    """CBAM spatial attention: rescale every spatial position of ``x`` by a learned gate.

    The channel-wise mean and max maps are stacked along the last axis, then a 7x7
    sigmoid convolution produces a single-channel gate that multiplies ``x``.
    """
    # [mean over channels, max over channels], each kept as a trailing size-1 axis.
    pooled_maps = [
        tf.expand_dims(reduce_fn(x, axis=-1), axis=-1)
        for reduce_fn in (tf.reduce_mean, tf.reduce_max)
    ]
    stacked = Concatenate()(pooled_maps)
    gate = Conv2D(1, kernel_size=7, padding="same", activation="sigmoid")(stacked)
    return Multiply()([x, gate])

def CBAM(x):
    """Convolutional Block Attention Module: channel attention, then spatial attention."""
    return spatial_attention_module(channel_attention_module(x))

In [None]:
def conv_block(input, num_filters):
    """Apply two rounds of 3x3 'same' Conv2D -> BatchNormalization -> ReLU."""
    out = input
    for _ in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def encoder_block(input, num_filters):
    """Return (conv_block features, the same features 2x2 max-pooled)."""
    features = conv_block(input, num_filters)
    return features, MaxPool2D((2, 2))(features)

def decoder_block(input, skip_features, num_filters):
    """Upsample 2x with a transposed conv, concatenate the skip features, then conv_block."""
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_model(input_shape):
    """Build the U-Net with a frozen VGG16 encoder and CBAM-refined skip connections.

    Encoder: VGG16 with ImageNet weights, frozen. Skip features come from the last
    conv of blocks 1-4 and are refined with CBAM; block5_conv3 is the bottleneck.
    Decoder: four upsampling stages (512, 256, 128, 64 filters), each concatenating
    one skip map, then a 1x1 sigmoid conv giving a 3-channel output in (0, 1).

    Feature-map sizes relative to an H x W input (VGG16 pools 2x after each block):
        block1_conv2: H x W        block2_conv2: H/2 x W/2
        block3_conv3: H/4 x W/4    block4_conv3: H/8 x W/8
        block5_conv3 (bottleneck): H/16 x W/16

    NOTE(review): the ImageNet VGG16 weights were trained on
    ``vgg16.preprocess_input``-style inputs, while this notebook feeds RGB scaled
    to [0, 1]. Confirm this is intended.
    """
    inputs = Input(input_shape)

    encoder = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)
    encoder.trainable = False

    skip_layer_names = ("block1_conv2", "block2_conv2", "block3_conv3", "block4_conv3")
    skip_outputs = [encoder.get_layer(name).output for name in skip_layer_names]
    bottleneck = encoder.get_layer("block5_conv3").output

    # Attention on every skip connection (shallowest first).
    attended_skips = [CBAM(skip) for skip in skip_outputs]

    # Decoder: deepest skip first, each stage doubling the spatial size.
    decoded = bottleneck
    for skip, filters in zip(reversed(attended_skips), (512, 256, 128, 64)):
        decoded = decoder_block(decoded, skip, filters)

    outputs = Conv2D(3, 1, padding="same", activation="sigmoid")(decoded)
    return Model(inputs, outputs, name="VGG_Model_LowLight_Enhancement")


In [None]:
# Height and width are left as None so the model accepts any input size. Inputs are
# presumably padded (pad_image) so the decoder's skip concatenations line up.
input_shape = (None, None, 3)
model = build_model(input_shape)
model.summary()

In [None]:
# Number of weight tensors (not parameters) in each group; the frozen VGG16
# encoder contributes to the non-trainable group.
for weights_label, weight_list in (
    ("trainable_weights:", model.trainable_weights),
    ("non_trainable_weights:", model.non_trainable_weights),
):
    print(weights_label, len(weight_list))

In [None]:
# Optimizer (AdamW alternative kept for reference)
#optimizer = tf.keras.optimizers.experimental.AdamW(learning_rate=0.0001,weight_decay=0.004)
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
MSEloss = tf.keras.losses.MeanSquaredError()


def charbonnier_loss(y_true, y_pred):
    """Charbonnier loss: mean over all elements of sqrt((y_true - y_pred)**2 + (1e-3)**2).

    A smooth, everywhere-differentiable approximation of the L1 loss.
    """
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(tf.sqrt(squared_error + tf.square(1e-3)))


def psnr_loss_fn(y_true, y_pred):
    """Per-image PSNR assuming pixel values in [0, 1] (a metric: higher is better)."""
    return tf.image.psnr(y_pred, y_true, max_val=1.0)


def ssim_loss_fn(y_true, y_pred):
    """Per-image SSIM assuming pixel values in [0, 1] (a metric: higher is better)."""
    return tf.image.ssim(y_true, y_pred, max_val=1.0)

**Tensorboard Callback**

In [None]:
# Import from tensorflow.keras, like every other Keras import in this notebook, so
# the checkpoint callback comes from the same Keras implementation as the model.
from tensorflow.keras.callbacks import ModelCheckpoint

# TensorBoard callback: writes training logs under ./logs
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")

In [None]:
# Checkpoint callback: each run saves into its own timestamped folder under models/,
# writing a checkpoint only when val_loss improves.
from datetime import datetime

model_save_folder = datetime.now().strftime("%Y%m%d_%H%M%S")
model_save_dir = f'models/{model_save_folder}'

folder_already_exists = os.path.exists(model_save_dir)
if not folder_already_exists:
    print("FOLDER CREATED")
    os.makedirs(model_save_dir)

# '{epoch:03d}' is filled in by ModelCheckpoint, not by Python string formatting.
checkpoint_path_pattern = model_save_dir + '/model-{epoch:03d}.hdf5'
save_best_model_checkpoint = ModelCheckpoint(
    checkpoint_path_pattern,
    monitor='val_loss',
    save_best_only=True,
    mode='auto',
)

**Callback to show intermediate results while the model trains.**

In [None]:
def preprocess_image(img_path):
    """Load an image from disk for inference/plotting, scaled to [0, 1].

    Steps: read with OpenCV, convert BGR -> RGB, resize to 1024x1024, zero-pad to a
    multiple of 32 with ``pad_image`` (a no-op at 1024x1024), divide by 255.

    Raises:
        OSError: if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising. Report the path here rather
        # than letting cv2.cvtColor fail with an opaque assertion.
        raise OSError(f"cv2 could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,(1024,1024))
    preprocess_img = pad_image(img)
    return preprocess_img / 255

In [None]:
def _show_side_by_side(images, titles):
    """Plot ``images`` in a single row with the given titles and hidden axes."""
    _, axes = plt.subplots(ncols=len(images), figsize=(15,10))
    for axis, image, title in zip(axes, images, titles):
        axis.imshow(image)
        axis.axis('off')
        axis.set_title(title)
    plt.show()


def infer(x_paths,y_paths,n_images=2):
    """Show noisy input, model prediction and ground truth for the first ``n_images`` pairs.

    Uses the module-level ``model``. ``y_paths[i]`` must be the ground truth for
    ``x_paths[i]``.
    """
    for idx, input_path in enumerate(x_paths):
        if idx == n_images:
            break
        print(f'{idx+1}/{n_images} : {input_path}')
        input_img = preprocess_image(input_path)
        batch = np.expand_dims(input_img, axis=0)
        predicted = model.predict(batch, verbose=0)[0]

        target_img = preprocess_image(y_paths[idx])

        _show_side_by_side(
            (input_img, predicted, target_img),
            ("Noisy Image", "Predicted Image", "Ground Truth Image"),
        )

In [None]:
from tensorflow.keras import callbacks
class PredictionCallback(callbacks.Callback):
    """Keras callback that runs ``infer`` on the first two training pairs.

    It fires at the end of every epoch whose index is a multiple of
    ``log_interval``, starting with epoch 0.
    """

    def __init__(self, log_interval, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.log_interval = log_interval

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.log_interval:
            return
        infer(train_x_paths,train_y_paths,n_images=2)


In [None]:
# Train with the Charbonnier loss; MSE, PSNR and SSIM are reported as metrics.
model.compile(
    optimizer=optimizer,
    loss=charbonnier_loss,
    metrics=[MSEloss, psnr_loss_fn, ssim_loss_fn],
)

In [None]:
# PredictionCallback shows sample predictions on epochs where epoch % LOG_INTERVALS == 0.
LOG_INTERVALS = 5
# Number of training epochs passed to model.fit.
EPOCHS = 50

In [None]:
# Train for EPOCHS epochs, logging to TensorBoard, keeping the best checkpoint by
# val_loss, and plotting sample predictions every LOG_INTERVALS epochs.
training_callbacks = [
    tensorboard_callback,
    save_best_model_checkpoint,
    PredictionCallback(log_interval=LOG_INTERVALS),
]
hist = model.fit(train, epochs=EPOCHS, validation_data=val, callbacks=training_callbacks)

In [None]:
# Save the whole model (architecture + weights), not just the weights, twice:
# as a single HDF5 file (".h5") and to a path without an extension, which
# tf.keras 2.x writes in the TensorFlow SavedModel directory format.
for save_target in ("LLE_UNET_DENOISE.h5", 'LLE_UNET_DENOISE'):
    model.save(save_target)

**Loss graphs**

In [None]:
# Training vs. validation loss per epoch.
for history_key, line_color, line_label in (
    ('loss', 'teal', 'loss'),
    ('val_loss', 'orange', 'val loss'),
):
    plt.plot(hist.history[history_key], color=line_color, label=line_label)
plt.suptitle('Loss')
plt.legend()
plt.show()

**Tensorboard**

In [None]:
# Load the TensorBoard notebook extension (IPython magic, not plain Python).
%load_ext tensorboard

In [None]:
# Open TensorBoard on the training logs. tensorboard_callback writes to ./logs;
# this absolute path presumably equals that on Kaggle (working dir /kaggle/working).
%tensorboard --logdir '/kaggle/working/logs'

**Model Inference**

In [None]:
# Visualize noisy input / prediction / ground truth for up to the first 10 test pairs.
infer(test_x_paths,test_y_paths,n_images=10)