In [1]:
# # uncomment to make sure we only run on the CPU (hides all CUDA devices)
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os
import time
import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook

from fastmri_recon.data.sequences.fastmri_sequences import ZeroFilled2DSequence, Masked2DSequence
from fastmri_recon.evaluate.metrics.np_metrics import METRIC_FUNCS, Metrics
from fastmri_recon.evaluate.reconstruction.zero_filled_reconstruction import reco_and_gt_zfilled_from_val_file
from fastmri_recon.evaluate.reconstruction.cross_domain_reconstruction import reco_and_gt_net_from_val_file
from fastmri_recon.evaluate.reconstruction.unet_reconstruction import reco_and_gt_unet_from_val_file 
from fastmri_recon.models.functional_models.cascading import cascade_net
from fastmri_recon.models.functional_models.kiki import kiki_net
from fastmri_recon.models.functional_models.kiki_sep import full_kiki_net
from fastmri_recon.models.functional_models.pdnet import pdnet
from fastmri_recon.models.functional_models.unet import unet
from fastmri_recon.models.utils.non_linearities import lrelu

In [3]:
# Pin NumPy's legacy global random state so any use of it below is reproducible.
np.random.seed(seed=0)

In [4]:
# Default figure size and grayscale colormap for anything plotted in this notebook.
plt.rcParams.update({'figure.figsize': (9, 5), 'image.cmap': 'gray'})

In [5]:
# Acceleration factor, passed as `af` to the validation sequences; it matches the
# `af4` tag in the checkpoint run ids used below.
AF = 4
# Contrast subset of the validation data. Other values used with this notebook:
# 'CORPD_FBK', or None (presumably no contrast filtering).
contrast = 'CORPDFS_FBK'

In [6]:
# Root directory of the single-coil validation set. To point elsewhere, set the
# FASTMRI_VAL_PATH environment variable instead of editing this cell.
val_path = os.environ.get('FASTMRI_VAL_PATH', '/media/Zaccharie/UHRes/singlecoil_val/')
# Only the first N_VAL_FILES validation files are evaluated (a quick check).
# Set it to None to evaluate on every validation file.
N_VAL_FILES = 2

# val_gen_zero: zero-filled inputs with norm=True (used by the U-net configuration).
# val_gen_scaled: masked inputs scaled by 1e6 (used by the other networks and by the
# zero-filled baseline).
val_gen_zero = ZeroFilled2DSequence(val_path, af=AF, norm=True, mode='validation', contrast=contrast)
val_gen_zero.filenames = val_gen_zero.filenames[:N_VAL_FILES]
val_gen_scaled = Masked2DSequence(val_path, mode='validation', af=AF, scale_factor=1e6, contrast=contrast)
val_gen_scaled.filenames = val_gen_scaled.filenames[:N_VAL_FILES]

In [7]:
# Networks (by 'name') evaluated in this run. Add e.g. 'unet' or 'cascadenet' to compare more.
NETS_TO_EVALUATE = ('pdnet',)

