<a href="https://colab.research.google.com/github/priyanshu-iitrgh/image-denoising-project/blob/main/image_denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import random
import numpy as np
from glob import glob
from PIL import Image, ImageOps
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
# Side length, in pixels, that every image is resized to before batching.
IMAGE_SIZE = 256
# Number of images per training / validation batch.
BATCH_SIZE = 16
# The first MAX_TRAIN_IMAGES files (sorted) are used for training; the rest
# are held out for validation.
MAX_TRAIN_IMAGES = 400

In [None]:
# Split the low-light images by sorted (lexicographic, not numeric) filename:
# the first MAX_TRAIN_IMAGES files train the model, the remainder validate it.
all_low_light_images = sorted(glob("./Train/low/*"))
train_low_light_images = all_low_light_images[:MAX_TRAIN_IMAGES]
val_low_light_images = all_low_light_images[MAX_TRAIN_IMAGES:]

# Verify that the lists are populated correctly
print("Number of training images:", len(train_low_light_images))
print("Number of validation images:", len(val_low_light_images))

# Print a few file paths to check if they are valid
print("Sample training images:")
for sample_path in train_low_light_images[:5]:
    print(sample_path)

print("\nSample validation images:")
for sample_path in val_low_light_images[:5]:
    print(sample_path)


Number of training images: 400
Number of validation images: 84
Sample training images:
./Train/low/100.png
./Train/low/101.png
./Train/low/102.png
./Train/low/103.png
./Train/low/104.png

Sample validation images:
./Train/low/720.png
./Train/low/721.png
./Train/low/722.png
./Train/low/723.png
./Train/low/724.png


