In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os
import os.path as op
import time

from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.optimizers import Adam
from keras_tqdm import TQDMNotebookCallback
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm_notebook

from fastmri_recon.data.fastmri_sequences import ZeroFilled2DSequence
from fastmri_recon.helpers.adversarial_training import compile_models, adversarial_training_loop
from fastmri_recon.helpers.image_tboard_cback import TensorBoardImage
from fastmri_recon.helpers.keras_utils import wasserstein_loss
from fastmri_recon.models.discriminator import discriminator_model, generator_containing_discriminator_multiple_outputs
from fastmri_recon.models.unet import unet

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# U-Net hyper-parameters forwarded to `unet` when building the generator.
run_params = dict(
    n_layers=4,
    pool='max',
    layers_n_channels=[16, 32, 64, 128],
    layers_n_non_lins=2,
)


def generator_model():
    """Build the U-Net reconstructor used as the GAN generator.

    Returns:
        An uncompiled U-Net (``compile=False``) for inputs of size
        (320, 320, 1), configured with ``run_params`` and named
        'Reconstructor'.
    """
    reconstructor = unet(
        input_size=(320, 320, 1),
        compile=False,
        **run_params,
    )
    reconstructor.name = 'Reconstructor'
    return reconstructor

In [3]:
# model definitions
# Optional warm start of the generator from a previous run's checkpoint.
# Set `prev_run_id` (e.g. 'unet_wo_lastrelu_af4_1572013433') and `prev_epoch`
# (e.g. 500) to reload those weights; leave `prev_run_id` as None to start
# from a freshly initialised generator.
prev_run_id = None
prev_epoch = 500
g = generator_model()
if prev_run_id is not None:
    # Separate name so the training `chkpt_path` defined later is never shadowed.
    warm_start_path = f'checkpoints/{prev_run_id}-{prev_epoch}.hdf5'
    g.load_weights(warm_start_path)
d = discriminator_model()
# Stacked model combining the generator `g` and the discriminator `d`.
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

W1026 14:59:13.956563 140209183528704 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/backend/tensorflow_backend.py:4070: The name tf.nn.max_pool is deprecated. Please use tf.nn.max_pool2d instead.



In [4]:
# model compiling
# Losses and learning rates for the discriminator `d`, the generator `g`
# and the stacked generator->discriminator model `d_on_g`.
perceptual_loss = 'mae'  # Keras mean-absolute-error loss name

discriminator_lr = 1e-4  # 10x smaller than the stacked-model learning rate
d_on_reconstructor_lr = 1e-3

compile_kwargs = dict(
    d_lr=discriminator_lr,
    d_on_g_lr=d_on_reconstructor_lr,
    perceptual_loss=perceptual_loss,
    perceptual_weight=1.0,
)
compile_models(d, g, d_on_g, **compile_kwargs)

In [5]:
# Root of the fastMRI single-coil data. Set the FASTMRI_DATA_DIR environment
# variable to run on another machine; the default is the original local path.
data_root = os.environ.get('FASTMRI_DATA_DIR', '/media/Zaccharie/UHRes')
# The trailing '' component keeps the trailing separator of the original paths.
train_path = op.join(data_root, 'singlecoil_train', 'singlecoil_train', '')
val_path = op.join(data_root, 'singlecoil_val', '')
n_volumes_train = 973  # also passed as the number of batches per epoch
n_volumes_val = 199
AF = 4  # presumably the undersampling acceleration factor (`af=`)
train_gen = ZeroFilled2DSequence(train_path, af=AF, norm=True)
val_gen = ZeroFilled2DSequence(val_path, af=AF, norm=True)

epoch_num = 100
critic_updates = 5

run_id = f'unet_gan_af{AF}_{int(time.time())}'
# `{epoch:02d}` is left unformatted on purpose: ModelCheckpoint fills it in.
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
print(run_id)

unet_gan_af4_1572094755


In [6]:
log_dir = op.join('logs', run_id)
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tboard_cback = TensorBoard(
    log_dir=log_dir,
    write_graph=True,
    update_freq=50,
)
# period=epoch_num: at most one checkpoint every `epoch_num` epochs.
chkpt_cback = ModelCheckpoint(chkpt_path, period=epoch_num)
# Log one fixed validation slice as an image. The first validation batch is
# fetched once so the reference image and the model input come from the same
# load (indexing the sequence again would presumably re-read the volume).
selected_slice = 15
val_batch = val_gen[0]
slice_window = slice(selected_slice, selected_slice + 1)
tboard_image_cback = TensorBoardImage(
    log_dir=op.join(log_dir, 'images'),
    image=val_batch[1][slice_window],
    # NOTE: for cross-domain slice has to be on kspace and mask
    model_input=val_batch[0][slice_window],
)

In [None]:
%%time
adversarial_training_loop(
    g, 
    d, 
    d_on_g, 
    train_gen, 
    val_gen=val_gen,
    validation_steps=1,
    n_epochs=epoch_num, 
    n_batches=n_volumes_train, 
    n_critic_updates=critic_updates,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback],
    workers=35,
    use_multiprocessing=True,
    max_queue_size=35,
    include_d_metrics=True,
)

W1026 14:59:19.706170 140209183528704 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/callbacks/tensorboard_v1.py:200: The name tf.summary.merge_all is deprecated. Please use tf.compat.v1.summary.merge_all instead.

W1026 14:59:19.711432 140209183528704 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/callbacks/tensorboard_v1.py:203: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.



HBox(children=(IntProgress(value=0, description='Training', style=ProgressStyle(description_width='initial')),…

  .format(dtypeobj_in, dtypeobj_out))
W1026 14:59:19.944834 140209183528704 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/fastmri-reproducible-benchmark/fastmri_recon/helpers/image_tboard_cback.py:26: The name tf.Summary is deprecated. Please use tf.compat.v1.Summary instead.



HBox(children=(IntProgress(value=0, description='Epoch 0', max=973, style=ProgressStyle(description_width='ini…

W1026 14:59:26.360121 140209183528704 deprecation_wrapper.py:119] From /volatile/home/Zaccharie/workspace/keras/keras/backend/tensorflow_backend.py:422: The name tf.global_variables is deprecated. Please use tf.compat.v1.global_variables instead.



In [None]:
# Disabled sanity check: run the same adversarial loop on a single batch
# (n_batches=1) for 500 epochs, without validation, checkpointing or
# multiprocessing. Presumably used to check the models can overfit one
# batch. Uncomment to run.
# %%time
# #overfitting trial
# adversarial_training_loop(
#     g, 
#     d, 
#     d_on_g, 
#     train_gen, 
# #     val_gen=val_gen,
# #     validation_steps=1,
#     n_epochs=500, 
#     n_batches=1, 
#     n_critic_updates=5,
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback],
# #     workers=35,
# #     use_multiprocessing=True,
# #     max_queue_size=35,
#     include_d_metrics=True,
# )