<a href="https://colab.research.google.com/github/soumik12345/image-restoration-primer/blob/main/notebooks/03_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!--- @wandbcode{tfug-kol} -->

In [None]:
!pip install -q --upgrade wandb
!git clone https://github.com/soumik12345/BLR-ML-Monthly-Meetup

In [None]:
import tensorflow as tf
from tensorflow import keras

import os
import wandb
import numpy as np
from PIL import Image
from glob import glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

In [None]:
# Colab form fields: the W&B project and entity the inference run is recorded under.
wandb_project = "image-dehazing" #@param {type:"string"}
wandb_entity = "geekyrakshit" #@param {type:"string"}
wandb.init(job_type="inference", entity=wandb_entity, project=wandb_project)

# Colab form field: address of the model artifact to run inference with.
model_artifact_address = "geekyrakshit/image-dehazing/run_39nvjow7_model:latest" #@param {type:"string"}

# Register the model artifact as an input of the active run, then download its files.
artifact = wandb.use_artifact(model_artifact_address, type="model")
model_path = artifact.download()

# Restore the Keras model from the downloaded files; `compile=False` loads it
# without compiling, since only forward passes are run in this notebook.
model = keras.models.load_model(model_path, compile=False)

In [None]:
def preprocess_image(image):
    """Preprocesses the image for inference.

    PIL images that are not already in RGB mode (e.g. RGBA, palette or
    grayscale files) are converted to RGB first, so the resulting batch
    always carries three channels for them.

    Args:
        image: A PIL.Image.Image, or an array-like of shape
            (height, width, channels) or (height, width).

    Returns:
        A float32 numpy array of shape (1, height, width, channels) with the
        pixel values divided by 255. For PIL inputs `channels` is always 3;
        a 2-D array input gets a trailing channel axis of size 1.

    Raises:
        ValueError: If the input is neither 2-D nor 3-D once converted to an array.
    """
    # Duck-typed check so that plain arrays pass through untouched while any
    # PIL-style image with a non-RGB mode is normalised to three channels.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    pixels = np.asarray(image, dtype="float32")
    if pixels.ndim == 2:
        # Keep single-channel arrays rank-3 so the batch is always rank-4.
        pixels = pixels[..., np.newaxis]
    elif pixels.ndim != 3:
        raise ValueError(f"Unsupported image shape: {pixels.shape}")

    # Scale to [0, 1] (for 8-bit inputs) and prepend the batch dimension.
    return np.expand_dims(pixels / 255.0, axis=0)


def postprocess_image(model_output):
    """Converts a batched model prediction into a displayable image.

    Only the first element of the batch is used. Values are multiplied by
    255, clipped to [0, 255] and cast to unsigned 8-bit integers.

    Args:
        model_output: A numpy array of shape (batch, height, width, 3);
            presumably holds values in [0, 1] -- confirm against the model.

    Returns:
        A single PIL.Image.Image built from the first prediction in the batch.
    """
    scaled = (model_output * 255.0).clip(0, 255)
    height, width = np.shape(scaled)[1], np.shape(scaled)[2]
    # The reshape doubles as a check that the prediction has three channels.
    first_prediction = scaled[0].reshape((height, width, 3))
    return Image.fromarray(np.uint8(first_prediction))


def plot_results(images, titles, figure_size=(12, 12)):
    """Shows the given images side by side in a single row.

    Args:
        images: Sequence of images accepted by `imshow`.
        titles: Sequence of titles; `titles[i]` labels `images[i]`.
        figure_size: Size of the figure passed to matplotlib as `figsize`.
    """
    figure = plt.figure(figsize=figure_size)
    panel_count = len(images)
    for index, image in enumerate(images):
        # One column per image; subplot positions are 1-based.
        axes = figure.add_subplot(1, panel_count, index + 1)
        axes.set_title(titles[index])
        axes.imshow(image)
        axes.axis("off")
    plt.show()

In [None]:
# Table pairing every hazy input with the model's prediction for it.
table = wandb.Table(columns=["Hazy-Image", "Predicted-Image"])
# `glob` returns matches in arbitrary order; sorting keeps the rows of the
# logged table in the same order on every run.
hazy_images = sorted(glob("./BLR-ML-Monthly-Meetup/test_images/*"))

for hazy_image in tqdm(hazy_images):
    # The context manager releases the file handle once the pixels have been
    # copied into the preprocessed array.
    with Image.open(hazy_image) as input_image:
        preprocessed_input_image = preprocess_image(input_image)
    predicted_image = model.predict(preprocessed_input_image, verbose=0)
    table.add_data(
        wandb.Image(hazy_image),
        wandb.Image(postprocess_image(predicted_image)),
    )


# Log the table once, after all rows have been added.
wandb.log({"Inference-Results": table})

In [None]:
# Mark the W&B run started by `wandb.init` above as finished.
wandb.finish()