In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import os.path as op
import time

from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
from keras.optimizers import Adam
from keras_tqdm import TQDMNotebookCallback
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm_notebook
from keras.datasets import cifar10
import seaborn as sns
from skimage import color
import pandas as pd
import os
import matplotlib.image as mpimg

from fastmri_recon.data.test_generators import DataGenerator
from fastmri_recon.helpers.adversarial_training import compile_models, adversarial_training_loop
from fastmri_recon.helpers.image_tboard_cback import TensorBoardImage
from fastmri_recon.helpers.keras_utils import wasserstein_loss
from fastmri_recon.models.discriminator import discriminator_model, generator_containing_discriminator_multiple_outputs
from fastmri_recon.models.unet import unet
from fastmri_recon.helpers.utils import keras_ssim, keras_psnr
from fastmri_recon.helpers.evaluate import psnr, ssim, mse, nmse
from fastmri_recon.helpers.fourier import fft
from fastmri_recon.helpers.utils import gen_mask
from fastmri_recon.helpers.reconstruction import zero_filled_recon
from numpy.random import seed
from tensorflow import set_random_seed

Using TensorFlow backend.


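Download the Linnaeus 5 dataset (the 64×64 version is used below) from http://chaladze.com/l5/. If no dataset path is configured, the notebook falls back to CIFAR-10 converted to grayscale.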

In [51]:
# Parameter initialization.
# U-Net hyper-parameters, forwarded verbatim as keyword arguments to `unet`
# by `generator_model`.
run_params = dict(
    n_layers=4,
    pool='max',
    layers_n_channels=[16, 32, 64, 128],
    layers_n_non_lins=2,
)

# Undersampling factor given to DataGenerator and embedded in the run id
# (presumably the k-space acceleration factor — TODO confirm).
# NOTE(review): `seed` / `set_random_seed` are imported but never called, so runs
# are not seeded.
AF = 2

In [52]:
def _load_split(split_dir):
    """Read every image of every class sub-directory of `split_dir` and convert the stack to grayscale.

    Raises
    ------
    ValueError
        If no image file was found under `split_dir`.
    """
    images = []
    # os.listdir returns entries in arbitrary order: sort so the sample order is
    # reproducible across machines/runs. Skip hidden entries and stray files
    # (e.g. macOS `.DS_Store`), which would make the nested listdir/imread fail.
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if class_name.startswith('.') or not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            if image_name.startswith('.'):
                continue
            images.append(mpimg.imread(os.path.join(class_dir, image_name)))
    if not images:
        raise ValueError(f"no images found under {split_dir!r}")
    return color.rgb2gray(np.array(images))


def load_data(path):
    """Load the train and test splits of a class-per-folder image dataset as grayscale arrays.

    `path` must contain `train/` and `test/` directories, each holding one
    sub-directory per class with the image files inside. Class labels are
    discarded; classes and files are read in sorted name order.

    Parameters
    ----------
    path : str
        Dataset root directory; a trailing separator is optional.

    Returns
    -------
    x_train, x_test : np.ndarray
        The images of each split stacked along the first axis and converted
        with `skimage.color.rgb2gray`.
    """
    x_train = _load_split(os.path.join(path, "train"))
    x_test = _load_split(os.path.join(path, "test"))
    return x_train, x_test

In [53]:
# Load data.
# Point LINNAEUS5_PATH at the root of the 64x64 Linnaeus 5 dataset (the
# directory containing `train/` and `test/`); leave it unset or empty to fall back
# to grayscale CIFAR-10. Reading it from the environment avoids hardcoding a
# machine-specific absolute path in the notebook.
path = os.environ.get('LINNAEUS5_PATH', '')

if path:
    x_train, x_test = load_data(path)
else:
    (x_train, _), (x_test, _) = cifar10.load_data()
    x_train = color.rgb2gray(x_train)
    x_test = color.rgb2gray(x_test)

# `im_size` is used for both spatial dimensions of the models, so inputs must be square.
assert x_train[0].shape[0] == x_train[0].shape[1], f'expected square images, got {x_train[0].shape}'
im_size = x_train[0].shape[0]

# NOTE(review): the meaning of the 0.5 argument is defined by
# DataGenerator.flow_z_filled_images — TODO document it here.
val_gen = DataGenerator(AF, x_test).flow_z_filled_images(0.5)
train_gen = DataGenerator(AF, x_train).flow_z_filled_images(0.5)

In [54]:
def generator_model():
    """Build the uncompiled U-Net generator of the GAN, named 'Reconstructor'.

    Uses the global `im_size` for a single-channel `(im_size, im_size, 1)` input
    and forwards the global `run_params` as keyword arguments to `unet`.
    """
    input_shape = (im_size, im_size, 1)
    reconstructor = unet(input_size=input_shape, compile=False, **run_params)
    reconstructor.name = 'Reconstructor'
    return reconstructor

In [55]:
# model definitions
# Generator: the uncompiled U-Net reconstructor built by `generator_model`.
g = generator_model()
# Discriminator built for `im_size` inputs (assumed square — see the loading cell).
d = discriminator_model(im_size)
# Combined model wiring `g` into `d`; per its name it has several outputs
# (presumably the reconstruction and the discriminator score — TODO confirm).
d_on_g = generator_containing_discriminator_multiple_outputs(g, d, im_size=im_size)

In [56]:
# Model compiling.
# Learning rates for the discriminator and for the combined discriminator-on-
# generator model, plus the loss passed to compile_models as `perceptual_loss`.
perceptual_loss = 'mae'
discriminator_lr = 1e-4
d_on_reconstructor_lr = 1e-4

compile_kwargs = dict(
    d_lr=discriminator_lr,
    d_on_g_lr=d_on_reconstructor_lr,
    perceptual_loss=perceptual_loss,
    perceptual_weight=1.0,
)
compile_models(d, g, d_on_g, **compile_kwargs)

In [58]:
# NOTE(review): neither batch count is referenced by the training call in this
# notebook, which passes n_batches=10 directly.
n_batches_train = 1000
n_batches_val = 1

# Timestamped identifier shared by the checkpoint path and the TensorBoard log dirs.
run_id = f'unet_gan_af{AF}_{int(time.time())}'
# `{epoch:02d}` must stay literal (doubled braces) so ModelCheckpoint can fill it in.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

unet_gan_af2_1576326922


In [60]:
# Callbacks for the two phases of `adversarial_training_loop`: generator
# pre-training and adversarial training. Each phase logs under its own directory.
log_dir = op.join('logs/training', run_id)
log_dir_pre = op.join('logs/pretraining', run_id)

# One fixed validation example whose reconstruction is shown in TensorBoard.
selected_slice = 0
data = next(val_gen)
# Range slicing (rather than a plain index) keeps the leading batch axis.
image_slice = data[1][selected_slice:selected_slice + 1]
# NOTE: for cross-domain slice has to be on kspace and mask
input_slice = data[0][selected_slice:selected_slice + 1]


def _image_cback(images_log_dir):
    """Build a TensorBoardImage callback for the selected validation slice, logging to `images_log_dir`."""
    return TensorBoardImage(
        log_dir=images_log_dir,
        image=image_slice,
        model_input=input_slice,
    )


# Adversarial-training callbacks.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tboard_cback = TensorBoard(
    log_dir=log_dir,
    write_graph=True,
    update_freq=50,
)
chkpt_cback = ModelCheckpoint(chkpt_path, period=100)
tboard_image_cback = _image_cback(op.join(log_dir, 'images'))

# Generator pre-training callbacks. Their images go under log_dir_pre, like the
# other pre-training summaries, so the two phases do not write image summaries
# into the same TensorBoard run.
tqdm_cb_pre = TQDMNotebookCallback(outer_description="Pre-training", metric_format="{name}: {value:e}")
tboard_cback_pre = TensorBoard(
    log_dir=log_dir_pre,
    write_graph=True,
    update_freq=50,
)
tboard_image_cback_pre = _image_cback(op.join(log_dir_pre, 'images'))

earlystop = EarlyStopping(
    monitor='val_loss',
    min_delta=0,
    patience=0,
    verbose=0,
    mode='auto',
)

In [61]:
%%time
# %debug
# overfitting trial
hist = adversarial_training_loop(
    g, 
    d, 
    d_on_g, 
    train_gen,
    val_gen,
    n_epochs=0, 
    n_batches=10, 
    n_critic_updates=0,
    callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, earlystop],
    include_d_metrics=True,
    gen_pre_training_steps=50,
    pre_training_callbacks=[tqdm_cb_pre, tboard_cback_pre, tboard_image_cback_pre],
)


HBox(children=(IntProgress(value=1, bar_style='info', description='Training', max=1, style=ProgressStyle(descr…

HBox(children=(IntProgress(value=0, description='Pre-training', max=50, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=10, style=ProgressStyle(description_width='init…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 40', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 41', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 42', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 43', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 44', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 45', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 46', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 47', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 48', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 49', max=10, style=ProgressStyle(description_width='ini…



CPU times: user 1min 19s, sys: 28 s, total: 1min 47s
Wall time: 51.7 s
