In [13]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os.path as op
import time

from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.optimizers import Adam
from keras_tqdm import TQDMNotebookCallback
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm_notebook

from fastmri_recon.data.test_generators import RandomShapeGenerator
from fastmri_recon.helpers.adversarial_training import compile_models, adversarial_training_loop
from fastmri_recon.helpers.image_tboard_cback import TensorBoardImage
from fastmri_recon.helpers.keras_utils import wasserstein_loss
from fastmri_recon.models.discriminator import discriminator_model, generator_containing_discriminator_multiple_outputs
from fastmri_recon.models.unet import unet
from fastmri_recon.helpers.utils import keras_ssim, keras_psnr

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [129]:
# Keyword arguments that `generator_model` forwards verbatim to `unet(...)`.
run_params = dict(
    n_layers=4,
    pool='max',
    layers_n_channels=[16, 32, 64, 128],
    layers_n_non_lins=2,
)

# Side length used for both leading dimensions of the model input and passed
# as `size` to the random-shape generators.
im_size = 64

def generator_model():
    """Build the reconstruction network used as the generator.

    Calls `unet` with ``input_size=(im_size, im_size, 1)``, the notebook-level
    `run_params` and ``compile=False``, then sets the model's name to
    ``'Reconstructor'``.

    Returns:
        The model object returned by `unet` (presumably left uncompiled given
        ``compile=False`` -- confirm against `unet`).
    """
    input_shape = (im_size, im_size, 1)
    reconstructor = unet(input_size=input_shape, compile=False, **run_params)
    reconstructor.name = 'Reconstructor'
    return reconstructor

In [130]:
# Instantiate the reconstructor and compile it for supervised training with an
# MSE loss; PSNR and SSIM are tracked as additional metrics.
g = generator_model()
g_opt = Adam(
    lr=0.001,
    clipnorm=1.0,
)
g.compile(
    optimizer=g_opt,
    loss='mse',
    metrics=[keras_psnr, keras_ssim],
)

In [135]:
# --- Run configuration -------------------------------------------------------
n_batches_train = 1000
n_batches_val = 1
AF = 2  # forwarded as `af` to the generators (presumably the acceleration factor)
epoch_num = 100

# Training and validation streams come from two separate generator objects
# built with identical settings.
train_gen = RandomShapeGenerator(
    af=AF,
    n_shapes=100,
    size=im_size,
    batch_size=1,
).flow_z_filled_random_shapes()
val_gen = RandomShapeGenerator(
    af=AF,
    n_shapes=100,
    size=im_size,
    batch_size=1,
).flow_z_filled_random_shapes()

# The run id embeds AF and the current Unix timestamp (whole seconds).
run_id = f'gen_af{AF}_{int(time.time())}'
# The '{epoch:02d}' part is a plain (non-f) string on purpose: it is left as a
# literal placeholder in the resulting path rather than being formatted here.
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
print(run_id)

gen_af2_1572957673


In [136]:
# All TensorBoard output for this run is written under logs2/<run_id>.
log_dir = op.join('logs2', run_id)

# Progress bars, scalar logging and checkpointing callbacks.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tboard_cback = TensorBoard(log_dir=log_dir, write_graph=True, update_freq=50)
# NOTE(review): `chkpt_cback` is not passed to any of the `g.fit(...)` calls
# visible in this notebook, so no checkpoint is written by them.
chkpt_cback = ModelCheckpoint(chkpt_path, period=epoch_num)

# A single batch drawn from the validation generator provides the fixed
# (model_input, image) example handed to the image callback.
selected_slice = 0
data = next(val_gen)
tboard_image_cback = TensorBoardImage(
    log_dir=log_dir + '/images',
    # Slicing with start:start+1 keeps the leading axis instead of dropping it.
    image=data[1][selected_slice:selected_slice + 1],
    # NOTE: for cross-domain slice has to be on kspace and mask
    model_input=data[0][selected_slice:selected_slice + 1],
)