# Every known network configuration. They are kept as live data rather than as
# commented-out code, so enabling one only means listing its name above.
# Keys consumed downstream:
# - 'name', 'val_gen', 'reco_function': used by the evaluation helpers.
# - 'init_function', 'run_params', 'run_id', optional 'epoch': used to build the
#   model and locate its checkpoint.
ALL_NET_PARAMS = [
    {
        'name': 'unet',
        'init_function': unet,
        'run_params': {
            'n_layers': 4,
            'pool': 'max',
            "layers_n_channels": [16, 32, 64, 128],
            'layers_n_non_lins': 2,
            'input_size': (320, 320, 1),
        },
        'val_gen': val_gen_zero,
        'run_id': 'unet_af4_1569210349',
        'reco_function': reco_and_gt_unet_from_val_file,
    },
    {
        'name': 'pdnet',
        'init_function': pdnet,
        'run_params': {
            'n_primal': 5,
            'n_dual': 5,
            'n_iter': 10,
            'n_filters': 32,
        },
        'val_gen': val_gen_scaled,
        'run_id': 'pdnet_af4_1568384763',
        'reco_function': reco_and_gt_net_from_val_file,
    },
    {
        'name': 'cascadenet',
        'init_function': cascade_net,
        'run_params': {
            'n_cascade': 5,
            'n_convs': 5,
            'n_filters': 48,
            'noiseless': True,
        },
        'val_gen': val_gen_scaled,
        'run_id': 'cascadenet_af4_1568926824',
        'reco_function': reco_and_gt_net_from_val_file,
    },
    {
        'name': 'kikinet',
        'init_function': kiki_net,
        'run_params': {
            'n_cascade': 2,
            'n_convs': 25,
            'n_filters': 32,
            'noiseless': True,
        },
        'val_gen': val_gen_scaled,
        'run_id': 'kikinet_af4_1568724379',
        'reco_function': reco_and_gt_net_from_val_file,
    },
    {
        'name': 'kikinet-sep-8',
        'init_function': full_kiki_net,
        'run_params': {
            'n_convs': 8,
            'n_filters': 48,
            'noiseless': True,
            'activation': lrelu,
        },
        'val_gen': val_gen_scaled,
        'run_id': 'kikinet_sep_I2_af4_1569964596',
        'reco_function': reco_and_gt_net_from_val_file,
        'epoch': 50,
    },
    {
        'name': 'kikinet-sep-16',
        'init_function': full_kiki_net,
        'run_params': {
            'n_convs': 16,
            'n_filters': 48,
            'noiseless': True,
            'activation': lrelu,
        },
        'val_gen': val_gen_scaled,
        'run_id': 'kikinet_sep_I2_af4_1570049560',
        'reco_function': reco_and_gt_net_from_val_file,
        'epoch': 50,
    },
]

# Fail loudly on a typo in NETS_TO_EVALUATE instead of silently evaluating nothing.
_unknown_nets = set(NETS_TO_EVALUATE) - {params['name'] for params in ALL_NET_PARAMS}
assert not _unknown_nets, f'unknown network name(s) in NETS_TO_EVALUATE: {sorted(_unknown_nets)}'

# Selected configurations, in declaration order; this is what the evaluation cells iterate over.
all_net_params = [params for params in ALL_NET_PARAMS if params['name'] in NETS_TO_EVALUATE]

In [8]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=300,
                 checkpoints_dir='../checkpoints', **dummy_kwargs):
    """Build a network and load its trained weights from a checkpoint file.

    Parameters:
        init_function: callable building the model from ``run_params``.
        run_params (dict): keyword arguments for ``init_function``; ``None`` is
            treated as "no arguments".
        run_id (str): training-run identifier, used as the checkpoint file prefix.
        epoch (int): epoch of the checkpoint to load (default 300).
        checkpoints_dir (str): directory holding the checkpoint files
            (default ``'../checkpoints'``).
        **dummy_kwargs: any other keys are ignored, so a configuration dict can be
            splatted in directly.

    Returns:
        The object returned by ``init_function``, after calling its ``load_weights``
        with ``'<checkpoints_dir>/<run_id>-<epoch>.hdf5'``.
    """
    model = init_function(**(run_params or {}))
    chkpt_path = f'{checkpoints_dir}/{run_id}-{epoch}.hdf5'
    model.load_weights(chkpt_path)
    return model

