In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%env BSD500_DATA_DIR=/media/Zaccharie/UHRes/
# %env CUDA_VISIBLE_DEVICES=-1

env: BSD500_DATA_DIR=/media/Zaccharie/UHRes/


In [3]:
%cd ..
from tdv import TDV, TV
%cd experiments
from data import im_dataset_bsd500
from unrolled_fb import UnrolledFB

/volatile/home/Zaccharie/workspace/tf-tdv
/volatile/home/Zaccharie/workspace/tf-tdv/experiments


In [4]:
!pip install tensorflow-addons



In [5]:
%matplotlib nbagg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow_addons.callbacks import TQDMProgressBar

In [6]:
# Validation split with one sample per batch.
# noise_std=(25, 25) is presumably a (low, high) range, i.e. a fixed noise level of 25 — TODO confirm.
batch_size = 1
val_ds = im_dataset_bsd500(
    mode='validation',
    batch_size=batch_size,
    noise_std=(25, 25),
)

In [7]:
# Training split, built with the same batch size and noise setting as the validation split.
train_ds = im_dataset_bsd500(
    mode='training',
    batch_size=batch_size,
    noise_std=(25, 25),
)

In [8]:
# UnrolledFB wrapping TDV: 10 iterations, initial step size 1e-4.
tdv_config = {'n_macro': 3, 'n_filters': 32, 'activation_str': 'relu'}
model = UnrolledFB(TDV, tdv_config, init_step_size=0.0001, n_iter=10)
# Call once on a dummy batch so the subclassed model creates its variables;
# presumably required before load_weights can restore the .h5 checkpoint — TODO confirm.
model(tf.ones([1, 32, 32, 1]))
model.load_weights('denoising_unrolled_fb_tdv_relu.h5')
model.compile(loss='mse')

In [9]:
# MSE loss over the first 100 validation batches (batch_size=1, so 100 samples).
model.evaluate(x=val_ds.take(count=100))



0.002367957727983594

In [10]:
# MSE loss over the first 50 training batches, for comparison with the validation loss.
model.evaluate(x=train_ds.take(count=50))



0.0020186614710837603

In [11]:
def pad_for_pool(inputs, n_pools):
    """Zero-pad axes 1 and 2 of ``inputs`` up to the next multiple of ``2 ** n_pools``.

    Dimensions that are already a multiple are left unpadded. When a dimension
    needs an odd amount of padding, the extra row/column goes on the leading
    ("before") side.

    Args:
        inputs: 4-D tensor; only axes 1 and 2 are padded (presumably NHWC —
            TODO confirm).
        n_pools: number of 2x poolings the padded size must be divisible for.

    Returns:
        ``(inputs_padded, paddings)``. ``paddings`` is the 4-element list of
        ``(before, after)`` pairs passed to ``tf.pad``. Its entries for axes 1
        and 2 are scalar tensors, so callers can undo the padding by slicing.
    """
    multiple = 2 ** n_pools
    spatial_dims = tf.shape(inputs)[1:3]
    # Floor-mod of the negated size = smallest non-negative amount that rounds
    # each dim up to a multiple (zero when it is already aligned).
    total_pad = tf.math.floormod(-spatial_dims, multiple)
    pad_after = total_pad // 2
    # Whatever is left of an odd total goes before the data.
    pad_before = total_pad - pad_after
    paddings = [
        (0, 0),
        (pad_before[0], pad_after[0]),
        (pad_before[1], pad_after[1]),
        (0, 0),
    ]
    return tf.pad(inputs, paddings), paddings

class MultiScaleModel(tf.keras.models.Model):
    """Apply ``model`` to inputs padded to a multiple of ``2 ** n_scales``, then crop back.

    Axes 1 and 2 of the input are zero-padded with ``pad_for_pool``. The wrapped
    model is run on the padded tensor, and the same padding amounts are then
    sliced off its output.

    Args:
        model: the wrapped Keras model. Cropping is computed from the model
            output's own shape, so this assumes the model preserves the
            spatial size of its input — TODO confirm for other wrapped models.
        n_scales: number of 2x downsamplings to accommodate. Values <= 0
            disable padding, and the call simply forwards to ``model``.
        **kwargs: forwarded to ``tf.keras.models.Model``.
    """

    def __init__(self, model, n_scales=0, **kwargs):
        super().__init__(**kwargs)
        self.model = model
        self.n_scales = n_scales

    def call(self, inputs):
        # No padding requested: plain pass-through.
        if self.n_scales <= 0:
            return self.model(inputs)

        padded_inputs, paddings = pad_for_pool(inputs, n_pools=self.n_scales)
        padded_outputs = self.model(padded_inputs)

        # paddings holds (before, after) pairs for axes 0..3; only axes 1 and 2 are non-zero.
        _, (top, bottom), (left, right), _ = paddings
        out_size = tf.shape(padded_outputs)[1:3]
        return padded_outputs[
            :,
            top:out_size[0] - bottom,
            left:out_size[1] - right,
            :,
        ]
    
def tf_psnr(y_true, y_pred):
    """PSNR of ``y_pred`` w.r.t. ``y_true``, using the dynamic range of ``y_true``.

    The range is ``max - min`` over the whole ``y_true`` tensor, i.e. across the
    entire batch rather than per image; with batch_size=1 the two coincide.
    A constant ``y_true`` gives a zero range, and the result is then not finite.
    """
    data_range = tf.math.reduce_max(y_true) - tf.math.reduce_min(y_true)
    return tf.image.psnr(y_true, y_pred, max_val=data_range)

In [12]:
# Validation split without patch extraction (patch_size=None), same batch size and noise setting.
val_ds_full_size = im_dataset_bsd500(
    mode='validation',
    batch_size=batch_size,
    noise_std=(25, 25),
    patch_size=None,
)

In [13]:
# Pad arbitrary-size images to a multiple of 2**4 = 16 around the trained model, and report PSNR next to MSE.
# TODO: document why 4 scales are needed for this architecture.
full_model = MultiScaleModel(model=model, n_scales=4)
full_model.compile(metrics=[tf_psnr], loss='mse')

In [14]:
# [mse, tf_psnr] over the first 100 full-size validation images (batch_size=1).
full_model.evaluate(x=val_ds_full_size.take(count=100))



[0.0021508880890905857, 26.5772705078125]