In [126]:
# Materialise 150 draws from the training generator into two parallel lists:
# `aliased` receives the first element of each yielded pair and `og` the second.
aliased, og = [], []
for i in range(150):
    # Each draw is unpacked as (x, image); the names suggest the network input
    # and its reference image -- confirm against RandomShapeGenerator.
    x, image = next(train_gen)
    aliased.append(x)
    og.append(image)
    

In [127]:
# NOTE(review): this cell only uses the `np` alias imported at the top of the
# notebook; the bare `numpy` import below is not needed here.
import numpy

# Stack each list into one array and drop axis 1. `squeeze(axis=1)` raises
# unless that axis has length 1, which fits generators built with batch_size=1.
a = np.asarray(aliased).squeeze(axis=1)
o = np.asarray(og).squeeze(axis=1)

In [128]:
# Supervised training of the reconstructor on the in-memory sample arrays.
# The epoch count comes from `epoch_num` (the same constant used for the
# ModelCheckpoint `period`) instead of a duplicated literal 100, so the two
# cannot drift apart. The call stays the last expression of the cell so the
# returned History object is still displayed.
g.fit(
    x=a,
    y=o,
    batch_size=1,
    epochs=epoch_num,
    callbacks=[tqdm_cb, tboard_cback, tboard_image_cback],
)

