<a href="https://colab.research.google.com/github/soumik12345/image-restoration-primer/blob/main/notebooks/02_evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!--- @wandbcode{tfug-kol} -->

In [None]:
# Install (or upgrade to) the latest Weights & Biases client; `-q` keeps pip's output quiet.
!pip install -q --upgrade wandb

In [None]:
import tensorflow as tf
from tensorflow import keras

import os
import wandb
import numpy as np
from PIL import Image
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [None]:
# Colab form fields: the W&B project/entity this evaluation run is logged under.
wandb_project = "image-dehazing" #@param {type:"string"}
wandb_entity = "geekyrakshit" #@param {type:"string"}
# job_type="eval" tags this run as an evaluation job in the W&B UI.
wandb.init(
    project=wandb_project, entity=wandb_entity, job_type="eval"
)

config = wandb.config
# Dataset artifact to evaluate on; stored in the run config so it is recorded with the run.
config.dataset_artifact = 'geekyrakshit/image-dehazing/dehaze-dataset:v0' #@param {type:"string"}

# Fetch the dataset artifact through this run and download it locally;
# artifact_dir is the local directory holding the dataset folders used below.
artifact = wandb.use_artifact(config.dataset_artifact, type='dataset')
artifact_dir = artifact.download()

In [None]:
def get_image_file_list_from_dehazy(dataset_path):
    """List the hazy training images and the ground-truth image for each one.

    Hazy images are read from ``<dataset_path>/train_images/*.jpg``. The
    ground-truth file name is built from the first two ``_``-separated fields
    of the hazy file name, e.g. ``0001_0.9_0.2.jpg`` maps to
    ``<dataset_path>/original_images/0001_0.9.jpg``.

    Args:
        dataset_path: Root directory of the D-HAZE dataset.

    Returns:
        A tuple ``(hazy_image_paths, ground_truth_files)`` of index-aligned,
        equal-length lists of path strings. Hazy paths are sorted.

    Raises:
        IndexError: If a hazy file name contains no ``_`` separator.
    """
    hazy_image_paths = sorted(glob(os.path.join(dataset_path, 'train_images', '*.jpg')))
    ground_truth_files = []
    for image_path in hazy_image_paths:
        # os.path.basename (not split('/')) so non-POSIX separators work too.
        name_parts = os.path.basename(image_path).split('_')
        ground_truth_file_name = name_parts[0] + '_' + name_parts[1] + '.jpg'
        ground_truth_files.append(
            os.path.join(dataset_path, 'original_images', ground_truth_file_name)
        )
    return hazy_image_paths, ground_truth_files


dehazy_dataset_path = os.path.join(artifact_dir, "Dehazing")
dehazy_hazy_image_paths, dehazy_ground_truth_paths = get_image_file_list_from_dehazy(dehazy_dataset_path)
print("Number of Hazy Images:", len(dehazy_hazy_image_paths))
print("Number of Ground-truth Images:", len(dehazy_ground_truth_paths))

# Deterministic split: the last `val_split` fraction (in sorted file-name
# order) is held out for validation; the rest is the training split.
config.val_split = 0.2
num_train_images = len(dehazy_hazy_image_paths) - int(len(dehazy_hazy_image_paths) * config.val_split)

train_hazy_image_paths = dehazy_hazy_image_paths[:num_train_images]
train_ground_truth_image_paths = dehazy_ground_truth_paths[:num_train_images]

val_hazy_image_paths = dehazy_hazy_image_paths[num_train_images:]
# Slice the ground-truth list (not the hazy list) so validation targets are
# the clean images; otherwise metrics would compare against the hazy input.
val_ground_truth_image_paths = dehazy_ground_truth_paths[num_train_images:]

