# **LLE_UNET INFERENCE UNDERWATER**

Date - 03/08/2023

Images - EUVP (Enhancing Underwater Visual Perception) Dataset Images

In [None]:
import os
import random
import csv

In [None]:
import cv2
import sys
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import math
import random


import tensorflow as tf
from tensorflow.keras.layers import Dense, Activation, Concatenate, GlobalAveragePooling2D, Multiply,GlobalMaxPooling2D
from tensorflow.keras.layers import Conv2D, BatchNormalization, MaxPool2D, Conv2DTranspose, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications.vgg16 import VGG16

In [None]:
# Set the random seeds for NumPy and TensorFlow so their random draws are repeatable
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)
# Make tf.function-wrapped code run eagerly instead of as a traced graph
tf.config.run_functions_eagerly(True)

In [None]:
# Load train / test / validation image paths from CSV files
def load_excel_data(file_path, column1_name='UnderWater Images', column2_name='GroundTruth Images', sheet_name="Sheet1"):
    """Read two columns of image paths from a CSV file.

    Despite the name, the file is parsed with ``pandas.read_csv``; it is not
    an Excel workbook.

    Args:
        file_path: Path of the CSV file to read.
        column1_name: Header of the column holding the input image paths.
        column2_name: Header of the column holding the ground-truth image paths.
        sheet_name: Unused. Kept only so existing callers keep working.

    Returns:
        A tuple ``(column1_values, column2_values)`` of lists, or
        ``(None, None)`` if the file could not be read or a column is missing
        (the error is printed, not raised).
    """
    try:
        df = pd.read_csv(file_path)
        column1_data = df[column1_name].tolist()
        column2_data = df[column2_name].tolist()
    except (OSError, KeyError, ValueError) as e:
        # OSError: missing/unreadable file; KeyError: missing column;
        # ValueError: pandas parser/empty-data/decoding errors.
        print(f"Error occurred while loading data from CSV file '{file_path}': {e}")
        return None, None

    return column1_data, column2_data

In [None]:
# Locations of the CSV files for the train, test and validation splits
train_csv_path = r'/kaggle/input/train-lle-unet-for-underwater/train.csv'
test_csv_path = r'/kaggle/input/train-lle-unet-for-underwater/test.csv'
validation_csv_path = r'/kaggle/input/train-lle-unet-for-underwater/validation.csv'

In [None]:
# For every split, read the 'UnderWater Images' column as x and the
# 'GroundTruth Images' column as y. Each pair is (None, None) if loading failed.
train_x_paths, train_y_paths = load_excel_data(train_csv_path)
val_x_paths, val_y_paths = load_excel_data(validation_csv_path)
test_x_paths, test_y_paths = load_excel_data(test_csv_path)

In [None]:
# Put every path list into ascending order.
# NOTE(review): inputs and ground truths are sorted independently, so pairing
# them by index later assumes both lists sort into the same relative order --
# confirm this holds for the file names in the CSVs.
train_x_paths, train_y_paths = sorted(train_x_paths), sorted(train_y_paths)
val_x_paths, val_y_paths = sorted(val_x_paths), sorted(val_y_paths)
test_x_paths, test_y_paths = sorted(test_x_paths), sorted(test_y_paths)


In [None]:
# Report how many paths each split contains, one line per list
split_size_lines = [
    f'X_train : {len(train_x_paths)}',
    f'Y_train : {len(train_y_paths)}',
    f'X_val   : {len(val_x_paths)}',
    f'Y_val   : {len(val_y_paths)}',
    f'X_test  : {len(test_x_paths)}',
    f'Y_test  : {len(test_y_paths)}',
]
print("\n".join(split_size_lines))

### **Create Model**

