In [1]:
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # this is just to make sure we are running only on the CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import time
import warnings
warnings.filterwarnings("ignore")

import bm3d
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators, ConcatenateGenerators
from learning_wavelets.evaluate import Metrics

In [4]:
# Seed NumPy's legacy global RNG so that np.random draws are repeatable across runs.
# NOTE(review): this only fixes the noise realisation if im_generators draws it
# through np.random -- confirm.
np.random.seed(0)

In [5]:
# Notebook-wide matplotlib defaults: figure size and grayscale colormap for images.
plt.rcParams.update({
    'figure.figsize': (9, 5),
    'image.cmap': 'gray',
})

In [6]:
# Evaluation settings: single-channel images when `grey` is set, three channels otherwise.
grey = True
n_channels = 1 if grey else 3
# Noise standard deviation handed to the generators; metrics_bm3d divides it by 255,
# so it is presumably expressed on a 0-255 intensity scale -- TODO confirm.
noise_std = 25

# Only the test generator is kept for the evaluation below.
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = im_generators(
    'bsd68',
    batch_size=1,
    validation_split=0,
    no_augment=True,
    noise_std=noise_std,
    grey=grey,
)
im_bsd68 = im_gen_test

In [7]:
def enumerate_seq(seq, name):
    """Lazily yield every item of `seq`, in index order, behind a progress bar.

    Args:
        seq: object supporting `len()` and integer indexing.
        name: label inserted in the progress-bar description.

    Returns:
        A generator over `seq[0]`, ..., `seq[len(seq) - 1]`.
    """
    # The progress bar is created right away, when the function is called;
    # only the indexing into `seq` is deferred until iteration.
    indices = tqdm_notebook(range(len(seq)), desc=f'Val files for {name}')
    return (seq[idx] for idx in indices)

def enumerate_seq_noisy(seq, name):
    """Lazily yield the squeezed first element of every item of `seq`, with a progress bar.

    Args:
        seq: object supporting `len()` and integer indexing; each item is itself
            indexable, and its element 0 is used (the noisy image, going by the name).
        name: label inserted in the progress-bar description.

    Returns:
        A generator of `np.squeeze(seq[i][0])` for every index `i`.
    """
    # The progress bar is created at call time; items are only read during iteration.
    indices = tqdm_notebook(range(len(seq)), desc=f'Val files for {name}')
    return (np.squeeze(seq[idx][0]) for idx in indices)

def enumerate_seq_gt(seq):
    """Lazily yield the squeezed second element of every item of `seq`.

    Args:
        seq: object supporting `len()` and integer indexing; each item is itself
            indexable, and its element 1 is used (the ground truth, going by the name).

    Returns:
        A generator of `np.squeeze(seq[i][1])` for every index `i`. No progress bar.
    """
    n_items = len(seq)
    return (np.squeeze(seq[idx][1]) for idx in range(n_items))


def metrics_original():
    """Score the unprocessed noisy images of `im_bsd68` against their ground truth.

    This is the "no denoising" baseline: the noisy image itself is pushed as the
    reconstruction.

    Returns:
        Metrics: accumulator holding one pushed (ground truth, noisy) pair per item.
    """
    stats = Metrics()
    # First pass: read every (noisy, ground truth) pair from the generator.
    noisy_gt_pairs = []
    for noisy_batch, clean_batch in enumerate_seq(im_bsd68, 'Original noisy image'):
        noisy_gt_pairs.append((noisy_batch, clean_batch))
    # Second pass: accumulate the metrics, ground truth first.
    for noisy_batch, clean_batch in tqdm_notebook(noisy_gt_pairs, desc='Original noisy image'):
        stats.push(clean_batch, noisy_batch)
    return stats


