In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import codecs
from keras.optimizers import Adam
from keras.datasets import cifar10
import json
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import os
import pandas as pd
from skimage import color
from skimage.measure import compare_psnr, compare_ssim
import time
from tqdm import tqdm_notebook

from fastmri_recon.helpers.evaluate import METRIC_FUNCS, Metrics
from fastmri_recon.helpers.utils import keras_ssim, keras_psnr
from fastmri_recon.data.test_generators import DataGenerator
from fastmri_recon.models.unet import unet
from fastmri_recon.helpers.evaluate import psnr, ssim, mse, nmse
from fastmri_recon.models.discriminator import discriminator_model, generator_containing_discriminator_multiple_outputs
from fastmri_recon.helpers.adversarial_training import compile_models



Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


Link to dl the dataset: http://chaladze.com/l5/

In [2]:
def load_data(path):
    path_train = path + "train/"
    path_test = path + "test/"
    train = os.listdir(path_train)
    train_data = []
    for s in train:
        for image in os.listdir(path_train+s):       
            train_data.append(mpimg.imread(path_train+s+'/'+image))

    test = os.listdir(path_test)
    test_data = []
    for s in test:
        for image in os.listdir(path_test+s):          
            test_data.append(mpimg.imread(path_test+s+'/'+image))

    x_train = color.rgb2gray(np.array(train_data))
    x_test = color.rgb2gray( np.array(test_data))

    return x_train, x_test

In [3]:
#load data

AF=2
path="/Users/WorkAccount/Desktop/Linnaeus_5_64X64/"

if path != "":
    x_train, x_test = load_data(path)
else:
    (x_train, _), (x_test, _) = cifar10.load_data()
    x_train = color.rgb2gray(x_train)
    x_test = color.rgb2gray(x_test)

im_shape = x_train.shape

val_gen = DataGenerator(AF, x_test).flow_z_filled_images()

In [4]:
all_net_params = [
    {
        'name': 'unet_gan',
        'init_function': unet,
        'run_params': {
            'n_layers': 4, 
            'pool': 'max', 
            "layers_n_channels": [16, 32, 64, 128], 
            'layers_n_non_lins': 2,
            "input_size":(64, 64, 1),
        },
        'val_gen': val_gen,
        'run_id': 'unet_gan_af2_1576259200',
    },
]

In [5]:
def unpack_model(init_function=None, run_params=None, run_id=None, epoch=50, **dummy_kwargs):
        g = init_function(**run_params)
        g.name = 'Reconstructor'
        chkpt_path = f'../checkpoints/{run_id}-{epoch}.hdf5'
        g.load_weights(chkpt_path)
        return g

def original_metrics(val_gen=None):
    name='original metrics'
    metrics = Metrics(METRIC_FUNCS)
    vals=[]
    for i in tqdm_notebook(range(x_test.shape[0]), desc=f'Val files for {name}'):
        x, y = next(val_gen)
        vals.append((np.squeeze(x, axis=3), np.squeeze(y, axis=3)))
    for im_recos, images in tqdm_notebook(vals, desc=f'Stats for {name}'):
        metrics.push(images, im_recos)
    print(len(vals))
    return metrics

def metrics_for_params(val_gen=None, name=None, **net_params):
    g = unpack_model(**net_params)
    metrics = Metrics(METRIC_FUNCS)
    pred_and_gt=[]
    for i in tqdm_notebook(range(x_test.shape[0]), desc=f'Val files for {name}'):
        x, y = next(val_gen)
        pred_and_gt.append((np.squeeze(g.predict(x), axis=3), np.squeeze(y, axis=3)))
    for im_recos, images in tqdm_notebook(pred_and_gt, desc=f'Stats for {name}'):
        metrics.push(images, im_recos)
    print(len(pred_and_gt))
    return metrics

In [None]:
%%time
metrics = []
metrics.append(('original metrics', original_metrics(val_gen)))
for net_params in all_net_params:
    metrics.append((net_params['name'], metrics_for_params(**net_params)))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  if sys.path[0] == '':


HBox(children=(IntProgress(value=0, description='Val files for original metrics', max=2000, style=ProgressStyl…

  kspaces[i, ..., 0] = kspace
  x_final[i] = fourier_op.adj_op(kspace)





Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  from ipykernel import kernelapp as app


HBox(children=(IntProgress(value=0, description='Stats for original metrics', max=2000, style=ProgressStyle(de…

  return compare_psnr(gt, pred, data_range=gt.max() - gt.min())
  gt.transpose(1, 2, 0), pred.transpose(1, 2, 0), multichannel=True, data_range=gt.max() - gt.min()



2000



Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(IntProgress(value=0, description='Val files for unet_gan', max=2000, style=ProgressStyle(descri…




In [None]:
print(metrics)

In [None]:
def n_model_params_for_params(val_gen=None, name=None, **net_params):
    model = unpack_model(**net_params)
    n_params = model.count_params()
    return n_params

In [None]:
%%time
n_params = {}
for net_params in all_net_params:
    n_params[net_params['name']] =  n_model_params_for_params(**net_params)

In [None]:
def runtime_for_params(reco_function=None, val_gen=None, name=None, **net_params):
    g = unpack_model(**net_params)
    x, y = next(val_gen)
    start = time.time()
    pred = g.predict(x)
    end = time.time()
    return end - start

In [None]:
%%time
runtimes = {}
all_net_params
for net_params in tqdm_notebook(all_net_params):
    runtimes[net_params['name']] =  runtime_for_params(**net_params)

In [None]:
metrics_table = pd.DataFrame(
    index=[name for name, _ in metrics], 
    columns=['PSNR-mean (std) (dB)', 'SSIM-mean (std)', '# params', 'Runtime (s)'],
)

for name, m in metrics:
    metrics_table.loc[name, 'PSNR-mean (std) (dB)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['PSNR'].mean(), 
        std=m.metrics['PSNR'].stddev(),
    )
    metrics_table.loc[name, 'SSIM-mean (std)'] = "{mean:.4} ({std:.4})".format(
        mean=m.metrics['SSIM'].mean(), 
        std=m.metrics['SSIM'].stddev(),
    )
    if name != 'original metrics':
        metrics_table.loc[name, '# params'] = "{}".format(
            n_params[name], 
         )
        metrics_table.loc[name, 'Runtime (s)'] = "{runtime:.4}".format(
            runtime=runtimes[name], 
        )

In [None]:
metrics_table