In [None]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import keras
from keras import layers

import tensorflow as tf


In [None]:
# Images per batch for both the training and the validation datasets.
# Number of sorted (low, high) image pairs used for training; the remaining
# pairs form the validation split.
BATCH_SIZE, MAX_TRAIN_IMAGES = 4, 300

def read_image(image_path):
    """Read a PNG file into a float32 RGB tensor with values in [0, 1].

    Args:
        image_path: path (string tensor or str) of a PNG image.

    Returns:
        A ``tf.float32`` tensor whose static shape is declared as (400, 600, 3).
        ``set_shape`` only annotates the shape; it does not resize, so every
        image is assumed to already be 400x600 RGB.
    """
    raw_bytes = tf.io.read_file(image_path)
    pixels = tf.image.decode_png(raw_bytes, channels=3)
    pixels.set_shape([400, 600, 3])
    return tf.cast(pixels, dtype=tf.float32) / 255.0


def load_data(low_light_image_path, enhanced_image_path):
    """``tf.data`` map function: turn a pair of file paths into an image pair.

    Returns:
        ``(low_light_image, enhanced_image)``, each produced by ``read_image``.
    """
    return read_image(low_light_image_path), read_image(enhanced_image_path)


def get_dataset(low_light_images, enhanced_images):
    """Build a batched ``tf.data.Dataset`` of (low-light, enhanced) image pairs.

    Args:
        low_light_images: list of input image paths.
        enhanced_images: list of target image paths, aligned index-by-index
            with ``low_light_images``.

    Returns:
        A dataset yielding batches of ``BATCH_SIZE`` image pairs; a trailing
        partial batch is dropped. The pairs are not shuffled.
    """
    return (
        tf.data.Dataset.from_tensor_slices((low_light_images, enhanced_images))
        .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE, drop_remainder=True)
    )


# Pair low-light inputs with their well-lit targets by sorted file path. The
# first MAX_TRAIN_IMAGES pairs train the model; the rest are held out for validation.
all_low_light_images = sorted(glob("./Train/Train/low/*"))
all_enhanced_images = sorted(glob("./Train/Train/high/*"))
assert len(all_low_light_images) == len(all_enhanced_images), (
    "low/ and high/ must contain the same number of images to be paired"
)
assert len(all_low_light_images) >= MAX_TRAIN_IMAGES + BATCH_SIZE, (
    "need at least MAX_TRAIN_IMAGES + BATCH_SIZE images for a non-empty validation set"
)

train_low_light_images = all_low_light_images[:MAX_TRAIN_IMAGES]
train_enhanced_images = all_enhanced_images[:MAX_TRAIN_IMAGES]

val_low_light_images = all_low_light_images[MAX_TRAIN_IMAGES:]
# Targets must come from the "high" folder. Reading "low" here made validation
# an identity mapping, so the validation loss/PSNR measured nothing useful.
val_enhanced_images = all_enhanced_images[MAX_TRAIN_IMAGES:]

train_dataset = get_dataset(train_low_light_images, train_enhanced_images)
val_dataset = get_dataset(val_low_light_images, val_enhanced_images)

