In [1]:
# Run from the repository root (two levels above this notebook) so that the
# relative `checkpoints/` and `logs/` paths used later resolve there.
# Only change directory if we are not already at the root, so re-running this
# cell does not keep walking up the directory tree (the bare `%cd ../..` did).
# NOTE(review): the root is detected by the presence of a `learning_wavelets`
# directory (the local package imported below) — confirm this matches the
# repository layout.
import os

if not os.path.isdir('learning_wavelets'):
    os.chdir('../..')
print(os.getcwd())

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # uncomment to make sure only the CPU is used (hides all GPUs)
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
import os.path as op
import time

from fastmri_recon.config import *
from fastmri_recon.data.fastmri_tf_datasets import train_masked_kspace_dataset_from_indexable
from fastmri_recon.helpers.nn_mri import tf_fastmri_format
from fastmri_recon.helpers.utils import keras_psnr, keras_ssim
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
import tensorflow_addons as tfa
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm_notebook

from learning_wavelets.keras_utils.fourier import tf_masked_shifted_normed_fft2d, tf_masked_shifted_normed_ifft2d
from learning_wavelets.models.ista import IstaLearnlet

In [4]:
# Dataset locations under the fastMRI data root (FASTMRI_DATA_DIR, from
# fastmri_recon.config). `op.join` inserts separators as needed, so the paths
# are correct whether or not FASTMRI_DATA_DIR ends with '/'; the final ''
# component keeps the trailing separator the original paths had.
train_path = op.join(FASTMRI_DATA_DIR, 'singlecoil_train', 'singlecoil_train', '')
val_path = op.join(FASTMRI_DATA_DIR, 'singlecoil_val', '')
test_path = op.join(FASTMRI_DATA_DIR, 'singlecoil_test', '')  # unused in the visible cells

# Number of training volumes; used as `steps_per_epoch` in `model.fit` below.
n_volumes_train = 973
# Acceleration factor, passed as `AF` to both dataset builders.
af = 4
# Training set. NOTE(review): presumably yields (masked k-space inputs, target
# image) pairs; `inner_slices`/`rand` semantics are defined in fastmri_recon —
# confirm there.
train_set = train_masked_kspace_dataset_from_indexable(
    train_path,
    AF=af,
    contrast=None,
    inner_slices=8,
    rand=True,
    scale_factor=1e6,
    n_samples=None,
)
# Validation set, built with the same acceleration factor and scaling.
val_set = train_masked_kspace_dataset_from_indexable(
    val_path,
    AF=af,
    contrast=None,
    scale_factor=1e6,
)

Getting training files from /media/Zaccharie/UHRes/singlecoil_train/singlecoil_train/
Getting training files from /media/Zaccharie/UHRes/singlecoil_val/


In [5]:
# Hyper-parameters of the learnlet denoiser used inside the ISTA network.
learnlet_params = {
    'denoising_activation': 'dynamic_soft_thresholding',
    'learnlet_analysis_kwargs': {
        'n_tiling': 16,
        'mixing_details': False,
        'skip_connection': True,
        'kernel_size': 3,
    },
    'learnlet_synthesis_kwargs': {
        'res': True,
        'kernel_size': 5,
    },
    'threshold_kwargs': {
        'noise_std_norm': False,
        'alpha_init': 2.0,
    },
    'n_scales': 2,
    'exact_reconstruction': True,
    'undecimated': True,
    'clip': False,
}

# Unrolled ISTA network with 3 iterations (`fista_mode=True`), using the masked,
# shifted, normalised 2D FFT as forward operator and its inverse as adjoint;
# outputs are post-processed with `tf_fastmri_format`.
model = IstaLearnlet(
    n_iterations=3,
    forward_operator=tf_masked_shifted_normed_fft2d,
    adjoint_operator=tf_masked_shifted_normed_ifft2d,
    postprocess=tf_fastmri_format,
    fista_mode=True,
    **learnlet_params,
)
model.compile(
    # `learning_rate` replaces the deprecated `lr` alias, which newer
    # TensorFlow/Keras versions no longer accept.
    optimizer=Adam(learning_rate=1e-3),
    loss='mse',
    metrics=[keras_psnr, keras_ssim],
)

