In [1]:
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # this just to make sure we are using only on CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import time
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators, ConcatenateGenerators
from learning_wavelets.evaluate import Metrics
from learning_wavelets.keras_utils.thresholding import SoftThresholding
from learning_wavelets.learned_wavelet import learned_wavelet
from learning_wavelets.unet import unet
from learning_wavelets.wavelet_denoising import wavelet_denoising

Using TensorFlow backend.


In [4]:
np.random.seed(0)

In [5]:
plt.rcParams['figure.figsize'] = (9, 5)
plt.rcParams['image.cmap'] = 'gray'

In [6]:
grey = True
n_channels = 3
if grey:
    n_channels = 1
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = im_generators(
    'bsd68', 
    batch_size=1, 
    validation_split=0.1, 
    no_augment=True, 
    noise_std=30,
    grey=grey
)    
im_bsd68 = ConcatenateGenerators(im_gen_train, im_gen_val, im_gen_test)

In [7]:
all_net_params = [
    {
        'name': 'unet',
        'init_function': unet,
        'run_params': {
            'n_layers': 5, 
            'pool': 'max', 
            "layers_n_channels": [64, 128, 256, 512, 1024], 
            'layers_n_non_lins': 2,
            'non_relu_contract': False,
            'input_size': (size, size, n_channels),
        },
        'run_id': 'unet_div2k_30_1570805218',
    },
    {
        'name': 'unet_bn',
        'init_function': unet,
        'run_params': {
            'n_layers': 5, 
            'pool': 'max', 
            "layers_n_channels": [64, 128, 256, 512, 1024], 
            'layers_n_non_lins': 2,
            'non_relu_contract': False,
            'bn': True,
            'input_size': (size, size, n_channels),
        },
        'run_id': 'unet_div2k_30_1571047085',
    },
    {
        'name': 'learned wavelet',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'input_size': (size, size, n_channels),
        },
        'run_id': 'learned_wavelet_div2k_30_1570806527',
    },
    {
        'name': 'learned wavelet_st_0.1',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(0.1),
            'input_size': (size, size, n_channels),
        },
        'run_id': 'learned_wavelet_div2k_30_1571645964',
    },
    {
        'name': 'learned wavelet_st_0.1_wav_pooling',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(0.1),
            'input_size': (size, size, n_channels),
            'wav_pooling': True,
        },
        'run_id': 'learned_wavelet_div2k_30_1571646276',
    },
    {
        'name': 'learned wavelet_st_0.1_unit_norm',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(0.1),
            'input_size': (size, size, n_channels),
            'filters_normed': ['details', 'coarse'],

        },
        'run_id': 'learned_wavelet_div2k_30_1571646661',
    },
    {
        'name': 'learned wavelet_st_0.1_wav_pooling_unit_norm',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(0.1),
            'input_size': (size, size, n_channels),
            'wav_pooling': True,
            'filters_normed': ['details', 'coarse'],

        },
        'run_id': 'learned_wavelet_div2k_30_1571647144',
    },
    {
        'name': 'learned wavelet_st_0.01',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(0.01),
            'input_size': (size, size, n_channels),
        },
        'run_id': 'learned_wavelet_div2k_30_1571647448',
    },
    {
        'name': 'learned wavelet_st_1',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': SoftThresholding(1.0),
            'input_size': (size, size, n_channels),
        },
        'run_id': 'learned_wavelet_div2k_30_1571647811',
    },
    {
        'name': 'learned wavelet_linear',
        'init_function': learned_wavelet,
        'run_params': {
            'n_scales': 5, 
            'n_details': 256, 
            'n_coarse': 1, 
            'n_groupping': 256,
            'denoising_activation': 'linear',
            'input_size': (size, size, n_channels),
        },
        'run_id': 'learned_wavelet_div2k_30_1571645666',
    },
]

In [8]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=500, **dummy_kwargs):
    model = init_function(**run_params)
    chkpt_path = f'checkpoints/{run_id}-{epoch}.hdf5'
    model.load_weights(chkpt_path)
    return model

