In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import random
import time
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook

from fastmri_recon.data.oasis_sequences import Masked2DSequence, ZeroFilled2DSequence
from fastmri_recon.helpers.evaluate import METRIC_FUNCS, Metrics
from fastmri_recon.helpers.nn_mri import lrelu
from fastmri_recon.helpers.reconstruction import reco_and_gt_zfilled_from_val_file, reco_and_gt_net_from_val_file
from fastmri_recon.models.cascading import cascade_net
from fastmri_recon.models.kiki_sep import full_kiki_net
from fastmri_recon.models.pdnet import pdnet
from fastmri_recon.models.unet import unet

Using TensorFlow backend.


In [2]:
np.random.seed(0)

In [3]:
plt.rcParams['figure.figsize'] = (9, 5)
plt.rcParams['image.cmap'] = 'gray'

In [4]:
# paths
train_path = '/media/Zaccharie/UHRes/OASIS_data/'

# generators
AF = 4
train_gen = Masked2DSequence(
    train_path,
    af=AF,
    inner_slices=32,
    scale_factor=1e-2,
    seed=0,
    rand=True,
    val_split=0.1,
)
val_gen_mask = train_gen.val_sequence
n_train = 1000
n_val = 200
random.seed(0)
train_gen.filenames = random.sample(train_gen.filenames, n_train)
val_gen_mask.filenames = random.sample(val_gen_mask.filenames, n_val)

train_gen_zero = ZeroFilled2DSequence(
    train_path,
    af=AF,
    inner_slices=32,
    scale_factor=1e-2,
    seed=0,
    rand=False,
    val_split=0.1,
    n_pooling=3,
)
val_gen_zero = train_gen_zero.val_sequence
random.seed(0)
train_gen_zero.filenames = random.sample(train_gen_zero.filenames, n_train)
val_gen_zero.filenames = random.sample(val_gen_zero.filenames, n_val)

In [6]:
all_net_params = [
    {
        'name': 'unet',
        'init_function': unet,
        'run_params': {
            'n_layers': 4,
            'pool': 'max',
            "layers_n_channels": [16, 32, 64, 128],
            'layers_n_non_lins': 2,
        },
        'val_gen': val_gen_zero,
        'run_id': 'unet_af4_oasis_1570619888',
    },
    {
        'name': 'pdnet',
        'init_function': pdnet,
        'run_params': {
            'n_primal': 5,
            'n_dual': 5,
            'n_iter': 10,
            'n_filters': 32,
        },
        'run_id': 'pdnet_af4_oasis_1570722239',
    },
    {
        'name': 'cascadenet',
        'init_function': cascade_net,
        'run_params': {
            'n_cascade': 5,
            'n_convs': 5,
            'n_filters': 48,
            'noiseless': True,
        },
        'run_id': 'cascadenet_af4_oasis_1569491836',
    },
        {
        'name': 'kikinet-sep',
        'init_function': full_kiki_net,
        'run_params': {
            'n_convs': 16,
            'n_filters': 48,
            'noiseless': True,
            'activation': lrelu,
        },
        'run_id': 'kikinet_sep_I2_af4_oasis_1572552792',
        'epoch': 50,
    },
]

In [7]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=300, **dummy_kwargs):
    try:
        model = init_function(input_size=(None, None, 1), fastmri=False, **run_params)
    except:
        model = init_function(input_size=(None, None, 1), **run_params)
    chkpt_path = f'../checkpoints/{run_id}-{epoch}.hdf5'
    model.load_weights(chkpt_path)
    return model

def metrics_for_params(name=None, val_gen=None, **net_params):
    if val_gen is None:
        val_gen = val_gen_mask
    model = unpack_model(**net_params)
    metrics = Metrics(METRIC_FUNCS)
    pred_and_gt = [
        reco_and_gt_net_from_val_file(*val_gen[i], model)
        for i in tqdm_notebook(range(len(val_gen)), desc=f'Val files for {name}')
    ]    
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        metrics.push(images, im_recos)
    return metrics


def metrics_zfilled():
    metrics = Metrics(METRIC_FUNCS)
    pred_and_gt = [
        reco_and_gt_zfilled_from_val_file(*val_gen_mask[i], crop=False)
        for i in tqdm_notebook(range(len(val_gen_mask)), desc='Val files for z-filled')
    ]    
    for im_recos, images in tqdm_notebook(pred_and_gt, desc='Stats for z-filled'):
        metrics.push(images, im_recos)
    return metrics

In [8]:
%%time
metrics = []
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))
    
metrics.append(('zfilled', metrics_zfilled()))

W1115 16:22:05.872544 140339832608512 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.