In [None]:
def _check_paired_counts(dataset_name, hazy_paths, ground_truth_paths):
    """Raise ValueError unless the hazy and ground-truth path lists are equally long.

    Hazy and ground-truth images are paired purely by sorted position, so a
    count mismatch would silently misalign the pairs or overrun the shorter list.
    """
    if len(hazy_paths) != len(ground_truth_paths):
        raise ValueError(
            f"{dataset_name}: found {len(hazy_paths)} hazy images but "
            f"{len(ground_truth_paths)} ground-truth images"
        )


ohazy_dataset_path = os.path.join(artifact_dir, "O-HAZY")
ohazy_hazy_image_paths = sorted(glob(os.path.join(ohazy_dataset_path, "hazy", "*")))
ohazy_ground_truth_paths = sorted(glob(os.path.join(ohazy_dataset_path, "GT", "*")))
print("Number of Hazy Images in O-HAZY Dataset:", len(ohazy_hazy_image_paths))
print("Number of Ground-truth Images in O-HAZY Dataset:", len(ohazy_ground_truth_paths))

ihazy_dataset_path = os.path.join(artifact_dir, "I-HAZY")
ihazy_hazy_image_paths = sorted(glob(os.path.join(ihazy_dataset_path, "hazy", "*")))
ihazy_ground_truth_paths = sorted(glob(os.path.join(ihazy_dataset_path, "GT", "*")))
print("Number of Hazy Images in I-HAZY Dataset:", len(ihazy_hazy_image_paths))
print("Number of Ground-truth Images in I-HAZY Dataset:", len(ihazy_ground_truth_paths))

# Fail early (after printing the counts) rather than mid-evaluation.
_check_paired_counts("O-HAZY", ohazy_hazy_image_paths, ohazy_ground_truth_paths)
_check_paired_counts("I-HAZY", ihazy_hazy_image_paths, ihazy_ground_truth_paths)

In [None]:
# Colab form field: W&B model artifact to evaluate (":latest" resolves to the newest version).
model_artifact_address = "geekyrakshit/image-dehazing/run_39nvjow7_model:latest" #@param {type:"string"}

# Fetch model from WandB Model artifact
# (rebinds `artifact`; the dataset's local path is already held in artifact_dir).
artifact = wandb.use_artifact(model_artifact_address, type="model")
model_path = artifact.download()

# Load Model
# compile=False: the model is only used for model.predict here, so no
# optimizer/loss configuration needs to be restored.
model = keras.models.load_model(model_path, compile=False)

In [None]:
def preprocess_image(image):
    """Preprocesses the image for inference.

    Args:
        image: A PIL image, or an array-like of shape (height, width) or
            (height, width, channels) with pixel values in [0, 255].

    Returns:
        A float32 numpy array of shape (1, height, width, channels) scaled to
        [0, 1]. PIL images are always converted to 3-channel RGB first.

    Raises:
        ValueError: If the image is not 2-D or 3-D.
    """
    # Grayscale ("L"), palette ("P") or RGBA PIL images would otherwise yield
    # 1 or 4 channels; normalise them so inputs always have 3 channels.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    image = np.asarray(image, dtype="float32")
    if image.ndim == 2:
        # Same as keras' img_to_array: a 2-D array gains a trailing channel axis.
        image = image[..., np.newaxis]
    elif image.ndim != 3:
        raise ValueError(f"Unsupported image shape: {image.shape}")
    return np.expand_dims(image / 255.0, axis=0)


def postprocess_image(model_output):
    """Postprocesses the model output for visualization.

    Args:
        model_output: Model prediction of shape (1, height, width, 3) with
            values nominally in [0, 1]. Only the first batch item is used.

    Returns:
        A single PIL.Image.Image (uint8, 3 channels) for the first batch item.
    """
    # Scale back to [0, 255] and clip: predictions may overshoot the unit range.
    image = np.clip(np.asarray(model_output)[0] * 255.0, 0, 255)
    # The reshape doubles as a check that the prediction has exactly 3 channels.
    image = image.reshape((image.shape[0], image.shape[1], 3))
    # Round instead of truncating, so e.g. 0.999 * 255 maps to 255, not 254.
    return Image.fromarray(np.rint(image).astype(np.uint8))


