In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os
import os.path as op
import time

from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tqdm import tqdm_notebook

from fastmri_recon.data.fastmri_sequences import Masked2DSequence, KIKISequence
from fastmri_recon.helpers.nn_mri import MultiplyScalar, lrelu
from fastmri_recon.models.kiki_sep import kiki_sep_net

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Data locations.
# The root can be overridden with the FASTMRI_DATA_DIR environment variable;
# the default is the original author's machine. Trailing slashes are kept
# exactly as before in case the sequence classes concatenate file patterns
# onto these paths.
DATA_ROOT = os.environ.get('FASTMRI_DATA_DIR', '/media/Zaccharie/UHRes')
train_path = f'{DATA_ROOT}/singlecoil_train/singlecoil_train/'
val_path = f'{DATA_ROOT}/singlecoil_val/'
test_path = f'{DATA_ROOT}/singlecoil_test/'

# Dataset sizes: number of slices and number of volumes per split.
n_samples_train = 34742
n_samples_val = 7135

n_volumes_train = 973
n_volumes_val = 199


# Generator settings shared by every sequence below.
AF = 4  # passed as `af` (presumably the undersampling acceleration factor)
SCALE_FACTOR = 1e6  # passed as `scale_factor` to every sequence
INNER_SLICES = 8  # passed as `inner_slices`, training sequences only

# Training sequences are randomised (`rand=True`) and restricted to
# INNER_SLICES; validation sequences use the class defaults for both.
_TRAIN_SEQ_KWARGS = dict(af=AF, inner_slices=INNER_SLICES, rand=True, scale_factor=SCALE_FACTOR)
_VAL_SEQ_KWARGS = dict(af=AF, scale_factor=SCALE_FACTOR)

# generators
# *_last: Masked2DSequence, used when training image-space stage n=2.
train_gen_last = Masked2DSequence(train_path, **_TRAIN_SEQ_KWARGS)
val_gen_last = Masked2DSequence(val_path, **_VAL_SEQ_KWARGS)
# *_i / *_k: KIKISequence in image ('I') and k-space ('K') mode.
train_gen_i = KIKISequence(train_path, space='I', **_TRAIN_SEQ_KWARGS)
val_gen_i = KIKISequence(val_path, space='I', **_VAL_SEQ_KWARGS)
train_gen_k = KIKISequence(train_path, space='K', **_TRAIN_SEQ_KWARGS)
val_gen_k = KIKISequence(val_path, space='K', **_VAL_SEQ_KWARGS)

In [3]:
# Hyper-parameters shared by every stage of the KIKI cascade.
run_params = {
    'n_convs': 25,
    'n_filters': 48,
    'noiseless': True,
    'lr': 1e-3,
    'activation': lrelu,
}
multiply_scalar = MultiplyScalar()
n_epochs = 300


def _select_generators(space, n):
    """Return the ``(train, val)`` sequences to use for one cascade stage.

    Parameters
    ----------
    space : str
        ``'K'`` for a k-space stage or ``'I'`` for an image-space stage.
    n : int
        Stage index within that space. It is ignored for ``'K'``. For ``'I'``,
        ``1`` selects the KIKISequence image generators and ``2`` selects the
        ``*_last`` Masked2DSequence generators.

    Raises
    ------
    ValueError
        If no generators are wired up for ``(space, n)``.
    """
    if space == 'K':
        return train_gen_k, val_gen_k
    if space == 'I':
        if n == 1:
            return train_gen_i, val_gen_i
        if n == 2:
            return train_gen_last, val_gen_last
    raise ValueError(
        f"No generators for space={space!r}, n={n!r}; "
        "expected space 'K', or space 'I' with n in (1, 2)."
    )


def _make_callbacks(run_id, checkpoint_period=50):
    """Build the progress-bar, TensorBoard and checkpoint callbacks for one run."""
    chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
    # NOTE: `period` is deprecated in tf.keras in favour of `save_freq`, but the
    # unit of `save_freq` (samples vs. batches) differs across TF versions.
    # `period` is kept so checkpoints stay "every `checkpoint_period` epochs".
    chkpt_cback = ModelCheckpoint(chkpt_path, period=checkpoint_period)
    tboard_cback = TensorBoard(
        log_dir=op.join('logs', run_id),
        histogram_freq=0,
        write_graph=True,
        write_images=False,
        profile_batch=0,
    )
    tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
    # Alias the legacy on_batch_* hooks to the on_train_batch_* names that
    # tf.keras invokes (presumably keras_tqdm only implements the former).
    tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
    tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end
    return [tqdm_cb, tboard_cback, chkpt_cback]


def train_model(model, space='K', n=1, steps_per_epoch=5, validation_steps=1, checkpoint_period=50):
    """Fit one cascade stage on the generators for ``space``/``n`` and return it.

    The function reads the notebook globals ``AF``, ``n_epochs`` and the
    ``*_gen_*`` sequences defined in the configuration cell.

    Parameters
    ----------
    model : tf.keras.Model
        Compiled model to train. It is trained in place and also returned.
    space : str
        ``'K'`` (k-space) or ``'I'`` (image space). This selects the data
        generators and is embedded in the run id.
    n : int
        Stage index within ``space``; see ``_select_generators``.
    steps_per_epoch : int
        Training batches drawn per epoch.
    validation_steps : int
        Validation batches evaluated per epoch.
    checkpoint_period : int
        A checkpoint is written every this many epochs.

    Returns
    -------
    tf.keras.Model
        The same ``model`` object, after training.

    Raises
    ------
    ValueError
        If ``(space, n)`` has no matching generators. This is checked before
        any log or checkpoint directory is created.
    """
    # summary() prints itself and returns None, so it is not wrapped in print().
    model.summary(line_length=150)
    run_id = f'kikinet_sep_{space}{n}_af{AF}_{int(time.time())}'
    print(run_id)

    train_gen, val_gen = _select_generators(space, n)
    model.fit(
        train_gen,
        steps_per_epoch=steps_per_epoch,
        epochs=n_epochs,
        validation_data=val_gen,
        validation_steps=validation_steps,
        verbose=0,
        callbacks=_make_callbacks(run_id, checkpoint_period),
    )
    return model

In [4]:
%%time
model_1 = kiki_sep_net(None, multiply_scalar, to_add='K', last=False, **run_params)
# model_2 = kiki_sep_net(model_1, multiply_scalar, to_add='I', last=False, **run_params)
# model_3 = kiki_sep_net(model_2, multiply_scalar, to_add='K', last=False, **run_params)
# model_4 = kiki_sep_net(model_3, multiply_scalar, to_add='I', last=True, **run_params)
train_model(model_1, space='K', n=1)

W1201 20:49:48.150653 140651500717824 callbacks.py:886] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


Model: "model"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
kspace_input (InputLayer)                        [(None, 640, None, 1)]           0                                                                   
______________________________________________________________________________________________________________________________________________________
lambda (Lambda)                                  (None, 640, None, 1)             0                 kspace_input[0][0]                                
______________________________________________________________________________________________________________________________________________________
lambda_1 (Lambda)                                (None, 640, None, 1)          

HBox(children=(IntProgress(value=0, description='Training', max=300, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=5, style=ProgressStyle(description_width='initi…




KeyboardInterrupt: 