HBox(children=(IntProgress(value=0, description='Val files for unet', max=200, style=ProgressStyle(description…

W1115 16:22:12.575501 140339832608512 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.






HBox(children=(IntProgress(value=0, description='Stats for unet', max=200, style=ProgressStyle(description_wid…

W1115 16:25:15.751831 140339832608512 deprecation.py:323] From /volatile/home/Zaccharie/workspace/fastmri-reproducible-benchmark/fastmri_recon/helpers/nn_mri.py:92: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.





HBox(children=(IntProgress(value=0, description='Val files for pdnet', max=200, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Stats for pdnet', max=200, style=ProgressStyle(description_wi…




HBox(children=(IntProgress(value=0, description='Val files for cascadenet', max=200, style=ProgressStyle(descr…




HBox(children=(IntProgress(value=0, description='Stats for cascadenet', max=200, style=ProgressStyle(descripti…




HBox(children=(IntProgress(value=0, description='Val files for kinet', max=200, style=ProgressStyle(descriptio…




HBox(children=(IntProgress(value=0, description='Stats for kinet', max=200, style=ProgressStyle(description_wi…




HBox(children=(IntProgress(value=0, description='Val files for kikinet-sep', max=200, style=ProgressStyle(desc…




HBox(children=(IntProgress(value=0, description='Stats for kikinet-sep', max=200, style=ProgressStyle(descript…




HBox(children=(IntProgress(value=0, description='Val files for z-filled', max=200, style=ProgressStyle(descrip…




HBox(children=(IntProgress(value=0, description='Stats for z-filled', max=200, style=ProgressStyle(description…


CPU times: user 23min 2s, sys: 3min 35s, total: 26min 37s
Wall time: 29min


In [9]:
metrics

[('unet', PSNR = 29.8 +/- 2.78 SSIM = 0.8467 +/- 0.07955),
 ('pdnet', PSNR = 33.22 +/- 3.823 SSIM = 0.9097 +/- 0.07157),
 ('cascadenet', PSNR = 32 +/- 3.461 SSIM = 0.8867 +/- 0.06549),
 ('kinet', PSNR = 29.11 +/- 2.715 SSIM = 0.8226 +/- 0.07783),
 ('kikinet-sep', PSNR = 30.08 +/- 2.86 SSIM = 0.8532 +/- 0.06717),
 ('zfilled', PSNR = 26.11 +/- 2.901 SSIM = 0.6724 +/- 0.06137)]

In [10]:
metrics.sort(key=lambda x: x[1].metrics['PSNR'].mean())

In [11]:
# import pickle
# with open('metrics_net_rec_oasis', 'wb') as f:
#     pickle.dump(metrics, f)

In [12]:
def n_model_params_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    model = unpack_model(**net_params)
    n_params = model.count_params()
    return n_params

In [13]:
%%time
n_params = {}
for net_params in tqdm_notebook(all_net_params):
    n_params[net_params['name']] =  n_model_params_for_params(**net_params)
    
n_params['zfilled'] =  0

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


CPU times: user 19.3 s, sys: 590 ms, total: 19.9 s
Wall time: 19.6 s


In [14]:
def runtime_for_params(name=None, val_gen=None, **net_params):
    if val_gen is None:
        val_gen = val_gen_mask
    model = unpack_model(**net_params)
    data = val_gen[0]
    start = time.time()
    reco_and_gt_net_from_val_file(*data, model)
    end = time.time()
    return end - start

def runtime_zfilled():
    data = val_gen_mask[0]
    start = time.time()
    reco_and_gt_zfilled_from_val_file(*data, crop=False)
    end = time.time()
    return end - start

In [15]:
%%time
runtimes = {}
for net_params in tqdm_notebook(all_net_params):
    runtimes[net_params['name']] =  runtime_for_params(**net_params)
    
runtimes['zfilled'] = runtime_zfilled()

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))


CPU times: user 40.8 s, sys: 1.63 s, total: 42.4 s
Wall time: 43 s


In [16]:
metrics_table = pd.DataFrame(
    index=[name for name, _ in metrics], 
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)'],
)
for name, m in metrics:
    metrics_table.loc[name, 'PSNR-mean (std) (dB)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['PSNR'].mean(), 
        std=m.metrics['PSNR'].stddev(),
    )
    metrics_table.loc[name, 'SSIM-mean (std)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['SSIM'].mean(), 
        std=m.metrics['SSIM'].stddev(),
    )
    metrics_table.loc[name, '# params'] = "{}".format(
        n_params[name], 
     )
    metrics_table.loc[name, 'Runtime (s)'] = "{runtime:.4}".format(
        runtime=runtimes[name], 
    )

In [17]:
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std),# params,Runtime (s)
zfilled,26.11 (1.45),0.6724 (0.03069),0,0.1651
kinet,29.11 (1.357),0.8226 (0.03892),625540,2.492
unet,29.8 (1.39),0.8467 (0.03977),481801,1.202
kikinet-sep,30.08 (1.43),0.8532 (0.03358),1251080,3.567
cascadenet,32.0 (1.731),0.8867 (0.03274),424570,2.234
pdnet,33.22 (1.912),0.9097 (0.03579),318280,2.758