In [6]:
# Run identifier (timestamped, so each run gets fresh checkpoint/log names) and
# the checkpoint filename template. The doubled braces leave `{epoch:02d}` as a
# literal placeholder for ModelCheckpoint to fill in.
# NOTE(review): the 'ellipses' tag looks copied from an ellipses experiment;
# this notebook trains on fastMRI single-coil data. It is left unchanged so
# existing checkpoint/log names stay consistent.
n_epochs = 1
timestamp = int(time.time())
run_id = f'ista_learnlet_ellipses_{timestamp}'
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

ista_learnlet_ellipses_1583844806


In [7]:
def l_rate_schedule(epoch):
    """Step-decay learning-rate schedule.

    Starts at 1e-3 and halves every 25 epochs, never going below 1e-5.

    Parameters
    ----------
    epoch : int
        Epoch index (Keras' LearningRateScheduler passes it 0-based).

    Returns
    -------
    float
        Learning rate to use for this epoch.
    """
    n_halvings = epoch // 25
    decayed_lr = 1e-3 / 2**n_halvings
    return max(decayed_lr, 1e-5)
# Keras callback applying `l_rate_schedule` at the start of each epoch.
# NOTE(review): this callback is not in the `callbacks` list of the `model.fit`
# call below (only in its commented-out alternatives), so the schedule is
# currently inactive.
lrate_cback = LearningRateScheduler(schedule=l_rate_schedule)

In [8]:
# Training callbacks:
# - full-model checkpoints written to `chkpt_path`,
# - TensorBoard scalar logging under logs/<run_id> (no histograms, graph,
#   images or profiling),
# - a notebook progress bar showing metrics in scientific notation.
log_dir = op.join('logs', run_id)
chkpt_cback = ModelCheckpoint(
    chkpt_path,
    save_weights_only=False,
    # `period` (save every `n_epochs` epochs) is deprecated in TF 2.x in favour
    # of `save_freq`; kept as-is for the TF version this notebook was run with.
    period=n_epochs,
)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=False,
    write_images=False,
    profile_batch=0,
)
tqdm_cb = tfa.callbacks.TQDMProgressBar(metrics_format="{name}: {value:e}")

# Disabled image-logging callback.
# NOTE(review): `im_ds_val` and `TensorBoardImage` are not defined in the
# visible cells, so this cannot be re-enabled as-is.
# val_noisy, val_gt = next(iter(im_ds_val))
# tboard_image_cback = TensorBoardImage(
#     log_dir=log_dir + '/images',
#     image=val_gt[0:1],
#     noisy_image=val_noisy[0:1],
# )



In [9]:
# Train for `n_epochs` epochs of `n_volumes_train` steps each, evaluating a
# single validation step after each epoch.
# NOTE(review): `train_set` is presumably a tf.data.Dataset, for which Keras
# ignores `shuffle`.
# Alternatives used while debugging:
#   steps_per_epoch=5
#   validation_steps=int(validation_split * n_samples_train / batch_size)
#   callbacks=[tqdm_cb, norm_cback, lrate_cback]
#   callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback, lrate_cback]
model.fit(
    train_set,
    epochs=n_epochs,
    steps_per_epoch=n_volumes_train,
    validation_data=val_set,
    validation_steps=1,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback],
    verbose=0,
    shuffle=False,
)

HBox(children=(FloatProgress(value=0.0, description='Training', layout=Layout(flex='2'), max=1.0, style=Progre…

Epoch 1/1


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=973.0), HTML(value='')), layout=Layout(di…





<tensorflow.python.keras.callbacks.History at 0x7f75b8620e80>

In [10]:
# # %%time
# # overfitting trials
# data = next(iter(train_set))
# val_data = next(iter(val_set))

# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=1, 
# #     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback, lrate_cback],
#     callbacks=[tqdm_cb, lrate_cback],
#     epochs=1, 
#     verbose=0, 
#     shuffle=False,
# )

In [11]:
# Learned `alpha` value of each ISTA block after training.
[ista_block.alpha.numpy() for ista_block in model.ista_blocks]

[array([1.4283597], dtype=float32),
 array([0.60190666], dtype=float32),
 array([1.123911], dtype=float32)]

In [11]:
# Learned `momentum` value of each ISTA block after training.
# NOTE(review): the saved output of this cell is identical to the `alpha`
# values above and reuses execution count [11]; re-run to confirm it is not
# stale.
[ista_block.momentum.numpy() for ista_block in model.ista_blocks]

[array([1.4283597], dtype=float32),
 array([0.60190666], dtype=float32),
 array([1.123911], dtype=float32)]