<a href="https://colab.research.google.com/github/vnvo2409/deep-image-prior/blob/main/code/denoising.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Start from a clean working directory, presumably so that re-running this cell
# does not collide with an earlier clone or previously moved files.
# WARNING: `rm -rf *` deletes EVERYTHING in the current working directory
# (presumably /content when run on Colab — confirm before running elsewhere).
!rm -rf *
# Fetch the repository, then flatten its code/ directory into the working
# directory. Presumably this is where the local modules imported below
# (callbacks, metrics, skip, utils) come from — confirm against the repo layout.
!git clone https://github.com/vnvo2409/deep-image-prior
!mv deep-image-prior/code/* ./
# Resources (e.g. res/denoising/input.png, loaded further down) end up in ./res.
!mv deep-image-prior/res/ ./res

In [None]:
import functools

import tensorflow as tf

import callbacks
import metrics
import skip
import utils

In [None]:
# Tensor layout used throughout the notebook: "NCHW" (channels-first).
# Passed as data_format to image loading, input construction, the model,
# the prediction callback and plotting below.
kDATA_FORMAT = "NCHW"

In [None]:
def denoising_train_dataset_generator(x, y, noise_std=1 / 30):
    """Yield ``(x + noise, y)`` training pairs forever.

    At every step a fresh zero-mean normal perturbation with standard
    deviation ``noise_std`` and the same shape as ``x`` is added to ``x``.
    The network input therefore changes slightly on each step, while the
    target ``y`` stays fixed.

    Args:
        x: Base network input tensor.
        y: Target tensor, yielded unchanged.
        noise_std: Standard deviation of the per-step Gaussian perturbation.

    Yields:
        Tuple ``(perturbed_x, y)``.
    """
    while True:
        jitter = tf.random.normal(tf.shape(x), stddev=noise_std)
        perturbed = tf.add(x, jitter)
        yield perturbed, y


def build_denoising_initial_input(input_img, input_depth, maxval=0.1, data_format=None):
    """Sample the fixed random tensor fed to the denoising network.

    Values are drawn uniformly from ``[0, maxval)``. The result has a batch
    dimension of 1, ``input_depth`` channels and the spatial size (height,
    width) of ``input_img``.

    Args:
        input_img: Image tensor whose spatial size is copied. For "NCHW" the
            last two axes are taken as spatial; otherwise the two axes just
            before the last one.
        input_depth: Number of channels of the generated tensor.
        maxval: Exclusive upper bound of the uniform distribution (lower
            bound is 0).
        data_format: "NCHW" selects a channels-first result; any other value,
            including None, selects channels-last.

    Returns:
        Tensor of shape ``(1, input_depth, H, W)`` for "NCHW", otherwise
        ``(1, H, W, input_depth)``.
    """
    shape = tf.shape(input_img)
    rank = tf.rank(input_img)
    if data_format == "NCHW":
        # Channels-first: height and width are the last two axes.
        spatial = shape[rank - 2 : rank]
        return tf.random.uniform((1, input_depth, spatial[0], spatial[1]), 0, maxval)
    # Channels-last (also the fallback for None or any other value):
    # height and width sit just before the trailing channel axis.
    spatial = shape[rank - 3 : rank - 1]
    return tf.random.uniform((1, spatial[0], spatial[1], input_depth), 0, maxval)


def build_denoising_model(summary=False, plot=None, data_format=None):
    """Build the denoising network with ``skip.build_skip_net``.

    The network uses bilinear upsampling, reflect padding and LeakyReLU
    activations with negative slope 0.2.

    Args:
        summary: If truthy, print ``model.summary`` with ``line_length=150``.
        plot: Optional output path. If truthy, a diagram of the model
            (shapes shown, nested models expanded, 192 dpi) is written there
            via ``tf.keras.utils.plot_model``.
        data_format: Layout forwarded to ``skip.build_skip_net``.

    Returns:
        The model returned by ``skip.build_skip_net``.
    """
    # Positional hyper-parameters of skip.build_skip_net. Their meaning is
    # defined by that function's signature, which is not visible here.
    # NOTE(review): the leading 32 presumably is the input depth and matches
    # the input_depth used for build_denoising_initial_input — confirm.
    skip_net_args = (32, 3, 5, 128, 128, 4)
    leaky_relu = functools.partial(tf.keras.layers.LeakyReLU, 0.2)
    model = skip.build_skip_net(
        *skip_net_args,
        upsample_modes="bilinear",
        padding_mode="reflect",
        data_format=data_format,
        activations=leaky_relu,
    )

    if summary:
        model.summary(line_length=150)
    if not plot:
        return model

    tf.keras.utils.plot_model(
        model,
        to_file=plot,
        show_shapes=True,
        expand_nested=True,
        dpi=192,
    )
    return model

In [None]:
# Load the reference image (used below as the "original" for PSNR) and add a
# leading batch axis at axis 0.
input_img = utils.load_img("res/denoising/input.png", data_format=kDATA_FORMAT)
input_img = tf.expand_dims(input_img, axis=0)
# Corrupted observation that the network is fit to (the target in model.fit
# below). The noise model is defined by utils.make_noisy_img, not visible here.
noisy_img = utils.make_noisy_img(input_img)
# Fixed random network input with 32 channels and the image's spatial size.
# NOTE(review): 32 presumably must match the input depth expected by the skip
# net (leading positional argument in build_denoising_model) — confirm.
input_net = build_denoising_initial_input(input_img, 32, data_format=kDATA_FORMAT)

In [None]:
# Build the denoising network (skip.build_skip_net) in the notebook's layout.
model = build_denoising_model(data_format=kDATA_FORMAT)

In [None]:
# MSE loss with Adam (lr=0.01). The loss is computed against the targets
# supplied by the generator in model.fit below, i.e. noisy_img.
# NOTE(review): the "original" entry presumably makes build_psnr_metrics also
# report PSNR against the clean input_img — confirm in metrics.
model.compile(
    loss="mse",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    metrics=metrics.build_psnr_metrics(addition_imgs={"original": input_img}),
)

In [None]:
# Training callback from the local callbacks module, given the fixed network
# input input_net. Presumably it records/plots the model's predictions on
# input_net during training — confirm in callbacks.SavePredictions.
save_predictions_callback = callbacks.SavePredictions(input_net, plot=True, data_format=kDATA_FORMAT)

In [None]:
# Each step feeds input_net plus fresh Gaussian noise (see
# denoising_train_dataset_generator) and regresses onto noisy_img.
# The generator never ends, so steps_per_epoch=1 turns every "epoch" into a
# single optimisation step: 3000 epochs == 3000 steps.
# NOTE(review): the step count presumably acts as the stopping point before the
# network starts reproducing the noise — tune with care.
model.fit(
    x=denoising_train_dataset_generator(input_net, noisy_img),
    epochs=3000,
    steps_per_epoch=1,
    callbacks=[save_predictions_callback],
    verbose=2,
)

In [None]:
# Final prediction on the unperturbed network input (no training-time noise).
denoised_img = model(input_net)

In [None]:
# Stack the reference, noisy and denoised images along the batch axis (axis=0)
# and plot them together with utils.plot_img.
utils.plot_img(tf.concat([input_img, noisy_img, denoised_img], axis=0), figsize=(25,25), data_format=kDATA_FORMAT)