In [None]:
def SKFF(multiscale_feature_1, multiscale_feature_2, multiscale_feature_3):
    """Selective Kernel Feature Fusion of three same-shaped feature maps.

    The inputs are summed, squeezed to one statistic per channel by global
    average pooling, and compressed to ``channels // 8`` with a 1x1 conv + ReLU.
    They are then expanded back to ``channels`` by one linear 1x1 conv per
    branch. A softmax taken *across the three branches* (independently for
    every channel) turns these descriptors into attention weights that sum to 1
    per channel. The output is the attention-weighted sum of the inputs.

    Previously each descriptor had its own softmax over the channel axis. That
    normalised channels against each other instead of weighting the branches.

    Args:
        multiscale_feature_1, multiscale_feature_2, multiscale_feature_3:
            feature maps that must share the same shape (assumed channels-last,
            the Keras default).

    Returns:
        The fused feature map, with the same shape as each input.
    """
    features = [multiscale_feature_1, multiscale_feature_2, multiscale_feature_3]
    num_branches = len(features)
    channels = list(multiscale_feature_1.shape)[-1]

    combined_feature = layers.Add()(features)
    gap = layers.GlobalAveragePooling2D()(combined_feature)
    channel_wise_stat = layers.Reshape((1, 1, channels))(gap)
    compact_feature_representation = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(channel_wise_stat)

    # One linear 1x1 conv per branch; normalisation happens jointly below.
    feature_descriptors = [
        layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation)
        for _ in features
    ]
    # (batch, 1, 1, num_branches * channels) -> (batch, 1, 1, num_branches, channels)
    stacked_descriptors = layers.Reshape((1, 1, num_branches, channels))(
        layers.Concatenate(axis=-1)(feature_descriptors)
    )
    # Softmax over the branch axis (-2), not the channel axis.
    attention = layers.Softmax(axis=-2)(stacked_descriptors)

    # attention[:, :, :, i, :] has shape (batch, 1, 1, channels) and broadcasts
    # over the spatial dimensions of each feature map.
    weighted_features = [
        feature * attention[:, :, :, branch, :]
        for branch, feature in enumerate(features)
    ]
    return layers.Add()(weighted_features)

In [None]:
class ChannelPooling(layers.Layer):
    """Pool a feature map along its channel axis into [mean, max] maps.

    The output has size 2 along ``axis``: the mean over that axis followed by
    the max over that axis. It is used as the compression step of the
    spatial-attention branch.

    Args:
        axis: axis to pool over and to concatenate the two maps along
            (default -1, i.e. channels-last).
    """

    def __init__(self, axis=-1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.axis = axis
        self.concat = layers.Concatenate(axis=self.axis)

    def call(self, inputs):
        # Reduce over ``self.axis``. This was previously hard-coded to -1,
        # which silently ignored the constructor argument. ``keepdims=True`` is
        # equivalent to reduce + expand_dims on the same axis.
        average_pooling = tf.reduce_mean(inputs, axis=self.axis, keepdims=True)
        max_pooling = tf.reduce_max(inputs, axis=self.axis, keepdims=True)
        return self.concat([average_pooling, max_pooling])

    def get_config(self):
        config = super().get_config()
        config.update({"axis": self.axis})
        # get_config must return the dict; returning None breaks model
        # serialization (e.g. saving/cloning the model).
        return config

        
def spatial_attention_block(input_tensor):
    """Re-weight every spatial position of ``input_tensor`` by a learned gate.

    The channel-wise mean and max maps (``ChannelPooling``) go through a 1x1
    conv to a single channel and a sigmoid. The resulting (…, 1) gate
    multiplies the input and broadcasts over its channels.
    """
    pooled = ChannelPooling(axis=-1)(input_tensor)
    gate_logits = layers.Conv2D(1, kernel_size=(1, 1))(pooled)
    gate = keras.activations.sigmoid(gate_logits)
    return input_tensor * gate


def channel_attention_block(input_tensor):
    """Squeeze-and-excitation style channel gating.

    Global average pooling yields one value per channel. A 1x1 conv bottleneck
    (``channels // 8``, ReLU) and a 1x1 conv back to ``channels`` with sigmoid
    produce per-channel weights, which scale ``input_tensor``. Assumes
    channels-last layout.
    """
    *_, channels = input_tensor.shape
    squeezed = layers.GlobalAveragePooling2D()(input_tensor)
    squeezed = layers.Reshape((1, 1, channels))(squeezed)
    bottleneck = layers.Conv2D(
        filters=channels // 8, kernel_size=(1, 1), activation="relu"
    )(squeezed)
    channel_weights = layers.Conv2D(
        filters=channels, kernel_size=(1, 1), activation="sigmoid"
    )(bottleneck)
    return input_tensor * channel_weights


def dual_attention_unit_block(input_tensor):
    """Dual Attention Unit: conv features refined by channel + spatial attention.

    Two 3x3 convs (the first with ReLU) produce features. These are passed in
    parallel through channel attention and spatial attention, concatenated,
    projected back to the input channel count by a 1x1 conv, and added to the
    input as a residual.
    """
    *_, channels = input_tensor.shape
    body = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(input_tensor)
    body = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(body)

    # Build both attention branches before the Concatenate layer so layers are
    # created in the same order as before.
    channel_branch = channel_attention_block(body)
    spatial_branch = spatial_attention_block(body)
    fused = layers.Concatenate(axis=-1)([channel_branch, spatial_branch])
    fused = layers.Conv2D(channels, kernel_size=(1, 1))(fused)
    return layers.Add()([input_tensor, fused])

    

In [None]:
def down_sampling_module(input_tensor):
    """Residual down-sampler: halve spatial size, double the channel count.

    Main path: 1x1 conv + ReLU -> 3x3 conv + ReLU -> MaxPooling2D -> 1x1 conv
    to ``2 * C`` channels. Skip path: MaxPooling2D -> 1x1 conv to ``2 * C``
    channels. The two paths are summed.
    """
    *_, channels = input_tensor.shape
    out_channels = channels * 2

    main = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(input_tensor)
    main = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main)
    main = layers.MaxPooling2D()(main)
    main = layers.Conv2D(out_channels, kernel_size=(1, 1))(main)

    skip = layers.MaxPooling2D()(input_tensor)
    skip = layers.Conv2D(out_channels, kernel_size=(1, 1))(skip)

    return layers.Add()([main, skip])

