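This notebook builds a Jupyter Dash app that enhances an uploaded image with the ESRGAN super-resolution model from TensorFlow Hub. To start the app, run all the cells below in order; the app will appear in the output of the last cell.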

In [None]:
!pip install dash-bootstrap-components jupyter-dash plotly -q

In [2]:
import base64
from io import BytesIO

import dash
import dash_bootstrap_components as dbc
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output, State
import jupyter_dash
from PIL import Image
import tensorflow as tf
import tensorflow_hub as hub

## Custom components

In [3]:
def image_card(src, header=None):
    """Build a Bootstrap card that shows a single image under a header.

    Args:
        src: Image source handed to ``html.Img`` (e.g. a URL or a data URI).
        header: Header content, passed unchanged to ``dbc.CardHeader``.

    Returns:
        A ``dbc.Card`` whose body holds the image stretched to 100% width.
    """
    card_header = dbc.CardHeader(header)
    card_body = dbc.CardBody(html.Img(src=src, style={"width": "100%"}))
    return dbc.Card([card_header, card_body])

## Helper functions

In [4]:
def preprocess_b64(image_enc):
    """Decode a base64 image string into a batched float32 TF tensor.

    Args:
        image_enc: Base64-encoded image, optionally prefixed with a data-URI
            header such as ``"data:image/png;base64,"`` (e.g. the ``contents``
            of a ``dcc.Upload``). Everything up to the last ``"base64,"`` is
            dropped before decoding.

    Returns:
        A ``tf.float32`` tensor of shape ``(1, height, width, 3)`` holding the
        decoded pixel values (0-255 range, not normalized).
    """
    # Keep only the payload after the data-URI prefix (no-op if absent).
    payload = image_enc.split("base64,")[-1]
    decoded = base64.b64decode(payload)

    # channels=3 converts grayscale to RGB and drops any alpha channel, so the
    # model always receives exactly three channels. expand_animations=False
    # makes GIFs decode to a single 3-D frame instead of a 4-D frame stack,
    # which would otherwise become a 5-D tensor after expand_dims below.
    hr_image = tf.image.decode_image(decoded, channels=3, expand_animations=False)

    # Add the leading batch dimension the model expects.
    return tf.expand_dims(tf.cast(hr_image, tf.float32), 0)


def tf_to_b64(tensor, ext="jpeg"):
    """Encode the first image of a batched tensor as a base64 data URI.

    Args:
        tensor: Batched image tensor; only ``tensor[0]`` is encoded. Presumably
            ``(batch, height, width, channels)`` with 0-255 values — TODO
            confirm against the model output.
        ext: Image format passed to PIL's ``save`` and used as the MIME
            subtype of the URI (e.g. ``"jpeg"``, ``"png"``).

    Returns:
        A ``"data:image/<ext>;base64,<payload>"`` string, suitable as the
        ``src`` of an ``html.Img``.
    """
    buffer = BytesIO()

    # Clip before casting so out-of-range model outputs do not overflow when
    # converted to uint8.
    image = tf.cast(tf.clip_by_value(tensor[0], 0, 255), tf.uint8).numpy()
    Image.fromarray(image).save(buffer, format=ext)

    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")

    # The payload must follow the comma directly; no separating whitespace.
    return f"data:image/{ext};base64,{encoded}"

## App layout

In [5]:
app = jupyter_dash.JupyterDash(external_stylesheets=[dbc.themes.BOOTSTRAP])

# Single-file upload area styled as a dashed card. accept="image/*" limits the
# file picker / drop zone to image types so non-image files are not sent to
# the TensorFlow decoder in the callback.
controls = [
    dcc.Upload(
        dbc.Card(
            "Drag and Drop or Click",
            body=True,
            style={
                "textAlign": "center",
                "borderStyle": "dashed",
                "borderColor": "black",
            },
        ),
        id="img-upload",
        accept="image/*",
        multiple=False,
    )
]


# Page layout: title, the upload control, then one column each for the
# original and the enhanced image. The Spinner wraps the image row so a
# loading indicator is shown while those outputs are being updated.
app.layout = dbc.Container(
    [
        html.H1("Dash Image Enhancing with TensorFlow"),
        html.Hr(),
        dbc.Row([dbc.Col(c) for c in controls]),
        html.Br(),
        dbc.Spinner(
            dbc.Row(
                [
                    dbc.Col(html.Div(id=img_id))
                    for img_id in ["original-img", "enhanced-img"]
                ]
            )
        ),
    ],
    fluid=True,
)

## Load ESRGAN model

In [6]:
# Load the ESRGAN model from TensorFlow Hub (needs network access the first
# time). The callback below calls `model` directly on a batched float32 image
# tensor and encodes its output as the enhanced image.
model = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")

## Dash Callbacks

In [7]:
@app.callback(
    [Output("original-img", "children"), Output("enhanced-img", "children")],
    [Input("img-upload", "contents")],
    [State("img-upload", "filename")],
)
def enhance_image(img_str, filename):
    """Run ESRGAN on an uploaded image and show it next to the original.

    Args:
        img_str: Data-URI contents of the upload, or ``None`` before any
            file has been uploaded.
        filename: Name of the uploaded file; only used in the error message.

    Returns:
        A pair of children for ``original-img`` and ``enhanced-img``:
        ``(dash.no_update, dash.no_update)`` before any upload, an error
        alert plus ``None`` if the upload cannot be decoded as an image, and
        otherwise two image cards (original, enhanced).
    """
    if img_str is None:
        return dash.no_update, dash.no_update

    try:
        low_res = preprocess_b64(img_str)
    except (tf.errors.InvalidArgumentError, ValueError):
        # Corrupt or non-image upload (TF decode error, or bad base64, which
        # raises a ValueError subclass). Report it instead of raising inside
        # the callback, and clear any previously enhanced image.
        alert = dbc.Alert(
            f"Could not decode {filename!r} as an image.", color="danger"
        )
        return alert, None

    # preprocess_b64 already returns a float32 tensor, so no re-cast is needed.
    super_res = model(low_res)
    sr_str = tf_to_b64(super_res)

    lr = image_card(img_str, header="Original Image")
    sr = image_card(sr_str, header="Enhanced Image")

    return lr, sr

## Run the app

In [None]:
# Launch the app. mode='inline' renders it inside this cell's output, and
# height=700 sets the display height of that inline view.
app.run_server(mode='inline', height=700)