def metrics_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    """Evaluate one trained network on every file of its validation sequence.

    Parameters:
        reco_function: callable ``reco_function(*val_gen[i], model)`` returning a
            (reconstructions, ground truth) pair for one validation file.
        val_gen: indexable validation sequence supporting ``len``.
        name (str): network name, used only in the progress-bar label.
        **net_params: forwarded to ``unpack_model`` (init_function, run_params,
            run_id, optional epoch, ...).

    Returns:
        A ``Metrics`` object accumulated over all files of ``val_gen``.
    """
    model = unpack_model(**net_params)
    metrics = Metrics(METRIC_FUNCS)
    # Score each file right after reconstructing it, instead of first materialising
    # every (reconstruction, ground truth) pair of the whole set, so only one file's
    # reconstruction is kept alive at a time.
    for i in tqdm_notebook(range(len(val_gen)), desc=f'Val files for {name}'):
        im_recos, images = reco_function(*val_gen[i], model)
        metrics.push(images, im_recos)
    return metrics

def metrics_zfilled():
    """Evaluate the zero-filled baseline on every file of ``val_gen_scaled``.

    Reconstructions come from ``reco_and_gt_zfilled_from_val_file``, which takes
    no trained model.

    Returns:
        A ``Metrics`` object accumulated over all files of ``val_gen_scaled``.
    """
    metrics = Metrics(METRIC_FUNCS)
    # Score each file right after reconstructing it, rather than materialising every
    # (reconstruction, ground truth) pair first, so only one file is held at a time.
    for i in tqdm_notebook(range(len(val_gen_scaled)), desc='Val files for z-filled'):
        im_recos, images = reco_and_gt_zfilled_from_val_file(*val_gen_scaled[i])
        metrics.push(images, im_recos)
    return metrics

In [9]:
%%time
metrics = []
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))
    
metrics.append(('zfilled', metrics_zfilled()))

