In [1]:
# Move up two directories, presumably to the repository root (see the printed
# path), so that the `learning_wavelets` package imported below can be found.
# NOTE(review): hard-coded relative path; breaks if the notebook is moved.
%cd ../..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment the two lines below to force CPU-only execution
# # (hides every CUDA device from this process).
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
import copy
import time

import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
from tqdm.auto import tqdm

from learning_wavelets.data.datasets import im_dataset_bsd68
from learning_wavelets.models.learnlet_model import Learnlet
from learning_wavelets.models.unet import unet
from learning_wavelets.utils.metrics import metrics_from_ds, metrics_original_from_ds


                 .|'''|       /.\      '||'''|,
                 ||          // \\      ||   ||
'||''|, '||  ||` `|'''|,    //...\\     ||...|'
 ||  ||  `|..||   .   ||   //     \\    ||
 ||..|'      ||   |...|' .//       \\. .||
 ||       ,  |'
.||        ''

Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=1.3.0   - required | 1.4.1     installed
numpy          : >=1.16.4  - required | 1.17.4    installed
matplotlib     : >=3.0.0   - required | 3.1.2     installed
astropy        : >=3.0.0   - required | 3.2.3     installed
nibabel        : >=2.3.2   - required | 2.5.1     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.4.0   - required | 1.4.1     installed
scikit-learn   : >=0.19.1  - requi

In [4]:
# Seed NumPy's global RNG so NumPy-based randomness is reproducible across runs.
# NOTE(review): only NumPy is seeded here; any other RNG used by the datasets
# or models is not. Confirm whether they need seeding too.
np.random.seed(seed=0)

In [5]:
def build_learnlet_subclassing(**run_params):
    """Instantiate a ``Learnlet`` model and build its weights.

    All keyword arguments are forwarded unchanged to ``Learnlet``.

    The model is built for two inputs:
    - an image batch of shape ``[None, None, None, 1]`` (presumably batch,
      height, width, one channel);
    - a noise-level input of shape ``[None, 1]``, presumably the per-image
      noise std yielded by datasets created with ``return_noise_level=True``.
      TODO confirm against ``Learnlet``.

    Returns:
        The built ``Learnlet`` instance.
    """
    learnlet = Learnlet(**run_params)
    image_input_shape = [None, None, None, 1]
    noise_level_input_shape = [None, 1]
    learnlet.build([image_input_shape, noise_level_input_shape])
    return learnlet

In [6]:
# Models evaluated on the noisy image only (the evaluation loop builds their
# datasets with return_noise_level=False). Every entry is unpacked as keyword
# arguments into metrics_from_ds:
#   name          -- label used for the metric tables and the plots
#   init_function -- model constructor; presumably called as
#                    init_function(**run_params). TODO confirm in metrics_from_ds
#   run_params    -- hyper-parameters of the architecture
#   run_id, epoch -- presumably select the trained checkpoint to restore. TODO confirm
all_net_params = [
    {
        'name': 'unet-multiple-stds',
        'init_function': unet,
        'run_params': {
            'n_layers': 5, 
            'pool': 'max', 
            "layers_n_channels": [64, 128, 256, 512, 1024], 
            'layers_n_non_lins': 2,
            'non_relu_contract': False,
            'bn': True,
            'input_size': (None, None, 1),
        },
        # The run id suggests training on noise stds in [0, 55]. TODO confirm.
        'run_id': 'unet_dynamic_st_bsd500_0_55_1576668365',
        'epoch': 500,
    },
]

# Models that also receive the noise level as an input (the evaluation loop
# builds their datasets with return_noise_level=True). Same entry layout as
# all_net_params.
dynamic_denoising_net_params = [
    {
        'name': 'learnlet',
        'init_function': build_learnlet_subclassing,
        'run_params': {
            'denoising_activation': 'dynamic_soft_thresholding',
            'learnlet_analysis_kwargs':{
                'n_tiling': 64, 
                'mixing_details': False,    
                'skip_connection': True,
                'kernel_size': 11,
            },
            'learnlet_synthesis_kwargs': {
                'res': True,
                'kernel_size': 13,
            },
            'threshold_kwargs':{
                'noise_std_norm': True,
            },
            'n_scales': 5,
            'exact_reconstruction': True,
            'n_reweights_learn': 1,
            'undecimated': True,
            'clip': False,
        },
        # The run id suggests training on noise stds in [0.0, 55.0]. TODO confirm.
        'run_id': 'learnlet_subclassed_undecimated_exact_reco_64_dynamic_soft_thresholding_bsd500_0.0_55.0_None_1582907244',
        'epoch': 100,
    },
]

In [7]:
# Test-time noise standard deviations. 0.0001 presumably stands in for an
# (almost) noise-free input. The run ids of the evaluated models suggest
# training on stds in [0, 55], so 60 and 75 would probe generalisation beyond
# the training range. TODO confirm.
noise_stds = [
    0.0001,
    5, 15, 20, 25, 30,
    50, 55,
    60, 75,
]

In [8]:
def _bsd68_test_ds(noise_std, return_noise_level, n_samples=None):
    """Build the BSD68 testing dataset at ``noise_std`` with one image per batch.

    Args:
        noise_std: standard deviation of the noise passed to ``im_dataset_bsd68``.
        return_noise_level: whether the dataset also yields the noise level
            (needed by the models in ``dynamic_denoising_net_params``).
        n_samples: forwarded to ``im_dataset_bsd68``; ``None`` is the value
            this notebook uses (presumably "all test images". TODO confirm).
    """
    return im_dataset_bsd68(
        mode='testing',
        batch_size=1,
        patch_size=None,
        noise_std=noise_std,
        return_noise_level=return_noise_level,
        n_pooling=5,
        n_samples=n_samples,
    )


# noise std -> list of (model name, metrics object) pairs, in the order:
# all_net_params models, the noisy input ('original'), then the
# dynamic_denoising_net_params models.
noise_std_metrics = {}
n_samples = None
# A fresh dataset is built for every model.
# NOTE(review): if im_dataset_bsd68 draws new noise on each build, every model
# is scored on a different noise realisation. Confirm whether that is intended.
for noise_std in tqdm(noise_stds, desc='Noise stds'):
    metrics = []
    # Models that only see the noisy image.
    for net_params in all_net_params:
        im_ds = _bsd68_test_ds(noise_std, return_noise_level=False, n_samples=n_samples)
        metrics.append((net_params['name'], metrics_from_ds(im_ds, **net_params)))
    # Baseline: metrics of the noisy input itself.
    im_ds = _bsd68_test_ds(noise_std, return_noise_level=False, n_samples=n_samples)
    metrics.append(('original', metrics_original_from_ds(im_ds)))
    # Models that also receive the noise level as an input.
    for net_params in dynamic_denoising_net_params:
        im_ds = _bsd68_test_ds(noise_std, return_noise_level=True, n_samples=n_samples)
        metrics.append((net_params['name'], metrics_from_ds(im_ds, **net_params)))
    noise_std_metrics[noise_std] = metrics

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  This is separate from the ipykernel package so we can avoid doing imports until


HBox(children=(FloatProgress(value=0.0, description='Noise stds', max=10.0, style=ProgressStyle(description_wi…

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…

  return compare_psnr(gt, pred, data_range=1)
  gt, pred, multichannel=True, data_range=1





HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for unet-multiple-stds', max=68.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for original noisy images', max=68.0, style=Progres…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Stats for learnlet', max=68.0, style=ProgressStyle(descri…





In [9]:
# Quick inspection hook; the trailing ';' suppresses display of the metrics dict.
noise_std_metrics;

In [14]:
# Hand-typed reference PSNRs of external baselines, keyed by the same test
# noise stds as noise_stds.
# NOTE(review): the source of these numbers is not given. TODO cite it.
# NOTE(review): the rendered PSNR table below shows BM3D = 127.1 at std 0.0001
# and an empty BM3D entry at std 60, i.e. it predates these values; re-run it.

# BM3D: noise std -> mean PSNR.
BM3D_psnr_results = {
    0.0001: 127.0,
    5: 37.57,
    15: 31.07,
    20: 29.60,
    25: 28.57,
    30: 27.74,
    50: 25.62,
    55: 25.26,
    60: 25.02,
    75: 24.21,
}

# Wavelet baseline ('wavelets_24' in the tables):
# noise std -> (mean PSNR, std of PSNR).
wavelets_psnr_results = {
    0.0001: (127.3,  0.3614),
    5: (35.76,  1.937),
    15: (29.56,  2.553),
    20: (28.25,  2.645),
    25: (27.32,  2.684),
    30: (26.61,  2.694),
    50: (24.79,  2.673),
    55: (24.46, 2.642),
    60: (24.18, 2.624),
    75: (23.46,  2.574),
}

In [15]:
# PSNR table: one row per test noise std. Model cells read "mean (std)" of
# m.metrics['PSNR'] (std presumably across test images), formatted with 4
# significant digits. BM3D holds the bare reference mean. Wavelets shows the
# reference "mean (std)". A noise std without a reference value is left NaN
# for both baselines (previously the wavelets lookup raised KeyError).
psnr_columns = (
    ['noise_std']
    + [p['name'] for p in all_net_params]
    + [p['name'] for p in dynamic_denoising_net_params]
    + ['original', 'wavelets_24', 'bm3d']
)
psnr_rows = []
for noise_std, metrics in noise_std_metrics.items():
    row = {'noise_std': noise_std}
    for name, m in metrics:
        row[name] = "{mean:.4} ({std:.4})".format(
            mean=m.metrics['PSNR'].mean(),
            std=m.metrics['PSNR'].stddev(),
        )
    row['bm3d'] = BM3D_psnr_results.get(noise_std, np.nan)
    if noise_std in wavelets_psnr_results:
        wavelets_mean, wavelets_std = wavelets_psnr_results[noise_std]
        row['wavelets_24'] = "{mean:.4} ({std:.4})".format(
            mean=wavelets_mean,
            std=wavelets_std,
        )
    # Missing keys become NaN when the frame is built below.
    psnr_rows.append(row)

# Build the frame once instead of growing it row by row with .loc.
psnr_metrics_table = pd.DataFrame(psnr_rows, columns=psnr_columns)
psnr_metrics_table

Unnamed: 0,noise_std,unet-multiple-stds,learnlet,original,wavelets_24,bm3d
0,0.0001,51.88 (3.037),124.9 (0.2378),128.1 (0.01528),127.3 (0.3614),127.1
1,5.0,37.61 (1.853),36.56 (1.978),34.15 (0.01634),35.76 (1.937),37.57
2,15.0,31.59 (2.439),30.87 (2.412),24.61 (0.01391),29.56 (2.553),31.07
3,20.0,30.2 (2.544),29.53 (2.459),22.11 (0.0161),28.25 (2.645),29.6
4,25.0,29.18 (2.619),28.54 (2.486),20.17 (0.01689),27.32 (2.684),28.57
5,30.0,28.38 (2.669),27.75 (2.48),18.59 (0.01477),26.61 (2.694),27.74
6,50.0,26.28 (2.708),25.59 (2.403),14.16 (0.01445),24.79 (2.673),25.62
7,55.0,25.91 (2.71),25.18 (2.351),13.32 (0.01744),24.46 (2.642),25.26
8,60.0,25.45 (2.683),24.82 (2.327),12.57 (0.0158),24.18 (2.624),
9,75.0,22.31 (1.446),23.87 (2.222),10.63 (0.01599),23.46 (2.574),24.21


In [None]:
# SSIM table: one row per test noise std, cells read "mean (std)" of
# m.metrics['SSIM'] (std presumably across test images), formatted with 4
# significant digits. Only the models evaluated in this notebook get a column.
# The BM3D / wavelet reference dicts hold PSNRs only, so columns for them would
# always be empty.
ssim_columns = (
    ['noise_std']
    + [p['name'] for p in all_net_params]
    + [p['name'] for p in dynamic_denoising_net_params]
    + ['original']
)
ssim_rows = []
for noise_std, metrics in noise_std_metrics.items():
    row = {'noise_std': noise_std}
    for name, m in metrics:
        row[name] = "{mean:.4} ({std:.4})".format(
            mean=m.metrics['SSIM'].mean(),
            std=m.metrics['SSIM'].stddev(),
        )
    ssim_rows.append(row)

ssim_metrics_table = pd.DataFrame(ssim_rows, columns=ssim_columns)
# Displayed (no trailing ';'), consistent with the PSNR table.
ssim_metrics_table

In [None]:
%matplotlib nbagg
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Global plotting theme for every figure in the rest of the notebook.
sns.set(
    style="whitegrid",
    palette="muted",
    rc={
        'figure.figsize': (9, 5),  # default size of new figures
        'image.cmap': 'gray',      # default colormap for image plots
    },
)

In [None]:
relative_to_original = True


model_family_str = r'$\bf{Model}$'
train_stds_str = r'$\bf{Train}$  $\bf{std}$'
noise_std_str = r'$\sigma$'
# The column name doubles as the y-axis label of the plot, so say when the
# values are ratios to the noisy-input PSNR rather than plain PSNRs.
psnr_str = 'PSNR / noisy PSNR' if relative_to_original else 'PSNR'


def from_name_to_family(model_name):
    """Map a model or baseline name to its legend family.

    Matching is by substring, checked in this order: 'learnlet', 'unet',
    'bm3d', 'wavelets'. Anything else (e.g. 'original') maps to 'Original'.
    """
    if 'learnlet' in model_name:
        return 'Learnlets'
    elif 'unet' in model_name:
        return 'U-net'
    elif 'bm3d' in model_name:
        return 'BM3D'
    elif 'wavelets' in model_name:
        return 'Wavelets'
    else:
        return 'Original'


def from_name_to_train_stds(model_name):
    """Map a model name to a label for the noise stds it was trained on.

    Names containing '30' give '30', names containing '20_40' give
    '[20; 40]'. Every other name (baselines included) gives '[0; 55]'.
    """
    if '30' in model_name:
        return '30'
    elif '20_40' in model_name:
        return '[20; 40]'
    else:
        return '[0; 55]'


# Colour per model family. 'Original' rows only appear when
# relative_to_original is False, but seaborn raises a ValueError when a hue
# level is missing from a dict palette, so the key must always be present.
family_model_to_color = {
    'U-net': 'C0',
    'Learnlets': 'C1',
    'BM3D': 'C2',
    'Wavelets': 'C3',
    'Original': 'C4',
}


def _psnr_plot_row(noise_std, psnr, name):
    """Build one record of ``psnr_to_plot`` for model/baseline ``name``."""
    return {
        noise_std_str: noise_std,
        psnr_str: psnr,
        'model_name': name,
        model_family_str: from_name_to_family(name),
        train_stds_str: from_name_to_train_stds(name),
    }


# One record per (noise std, model or baseline). The frame is built once
# instead of being grown row by row with .loc.
psnr_plot_rows = []
orig_psnrs = {}
for noise_std, metrics in noise_std_metrics.items():
    for name, m in metrics:
        mean_psnr = m.metrics['PSNR'].mean()
        if relative_to_original and name == 'original':
            # Kept aside as the per-noise-std reference; not plotted itself.
            orig_psnrs[noise_std] = mean_psnr
        else:
            psnr_plot_rows.append(_psnr_plot_row(noise_std, mean_psnr, name))
    psnr_plot_rows.append(
        _psnr_plot_row(noise_std, BM3D_psnr_results.get(noise_std, np.nan), 'bm3d')
    )
    psnr_plot_rows.append(
        _psnr_plot_row(
            noise_std,
            wavelets_psnr_results.get(noise_std, (np.nan, np.nan))[0],
            'wavelets_24',
        )
    )

psnr_to_plot = pd.DataFrame(
    psnr_plot_rows,
    columns=[noise_std_str, psnr_str, 'model_name', model_family_str, train_stds_str],
)

if relative_to_original:
    # Divide every row by the noisy-input PSNR at the same noise std. Rows
    # without a reference value are left unchanged (factor 1).
    reference_psnr = psnr_to_plot[noise_std_str].map(orig_psnrs).fillna(1.0)
    psnr_to_plot[psnr_str] = psnr_to_plot[psnr_str] / reference_psnr

psnr_to_plot

In [None]:
# Line plot of the PSNR column of psnr_to_plot against the test noise std,
# one coloured line per model family; the figure is also saved to disk.
fig, ax = plt.subplots()
psnr_to_plot[psnr_str] = psnr_to_plot[psnr_str].astype(float)
lplot = sns.lineplot(
    data=psnr_to_plot,
    x=noise_std_str,
    y=psnr_str,
    hue=model_family_str,
    palette=family_model_to_color,
    ax=ax,
)
# Legend laid out horizontally in a strip just above the axes.
ax.legend(
    bbox_to_anchor=(0., 1.01, 1., .05),
    loc='center',
    borderaxespad=0.,
    ncol=5,
    fontsize=13.35,
)
fig.tight_layout()
fig.savefig('gen_wo_error_bars.png')