In [1]:
# Move one directory up (presumably the repository root, per the printed path
# below) so relative paths such as 'checkpoints/...' used by unpack_model resolve.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# Hide every GPU so the whole notebook runs on the CPU only. This has to happen
# before any CUDA-aware library (presumably TensorFlow, pulled in by the
# learning_wavelets imports in the next cell) is imported.
import os
os.environ.update(CUDA_VISIBLE_DEVICES="-1")

In [3]:
# Auto-reload project modules edited on disk, and use the interactive nbagg
# matplotlib backend.
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
# NOTE(review): `time` is not used in the visible cells.
import time
import warnings
# NOTE(review): this silences *every* warning for the rest of the session, which
# can hide deprecation or numerical problems; consider filtering by category.
warnings.filterwarnings("ignore")

import bm3d
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# NOTE(review): newer tqdm releases deprecate `tqdm_notebook` in favour of
# `tqdm.notebook.tqdm` (or `tqdm.auto.tqdm`).
from tqdm import tqdm_notebook

# NOTE(review): `ConcatenateGenerators`, `SoftThresholding` and `unet` are unused in
# the visible cells; `HardThresholding` only appears in commented-out config.
from learning_wavelets.data import im_generators, ConcatenateGenerators
from learning_wavelets.evaluate import Metrics
from learning_wavelets.keras_utils.thresholding import SoftThresholding, HardThresholding
from learning_wavelets.learned_wavelet import learned_wavelet, learnlet
from learning_wavelets.unet import unet
# Presumably this import (through pysap) prints the package banner shown below.
from learning_wavelets.wavelet_denoising import wavelet_denoising_pysap


                 .|'''|       /.\      '||'''|,
                 ||          // \\      ||   ||