def plot_results(images, titles, figure_size=(12, 12)):
    """Display `images` side by side in one row.

    Each panel is captioned with the entry of `titles` at the same index, and
    its axes are hidden. The figure is shown immediately.
    """
    figure = plt.figure(figsize=figure_size)
    panel_count = len(images)
    for panel_index, panel_image in enumerate(images):
        panel_axes = figure.add_subplot(1, panel_count, panel_index + 1)
        panel_axes.set_title(titles[panel_index])
        plt.imshow(panel_image)
        plt.axis("off")
    plt.show()

In [None]:
# Cap on D-HAZE samples logged per split: logging every image takes a long
# time. O-HAZY and I-HAZY are small and are logged in full.
max_dhaze_samples_per_split = 1000


def _add_evaluation_rows(table, dataset_name, hazy_image_paths, ground_truth_paths, max_samples=None):
    """Run the model on paired hazy/ground-truth images and append one row per pair.

    Args:
        table: The wandb.Table to append rows to.
        dataset_name: Label written into the "Dataset" column.
        hazy_image_paths: Paths of the hazy input images.
        ground_truth_paths: Paths of the matching ground-truth images, in the
            same order as `hazy_image_paths`.
        max_samples: Optional cap on the number of pairs evaluated (the first
            `max_samples` pairs). None evaluates every pair. A cap larger than
            the number of pairs is fine.

    Raises:
        ValueError: If the two path lists differ in length.
    """
    if len(hazy_image_paths) != len(ground_truth_paths):
        raise ValueError(
            f"{dataset_name}: {len(hazy_image_paths)} hazy images but "
            f"{len(ground_truth_paths)} ground-truth images"
        )
    # Slicing (rather than range(N)) cannot run past the end of a short split.
    pairs = list(zip(hazy_image_paths, ground_truth_paths))[:max_samples]
    for hazy_path, ground_truth_path in tqdm(pairs, desc=dataset_name):
        # `with` closes each file handle once its pixels have been read.
        with Image.open(hazy_path) as input_image:
            preprocessed_input_image = preprocess_image(input_image)
        with Image.open(ground_truth_path) as ground_truth_image:
            preprocessed_ground_truth_image = preprocess_image(ground_truth_image)
        predicted_image = model.predict(preprocessed_input_image, verbose=0)
        psnr = tf.image.psnr(preprocessed_ground_truth_image, predicted_image, max_val=1.0).numpy().item()
        ssim = tf.image.ssim(preprocessed_ground_truth_image, predicted_image, max_val=1.0).numpy().item()
        table.add_data(
            dataset_name,
            wandb.Image(hazy_path),
            wandb.Image(ground_truth_path),
            wandb.Image(postprocess_image(predicted_image)),
            psnr,
            ssim,
        )


table = wandb.Table(
    columns=[
        "Dataset",
        "Hazy-Image",
        "Ground-Truth",
        "Predicted-Image",
        "Peak-Signal-To-Noise-Ratio",
        "Structural-Similarity"
    ]
)

_add_evaluation_rows(
    table, "D-HAZE/Train",
    train_hazy_image_paths, train_ground_truth_image_paths,
    max_samples=max_dhaze_samples_per_split,
)
_add_evaluation_rows(
    table, "D-HAZE/Validation",
    val_hazy_image_paths, val_ground_truth_image_paths,
    max_samples=max_dhaze_samples_per_split,
)
_add_evaluation_rows(table, "O-HAZY", ohazy_hazy_image_paths, ohazy_ground_truth_paths)
_add_evaluation_rows(table, "I-HAZY", ihazy_hazy_image_paths, ihazy_ground_truth_paths)


wandb.log({"Evaluation-Results": table})

In [None]:
# Mark the W&B run as finished so all logged data (including the evaluation table) finishes uploading.
wandb.finish()