In [None]:
def channel_attention_module(x, ratio=8):
    """Re-weight the channels of a 4-D feature map with a learned gate.

    Global average-pooled and global max-pooled descriptors of ``x`` go
    through the same two bias-free dense layers. Their sum is passed through
    a sigmoid and multiplied with ``x``.

    Args:
        x: Feature map whose shape unpacks into four entries; the last one is
            the channel count.
        ratio: Divisor applied to the channel count for the first dense layer.

    Returns:
        ``x`` multiplied by the per-channel sigmoid weights.
    """
    _, _, _, channel = x.shape

    # Dense layers shared by the average-pool and max-pool branches
    reduce_fc = Dense(channel//ratio, activation="relu", use_bias=False)
    restore_fc = Dense(channel, use_bias=False)

    # Branch 1: global average pooling
    avg_branch = restore_fc(reduce_fc(GlobalAveragePooling2D()(x)))

    # Branch 2: global max pooling
    max_branch = restore_fc(reduce_fc(GlobalMaxPooling2D()(x)))

    # Combine both branches into per-channel weights in (0, 1)
    channel_gate = Activation("sigmoid")(avg_branch + max_branch)
    return Multiply()([x, channel_gate])

def spatial_attention_module(x):
    """Re-weight the positions of a feature map with a learned single-channel mask.

    The mean and the max over the last axis are stacked into a two-channel
    map. A 7x7 convolution with sigmoid activation turns that map into a
    single-channel mask, which is multiplied with ``x``.

    Args:
        x: Input feature map (the last axis is reduced).

    Returns:
        ``x`` multiplied by the sigmoid mask.
    """
    # Mean over the last axis, kept as a trailing singleton channel
    mean_map = tf.expand_dims(tf.reduce_mean(x, axis=-1), axis=-1)

    # Max over the last axis, kept as a trailing singleton channel
    max_map = tf.expand_dims(tf.reduce_max(x, axis=-1), axis=-1)

    # Stack both maps and convolve them into a single-channel mask
    stacked = Concatenate()([mean_map, max_map])
    mask = Conv2D(1, kernel_size=7, padding="same", activation="sigmoid")(stacked)
    return Multiply()([x, mask])

def CBAM(x):
    """Apply channel attention to ``x``, then spatial attention to the result."""
    return spatial_attention_module(channel_attention_module(x))

In [None]:
def conv_block(input, num_filters):
    """Apply two successive (3x3 same-padded Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        input: Tensor fed to the first convolution.
        num_filters: Number of filters used by both convolutions.

    Returns:
        The output tensor of the second ReLU activation.
    """
    features = input
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features

def encoder_block(input, num_filters):
    """Run ``conv_block`` and 2x2 max-pool its output.

    Returns:
        A tuple ``(features, pooled)``: the conv-block output and its
        max-pooled version.
    """
    features = conv_block(input, num_filters)
    pooled = MaxPool2D((2, 2))(features)
    return features, pooled

def decoder_block(input, skip_features, num_filters):
    """Upsample ``input``, concatenate the skip features and refine with ``conv_block``.

    Args:
        input: Tensor to upsample with a stride-2, 2x2 transposed convolution.
        skip_features: Tensor concatenated with the upsampled result.
        num_filters: Filter count for the transposed convolution and the conv block.

    Returns:
        The output of the conv block.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(input)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

def build_model(input_shape):
    """Build a U-Net style model with a frozen VGG16 encoder and CBAM on the skips.

    Args:
        input_shape: Shape passed to ``Input`` (without the batch dimension).

    Returns:
        A Keras ``Model`` named "LLE_UNET_undwerwater". Its output is a
        3-channel sigmoid map, so values lie in (0, 1).

    Note:
        The size annotations next to the layers below look like they assume a
        512 x 512 input -- confirm.
    """
    inputs = Input(input_shape)

    # ImageNet-pretrained VGG16 without its classifier head, used as a frozen encoder
    vgg_model = VGG16(include_top=False, weights="imagenet",input_tensor=inputs)
    vgg_model.trainable = False

    # Encoder
    # Skip connections: the last conv output of VGG blocks 1 to 4
    s1 = vgg_model.get_layer("block1_conv2").output                             ## (512 x 512)
    s2 = vgg_model.get_layer("block2_conv2").output                             ## (256 x 256)
    s3 = vgg_model.get_layer("block3_conv3").output                             ## (128 x 128)
    s4 = vgg_model.get_layer("block4_conv3").output                             ## (64 x 64)

    # Bottleneck: the last conv output of VGG block 5
    b1 = vgg_model.get_layer("block5_conv3").output                             ## (32 x 32)

    # Attention
    # Each skip tensor goes through channel + spatial attention before decoding
    s1 = CBAM(s1)
    s2 = CBAM(s2)
    s3 = CBAM(s3)
    s4 = CBAM(s4)

    # Decoder
    # Each block upsamples by 2 and merges the skip from the matching encoder level
    d1 = decoder_block(b1, s4, 512)                                             ## (64 x 64)
    d2 = decoder_block(d1, s3, 256)                                             ## (128 x 128)
    d3 = decoder_block(d2, s2, 128)                                             ## (256 x 256)
    d4 = decoder_block(d3, s1, 64)                                              ## (512 x 512)

    # Output
    # 1x1 convolution down to 3 channels, with sigmoid activation
    outputs = Conv2D(3, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="LLE_UNET_undwerwater")
    return model


In [None]:
# Height and width are left unspecified (None); only the 3 channels are fixed.
# NOTE: `model` is reassigned further down by the load_model(...) call.
input_shape = (None, None, 3)
model = build_model(input_shape)
model.summary()

In [None]:
# Alternative optimizer kept for reference (disabled):
#optimizer = tf.keras.optimizers.experimental.AdamW(learning_rate=0.0001,weight_decay=0.004)
# Adam optimizer and MSE loss objects. Neither is referenced again in the
# visible inference code -- presumably carried over from the training notebook.
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
MSEloss = tf.keras.losses.MeanSquaredError()

def charbonnier_loss(y_true, y_pred):
    """Return the mean of ``sqrt((y_true - y_pred)^2 + eps^2)`` with ``eps = 1e-3``."""
    epsilon = 1e-3
    residual = y_true - y_pred
    smoothed_abs = tf.sqrt(tf.square(residual) + tf.square(epsilon))
    return tf.reduce_mean(smoothed_abs)

def psnr_loss_fn(y_true, y_pred):
    """Return ``tf.image.psnr`` of the prediction against the target (max value 1.0).

    Note: this returns the PSNR value itself, with no negation or other
    transformation.
    """
    peak_value = 1.0
    return tf.image.psnr(y_pred, y_true, max_val=peak_value)

def ssim_loss_fn(y_true, y_pred):
    """Return ``tf.image.ssim`` between target and prediction (max value 1.0).

    Note: this returns the SSIM value itself, with no negation or other
    transformation.
    """
    peak_value = 1.0
    return tf.image.ssim(y_true, y_pred, peak_value)

In [None]:
# Path of the saved model file (.h5) that load_model reads below
model_path = r'/kaggle/input/train-lle-unet-for-underwater/LLE_UNET_UNDWERWATER.h5'

In [None]:
# Load the saved model from disk; this replaces the `model` built above.
# custom_objects maps the custom loss/metric names to the functions defined in
# this notebook. compile=False loads the model without compiling it.
# NOTE(review): this imports from `keras` while the rest of the notebook uses
# `tensorflow.keras` -- confirm both resolve to the same Keras installation.
from keras.models import load_model
model = load_model(model_path,custom_objects={'charbonnier_loss': charbonnier_loss,
                                              'psnr_loss_fn':psnr_loss_fn,
                                              'ssim_loss_fn':ssim_loss_fn},
                                              compile = False)

### **Preprocessing Functions**

In [None]:
def padding_calc(input_dim, multiplier=32):
    """Return the padding needed to bring ``input_dim`` up to a multiple of ``multiplier``.

    Args:
        input_dim: Current size of one image dimension.
        multiplier: The padded size must be a multiple of this value.

    Returns:
        The number of extra pixels required; 0 if ``input_dim`` is already a
        multiple of ``multiplier``.
    """
    return math.ceil(input_dim / multiplier) * multiplier - input_dim


def pad_image(image, mood="center_padding"):
    """Zero-pad an image so that its height and width are multiples of 32.

    Args:
        image: Array of shape (height, width) or (height, width, channels, ...).
            Only the first two axes are padded.
        mood: Padding layout. "center_padding" splits the padding between
            both sides of each axis, and the extra pixel of an odd amount goes
            to the bottom/right. "corner_padding" puts all padding on the
            bottom and right.

    Returns:
        A new, zero-padded array.

    Raises:
        ValueError: If ``mood`` is not one of the two supported layouts.
    """
    pad_y = padding_calc(image.shape[0])
    pad_x = padding_calc(image.shape[1])

    if mood == "center_padding":
        pad_top = pad_y // 2
        pad_left = pad_x // 2
        # Compute each axis independently so that an odd amount on BOTH axes
        # is still applied in full (bottom/right receive the extra pixel).
        spatial_pad = [(pad_top, pad_y - pad_top), (pad_left, pad_x - pad_left)]
    elif mood == "corner_padding":
        spatial_pad = [(0, pad_y), (0, pad_x)]
    else:
        raise ValueError(
            f"Unsupported padding mode {mood!r}; "
            "expected 'center_padding' or 'corner_padding'"
        )

    # Leave any axis after height/width (e.g. channels) untouched
    untouched_axes = [(0, 0)] * (image.ndim - 2)
    return np.pad(image, spatial_pad + untouched_axes, mode='constant')

def inverse_padding(pad_image, image_dim, pad_method="center_padding"):
    """Crop a padded image back to its original height and width.

    Args:
        pad_image: Padded array; the first two axes are height and width.
        image_dim: Sequence whose first two entries are the original height
            and width.
        pad_method: Layout that was used for padding. "center_padding" assumes
            the padding was split between both sides, with the extra pixel of
            an odd amount on the bottom/right. "corner_padding" assumes all
            padding is on the bottom/right.

    Returns:
        The region of ``pad_image`` that corresponds to the original image.

    Raises:
        ValueError: If ``pad_method`` is not one of the two supported layouts.
    """
    padded_height, padded_width = pad_image.shape[0], pad_image.shape[1]
    img_height, img_width = image_dim[0], image_dim[1]

    if pad_method == "center_padding":
        # The top/left padding is the floor of half the total padding, so the
        # original content starts there and spans the original size.
        top = (padded_height - img_height) // 2
        left = (padded_width - img_width) // 2
        return pad_image[top:top + img_height, left:left + img_width]

    if pad_method == "corner_padding":
        return pad_image[0:img_height, 0:img_width]

    raise ValueError(
        f"Unsupported padding method {pad_method!r}; "
        "expected 'center_padding' or 'corner_padding'"
    )

In [None]:
def preprocess_image(img_path, target_size=(256, 256)):
    """Load an image file and prepare it as network input.

    The image is read with OpenCV, converted from BGR to RGB, resized, padded
    with ``pad_image`` and scaled to float32 values in [0, 1].

    Args:
        img_path: Path of the image file to read.
        target_size: ``(width, height)`` passed to ``cv2.resize``. Defaults to
            the 256 x 256 size originally hard-coded here.

    Returns:
        A float32 RGB array with values in [0, 1].

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None;
        # fail here with the path rather than with an opaque cvtColor error.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    preprocess_img = pad_image(img)
    preprocess_img = preprocess_img / 255
    return preprocess_img.astype(np.float32)

**Inference**

In [None]:
# Results are written to a sub-directory named after the experiment version;
# the directory is created if it does not exist yet.
EXPERIMENT_VERSION = 'v1'
RESULTS_SAVE_DIR = os.path.join(r'/kaggle/working/',EXPERIMENT_VERSION)
os.makedirs(RESULTS_SAVE_DIR,exist_ok = True)

In [None]:
import numpy as np
from skimage.metrics import structural_similarity as ssim
from math import log10, sqrt
import time
from PIL import Image

In [None]:
class DatasetMetrics:
    """Accumulate PSNR, SSIM and MSE over image pairs and report their averages.

    PSNR and SSIM use a data range of 1, so images are expected to be float
    arrays scaled to [0, 1].
    """

    def __init__(self):
        # Running sums over all pairs passed to update_metrics
        self.total_psnr = 0
        self.total_ssim = 0
        self.total_mse = 0
        # Number of pairs accumulated so far
        self.total_images = 0

    def update_metrics(self, img1, img2):
        """Compute all three metrics for one image pair and add them to the totals."""
        mse_score = self.calculate_mse(img1, img2)
        psnr = self.calculate_psnr(img1, img2)
        ssim_score = self.calculate_ssim(img1, img2)

        self.total_psnr += psnr
        self.total_ssim += ssim_score
        self.total_mse += mse_score
        self.total_images += 1

    def calculate_mse(self, img1, img2):
        """Return the mean squared difference of the two images (computed in float32)."""
        diff = np.array(img1, dtype=np.float32) - np.array(img2, dtype=np.float32)
        return np.mean(diff ** 2)

    def calculate_psnr(self, img1, img2, max_pixel=1.0):
        """Return the PSNR in dB, or 100 when the images are identical (MSE of 0)."""
        mse = self.calculate_mse(img1, img2)
        if mse == 0:
            # Avoid dividing by zero; 100 dB is used as the "identical" value
            return 100
        return 20 * np.log10(max_pixel / (np.sqrt(mse)))

    def calculate_ssim(self, img1, img2):
        """Return the SSIM of the grayscale versions of two RGB images."""
        # Images in this notebook are RGB (preprocess_image converts BGR->RGB),
        # so the RGB conversion code gives the correct luminance weights.
        gray1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
        gray2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
        return ssim(gray1, gray2, data_range=1 - 0)

    def __str__(self):
        if self.total_images == 0:
            return "No images processed yet."

        avg_psnr = self.total_psnr / self.total_images
        avg_ssim = self.total_ssim / self.total_images
        avg_mse = self.total_mse / self.total_images

        # The third value is the mean squared error, so it is labelled MSE
        return f"Average PSNR: {avg_psnr:.4f}, Average SSIM: {avg_ssim:.4f}, Average MSE: {avg_mse:.4f}"

# # Usage example:
# if __name__ == "__main__":
#     # Assuming you have a list of image pairs, img_list1 and img_list2
#     dataset_metrics = DatasetMetrics()

#     for img1, img2 in zip(img_list1, img_list2):
#         dataset_metrics.update_metrics(img1, img2)

#     print(dataset_metrics)

In [None]:
def save_image(numpy_image, img_name):
    """Save a float image with values in [0, 1] as ``<img_name>.png`` in RESULTS_SAVE_DIR.

    Args:
        numpy_image: Array with values nominally in [0, 1].
        img_name: File name without extension.
    """
    # Round to the nearest integer instead of truncating (e.g. 254.99998 must
    # become 255, not 254), and clip so out-of-range values cannot wrap around
    # when cast to uint8.
    pixels = np.clip(np.rint(numpy_image * 255), 0, 255).astype(np.uint8)
    img = Image.fromarray(pixels)
    save_image_path = os.path.join(RESULTS_SAVE_DIR, f'{img_name}.png')
    img.save(save_image_path)

In [None]:
def infer(x_paths,y_paths,n_images,suffix,save1=False,save2=False,disp=False):
    """Run the model on up to ``n_images`` images and print average metrics.

    Args:
        x_paths: A list of underwater image file paths for prediction.
        y_paths: A list of ground truth image file paths corresponding to the
            underwater images (matched by index).
        n_images: The number of images to process for inference and evaluation.
        suffix: Suffix for naming saved images and results to differentiate.
        save1: If True, save individual images of input, prediction and
            ground truth.
        save2: If True, save a single figure with all three images side by
            side (independent of ``disp``).
        disp: If True, show the figure with all three images side by side.
    """
    st = time.time()
    dataset_metrics = DatasetMetrics()
    processed = 0
    for i, x_image_path in enumerate(x_paths):
        if i == n_images:
            break
        print(f'{i+1}/{n_images} : {x_image_path}')
        x_img = preprocess_image(x_image_path)
        prediction = model.predict(np.expand_dims(x_img, axis=0), verbose=0)[0]

        y_img = preprocess_image(y_paths[i])

        if save1:
            save_image(x_img, f'underwater_{suffix}_{i}')
            save_image(prediction, f'pred_{suffix}_{i}')
            save_image(y_img, f'groundtruth_{suffix}_{i}')

        # Build the comparison figure when it is shown OR saved, so that
        # save2 no longer depends on disp being True.
        if disp or save2:
            fig, ax = plt.subplots(ncols=3, figsize=(15, 10))
            panels = (
                (x_img, "UnderWater Image"),
                (prediction, "Predicted Image"),
                (y_img, "GroundTruth Image"),
            )
            for axis, (panel, title) in zip(ax, panels):
                axis.imshow(panel)
                axis.axis('off')
                axis.set_title(title)
            if save2:
                full_image_path = os.path.join(RESULTS_SAVE_DIR, f'full_pred_{suffix}_{i}.png')
                fig.savefig(full_image_path, bbox_inches='tight', pad_inches=0)
            if disp:
                plt.show()
            # Release the figure so figures do not pile up across iterations
            # (and are not displayed later when disp is False).
            plt.close(fig)

        dataset_metrics.update_metrics(y_img, prediction)
        processed += 1

    print(f'\n{suffix} images')
    print("="*80)
    # Report the number of images actually processed, which can be lower than
    # n_images when fewer paths are available.
    print(f'{time.time()-st :.4f} sec for {processed} images')
    print(dataset_metrics)
    print("="*80)

In [None]:
# Evaluate on the first 10 test images: show the comparisons, save nothing
infer(test_x_paths,test_y_paths,n_images=10,suffix='test',save1=False,save2=False,disp=True)

In [None]:
# Evaluate on the first 10 training images: show the comparisons, save nothing
infer(train_x_paths,train_y_paths,n_images=10,suffix='train',save1=False,save2=False,disp=True)