In [200]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import codecs
from keras.optimizers import Adam
from keras.datasets import cifar10
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
import pandas as pd
from skimage import color
from skimage.measure import compare_psnr, compare_ssim
import time
from tqdm import tqdm_notebook

from fastmri_recon.helpers.evaluate import METRIC_FUNCS, Metrics
from fastmri_recon.helpers.utils import keras_ssim, keras_psnr
from fastmri_recon.data.test_generators import DataGenerator
from fastmri_recon.models.unet import unet
from fastmri_recon.helpers.evaluate import psnr, ssim, mse, nmse
from fastmri_recon.models.discriminator import discriminator_model, generator_containing_discriminator_multiple_outputs
from fastmri_recon.helpers.adversarial_training import compile_models



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Link to dl the dataset: http://chaladze.com/l5/

In [149]:
def load_data(path):
    path_train = path + "train/"
    path_test = path + "test/"
    train = os.listdir(path_train)
    train_data = []
    for s in train:
        for image in os.listdir(path_train+s):       
            train_data.append(mpimg.imread(path_train+s+'/'+image))

    test = os.listdir(path_test)
    test_data = []
    for s in test:
        for image in os.listdir(path_test+s):          
            test_data.append(mpimg.imread(path_test+s+'/'+image))

    x_train = color.rgb2gray(np.array(train_data))
    x_test = color.rgb2gray( np.array(test_data))

    return x_train, x_test

In [150]:
#load data

AF=2
path="/Users/WorkAccount/Desktop/Linnaeus_5_64X64/"

if path != "":
    x_train, x_test = load_data(path)
else:
    (x_train, _), (x_test, _) = cifar10.load_data()
    x_train = color.rgb2gray(x_train)
    x_test = color.rgb2gray(x_test)

im_shape = x_train.shape

val_gen = DataGenerator(AF, x_test).flow_z_filled_images()

In [161]:
all_net_params = [
    {
        'name': 'unet_gan',
        'init_function': unet,
        'run_params': {
            'n_layers': 4, 
            'pool': 'max', 
            "layers_n_channels": [16, 32, 64, 128], 
            'layers_n_non_lins': 2,
            "input_size":(64, 64, 1),
        },
        'val_gen': val_gen,
        'run_id': 'unet_gan_af2_1576259200',
    },
]

In [185]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=50, **dummy_kwargs):
        g = init_function(**run_params)
        g.name = 'Reconstructor'
        chkpt_path = f'../checkpoints/{run_id}-{epoch}.hdf5'
        g.load_weights(chkpt_path)
        return g

def metrics_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    g = unpack_model(**net_params)
    metrics = Metrics(METRIC_FUNCS)
    pred_and_gt=[]
    for i in tqdm_notebook(range(x_test.shape[0]), desc=f'Val files for {name}'):
        x, y = next(val_gen)
        pred_and_gt.append((np.squeeze(g.predict(x), axis=3), np.squeeze(y, axis=3)))
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        metrics.push(images, im_recos)
    print(len(pred_and_gt))
    return metrics

In [186]:
%%time
metrics = []
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))

{'name': 'unet_gan', 'init_function': <function unet at 0x1c3c526710>, 'run_params': {'n_layers': 4, 'pool': 'max', 'layers_n_channels': [16, 32, 64, 128], 'layers_n_non_lins': 2, 'input_size': (64, 64, 1)}, 'val_gen': <fastmri_recon.helpers.threadsafe_gen.threadsafe_iter object at 0x1c3d101750>, 'run_id': 'unet_gan_af2_1576259200'}


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if sys.path[0] == '':


HBox(children=(IntProgress(value=0, description='Val files for unet_gan', max=2000, style=ProgressStyle(descri…




Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app


HBox(children=(IntProgress(value=0, description='Stats for unet_gan', max=2000, style=ProgressStyle(descriptio…


2000
CPU times: user 49.4 s, sys: 3.82 s, total: 53.2 s
Wall time: 32.7 s


In [187]:
print(metrics)

[('unet_gan', PSNR = 12.42 +/- 4.245 SSIM = 0.1931 +/- 0.1903)]


In [188]:
def n_model_params_for_params(val_gen=None, name=None, **net_params):
    model = unpack_model(**net_params)
    n_params = model.count_params()
    return n_params

In [190]:
%%time
n_params = {}
for net_params in all_net_params:
    n_params[net_params['name']] =  n_model_params_for_params(**net_params)

CPU times: user 3.82 s, sys: 120 ms, total: 3.94 s
Wall time: 3.99 s


In [196]:
def runtime_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    g = unpack_model(**net_params)
    x, y = next(val_gen)
    start = time.time()
    pred = g.predict(x)
    end = time.time()
    return end - start

In [197]:
%%time
runtimes = {}
for net_params in tqdm_notebook(all_net_params):
    runtimes[net_params['name']] =  runtime_for_params(**net_params)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))


CPU times: user 5.3 s, sys: 151 ms, total: 5.45 s
Wall time: 5.48 s


In [198]:
metrics_table = pd.DataFrame(
    index=[name for name, _ in metrics], 
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)'],
)
for name, m in metrics:
    metrics_table.loc[name, 'PSNR-mean (std) (dB)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['PSNR'].mean(), 
        std=m.metrics['PSNR'].stddev(),
    )
    metrics_table.loc[name, 'SSIM-mean (std)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['SSIM'].mean(), 
        std=m.metrics['SSIM'].stddev(),
    )
    metrics_table.loc[name, '# params'] = "{}".format(
        n_params[name], 
     )
    metrics_table.loc[name, 'Runtime (s)'] = "{runtime:.4}".format(
        runtime=runtimes[name], 
    )

In [199]:
metrics_table

Unnamed: 0,PSNR-mean (std) (dB),SSIM-mean (std),# params,Runtime (s)
unet_gan,12.42 (2.122),0.1931 (0.09515),481801,1.453
