In [1]:
# Move two directories up (presumably the repository root) so that the
# `learning_wavelets` package imported below can be resolved.
%cd ../..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# Hide every CUDA device so that the frameworks imported in the next cell run
# on the CPU only. This must happen before those imports.
import os
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

In [3]:
# Automatically reload imported modules when their source changes, so edits to
# the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
import copy
import time

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

from learning_wavelets.data.datasets import im_dataset_div2k
from learning_wavelets.models.learned_wavelet import learnlet
from learning_wavelets.models.unet import unet
from learning_wavelets.utils.metrics import metrics_from_ds


                 .|'''|       /.\      '||'''|,
                 ||          // \\      ||   ||
'||''|, '||  ||` `|'''|,    //...\\     ||...|'
 ||  ||  `|..||   .   ||   //     \\    ||
 ||..|'      ||   |...|' .//       \\. .||
 ||       ,  |'
.||        ''

Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=1.3.0   - required | 1.4.1     installed
numpy          : >=1.16.4  - required | 1.17.4    installed
matplotlib     : >=3.0.0   - required | 3.1.2     installed
astropy        : >=3.0.0   - required | 3.2.3     installed
nibabel        : >=2.3.2   - required | 2.5.1     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.4.0   - required | 1.4.1     installed
scikit-learn   : >=0.19.1  - requi

In [4]:
# Fix NumPy's global RNG so that NumPy-based randomness below is reproducible.
# NOTE(review): only NumPy is seeded here; if the dataset pipeline draws its
# noise with a deep-learning framework's own RNG, that RNG is not seeded — confirm.
np.random.seed(seed=0)

In [5]:
# Networks evaluated on datasets built with `return_noise_level=False`.
# Each entry holds a display name, the model constructor, the constructor
# kwargs, and `run_id` / `epoch` (presumably selecting the trained weights —
# confirm against `metrics_from_ds`). The whole dict is forwarded to
# `metrics_from_ds` as keyword arguments.
all_net_params = [
    dict(
        name='unet-multiple-stds',
        init_function=unet,
        run_params=dict(
            n_layers=5,
            pool='max',
            layers_n_channels=[64, 128, 256, 512, 1024],
            layers_n_non_lins=2,
            non_relu_contract=False,
            bn=True,
            input_size=(None, None, 1),
        ),
        run_id='unet_dynamic_st_bsd500_0_55_1576668365',
        epoch=500,
    ),
]

# Networks using dynamic soft thresholding; they are evaluated on datasets
# built with `return_noise_level=True`.
dynamic_denoising_net_params = [
    dict(
        name='learnlet_0_55_big_bsd',
        init_function=learnlet,
        run_params=dict(
            denoising_activation='dynamic_soft_thresholding',
            learnlet_analysis_kwargs=dict(
                n_tiling=256,
                mixing_details=False,
                kernel_size=11,
            ),
            learnlet_synthesis_kwargs=dict(
                res=True,
                kernel_size=13,
            ),
            n_scales=5,
            exact_reconstruction_weight=0,
            clip=True,
            input_size=(None, None, 1),
        ),
        run_id='learnlet_dynamic_st_bsd500_0_55_1576762010',
        epoch=500,
    ),
]

In [6]:
# Noise standard deviations at which every network is evaluated
# (one row per value in the result tables below).
noise_stds = [5, 15, 20, 25, 30, 50, 75]

In [7]:
def _evaluate_nets(net_params_list, noise_std, return_noise_level):
    """Compute test metrics for each network of ``net_params_list``.

    A fresh DIV2K test dataset is built for every network.

    Parameters
    ----------
    net_params_list : list of dict
        Network descriptions. Each dict must have a ``'name'`` key and is
        forwarded as keyword arguments to ``metrics_from_ds``.
    noise_std : int
        Noise standard deviation passed to ``im_dataset_div2k``.
    return_noise_level : bool
        Forwarded to ``im_dataset_div2k``. In this notebook it is True only
        for the dynamic-denoising networks.

    Returns
    -------
    list of (str, object)
        ``(name, metrics)`` pairs in the order of ``net_params_list``.
    """
    results = []
    for net_params in net_params_list:
        im_ds = im_dataset_div2k(
            mode='testing',
            batch_size=1,
            patch_size=256,
            noise_std=noise_std,
            exact_recon=False,
            return_noise_level=return_noise_level,
        )
        results.append((net_params['name'], metrics_from_ds(im_ds, **net_params)))
    return results


# noise std -> list of (network name, metrics) pairs; the static networks come
# first, then the dynamic-denoising ones.
noise_std_metrics = {}
for noise_std in tqdm(noise_stds, desc='Noise stds'):
    noise_std_metrics[noise_std] = (
        _evaluate_nets(all_net_params, noise_std, return_noise_level=False)
        + _evaluate_nets(dynamic_denoising_net_params, noise_std, return_noise_level=True)
    )

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, description='Noise stds', max=7.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…

  return compare_psnr(gt, pred, data_range=1)
  gt, pred, multichannel=True, data_range=1





HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet_0_55_big_bsd', style=ProgressStyle(des…





In [8]:
# PSNR table
psnr_metrics_table = pd.DataFrame(
    index=noise_stds, 
    columns=[p['name'] for p in all_net_params] + [p['name'] for p in dynamic_denoising_net_params] + ['original', 'wavelets_24', 'bm3d'],
)
for noise_std, metrics in noise_std_metrics.items():
    for name, m in metrics:
        psnr_metrics_table.loc[noise_std, name] = "{mean:.4} ({std:.4})".format(
            mean=m.metrics['PSNR'].mean(), 
            std=m.metrics['PSNR'].stddev(),
        )
psnr_metrics_table

Unnamed: 0,unet-multiple-stds,learnlet_0_55_big_bsd,original,wavelets_24,bm3d
5,39.3 (3.724),38.5 (3.771),,,
15,34.45 (4.597),33.09 (4.12),,,
20,33.43 (5.02),32.01 (3.764),,,
25,32.11 (4.978),30.72 (3.982),,,
30,31.29 (4.83),29.97 (3.862),,,
50,28.95 (5.098),27.37 (3.565),,,
75,23.25 (1.838),25.33 (2.838),,,


In [9]:
# SSIM table
ssim_metrics_table = pd.DataFrame(
    index=noise_stds, 
    columns=[p['name'] for p in all_net_params] + [p['name'] for p in dynamic_denoising_net_params] + ['original', 'wavelets_24', 'bm3d'],
)
for noise_std, metrics in noise_std_metrics.items():
    for name, m in metrics:
        ssim_metrics_table.loc[noise_std, name] = "{mean:.4} ({std:.4})".format(
            mean=m.metrics['SSIM'].mean(), 
            std=m.metrics['SSIM'].stddev(),
        )
ssim_metrics_table

Unnamed: 0,unet-multiple-stds,learnlet_0_55_big_bsd,original,wavelets_24,bm3d
5,0.9621 (0.03139),0.9569 (0.03288),,,
15,0.9061 (0.0556),0.8826 (0.05417),,,
20,0.8855 (0.05858),0.853 (0.05976),,,
25,0.8572 (0.07313),0.82 (0.06098),,,
30,0.8351 (0.08393),0.7892 (0.06785),,,
50,0.7615 (0.111),0.6788 (0.07534),,,
75,0.4257 (0.09507),0.5581 (0.07022),,,