In [None]:
def load_data(image_path):
    """Read one image file and return it as a normalized float tensor.

    Args:
        image_path: Scalar string tensor (or Python str) holding a file path.

    Returns:
        float32 tensor of shape (IMAGE_SIZE, IMAGE_SIZE, 3), values in [0, 1].
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image accepts PNG, JPEG, BMP and GIF, so a non-PNG file picked up
    # by the "*" glob no longer aborts the pipeline. expand_animations=False
    # guarantees a rank-3 result (first frame of a GIF), which resize needs.
    image = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # Bilinear resize returns float32 (still in the [0, 255] range).
    image = tf.image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    return image / 255.0

def data_generator(low_light_images):
    """Build a batched tf.data pipeline over a list of image file paths.

    Args:
        low_light_images: Sequence of image file paths (str).

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of shape
        (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, 3). The trailing partial batch is
        dropped, so fewer than BATCH_SIZE paths produce an empty dataset.
    """
    # An explicit string dtype keeps an empty path list from being inferred
    # as float32, which would make tf.io.read_file fail inside load_data.
    paths = tf.constant(list(low_light_images), dtype=tf.string)
    dataset = tf.data.Dataset.from_tensor_slices(paths)
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    # Overlap image decoding with the training / evaluation steps.
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# Wrap the train / validation path lists in batched tf.data pipelines.
train_dataset, val_dataset = (
    data_generator(train_low_light_images),
    data_generator(val_low_light_images),
)

for dataset_label, dataset_obj in (
    ("Train Dataset: ", train_dataset),
    ("Validation Dataset: ", val_dataset),
):
    print(dataset_label, dataset_obj)


Train Dataset:  <_BatchDataset element_spec=TensorSpec(shape=(16, 256, 256, 3), dtype=tf.float32, name=None)>
Validation Dataset:  <_BatchDataset element_spec=TensorSpec(shape=(16, 256, 256, 3), dtype=tf.float32, name=None)>


In [None]:
def build_dce_net():
    """Build DCE-Net, the fully convolutional curve-parameter estimator.

    Seven 3x3 convolutions with symmetric skip connections: the 4th feature
    map is concatenated with the 3rd, the 5th with the 2nd and the 6th with
    the 1st. The final layer emits 24 tanh-activated channels (values in
    [-1, 1]). The input accepts any spatial size with 3 channels.
    """
    def conv3x3(filters, activation):
        # Every DCE-Net convolution is 3x3, stride 1, "same" padding.
        return layers.Conv2D(
            filters, (3, 3), strides=(1, 1), activation=activation, padding='same'
        )

    def concat(tensors):
        return layers.Concatenate(axis=-1)(tensors)

    input_img = keras.Input(shape=[None, None, 3])
    feat1 = conv3x3(32, 'relu')(input_img)
    feat2 = conv3x3(32, 'relu')(feat1)
    feat3 = conv3x3(32, 'relu')(feat2)
    feat4 = conv3x3(32, 'relu')(feat3)
    merged_43 = concat([feat4, feat3])
    feat5 = conv3x3(32, 'relu')(merged_43)
    merged_52 = concat([feat5, feat2])
    feat6 = conv3x3(32, 'relu')(merged_52)
    merged_61 = concat([feat6, feat1])
    curve_params = conv3x3(24, 'tanh')(merged_61)
    return keras.Model(inputs=input_img, outputs=curve_params)


In [None]:
def color_constancy_loss(x):
    """Penalize differences between an image's average R, G and B levels.

    Args:
        x: Image batch in NHWC layout; channels 0, 1 and 2 are treated as
            R, G and B.

    Returns:
        Tensor of shape (batch, 1, 1) with one loss value per image.
    """
    channel_means = tf.reduce_mean(x, axis=(1, 2), keepdims=True)
    mean_r = channel_means[:, :, :, 0]
    mean_g = channel_means[:, :, :, 1]
    mean_b = channel_means[:, :, :, 2]
    # Squared gaps between channel means: R-G, G-B, R-B ...
    pair_gaps = (
        tf.square(mean_r - mean_g),
        tf.square(mean_g - mean_b),
        tf.square(mean_r - mean_b),
    )
    # ... each squared once more, summed, then square-rooted.
    return tf.sqrt(sum(tf.square(gap) for gap in pair_gaps))

In [None]:
def exposure_loss(x, mean_val=0.6):
    """Mean squared gap between local brightness and a target exposure level.

    The batch is averaged over channels, split into non-overlapping 16x16
    patches (partial border patches are discarded), and each patch mean is
    compared with ``mean_val``.

    Args:
        x: Image batch in NHWC layout.
        mean_val: Target average intensity for each patch.

    Returns:
        Scalar tensor.
    """
    brightness = tf.reduce_mean(x, axis=3, keepdims=True)
    patch_means = tf.nn.avg_pool2d(brightness, ksize=16, strides=16, padding='VALID')
    deviation = patch_means - mean_val
    return tf.reduce_mean(tf.square(deviation))


In [None]:
def illumination_smoothness_loss(x):
    """Total-variation smoothness penalty for an NHWC tensor.

    Squared differences between vertically and horizontally adjacent pixels
    are summed over the whole batch and all channels. Each sum is normalized
    by the number of neighbour pairs in one (height x width) map and then
    averaged over the batch.

    Args:
        x: 4-D tensor laid out as (batch, height, width, channels).

    Returns:
        Scalar float32 tensor.
    """
    shape = tf.shape(x)
    batch_size, h_x, w_x = shape[0], shape[1], shape[2]
    # Neighbour-pair counts per map. In NHWC, height and width are axes 1
    # and 2; axis 3 is the channel count and must not enter these counts.
    count_h = tf.cast((h_x - 1) * w_x, tf.float32)
    count_w = tf.cast(h_x * (w_x - 1), tf.float32)
    h_tv = tf.reduce_sum(tf.square(x[:, 1:, :, :] - x[:, :-1, :, :]))
    w_tv = tf.reduce_sum(tf.square(x[:, :, 1:, :] - x[:, :, :-1, :]))
    batch_size = tf.cast(batch_size, tf.float32)
    # divide_no_nan: a height or width of 1 has no neighbour pairs (0 / 0),
    # which should contribute 0 rather than NaN.
    smoothness = (
        tf.math.divide_no_nan(h_tv, count_h)
        + tf.math.divide_no_nan(w_tv, count_w)
    )
    return 2 * smoothness / batch_size

In [None]:
class SpatialConsistencyLoss(keras.losses.Loss):
    """Spatial consistency loss from Zero-DCE.

    Both images are reduced to one channel (channel mean) and average-pooled
    into non-overlapping 4x4 regions. For every region, the difference to its
    left, right, upper and lower neighbour is computed in each image. The loss
    is the sum over the four directions of the squared gap between the two
    images' neighbour differences.

    With the default reduction "none", ``call`` returns a per-region map of
    shape (batch, H // 4, W // 4, 1).
    """

    def __init__(self, **kwargs):
        # Unreduced output by default; callers such as ZeroDCE average it.
        # Other kwargs (e.g. name) are now forwarded instead of dropped.
        kwargs.setdefault("reduction", "none")
        super(SpatialConsistencyLoss, self).__init__(**kwargs)

        def _kernel(rows):
            # tf.nn.conv2d expects filters as (height, width, in_ch, out_ch);
            # a 3x3 single-channel kernel is therefore (3, 3, 1, 1).
            return tf.reshape(tf.constant(rows, dtype=tf.float32), (3, 3, 1, 1))

        # Each kernel yields "centre - neighbour" for one direction.
        self.left_kernel = _kernel([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
        self.right_kernel = _kernel([[0, 0, 0], [0, 1, -1], [0, 0, 0]])
        self.up_kernel = _kernel([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
        self.down_kernel = _kernel([[0, 0, 0], [0, 1, 0], [0, -1, 0]])

    @staticmethod
    def _region_means(images):
        """Channel-averaged, 4x4 average-pooled version of an NHWC batch."""
        gray = tf.reduce_mean(images, 3, keepdims=True)
        return tf.nn.avg_pool2d(gray, ksize=4, strides=4, padding='VALID')

    @staticmethod
    def _neighbour_diff(pooled, kernel):
        """Apply one directional difference kernel (zero-padded borders)."""
        return tf.nn.conv2d(pooled, kernel, strides=[1, 1, 1, 1], padding='SAME')

    def call(self, y_true, y_pred):
        original_pool = self._region_means(y_true)
        enhanced_pool = self._region_means(y_pred)
        kernels = (
            self.left_kernel,
            self.right_kernel,
            self.up_kernel,
            self.down_kernel,
        )
        return tf.add_n([
            tf.square(
                self._neighbour_diff(original_pool, kernel)
                - self._neighbour_diff(enhanced_pool, kernel)
            )
            for kernel in kernels
        ])


In [None]:

class ZeroDCE(keras.Model):
    """Zero-reference deep curve estimation (Zero-DCE) for low-light images.

    Wraps the DCE-Net built by ``build_dce_net``. The network predicts
    per-pixel curve parameters that are applied iteratively to the input
    image. Training uses only the non-reference losses in
    ``compute_losses``. Only the inner DCE-Net is trained, saved and loaded.
    """

    # Curve iterations; DCE-Net emits 3 channels (one per RGB channel) for
    # each iteration, i.e. 24 channels in total.
    NUM_ITERATIONS = 8

    def __init__(self, **kwargs):
        super(ZeroDCE, self).__init__(**kwargs)
        self.dce_model = build_dce_net()

    def compile(self, learning_rate, **kwargs):
        """Set up an Adam optimizer and the spatial consistency loss.

        Args:
            learning_rate: Adam learning rate.
            **kwargs: Forwarded to ``keras.Model.compile``.
        """
        super(ZeroDCE, self).compile(**kwargs)
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.spatial_consistency_loss = SpatialConsistencyLoss(reduction="none")

    def get_enchanced_image(self, data, output):
        """Apply the predicted enhancement curves to ``data``.

        Args:
            data: Input image batch (NHWC), values expected in [0, 1].
            output: DCE-Net output; channels 3*i to 3*i+2 hold the
                per-pixel RGB curve parameters for iteration i.

        Returns:
            Enhanced image batch with the same shape as ``data``.
        """
        x = data
        for i in range(self.NUM_ITERATIONS):
            alpha = output[:, :, :, 3 * i:3 * (i + 1)]
            # Quadratic light-enhancement curve: x + alpha * (x^2 - x).
            x = x + alpha * (tf.square(x) - x)
        return x

    def call(self, data):
        """Enhance an image batch end to end."""
        dce_net_output = self.dce_model(data)
        return self.get_enchanced_image(data, dce_net_output)

    def compute_losses(self, data, output):
        """Return the weighted Zero-DCE losses for one batch.

        Args:
            data: Input low-light batch.
            output: DCE-Net curve-parameter maps predicted for ``data``.

        Returns:
            Dict of scalar losses; "total loss" is the sum of the other four.
        """
        enchanced_image = self.get_enchanced_image(data, output)
        # Illumination smoothness regularizes the curve-parameter maps so
        # that neighbouring pixels receive similar curves. It is applied to
        # ``output``, not to the enhanced image.
        loss_illumination = 200 * illumination_smoothness_loss(output)
        loss_spatial_constancy = tf.reduce_mean(
            self.spatial_consistency_loss(data, enchanced_image)
        )
        loss_color_constancy = 5 * tf.reduce_mean(
            color_constancy_loss(enchanced_image)
        )
        loss_exposure = 10 * tf.reduce_mean(exposure_loss(enchanced_image))
        total_loss = (
            loss_illumination
            + loss_spatial_constancy
            + loss_color_constancy
            + loss_exposure
        )
        return {
            "total loss": total_loss,
            "illumination_smoothness_loss": loss_illumination,
            "spatial_constancy loss": loss_spatial_constancy,
            "color_constancy loss": loss_color_constancy,
            "exposure loss": loss_exposure,
        }

    def train_step(self, data):
        """One optimization step on the DCE-Net weights."""
        with tf.GradientTape() as tape:
            output = self.dce_model(data)
            losses = self.compute_losses(data, output)
        gradients = tape.gradient(
            losses["total loss"], self.dce_model.trainable_weights
        )
        self.optimizer.apply_gradients(zip(gradients, self.dce_model.trainable_weights))
        return losses

    def test_step(self, data):
        """Compute the losses for a validation batch without updating weights."""
        output = self.dce_model(data)
        return self.compute_losses(data, output)

    def save_weights(self, filepath, overwrite=True, save_format=None, options=None):
        """Save only the inner DCE-Net weights."""
        self.dce_model.save_weights(
            filepath, overwrite=overwrite, save_format=save_format, options=options
        )

    def load_weights(self, filepath, by_name=False, skip_mismatch=False, options=None):
        """Load weights into the inner DCE-Net."""
        self.dce_model.load_weights(
            filepath, by_name=by_name, skip_mismatch=skip_mismatch, options=options
        )




In [None]:
zero_dce_model = ZeroDCE()
zero_dce_model.compile(learning_rate = 1e-4)
history = zero_dce_model.fit(train_dataset, epochs = 100, validation_data = val_dataset)


def plot_loss_history(item, history=history):
    """Plot the training and validation curves of one logged loss.

    Named distinctly from the image-grid ``plot_result`` defined in a later
    cell, so neither definition silently shadows the other on re-runs.

    Args:
        item: Key in ``history.history`` (e.g. "total loss"); the validation
            series is read from ``"val_" + item``.
        history: Keras ``History`` returned by ``fit``.
    """
    fig, ax = plt.subplots()
    ax.plot(history.history[item], label=item)
    ax.plot(history.history["val_" + item], label="val_" + item)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(item)
    ax.set_title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    ax.legend()
    ax.grid()
    plt.show()


for loss_name in (
    "total loss",
    "illumination_smoothness_loss",
    "spatial_constancy loss",
    "color_constancy loss",
    "exposure loss",
):
    plot_loss_history(loss_name)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100

In [None]:
def plot_result(images, titles, figure_size=(12,12)):
    """Show ``images`` side by side in a single row, one title per panel.

    Args:
        images: Sequence of images accepted by ``imshow``.
        titles: Titles indexed in parallel with ``images``.
        figure_size: Matplotlib figure size in inches.
    """
    fig = plt.figure(figsize=figure_size)
    n_panels = len(images)
    for idx, img in enumerate(images):
        ax = fig.add_subplot(1, n_panels, idx + 1)
        ax.set_title(titles[idx])
        ax.imshow(img)
        ax.axis("off")
    plt.show()


def infer(original_image, model=None):
    """Enhance a PIL image with the trained Zero-DCE model.

    Args:
        original_image: PIL image. It is converted to RGB so grayscale or
            RGBA files match the model's 3-channel input.
        model: Callable mapping a (1, H, W, 3) float32 batch in [0, 1] to an
            enhanced batch. Defaults to the notebook's ``zero_dce_model``.

    Returns:
        The enhanced image as an RGB ``PIL.Image``.
    """
    if model is None:
        model = zero_dce_model
    image = np.asarray(original_image.convert("RGB"), dtype="float32") / 255.0
    image = np.expand_dims(image, axis=0)
    # Take the single image out of the batch dimension.
    output_image = np.asarray(model(image))[0]
    # Clip before the uint8 cast so values drifting marginally outside [0, 1]
    # cannot wrap around.
    output_image = (np.clip(output_image, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(output_image)

In [None]:
# No separate test split is loaded earlier in the notebook, so preview the
# held-out validation images.
# TODO: point this at a dedicated test folder if one is available.
test_low_light_images = val_low_light_images

for val_image_file in test_low_light_images:
    # Load into memory and close the file handle. Converting to RGB keeps
    # autocontrast and the model input consistent for grayscale/RGBA files.
    with Image.open(val_image_file) as img:
        original_image = img.convert("RGB")
    enhanced_image = infer(original_image)
    plot_result(
        [original_image, ImageOps.autocontrast(original_image), enhanced_image],
        ["Original", "PIL Autocontrast", "Enhanced"],
        (20, 12),
    )