'||''|, '||  ||` `|'''|,    //...\\     ||...|'
 ||  ||  `|..||   .   ||   //     \\    ||
 ||..|'      ||   |...|' .//       \\. .||
 ||       ,  |'
.||        ''

Package version: 0.0.3

License: CeCILL-B

Authors: 

Antoine Grigis <antoine.grigis@cea.fr>
Samuel Farrens <samuel.farrens@cea.fr>
Jean-Luc Starck <jl.stark@cea.fr>
Philippe Ciuciu <philippe.ciuciu@cea.fr>

Dependencies: 

scipy          : >=1.3.0   - required | 1.3.0     installed
numpy          : >=1.16.4  - required | 1.17.4    installed
matplotlib     : >=3.0.0   - required | 3.1.0     installed
astropy        : >=3.0.0   - required | 3.1.2     installed
nibabel        : >=2.3.2   - required | 2.5.1     installed
pyqtgraph      : >=0.10.0  - required | 0.10.0    installed
progressbar2   : >=3.34.3  - required | ?         installed
modopt         : >=1.4.0   - required | 1.4.1     installed
scikit-learn   : >=0.19.1  - requi

In [4]:
# Fix NumPy's global RNG so any randomness below is reproducible.
np.random.seed(seed=0)

In [5]:
# Default figure size and a grayscale colormap for displaying images.
plt.rcParams.update({
    'figure.figsize': (9, 5),
    'image.cmap': 'gray',
})

In [6]:
# Evaluation data: greyscale BSD68 with noise of standard deviation `noise_std`,
# expressed on the 0-255 intensity scale (it is divided by 255 wherever used).
grey = True
n_channels = 1 if grey else 3
noise_std = 30
(
    im_gen_train,
    im_gen_val,
    im_gen_test,
    size,
    n_samples_train,
) = im_generators(
    'bsd68',
    batch_size=1,
    validation_split=0,
    no_augment=True,
    noise_std=noise_std,
    grey=grey,
)
# Only the test split is evaluated below.
im_bsd68 = im_gen_test

In [7]:
# Threshold level: 2 * noise_std rescaled from the 0-255 scale. Only referenced by
# the disabled hard-thresholding configuration below.
thresh = 2 * noise_std / 255

# Single-input networks, evaluated with `metrics_for_params`.
all_net_params = [
    dict(
        name='learned wavelet_st',
        init_function=learned_wavelet,
        run_params=dict(
            n_scales=5,
            n_details=256,
            n_coarse=n_channels,
            mixing_details=False,
            denoising_activation='linear',
            wav_pooling=True,
            wav_use_bias=False,
            wav_normed=True,
            filters_normed=['details', 'coarse'],
            input_size=(size, size, n_channels),
        ),
        run_id='learned_wavelet_div2k_30_1574790444',
        epoch=250,
    ),
    # Disabled hard-thresholding variant, kept for reference:
    # dict(
    #     name='learned wavelet_ht',
    #     init_function=learned_wavelet,
    #     run_params=dict(
    #         n_scales=5,
    #         n_details=256,
    #         n_coarse=n_channels,
    #         mixing_details=False,
    #         denoising_activation=HardThresholding(1.5*thresh),
    #         wav_pooling=True,
    #         wav_use_bias=False,
    #         wav_normed=True,
    #         filters_normed=['details', 'coarse'],
    #         input_size=(size, size, n_channels),
    #     ),
    #     run_id='learned_wavelet_div2k_30_1574371296',
    #     epoch=50,
    # ),
]

# Exact-reconstruction networks, evaluated with `metrics_exact_recon_net`.
all_exact_recon_net_params = [
    dict(
        name='learnlet_exact_reco',
        init_function=learnlet,
        run_params=dict(
            denoising_activation='linear',
            learnlet_analysis_kwargs=dict(
                n_tiling=256,
                mixing_details=False,
            ),
            learnlet_synthesis_kwargs=dict(),
            n_scales=5,
            exact_reconstruction_weight=1,
            input_size=(size, size, n_channels),
        ),
        run_id='learnlet_div2k_30_1575052510',
        epoch=250,
    ),
]

In [8]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=250, **dummy_kwargs):
    """Build a model from its configuration and load one of its checkpoints.

    Parameters:
        init_function: callable that builds the model from ``run_params``.
        run_params: keyword arguments forwarded to ``init_function``.
        run_id: training-run identifier used in the checkpoint filename.
        epoch: epoch number of the checkpoint to load.
        **dummy_kwargs: ignored; lets callers pass a whole config dict
            (e.g. one that also carries a ``name`` key).

    Returns:
        The built model, with weights read from
        ``checkpoints/<run_id>-<epoch>.hdf5`` relative to the working directory.
    """
    checkpoint_file = 'checkpoints/{}-{}.hdf5'.format(run_id, epoch)
    built = init_function(**run_params)
    built.load_weights(checkpoint_file)
    return built

def enumerate_seq(seq, name):
    """Return a generator over the elements ``seq[i]`` with a notebook progress bar.

    Despite the name, no indices are produced, only the elements themselves.
    The progress bar (labelled with ``name``) is created as soon as this is called.
    """
    progress = tqdm_notebook(range(len(seq)), desc=f'Val files for {name}')
    return (seq[idx] for idx in progress)

def enumerate_seq_noisy(seq, name):
    """Return a generator over ``np.squeeze(seq[i][0])`` with a notebook progress bar.

    Element 0 of each item is taken as the noisy batch; size-1 axes are squeezed
    away. The progress bar (labelled with ``name``) is created immediately.
    """
    progress = tqdm_notebook(range(len(seq)), desc=f'Val files for {name}')
    return (np.squeeze(seq[idx][0]) for idx in progress)

def enumerate_seq_gt(seq):
    """Return a generator over ``np.squeeze(seq[i][1])`` (the ground-truth batches).

    ``len(seq)`` is read when this is called; items are fetched lazily.
    """
    n_items = len(seq)
    return (np.squeeze(seq[k][1]) for k in range(n_items))

def metrics_for_params(reco_function=None, name=None, **net_params):
    """Score a single-input network on every batch of the BSD68 test generator.

    Parameters:
        reco_function: unused; kept so existing callers keep working.
        name: label shown in the progress-bar descriptions.
        **net_params: forwarded to ``unpack_model`` (``init_function``,
            ``run_params``, ``run_id``, ``epoch``).

    Returns:
        A ``Metrics`` object accumulated over all test batches.
    """
    model = unpack_model(**net_params)
    metrics = Metrics()
    # NOTE(review): the model is fed the *ground-truth* batch, not the noisy one,
    # so this scores how well it reproduces clean images rather than denoising.
    # Presumably intentional (the configs use a 'linear' denoising activation);
    # confirm before reporting these numbers as denoising results.
    pred_and_gt = [
        (model.predict_on_batch(images_gt), images_gt)
        for _images_noisy, images_gt in enumerate_seq(im_bsd68, name)
    ]
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        # Depending on the TensorFlow version, `predict_on_batch` returns an eager
        # tensor or a NumPy array; `np.asarray` accepts both, `.numpy()` does not.
        metrics.push(images, np.asarray(im_recos))
    return metrics

def metrics_exact_recon_net(name=None, **net_params):
    """Score a two-input (exact-reconstruction) network on the BSD68 test batches.

    The model is called with a pair of input batches and returns a sequence of
    outputs, of which only the first is scored against the ground truth.

    Parameters:
        name: label shown in the progress-bar descriptions.
        **net_params: forwarded to ``unpack_model`` (``init_function``,
            ``run_params``, ``run_id``, ``epoch``).

    Returns:
        A ``Metrics`` object accumulated over all test batches.
    """
    model = unpack_model(**net_params)
    metrics = Metrics()
    # NOTE(review): both inputs are the ground-truth batch (the noisy batch is
    # ignored), so this scores reconstruction of clean images, not denoising.
    # Presumably intended; confirm what the second model input is meant to be.
    pred_and_gt = [
        (model.predict_on_batch((images_gt, images_gt))[0], images_gt)
        for _images_noisy, images_gt in enumerate_seq(im_bsd68, name)
    ]
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        # Depending on the TensorFlow version, `predict_on_batch` yields eager
        # tensors or NumPy arrays; `np.asarray` accepts both, `.numpy()` does not.
        metrics.push(images, np.asarray(im_recos))
    return metrics


def metrics_wavelets(wavelet_id, n_scales=5, soft_thresh=False, n_sigma=3):
    """Score pysap wavelet-thresholding denoising on the noisy BSD68 test images.

    Parameters:
        wavelet_id: wavelet identifier forwarded to ``wavelet_denoising_pysap``
            (e.g. ``'24'``).
        n_scales: number of decomposition scales (default 5, as before).
        soft_thresh: forwarded as-is; ``False`` (default) presumably selects hard
            thresholding.
        n_sigma: forwarded as-is (default 3); presumably the threshold in
            multiples of the noise standard deviation.

    Returns:
        A ``Metrics`` object accumulated over all test images.
    """
    metrics = Metrics()
    pred = wavelet_denoising_pysap(
        enumerate_seq_noisy(im_bsd68, f'Wavelet denoising {wavelet_id}'),
        noise_std=noise_std/255,
        wavelet_id=wavelet_id,
        n_scales=n_scales,
        soft_thresh=soft_thresh,
        n_sigma=n_sigma,
    )
    gt = enumerate_seq_gt(im_bsd68)
    # `zip` has no length, so pass `total` for a proper progress bar.
    for im_recos, images in tqdm_notebook(
            zip(pred, gt), desc='Stats for wavelet denoising', total=len(im_bsd68)):
        # The sequences squeeze away size-1 axes; add back a trailing channel axis.
        metrics.push(images[..., None], im_recos[..., None])
    return metrics

def metrics_bm3d():
    """Score BM3D (all stages) on the noisy BSD68 test images.

    All images are denoised first; the metrics are then accumulated in a second pass.

    Returns:
        A ``Metrics`` object accumulated over all test images.
    """
    metrics = Metrics()
    sigma = noise_std / 255
    denoised = []
    for noisy in enumerate_seq_noisy(im_bsd68, 'BM3D'):
        # Shift by +0.5 before BM3D and back afterwards; presumably the generator
        # yields images centred on 0 (e.g. in [-0.5, 0.5]). TODO confirm.
        restored = bm3d.bm3d(noisy + 0.5, sigma_psd=sigma, stage_arg=bm3d.BM3DStages.ALL_STAGES)
        denoised.append(restored - 0.5)
    ground_truths = enumerate_seq_gt(im_bsd68)
    for im_recos, images in tqdm_notebook(zip(denoised, ground_truths), desc='Stats for bm3d'):
        # Re-add the trailing channel axis removed by the squeezing sequences.
        metrics.push(images[..., None], im_recos[..., None])
    return metrics

In [9]:
%%time
metrics = []
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))
    
for net_params in all_exact_recon_net_params:
    metrics.append((net_params['name'], metrics_exact_recon_net(**net_params)))
    
# metrics.append(('bm3d', metrics_bm3d()))
# metrics.append(('wavelets_24', metrics_wavelets('24')))

HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st', max=68, style=ProgressStyl…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st', max=68, style=ProgressStyle(de…




HBox(children=(IntProgress(value=0, description='Val files for learnlet_exact_reco', max=68, style=ProgressSty…




HBox(children=(IntProgress(value=0, description='Stats for learnlet_exact_reco', max=68, style=ProgressStyle(d…


CPU times: user 23min 1s, sys: 6min 47s, total: 29min 48s
Wall time: 1min 34s


In [10]:
# Show the (name, Metrics) pairs in evaluation order.
# NOTE: the '+/-' value in this repr is twice the std reported in the table below
# (compare the two outputs).
metrics

[('learned wavelet_st', PSNR = 25.56 +/- 5.584 SSIM = 0.8816 +/- 0.06021),
 ('learnlet_exact_reco', PSNR = 41.48 +/- 4.28 SSIM = 0.991 +/- 0.00542)]

In [11]:
# Order models by mean PSNR (ascending). Build a new list instead of sorting in
# place, so the list displayed by the previous cell (kept in Out[]) is not
# silently reordered.
metrics = sorted(metrics, key=lambda pair: pair[1].metrics['PSNR'].mean())

In [12]:
# Optional: persist the metrics to disk.
# NOTE(review): the filename mentions fastMRI, but these metrics are computed on
# BSD68; consider renaming before re-enabling.
# import pickle
# with open('metrics_net_rec_fastmri', 'wb') as f:
#     pickle.dump(metrics, f)

In [15]:
# Summary table: one row per model, one "mean (std)" string per metric.
# `.4` is the general format spec, i.e. 4 significant digits (not 4 decimals).
# The frame is built in one go from rows instead of being filled cell by cell.
metrics_table = pd.DataFrame(
    [
        [
            "{mean:.4} ({std:.4})".format(
                mean=model_metrics.metrics[metric_key].mean(),
                std=model_metrics.metrics[metric_key].stddev(),
            )
            for metric_key in ('PSNR', 'SSIM')
        ]
        for _, model_metrics in metrics
    ],
    index=[name for name, _ in metrics],
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)'],
)

In [16]:
# Display the summary table (rows ordered by mean PSNR, ascending).
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std)
learned wavelet_st,25.56 (2.792),0.8816 (0.03011)
learnlet_exact_reco,41.48 (2.14),0.991 (0.00271)