def up_sampling_module(input_tensor):
    """Residual up-sampler: double spatial size, halve the channel count.

    Main path: 1x1 conv + ReLU -> 3x3 conv + ReLU -> UpSampling2D -> 1x1 conv
    to ``C // 2`` channels. Skip path: UpSampling2D -> 1x1 conv to ``C // 2``
    channels. The two paths are summed.
    """
    *_, channels = input_tensor.shape
    out_channels = channels // 2

    main = layers.Conv2D(channels, kernel_size=(1, 1), activation="relu")(input_tensor)
    main = layers.Conv2D(
        channels, kernel_size=(3, 3), padding="same", activation="relu"
    )(main)
    main = layers.UpSampling2D()(main)
    main = layers.Conv2D(out_channels, kernel_size=(1, 1))(main)

    skip = layers.UpSampling2D()(input_tensor)
    skip = layers.Conv2D(out_channels, kernel_size=(1, 1))(skip)

    return layers.Add()([main, skip])


def MRB_block(input_tensor,channels):
    """Multi-scale Residual Block (MRB).

    Builds three parallel streams from ``input_tensor``:
    - level 1 at full resolution;
    - level 2 after one ``down_sampling_module`` (half size, 2x channels);
    - level 3 after a second ``down_sampling_module``.
    Each stream goes through a dual attention unit. The streams then exchange
    information via one SKFF per resolution (other levels are resampled to
    match), pass through a second dual attention unit, are up-sampled back to
    full resolution, and are fused by a final SKFF. A 3x3 conv and a residual
    add to ``input_tensor`` produce the output.

    Args:
        input_tensor: feature map to refine.
        channels: filters of the final 3x3 conv. The residual ``Add`` only
            works when this equals the channel count of ``input_tensor``.

    Returns:
        A tensor with the same shape as ``input_tensor``.

    NOTE(review): level 3 is max-pooled twice and later up-sampled twice, so
    H and W presumably must be multiples of 4 for the shapes inside the
    SKFF/Add layers to line up. Confirm this before feeding new image sizes.
    """
    # Three resolution streams: full, 1/2, 1/4.
    level_1 = input_tensor
    level_2 = down_sampling_module(level_1)
    level_3 = down_sampling_module(level_2)
    
    # Per-stream dual attention.
    level_1_dau = dual_attention_unit_block(level_1)
    level_2_dau = dual_attention_unit_block(level_2)
    level_3_dau = dual_attention_unit_block(level_3)
    
    # Cross-scale exchange: at each resolution, resample the other two streams
    # to that resolution and fuse all three with SKFF. Every up/down call
    # builds new layers, so no weights are shared between these paths.
    level_1_skff = SKFF(level_1_dau, 
                        up_sampling_module(level_2_dau),
                        up_sampling_module(up_sampling_module(level_3_dau)),)
    level_2_skff = SKFF(down_sampling_module(level_1_dau),
                       level_2_dau,
                       up_sampling_module(level_3_dau),)
    level_3_skff = SKFF(down_sampling_module(down_sampling_module(level_1_dau)),
                       down_sampling_module(level_2_dau),
                       level_3_dau,)
    
    # Second attention pass, then bring every stream back to full resolution.
    level_1_dau_2 = dual_attention_unit_block(level_1_skff)
    level_2_dau_2 = up_sampling_module(dual_attention_unit_block(level_2_skff))
    level_3_dau_2 = up_sampling_module(up_sampling_module(dual_attention_unit_block(level_3_skff)))
    
    # Final fusion, 3x3 conv, and residual connection to the block input.
    skff = SKFF(level_1_dau_2,level_2_dau_2,level_3_dau_2)
    conv = layers.Conv2D(channels, kernel_size = (3,3), padding = 'same')(skff)
    return layers.Add()([input_tensor, conv])

