<a href="https://colab.research.google.com/github/soumik12345/image-restoration-primer/blob/main/notebooks/04_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!--- @wandbcode{tfug-kol} -->

In [None]:
!pip install -q --upgrade wandb gradio

In [None]:
import tensorflow as tf
from tensorflow import keras

import wandb
import gradio as gr

import os
import numpy as np
from PIL import Image
from glob import glob

In [None]:
# Log in to Weights & Biases. `dehaze_image` below starts W&B runs
# (`wandb.init`) and pulls model artifacts (`use_artifact`), so the session
# needs to be authenticated before the Gradio demo is used.
wandb.login()

In [None]:
def preprocess_image(image):
    """Turn one input image into a normalized, batched model input.

    Args:
        image: Anything ``keras.preprocessing.image.img_to_array`` accepts;
            in this notebook it comes from the Gradio ``"image"`` input.

    Returns:
        A float32 NumPy array of shape ``(1, height, width, channels)`` whose
        pixel values have been divided by 255.
    """
    pixels = keras.preprocessing.image.img_to_array(image).astype("float32")
    normalized = pixels / 255.0
    # Prepend a batch axis of size 1 so the array can go straight into predict().
    return normalized[np.newaxis, ...]


def postprocess_image(model_output):
    """Convert a batched model prediction into a displayable PIL image.

    Only the first item of the batch is converted.

    Args:
        model_output: Array-like of shape ``(batch, height, width, channels)``
            with values expected in ``[0, 1]``, such as the output of
            ``model.predict`` on a batch built by ``preprocess_image``.

    Returns:
        A single ``PIL.Image.Image`` built from ``model_output[0]``.
    """
    first = np.asarray(model_output[0], dtype="float32")
    # Round before casting: a bare uint8 cast floors, so float error such as
    # 254.9999 (from an exact 1.0 * 255) would lose one intensity level.
    pixels = np.clip(np.rint(first * 255.0), 0, 255).astype(np.uint8)
    # A single trailing channel is shown as grayscale; (H, W, 1) is not a
    # shape PIL can build an image from directly.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)

In [None]:
def dehaze_image(wandb_project, wandb_entity, hazy_image, model_artifact_address, model_artifact_version):
    """Dehaze an image with a model artifact from W&B and log the result.

    Every call starts a W&B run (``job_type="demo"``). Within the run it
    fetches and loads the requested model artifact and predicts the dehazed
    image. It then logs a table pairing the input with the prediction.

    Args:
        wandb_project: W&B project the demo run is logged to.
        wandb_entity: W&B entity that owns the project.
        hazy_image: Input image coming from the Gradio image component.
        model_artifact_address: Artifact path without a version suffix,
            e.g. ``"geekyrakshit/image-dehazing/denim-sea-39"``.
        model_artifact_version: Version or alias appended after ``":"``
            (e.g. ``"v3"`` or ``"latest"``); falls back to ``"latest"``
            when empty or ``None``.

    Returns:
        The dehazed image as a ``PIL.Image.Image`` (from ``postprocess_image``).

    Raises:
        gr.Error: If no image was supplied or no model artifact was selected,
            so Gradio shows a readable message instead of a traceback.
    """
    # Validate before starting a W&B run, so bad input doesn't leave behind
    # an empty run and the user sees a readable message.
    if hazy_image is None:
        raise gr.Error("Please upload a hazy image to dehaze.")
    if not model_artifact_address:
        raise gr.Error("Please select a model artifact.")
    artifact_name = f"{model_artifact_address}:{model_artifact_version or 'latest'}"

    with wandb.init(project=wandb_project, entity=wandb_entity, job_type="demo") as run:
        # NOTE(review): the model is downloaded and loaded on every request;
        # consider caching if latency matters.
        artifact = run.use_artifact(artifact_name, type="model")
        model_path = artifact.download()
        model = keras.models.load_model(model_path, compile=False)
        prediction = postprocess_image(model.predict(preprocess_image(hazy_image)))
        table = wandb.Table(columns=["Hazy-Image", "Predicted-Image"])
        table.add_data(wandb.Image(hazy_image), wandb.Image(prediction))
        run.log({"Demo-Table": table})
    return prediction


# Model artifacts selectable in the demo (W&B paths without a version suffix).
model_artifact_addresses = [
    "geekyrakshit/image-dehazing/denim-sea-39",
    "geekyrakshit/image-dehazing/run_39nvjow7_model",
    "geekyrakshit/image-dehazing/run_nj3biqvb_model",
]
# The "latest" alias plus explicit versions v0..v29.
# NOTE(review): not every artifact is guaranteed to have all of these versions.
model_artifact_versions = ["latest"] + [f"v{idx}" for idx in range(30)]

# Both dropdowns get a default value: without one, a user who never touches
# them sends None into dehaze_image.
demo = gr.Interface(
    fn=dehaze_image,
    inputs=[
        gr.Text(value="image-dehazing", label="WandB Project", show_label=True),
        gr.Text(value="geekyrakshit", label="WandB Entity", show_label=True),
        gr.Image(type="numpy", label="Hazy Image"),
        gr.Dropdown(
            choices=model_artifact_addresses,
            value=model_artifact_addresses[0],
            label="Model Artifact",
        ),
        gr.Dropdown(
            choices=model_artifact_versions,
            value="latest",
            label="Model Artifact Version",
        ),
    ],
    outputs=gr.Image(label="Dehazed Image"),
)

demo.launch(debug=True)