In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib nbagg
import os.path as op
import time

from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from tqdm import tqdm_notebook

from fastmri_recon.data.fastmri_sequences import ZeroFilled2DSequence
from fastmri_recon.models.unet import unet

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# paths
train_path = '/media/Zaccharie/UHRes/singlecoil_train/singlecoil_train/'
val_path = '/media/Zaccharie/UHRes/singlecoil_val/'
test_path = '/media/Zaccharie/UHRes/singlecoil_test/'

In [3]:
n_samples_train = 34742
n_samples_val = 7135

n_volumes_train = 973
n_volumes_val = 199

In [4]:
# generators
AF = 4
train_set = ZeroFilled2DSequence(train_path, af=AF, norm=True)
val_set = ZeroFilled2DSequence(val_path, af=AF, norm=True)

In [5]:
run_params = {
    'n_layers': 4, 
    'pool': 'max', 
    "layers_n_channels": [16, 32, 64, 128], 
    'layers_n_non_lins': 2,
#     'n_layers': 2,
}
n_epochs = 10
run_id = f'unet_af{AF}_{int(time.time())}'
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'

In [6]:
chkpt_cback = ModelCheckpoint(chkpt_path, period=100, save_weights_only=True)
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir, 
    histogram_freq=0, 
    write_graph=True, 
    write_images=False, 
    profile_batch=0,
)
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end

W1201 17:33:41.559658 140227304007424 callbacks.py:886] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [7]:
model = unet(input_size=(320, 320, 1), lr=1e-3, **run_params)
print(model.summary())

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 320, 320, 1) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 320, 320, 16) 160         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 320, 320, 16) 2320        conv2d[0][0]                     
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)    (None, 160, 160, 16) 0           conv2d_1[0][0]                   
______________________________________________________________________________________________

In [8]:
model.fit(
    train_set, 
    steps_per_epoch=5, 
    epochs=n_epochs,
    validation_data=val_set,
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback,],
)

HBox(children=(IntProgress(value=0, description='Training', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=5, style=ProgressStyle(description_width='initi…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=5, style=ProgressStyle(description_width='initi…




KeyboardInterrupt: 

In [None]:
# %%time
# # simple overfit trials
# data = next(iter(train_ds))
# model.fit(
#     x=data[0][15:16], 
#     y=data[1][15:16], 
#     batch_size=1, 
#     epochs=500,
#     verbose=2, 
#     shuffle=False,
# )