In [None]:
def recursive_residual_group(input_tensor, num_mrb, channels):
    """Recursive Residual Group (RRG).

    A 3x3 conv, ``num_mrb`` stacked ``MRB_block``s, and another 3x3 conv,
    wrapped in a long residual connection back to ``input_tensor``. The final
    ``Add`` requires ``input_tensor`` to already have ``channels`` channels.
    """
    def conv3x3(tensor):
        # Fresh 3x3 "same" conv with `channels` filters on each call.
        return layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(tensor)

    features = conv3x3(input_tensor)
    for _ in range(num_mrb):
        features = MRB_block(features, channels)
    tail = conv3x3(features)
    return layers.Add()([tail, input_tensor])


def mirnet_model(num_rrg, num_mrb, channels):
    """Build the MIRNet enhancement model as a functional ``keras.Model``.

    Pipeline: 3x3 conv to ``channels`` features -> ``num_rrg`` recursive
    residual groups (each with ``num_mrb`` MRBs) -> 3x3 conv back to 3 channels.
    The result is added to the input image, so the network predicts a residual.

    The input accepts RGB images of any spatial size (``[None, None, 3]``).
    NOTE(review): the MRBs down-sample twice by 2 and up-sample back, so H and
    W presumably need to be multiples of 4.
    """
    image_input = keras.Input(shape=[None, None, 3])
    features = layers.Conv2D(channels, kernel_size=(3, 3), padding="same")(image_input)
    for _ in range(num_rrg):
        features = recursive_residual_group(features, num_mrb, channels)
    residual = layers.Conv2D(3, kernel_size=(3, 3), padding="same")(features)
    enhanced = layers.Add()([image_input, residual])
    return keras.Model(inputs=image_input, outputs=enhanced)

# 64 base channels, 3 recursive residual groups of 2 multi-scale residual blocks each.
model = mirnet_model(channels=64, num_rrg=3, num_mrb=2)

In [None]:
def charbonnier_loss(y_true, y_pred):
    """Charbonnier loss: mean of sqrt((y_true - y_pred)^2 + eps^2), eps = 1e-3.

    A smooth, differentiable approximation of the L1 loss.
    """
    epsilon = 1e-3
    difference = y_true - y_pred
    return tf.reduce_mean(tf.sqrt(tf.square(difference) + tf.square(epsilon)))