HBox(children=(IntProgress(value=0, description='Training', style=ProgressStyle(description_width='initial')),…

Epoch 1/100


HBox(children=(IntProgress(value=0, description='Epoch 0', max=150, style=ProgressStyle(description_width='ini…

Epoch 2/100


HBox(children=(IntProgress(value=0, description='Epoch 1', max=150, style=ProgressStyle(description_width='ini…

Epoch 3/100


HBox(children=(IntProgress(value=0, description='Epoch 2', max=150, style=ProgressStyle(description_width='ini…

Epoch 4/100


HBox(children=(IntProgress(value=0, description='Epoch 3', max=150, style=ProgressStyle(description_width='ini…

Epoch 5/100


HBox(children=(IntProgress(value=0, description='Epoch 4', max=150, style=ProgressStyle(description_width='ini…

Epoch 6/100


HBox(children=(IntProgress(value=0, description='Epoch 5', max=150, style=ProgressStyle(description_width='ini…

Epoch 7/100


HBox(children=(IntProgress(value=0, description='Epoch 6', max=150, style=ProgressStyle(description_width='ini…

Epoch 8/100


HBox(children=(IntProgress(value=0, description='Epoch 7', max=150, style=ProgressStyle(description_width='ini…

Epoch 9/100


HBox(children=(IntProgress(value=0, description='Epoch 8', max=150, style=ProgressStyle(description_width='ini…

Epoch 10/100


HBox(children=(IntProgress(value=0, description='Epoch 9', max=150, style=ProgressStyle(description_width='ini…

Epoch 11/100


HBox(children=(IntProgress(value=0, description='Epoch 10', max=150, style=ProgressStyle(description_width='in…

Epoch 12/100


HBox(children=(IntProgress(value=0, description='Epoch 11', max=150, style=ProgressStyle(description_width='in…

Epoch 13/100


HBox(children=(IntProgress(value=0, description='Epoch 12', max=150, style=ProgressStyle(description_width='in…

Epoch 14/100


HBox(children=(IntProgress(value=0, description='Epoch 13', max=150, style=ProgressStyle(description_width='in…

Epoch 15/100


HBox(children=(IntProgress(value=0, description='Epoch 14', max=150, style=ProgressStyle(description_width='in…

Epoch 16/100


HBox(children=(IntProgress(value=0, description='Epoch 15', max=150, style=ProgressStyle(description_width='in…

Epoch 17/100


HBox(children=(IntProgress(value=0, description='Epoch 16', max=150, style=ProgressStyle(description_width='in…

Epoch 18/100


HBox(children=(IntProgress(value=0, description='Epoch 17', max=150, style=ProgressStyle(description_width='in…

Epoch 19/100


HBox(children=(IntProgress(value=0, description='Epoch 18', max=150, style=ProgressStyle(description_width='in…

Epoch 20/100


HBox(children=(IntProgress(value=0, description='Epoch 19', max=150, style=ProgressStyle(description_width='in…

Epoch 21/100


HBox(children=(IntProgress(value=0, description='Epoch 20', max=150, style=ProgressStyle(description_width='in…

Epoch 22/100


HBox(children=(IntProgress(value=0, description='Epoch 21', max=150, style=ProgressStyle(description_width='in…

Epoch 23/100


HBox(children=(IntProgress(value=0, description='Epoch 22', max=150, style=ProgressStyle(description_width='in…

Epoch 24/100


HBox(children=(IntProgress(value=0, description='Epoch 23', max=150, style=ProgressStyle(description_width='in…

Epoch 25/100


HBox(children=(IntProgress(value=0, description='Epoch 24', max=150, style=ProgressStyle(description_width='in…

Epoch 26/100


HBox(children=(IntProgress(value=0, description='Epoch 25', max=150, style=ProgressStyle(description_width='in…

Epoch 27/100


HBox(children=(IntProgress(value=0, description='Epoch 26', max=150, style=ProgressStyle(description_width='in…

Epoch 28/100


HBox(children=(IntProgress(value=0, description='Epoch 27', max=150, style=ProgressStyle(description_width='in…

Epoch 29/100


HBox(children=(IntProgress(value=0, description='Epoch 28', max=150, style=ProgressStyle(description_width='in…

Epoch 30/100


HBox(children=(IntProgress(value=0, description='Epoch 29', max=150, style=ProgressStyle(description_width='in…

Epoch 31/100


HBox(children=(IntProgress(value=0, description='Epoch 30', max=150, style=ProgressStyle(description_width='in…

Epoch 32/100


HBox(children=(IntProgress(value=0, description='Epoch 31', max=150, style=ProgressStyle(description_width='in…

Epoch 33/100


HBox(children=(IntProgress(value=0, description='Epoch 32', max=150, style=ProgressStyle(description_width='in…

Epoch 34/100


HBox(children=(IntProgress(value=0, description='Epoch 33', max=150, style=ProgressStyle(description_width='in…

Epoch 35/100


HBox(children=(IntProgress(value=0, description='Epoch 34', max=150, style=ProgressStyle(description_width='in…

Epoch 36/100


HBox(children=(IntProgress(value=0, description='Epoch 35', max=150, style=ProgressStyle(description_width='in…

Epoch 37/100


HBox(children=(IntProgress(value=0, description='Epoch 36', max=150, style=ProgressStyle(description_width='in…

Epoch 38/100


HBox(children=(IntProgress(value=0, description='Epoch 37', max=150, style=ProgressStyle(description_width='in…

Epoch 39/100


HBox(children=(IntProgress(value=0, description='Epoch 38', max=150, style=ProgressStyle(description_width='in…

Epoch 40/100


HBox(children=(IntProgress(value=0, description='Epoch 39', max=150, style=ProgressStyle(description_width='in…

Epoch 41/100


HBox(children=(IntProgress(value=0, description='Epoch 40', max=150, style=ProgressStyle(description_width='in…

Epoch 42/100


HBox(children=(IntProgress(value=0, description='Epoch 41', max=150, style=ProgressStyle(description_width='in…

Epoch 43/100


HBox(children=(IntProgress(value=0, description='Epoch 42', max=150, style=ProgressStyle(description_width='in…

Epoch 44/100


HBox(children=(IntProgress(value=0, description='Epoch 43', max=150, style=ProgressStyle(description_width='in…

Epoch 45/100


HBox(children=(IntProgress(value=0, description='Epoch 44', max=150, style=ProgressStyle(description_width='in…

Epoch 46/100


HBox(children=(IntProgress(value=0, description='Epoch 45', max=150, style=ProgressStyle(description_width='in…

Epoch 47/100


HBox(children=(IntProgress(value=0, description='Epoch 46', max=150, style=ProgressStyle(description_width='in…

Epoch 48/100


HBox(children=(IntProgress(value=0, description='Epoch 47', max=150, style=ProgressStyle(description_width='in…

Epoch 49/100


HBox(children=(IntProgress(value=0, description='Epoch 48', max=150, style=ProgressStyle(description_width='in…

Epoch 50/100


HBox(children=(IntProgress(value=0, description='Epoch 49', max=150, style=ProgressStyle(description_width='in…

Epoch 51/100


HBox(children=(IntProgress(value=0, description='Epoch 50', max=150, style=ProgressStyle(description_width='in…

Epoch 52/100


HBox(children=(IntProgress(value=0, description='Epoch 51', max=150, style=ProgressStyle(description_width='in…

Epoch 53/100


HBox(children=(IntProgress(value=0, description='Epoch 52', max=150, style=ProgressStyle(description_width='in…

Epoch 54/100


HBox(children=(IntProgress(value=0, description='Epoch 53', max=150, style=ProgressStyle(description_width='in…

Epoch 55/100


HBox(children=(IntProgress(value=0, description='Epoch 54', max=150, style=ProgressStyle(description_width='in…

Epoch 56/100


HBox(children=(IntProgress(value=0, description='Epoch 55', max=150, style=ProgressStyle(description_width='in…

Epoch 57/100


HBox(children=(IntProgress(value=0, description='Epoch 56', max=150, style=ProgressStyle(description_width='in…

Epoch 58/100


HBox(children=(IntProgress(value=0, description='Epoch 57', max=150, style=ProgressStyle(description_width='in…

Epoch 59/100


HBox(children=(IntProgress(value=0, description='Epoch 58', max=150, style=ProgressStyle(description_width='in…

Epoch 60/100


HBox(children=(IntProgress(value=0, description='Epoch 59', max=150, style=ProgressStyle(description_width='in…

Epoch 61/100


HBox(children=(IntProgress(value=0, description='Epoch 60', max=150, style=ProgressStyle(description_width='in…

Epoch 62/100


HBox(children=(IntProgress(value=0, description='Epoch 61', max=150, style=ProgressStyle(description_width='in…

Epoch 63/100


HBox(children=(IntProgress(value=0, description='Epoch 62', max=150, style=ProgressStyle(description_width='in…

Epoch 64/100


HBox(children=(IntProgress(value=0, description='Epoch 63', max=150, style=ProgressStyle(description_width='in…

Epoch 65/100


HBox(children=(IntProgress(value=0, description='Epoch 64', max=150, style=ProgressStyle(description_width='in…

Epoch 66/100


HBox(children=(IntProgress(value=0, description='Epoch 65', max=150, style=ProgressStyle(description_width='in…

Epoch 67/100


HBox(children=(IntProgress(value=0, description='Epoch 66', max=150, style=ProgressStyle(description_width='in…

Epoch 68/100


HBox(children=(IntProgress(value=0, description='Epoch 67', max=150, style=ProgressStyle(description_width='in…

Epoch 69/100


HBox(children=(IntProgress(value=0, description='Epoch 68', max=150, style=ProgressStyle(description_width='in…

Epoch 70/100


HBox(children=(IntProgress(value=0, description='Epoch 69', max=150, style=ProgressStyle(description_width='in…

Epoch 71/100


HBox(children=(IntProgress(value=0, description='Epoch 70', max=150, style=ProgressStyle(description_width='in…

Epoch 72/100


HBox(children=(IntProgress(value=0, description='Epoch 71', max=150, style=ProgressStyle(description_width='in…

Epoch 73/100


HBox(children=(IntProgress(value=0, description='Epoch 72', max=150, style=ProgressStyle(description_width='in…

Epoch 74/100


HBox(children=(IntProgress(value=0, description='Epoch 73', max=150, style=ProgressStyle(description_width='in…

Epoch 75/100


HBox(children=(IntProgress(value=0, description='Epoch 74', max=150, style=ProgressStyle(description_width='in…

Epoch 76/100


HBox(children=(IntProgress(value=0, description='Epoch 75', max=150, style=ProgressStyle(description_width='in…

Epoch 77/100


HBox(children=(IntProgress(value=0, description='Epoch 76', max=150, style=ProgressStyle(description_width='in…

Epoch 78/100


HBox(children=(IntProgress(value=0, description='Epoch 77', max=150, style=ProgressStyle(description_width='in…

Epoch 79/100


HBox(children=(IntProgress(value=0, description='Epoch 78', max=150, style=ProgressStyle(description_width='in…

Epoch 80/100


HBox(children=(IntProgress(value=0, description='Epoch 79', max=150, style=ProgressStyle(description_width='in…

Epoch 81/100


HBox(children=(IntProgress(value=0, description='Epoch 80', max=150, style=ProgressStyle(description_width='in…

Epoch 82/100


HBox(children=(IntProgress(value=0, description='Epoch 81', max=150, style=ProgressStyle(description_width='in…

Epoch 83/100


HBox(children=(IntProgress(value=0, description='Epoch 82', max=150, style=ProgressStyle(description_width='in…

Epoch 84/100


HBox(children=(IntProgress(value=0, description='Epoch 83', max=150, style=ProgressStyle(description_width='in…

Epoch 85/100


HBox(children=(IntProgress(value=0, description='Epoch 84', max=150, style=ProgressStyle(description_width='in…

Epoch 86/100


HBox(children=(IntProgress(value=0, description='Epoch 85', max=150, style=ProgressStyle(description_width='in…

Epoch 87/100


HBox(children=(IntProgress(value=0, description='Epoch 86', max=150, style=ProgressStyle(description_width='in…

Epoch 88/100


HBox(children=(IntProgress(value=0, description='Epoch 87', max=150, style=ProgressStyle(description_width='in…

Epoch 89/100


HBox(children=(IntProgress(value=0, description='Epoch 88', max=150, style=ProgressStyle(description_width='in…

Epoch 90/100


HBox(children=(IntProgress(value=0, description='Epoch 89', max=150, style=ProgressStyle(description_width='in…

Epoch 91/100


HBox(children=(IntProgress(value=0, description='Epoch 90', max=150, style=ProgressStyle(description_width='in…

Epoch 92/100


HBox(children=(IntProgress(value=0, description='Epoch 91', max=150, style=ProgressStyle(description_width='in…

Epoch 93/100


HBox(children=(IntProgress(value=0, description='Epoch 92', max=150, style=ProgressStyle(description_width='in…

Epoch 94/100


HBox(children=(IntProgress(value=0, description='Epoch 93', max=150, style=ProgressStyle(description_width='in…

Epoch 95/100


HBox(children=(IntProgress(value=0, description='Epoch 94', max=150, style=ProgressStyle(description_width='in…

Epoch 96/100


HBox(children=(IntProgress(value=0, description='Epoch 95', max=150, style=ProgressStyle(description_width='in…

Epoch 97/100


HBox(children=(IntProgress(value=0, description='Epoch 96', max=150, style=ProgressStyle(description_width='in…

Epoch 98/100


HBox(children=(IntProgress(value=0, description='Epoch 97', max=150, style=ProgressStyle(description_width='in…

Epoch 99/100


HBox(children=(IntProgress(value=0, description='Epoch 98', max=150, style=ProgressStyle(description_width='in…

Epoch 100/100


HBox(children=(IntProgress(value=0, description='Epoch 99', max=150, style=ProgressStyle(description_width='in…



<keras.callbacks.callbacks.History at 0x1c3f5880b8>

In [137]:
# Draw a fresh set of 500 samples from the training generator; `aliased` gets
# the first element of each yielded pair and `og` the second.
aliased, og = [], []
for i in range(500):
    x, image = next(train_gen)
    aliased.append(x)
    og.append(image)

# Stack the lists into arrays and drop axis 1 (`squeeze(axis=1)` raises unless
# that axis has length 1).
a = np.asarray(aliased).squeeze(axis=1)
o = np.asarray(og).squeeze(axis=1)

# One further epoch on the new samples, reusing the same callbacks. The call is
# kept as the last expression so the History object is displayed.
g.fit(
    x=a,
    y=o,
    batch_size=1,
    epochs=1,
    callbacks=[tqdm_cb, tboard_cback, tboard_image_cback],
)

HBox(children=(IntProgress(value=0, description='Training', max=1, style=ProgressStyle(description_width='init…

Epoch 1/1


HBox(children=(IntProgress(value=0, description='Epoch 0', max=500, style=ProgressStyle(description_width='ini…



<keras.callbacks.callbacks.History at 0x1c460df7f0>