In [1]:
# Move to the repository root, presumably so the local `learning_wavelets`
# package imported below resolves.
# Guarded (assumes the repo root contains a `learning_wavelets/` directory) so
# that re-running this cell does not keep climbing the directory tree.
import os

if not os.path.isdir('learning_wavelets'):
    os.chdir('..')
os.getcwd()

/volatile/home/Zaccharie/workspace/understanding-unets


In [18]:
%matplotlib inline
import bm3d
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from learning_wavelets.data.datasets import im_dataset_bsd68
from learning_wavelets.image_utils import trim_padding
from learning_wavelets.keras_utils.load_model import unpack_model
from learning_wavelets.models.learned_wavelet import learnlet
from learning_wavelets.models.unet import unet
from learning_wavelets.models.wavelet_denoising import wavelet_denoising_pysap

In [3]:
# Notebook-wide matplotlib defaults: a very large canvas
# ((480/8, 320/8) = (60.0, 40.0) inches) and a grayscale colormap for images.
plt.rcParams.update({
    'figure.figsize': (480 / 8, 320 / 8),
    'image.cmap': 'gray',
})

In [4]:
# Fix TensorFlow's global seed so TF-side randomness (presumably including the
# noise added by the dataset pipeline) is reproducible across runs.
tf.random.set_seed(seed=0)

In [5]:
# Checkpointed networks to time; each entry is passed as `unpack_model(**params)`.
# NOTE(review): '0_55' presumably refers to the training noise-std range and the
# trailing number in each run_id looks like a Unix timestamp — confirm.
all_net_params = [
    dict(
        name='unet_0_55',
        init_function=unet,
        run_params=dict(
            n_layers=5,
            pool='max',
            layers_n_channels=[64, 128, 256, 512, 1024],
            layers_n_non_lins=2,
            non_relu_contract=False,
            bn=True,
            input_size=(None, None, 1),  # presumably any height/width, one channel
        ),
        run_id='unet_dynamic_st_bsd500_0_55_1576668365',
        epoch=500,
    ),
]

# Learnlet using dynamic soft thresholding; unlike the U-Net it is called as
# model([noisy_image, noise_std]), i.e. it also receives the noise level.
dynamic_denoising_net_params = [
    dict(
        name='learnlet_0_55_big_bsd',
        init_function=learnlet,
        run_params=dict(
            denoising_activation='dynamic_soft_thresholding',
            learnlet_analysis_kwargs=dict(
                n_tiling=256,
                mixing_details=False,
                kernel_size=11,
                skip_connection=True,
            ),
            learnlet_synthesis_kwargs=dict(
                res=True,
                kernel_size=13,
            ),
            n_scales=5,
            exact_reconstruction_weight=0,
            clip=True,
            input_size=(None, None, 1),
        ),
        run_id='learnlet_dynamic_st_bsd500_0_55_1580806694',
        epoch=500,
    ),
]

In [6]:
# Noise standard deviation on the 0-255 intensity scale (divided by 255 where
# the baselines need it on the unit scale).
noise_std: int = 30

In [7]:
# A single BSD68 test image, one per batch, with noise of std `noise_std`.
# Elements are unpacked as ((noisy, noise_level), clean, original_shape).
# NOTE(review): patch_size=None presumably means whole images rather than
# patches, and n_pooling=5 presumably pads to a multiple of 2**5 to suit the
# 5-level U-Net / 5-scale learnlet — confirm in im_dataset_bsd68.
im_ds = im_dataset_bsd68(
    mode='testing',
    n_samples=1,
    batch_size=1,
    patch_size=None,
    noise_std=noise_std,
    return_noise_level=True,
    n_pooling=5,
)

In [8]:
# Take the one test sample: the network inputs (noisy image and, presumably,
# its noise level as a tensor), the clean ground truth, and the original image
# shape (presumably before any padding applied by the dataset).
(im_noisy, tf_noise_std), im_gt, orig_im_shape = next(iter(im_ds))

In [9]:
# Load the trained U-Net described by all_net_params[0].
# NOTE(review): presumably builds init_function(**run_params) and restores the
# weights saved for run_id at the given epoch — see unpack_model.
unet_model = unpack_model(**all_net_params[0])

In [13]:
%%timeit
im_denoised_unet = unet_model(im_noisy)

63.7 ms ± 1.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
# Load the trained learnlet described by dynamic_denoising_net_params[0].
# NOTE(review): presumably builds init_function(**run_params) and restores the
# weights saved for run_id at the given epoch — see unpack_model.
learnlet_model = unpack_model(**dynamic_denoising_net_params[0])

In [15]:
%%timeit
im_denoised_learnlet = learnlet_model([im_noisy, tf_noise_std])

106 ms ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [16]:
# Same noisy input as a plain numpy array with the singleton axes (batch of 1,
# single channel) dropped, for the numpy-based baselines.
im_noisy_numpy = im_noisy.numpy().squeeze()

In [19]:
%%timeit
im_denoised_bm3d = bm3d.bm3d(im_noisy_numpy + 0.5, sigma_psd=noise_std/255, stage_arg=bm3d.BM3DStages.ALL_STAGES) - 0.5

10.8 s ± 223 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
%%timeit
im_denoised_wavelets = wavelet_denoising_pysap([im_noisy_numpy], noise_std/255, '24', 5, False, 3)[0]

274 ms ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