HBox(children=(FloatProgress(value=0.0, description='Val files for pdnet', max=2.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Stats for pdnet', max=2.0, style=ProgressStyle(descriptio…




HBox(children=(FloatProgress(value=0.0, description='Val files for z-filled', max=2.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, description='Stats for z-filled', max=2.0, style=ProgressStyle(descrip…


CPU times: user 1min 4s, sys: 21.5 s, total: 1min 25s
Wall time: 1min 18s


In [10]:
# (name, Metrics) pairs: one per evaluated network, then the zero-filled baseline.
# NOTE(review): the `+/-` printed by the Metrics repr appears to be twice the standard
# deviation shown in `metrics_table` (6.819 vs 3.41 for z-filled in the saved outputs);
# confirm against Metrics.__repr__ before quoting these numbers.
metrics

[('pdnet', PSNR = 27.66 +/- 6.883 SSIM = 0.5545 +/- 0.2644),
 ('zfilled', PSNR = 26.04 +/- 6.819 SSIM = 0.504 +/- 0.2814)]

In [11]:
# Order the methods from lowest to highest mean PSNR (worst first). Rebinding to a
# sorted copy, rather than sorting in place, leaves the list already displayed (and
# cached in Out[]) by the previous cell untouched.
metrics = sorted(metrics, key=lambda name_and_metrics: name_and_metrics[1].metrics['PSNR'].mean())

In [12]:
# import pickle
# with open('metrics_net_rec_fastmri', 'wb') as f:
#     pickle.dump(metrics, f)

In [13]:
# # former metrics
# metrics = [
#     ('unet', {'PSNR': {'mean': 31.78, 'stddev': 6.534},  'SSIM':{'mean': 0.7205, 'stddev': 0.2595}}),
#     ('pdnet', {'PSNR': {'mean': 32.15, 'stddev': 6.905},  'SSIM':{'mean': 0.7292, 'stddev': 0.2657}}),
#     ('cascadenet', {'PSNR': {'mean': 31.97, 'stddev': 6.951},  'SSIM':{'mean': 0.7191, 'stddev': 0.2719}}),
#     ('kikinet', {'PSNR': {'mean': 28.43, 'stddev': 4.345},  'SSIM':{'mean': 0.6307, 'stddev': 0.2366}}),
#     ('zfilled', {'PSNR': {'mean': 29.61, 'stddev': 5.287},  'SSIM':{'mean': 0.6577, 'stddev': 0.2333}}),
# ]

In [14]:
def n_model_params_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    """Return the number of parameters of the network described by a configuration.

    Loading trained weights does not change how many parameters a model has, so
    the model is only built from ``init_function(**run_params)``; no checkpoint is
    read. This is cheaper, and it works even when the checkpoint file is absent.

    Parameters:
        reco_function, val_gen, name: accepted and ignored, so a whole configuration
            dict can be splatted in.
        **net_params: must contain ``init_function``; ``run_params`` (dict) is
            optional. Other keys (``run_id``, ``epoch``, ...) are ignored.

    Returns:
        int: ``count_params()`` of the freshly built model.
    """
    init_function = net_params['init_function']
    model = init_function(**(net_params.get('run_params') or {}))
    return model.count_params()

In [15]:
%%time
n_params = {}
for net_params in all_net_params:
    n_params[net_params['name']] =  n_model_params_for_params(**net_params)
    
n_params['zfilled'] =  0

CPU times: user 1.46 s, sys: 94.4 ms, total: 1.55 s
Wall time: 1.45 s


In [16]:
def runtime_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    """Time one reconstruction of the first validation file with a trained network.

    Parameters:
        reco_function: callable invoked as ``reco_function(*val_gen[0], model)``.
        val_gen: indexable validation sequence; only item 0 is used.
        name: ignored (accepted so a configuration dict can be splatted in).
        **net_params: forwarded to ``unpack_model``.

    Returns:
        float: elapsed seconds for that single ``reco_function`` call. Model
        building and data loading happen before the clock starts. There is no
        warm-up call, so any first-call overhead is included.
    """
    model = unpack_model(**net_params)
    data = val_gen[0]  # load the inputs before starting the clock
    # perf_counter is monotonic and high-resolution, which suits duration
    # measurements; time.time() follows the adjustable system clock.
    start = time.perf_counter()
    reco_function(*data, model)
    return time.perf_counter() - start

def runtime_zfilled():
    """Time one zero-filled reconstruction of the first file of ``val_gen_scaled``.

    Returns:
        float: elapsed seconds for that single call. Data loading happens before
        the clock starts.
    """
    data = val_gen_scaled[0]  # load the inputs before starting the clock
    # perf_counter is monotonic and high-resolution, which suits duration
    # measurements; time.time() follows the adjustable system clock.
    start = time.perf_counter()
    reco_and_gt_zfilled_from_val_file(*data)
    return time.perf_counter() - start

In [17]:
%%time
runtimes = {}
for net_params in tqdm_notebook(all_net_params):
    runtimes[net_params['name']] =  runtime_for_params(**net_params)
    
runtimes['zfilled'] = runtime_zfilled()

HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


CPU times: user 28.5 s, sys: 9.18 s, total: 37.6 s
Wall time: 33.4 s


In [18]:
def _mean_std_str(metric_stats):
    """Format a metric's mean and standard deviation as 'mean (std)', 4 significant digits each."""
    return "{mean:.4} ({std:.4})".format(
        mean=metric_stats.mean(),
        std=metric_stats.stddev(),
    )

# One row of pre-formatted strings per method, in the current order of `metrics`.
table_columns = ['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)']
table_rows = [
    [
        _mean_std_str(m.metrics['PSNR']),
        _mean_std_str(m.metrics['SSIM']),
        "{}".format(n_params[name]),
        "{runtime:.4}".format(runtime=runtimes[name]),
    ]
    for name, m in metrics
]
metrics_table = pd.DataFrame(
    table_rows,
    index=[name for name, _ in metrics],
    columns=table_columns,
)

In [19]:
# Summary: one row per method in the order of `metrics` (sorted by ascending mean PSNR
# earlier). The "(std)" values are one standard deviation (Metrics.stddev()).
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std),# params,Runtime (s)
zfilled,26.04 (3.41),0.504 (0.1407),0,0.3925
pdnet,27.66 (3.442),0.5545 (0.1322),318280,31.57