def metrics_bm3d():
    """Denoise every noisy image of `im_bsd68` with BM3D and score the results.

    Uses the module-level `noise_std` (divided by 255) as the noise level given
    to BM3D, and runs all BM3D stages.

    Returns:
        Metrics: accumulator holding one pushed (ground truth, denoised) pair per image.
    """
    metrics = Metrics()
    # NOTE(review): the +0.5 / -0.5 shift presumably maps images from [-0.5, 0.5]
    # to [0, 1] for BM3D and back -- confirm the generator's intensity range.
    pred = [
        bm3d.bm3d(image_noisy + 0.5, sigma_psd=noise_std/255, stage_arg=bm3d.BM3DStages.ALL_STAGES) - 0.5
        for image_noisy in enumerate_seq_noisy(im_bsd68, 'BM3D')
    ]
    gt = enumerate_seq_gt(im_bsd68)
    # zip() has no len(), so tqdm cannot infer the total by itself; without it the
    # bar shows an unknown total instead of real progress.
    for im_recos, images in tqdm_notebook(zip(pred, gt), desc='Stats for bm3d', total=len(pred)):
        # Both arrays went through np.squeeze; append a trailing axis before pushing.
        metrics.push(images[..., None], im_recos[..., None])
    return metrics

In [8]:
%%time
# Evaluate each method over the whole set and keep (name, Metrics) pairs,
# noisy baseline first, then BM3D.
metrics = [
    ('original', metrics_original()),
    ('bm3d', metrics_bm3d()),
]

HBox(children=(IntProgress(value=0, description='Val files for Original noisy image', max=68, style=ProgressSt…




HBox(children=(IntProgress(value=0, description='Original noisy image', max=68, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Val files for BM3d', max=68, style=ProgressStyle(description_…




HBox(children=(IntProgress(value=1, bar_style='info', description='Stats for bm3d', max=1, style=ProgressStyle…


CPU times: user 17min 16s, sys: 5min 7s, total: 22min 24s
Wall time: 13min 38s


In [9]:
# Show the raw (name, Metrics) pairs, in the order they were evaluated.
metrics

[('original', PSNR = 20.17 +/- 0.03205 SSIM = 0.3883 +/- 0.2559),
 ('bm3d', PSNR = 28.6 +/- 5.105 SSIM = 0.7864 +/- 0.1415)]

In [10]:
# Rank the methods in place, from lowest to highest mean PSNR.
metrics.sort(key=lambda named_result: named_result[1].metrics['PSNR'].mean())

In [11]:
# import pickle
# with open('metrics_net_rec_fastmri', 'wb') as f:
#     pickle.dump(metrics, f)

In [12]:
def n_model_params_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    """Return the parameter count of the model built from `net_params`.

    Args:
        reco_function: accepted but never read.
        val_gen: accepted but never read.
        name: accepted but never read.
        **net_params: forwarded unchanged to `unpack_model`.

    Returns:
        Whatever `count_params()` returns on the built model.

    NOTE(review): `unpack_model` is neither imported nor defined in the visible
    part of this notebook; unless it is provided elsewhere, calling this function
    raises NameError.
    """
    return unpack_model(**net_params).count_params()

In [13]:
%%time
# Number of trainable parameters per method; neither entry here comes from a
# learned model, so the values are placeholders.
n_params = {
    'original': 0,
    'bm3d': 'NA',
}

CPU times: user 2 µs, sys: 1e+03 ns, total: 3 µs
Wall time: 5.25 µs


In [14]:
%%time
# Runtime per method, shown in the 'Runtime (s)' column; not measured yet.
runtimes = {
    'original': 'NA',
    'bm3d': 'NA',  # TODO: code function for that
}

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.72 µs


In [15]:
# Summary table: one row per evaluated method, in the current order of `metrics`.
# Every cell is stored as an already-formatted string.
metrics_table = pd.DataFrame(
    index=[name for name, _ in metrics],
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)'],
)
for name, m in metrics:
    # ".4" keeps 4 significant digits for floats.
    psnr_stats = m.metrics['PSNR']
    metrics_table.loc[name, 'PSNR-mean (std) (dB)'] = f"{psnr_stats.mean():.4} ({psnr_stats.stddev():.4})"
    ssim_stats = m.metrics['SSIM']
    metrics_table.loc[name, 'SSIM-mean (std)'] = f"{ssim_stats.mean():.4} ({ssim_stats.stddev():.4})"
    metrics_table.loc[name, '# params'] = format(n_params[name], '')
    # For a string such as 'NA', ".4" truncates to at most 4 characters.
    metrics_table.loc[name, 'Runtime (s)'] = format(runtimes[name], '.4')

In [16]:
# Render the formatted summary table.
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std),# params,Runtime (s)
original,20.17 (0.01602),0.3883 (0.128),0.0,
bm3d,28.6 (2.553),0.7864 (0.07073),,