def enumerate_seq(seq, name):
    return (seq[i] for i in tqdm_notebook(range(len(seq)), desc=f'Val files for {name}'))

def metrics_for_params(reco_function=None, name=None, **net_params):
    model = unpack_model(**net_params)
    metrics = Metrics()
    pred_and_gt = [
        (model.predict_on_batch(images_noisy), images_gt)
        for images_noisy, images_gt in enumerate_seq(im_bsd68, name)
    ]    
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        metrics.push(images, im_recos)
    return metrics

def metrics_original():
    metrics = Metrics()
    pred_and_gt = [
        (images_noisy, images_gt)
        for images_noisy, images_gt in enumerate_seq(im_bsd68, 'Original noisy image')
    ]    
    for im_recos, images in tqdm_notebook(pred_and_gt, desc='Original noisy image'):
        metrics.push(images, im_recos)
    return metrics

def metrics_wavelets():
    metrics = Metrics()
    pred_and_gt = [
        (wavelet_denoising(images_noisy[0], 30/255, n_scales=5), images_gt[0])
        for images_noisy, images_gt in enumerate_seq(im_bsd68, 'Wavelet denoising')
    ]    
    for im_recos, images in tqdm_notebook(pred_and_gt, desc='Stats for wavelet denoising'):
        metrics.push(images, im_recos)
    return metrics

In [9]:
%%time
metrics = []
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))
    
metrics.append(('original', metrics_original()))
metrics.append(('wavelets', metrics_wavelets()))

Instructions for updating:
Colocations handled automatically by placer.