def peak_signal_noise_ratio(y_true, y_pred):
    """PSNR in dB between target and prediction for images scaled to [0, 1].

    ``read_image`` divides pixel values by 255, so the dynamic range is 1.0.
    The previous ``max_val=255.0`` inflated every reading by
    20*log10(255) ~ 48.1 dB. Because the offset is constant, plateau detection
    on this metric behaves the same, but the reported values are now correct.
    """
    return tf.image.psnr(y_pred, y_true, max_val=1.0)


optimizer = keras.optimizers.Adam(learning_rate=1e-4)
model.compile(
    optimizer=optimizer,
    loss=charbonnier_loss,
    metrics=[peak_signal_noise_ratio],
)

# Halve the learning rate when validation PSNR (higher is better) shows no
# improvement for 5 epochs.
lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
    monitor="val_peak_signal_noise_ratio",
    mode="max",
    factor=0.5,
    patience=5,
    min_delta=1e-7,
    verbose=1,
)

history = model.fit(
    train_dataset,
    epochs=50,
    validation_data=val_dataset,
    callbacks=[lr_on_plateau],
)

In [None]:
def plot_history(value, name, training_history=None):
    """Plot training and validation curves of one metric over epochs.

    Args:
        value: key in ``History.history`` (e.g. ``"loss"``); the validation
            curve is read from ``f"val_{value}"``.
        name: human-readable metric name for the legend, y-label and title.
        training_history: the ``History`` object returned by ``model.fit``.
            Defaults to the notebook-global ``history`` so existing calls keep
            working. Passing it explicitly avoids depending on hidden state.
    """
    if training_history is None:
        training_history = history
    curves = training_history.history

    # A fresh figure per call, so successive calls never draw onto the same axes.
    _, ax = plt.subplots()
    ax.plot(curves[value], label=f"train_{name.lower()}")
    ax.plot(curves[f"val_{value}"], label=f"val_{name.lower()}")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(name)
    ax.set_title(f"Train and Validation {name} Over Epochs", fontsize=14)
    ax.legend()
    ax.grid()
    plt.show()


for history_key, display_name in (("loss", "Loss"), ("peak_signal_noise_ratio", "PSNR")):
    plot_history(history_key, display_name)

In [None]:
def preprocess_image(image_path):
    """Load an image file as a float RGB batch of one.

    Returns:
        A numpy array of shape (1, H, W, 3) with values in [0, 1].
    """
    with Image.open(image_path) as opened:
        rgb_pixels = np.asarray(opened.convert('RGB'))
    scaled = rgb_pixels / 255.0
    return scaled[np.newaxis, ...]


def postprocess_image(image):
    """Convert a (1, H, W, 3) model output in ~[0, 1] into a PIL RGB image.

    The network's output is input + residual and is not bounded. Values are
    therefore clipped to [0, 255] before the uint8 cast; an unclipped cast
    wraps out-of-range values around (e.g. 1.02 * 255 -> 4), producing
    speckle artefacts. Values are rounded rather than truncated so the
    quantisation is unbiased.
    """
    image = np.squeeze(image, axis=0)
    image = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(image)

test_dir = './test/low/'
predicted_dir = './test/predicted/'
os.makedirs(predicted_dir, exist_ok=True)

# Enhance every PNG in test_dir and save it to predicted_dir as "denoised_<name>".
for image_name in os.listdir(test_dir):
    if not image_name.endswith('.png'):
        continue  # ignore non-PNG entries
    image_path = os.path.join(test_dir, image_name)
    output_path = os.path.join(predicted_dir, 'denoised_' + image_name)
    input_image = preprocess_image(image_path)
    denoised_image = postprocess_image(model.predict(input_image))
    denoised_image.save(output_path)