HBox(children=(IntProgress(value=0, description='Val files for unet', max=300, style=ProgressStyle(description…




HBox(children=(IntProgress(value=0, description='Stats for unet', max=300, style=ProgressStyle(description_wid…




HBox(children=(IntProgress(value=0, description='Val files for unet_bn', max=300, style=ProgressStyle(descript…




HBox(children=(IntProgress(value=0, description='Stats for unet_bn', max=300, style=ProgressStyle(description_…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet', max=300, style=ProgressStyle(…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet', max=300, style=ProgressStyle(desc…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_0.1', max=300, style=Progres…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_0.1', max=300, style=ProgressSty…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_0.1_wav_pooling', max=300, s…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_0.1_wav_pooling', max=300, style…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_0.1_unit_norm', max=300, sty…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_0.1_unit_norm', max=300, style=P…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_0.1_wav_pooling_unit_norm', …




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_0.1_wav_pooling_unit_norm', max=…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_0.01', max=300, style=Progre…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_0.01', max=300, style=ProgressSt…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_st_1', max=300, style=ProgressS…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_st_1', max=300, style=ProgressStyle…




HBox(children=(IntProgress(value=0, description='Val files for learned wavelet_linear', max=300, style=Progres…




HBox(children=(IntProgress(value=0, description='Stats for learned wavelet_linear', max=300, style=ProgressSty…




HBox(children=(IntProgress(value=0, description='Val files for Original noisy image', max=300, style=ProgressS…




HBox(children=(IntProgress(value=0, description='Original noisy image', max=300, style=ProgressStyle(descripti…




HBox(children=(IntProgress(value=0, description='Val files for Wavelet denoising', max=300, style=ProgressStyl…




HBox(children=(IntProgress(value=0, description='Stats for wavelet denoising', max=300, style=ProgressStyle(de…


CPU times: user 4min 50s, sys: 1min 1s, total: 5min 51s
Wall time: 6min 5s


In [10]:
metrics

[('unet', PSNR = 26.14 +/- 3.94 SSIM = 0.6788 +/- 0.1393),
 ('unet_bn', PSNR = 26.3 +/- 4.001 SSIM = 0.6776 +/- 0.1396),
 ('learned wavelet', PSNR = 25.95 +/- 3.885 SSIM = 0.6633 +/- 0.1216),
 ('learned wavelet_st_0.1', PSNR = 26.31 +/- 3.984 SSIM = 0.6716 +/- 0.1281),
 ('learned wavelet_st_0.1_wav_pooling',
  PSNR = 26.13 +/- 4.135 SSIM = 0.6856 +/- 0.1444),
 ('learned wavelet_st_0.1_unit_norm',
  PSNR = 25.09 +/- 3.616 SSIM = 0.619 +/- 0.09594),
 ('learned wavelet_st_0.1_wav_pooling_unit_norm',
  PSNR = 25.56 +/- 3.615 SSIM = 0.6574 +/- 0.1093),
 ('learned wavelet_st_0.01', PSNR = 25.97 +/- 3.745 SSIM = 0.6678 +/- 0.1176),
 ('learned wavelet_st_1', PSNR = 23.87 +/- 3.235 SSIM = 0.5559 +/- 0.1102),
 ('learned wavelet_linear', PSNR = 24.73 +/- 3.436 SSIM = 0.6051 +/- 0.09892),
 ('original', PSNR = 19.3 +/- 1.085 SSIM = 0.348 +/- 0.2037),
 ('wavelets', PSNR = 22.19 +/- 1.438 SSIM = 0.4577 +/- 0.167)]

In [11]:
metrics.sort(key=lambda x: x[1].metrics['PSNR'].mean())

In [12]:
# import pickle
# with open('metrics_net_rec_fastmri', 'wb') as f:
#     pickle.dump(metrics, f)

In [13]:
def n_model_params_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    model = unpack_model(**net_params)
    n_params = model.count_params()
    return n_params

In [14]:
%%time
n_params = {}
for net_params in all_net_params:
    n_params[net_params['name']] =  n_model_params_for_params(**net_params)
    
n_params['original'] =  0
n_params['wavelets'] =  2  # number of scales and type of wavelets

CPU times: user 18.9 s, sys: 767 ms, total: 19.7 s
Wall time: 19.3 s


In [15]:
def runtime_for_params(reco_function=None, name=None, **net_params):
    model = unpack_model(**net_params)
    data = im_bsd68[0]
    start = time.time()
    model.predict_on_batch(data[0])
    end = time.time()
    return end - start

In [16]:
%%time
runtimes = {}
for net_params in tqdm_notebook(all_net_params):
    runtimes[net_params['name']] =  runtime_for_params(**net_params)
    
runtimes['original'] = 'NA'
runtimes['wavelets'] = 'NA'  # TODO: code function for that

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))


CPU times: user 38.6 s, sys: 790 ms, total: 39.4 s
Wall time: 39 s


In [17]:
metrics_table = pd.DataFrame(
    index=[name for name, _ in metrics], 
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)'],
)
for name, m in metrics:
    metrics_table.loc[name, 'PSNR-mean (std) (dB)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['PSNR'].mean(), 
        std=m.metrics['PSNR'].stddev(),
    )
    metrics_table.loc[name, 'SSIM-mean (std)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['SSIM'].mean(), 
        std=m.metrics['SSIM'].stddev(),
    )
    metrics_table.loc[name, '# params'] = "{}".format(
        n_params[name], 
     )
    metrics_table.loc[name, 'Runtime (s)'] = "{runtime:.4}".format(
        runtime=runtimes[name], 
    )

In [18]:
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std),# params,Runtime (s)
original,19.3 (0.5426),0.348 (0.1019),0,
wavelets,22.19 (0.719),0.4577 (0.0835),2,
learned wavelet_st_1,23.87 (1.617),0.5559 (0.05508),2976055,1.173
learned wavelet_linear,24.73 (1.718),0.6051 (0.04946),2976055,1.212
learned wavelet_st_0.1_unit_norm,25.09 (1.808),0.619 (0.04797),2976055,1.021
learned wavelet_st_0.1_wav_pooling_unit_norm,25.56 (1.808),0.6574 (0.05467),2976070,1.06
learned wavelet,25.95 (1.943),0.6633 (0.06082),2976055,0.9569
learned wavelet_st_0.01,25.97 (1.872),0.6678 (0.05881),2976055,1.165
learned wavelet_st_0.1_wav_pooling,26.13 (2.067),0.6856 (0.07222),2976070,0.9872
unet,26.14 (1.97),0.6788 (0.06967),31030793,0.8287
