In [1]:
%cd ..

/home/zaccharie/workspace/understanding-unets


In [2]:
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
import tensorflow_addons as tfa
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm_notebook

from learning_wavelets.data.toy_datasets import masked_kspace_ellipse_dataset
from learning_wavelets.evaluate import keras_psnr, keras_ssim
from learning_wavelets.keras_utils.fourier import tf_masked_shifted_normed_fft2d, tf_masked_shifted_normed_ifft2d
from learning_wavelets.models.ista import IstaLearnlet

In [3]:
im_size = 128
af = 1.5
kspace_ds_train = masked_kspace_ellipse_dataset(im_size, af=af, batch_size=1)
kspace_ds_val = masked_kspace_ellipse_dataset(im_size, af=af)

In [4]:
learnlet_params = {
    'denoising_activation': 'dynamic_soft_thresholding',
    'learnlet_analysis_kwargs':{
        'n_tiling': 32, 
        'mixing_details': False,    
        'skip_connection': True,
        'kernel_size': 5,
    },
    'learnlet_synthesis_kwargs': {
        'res': True,
        'kernel_size': 7,
    },
    'threshold_kwargs':{
        'noise_std_norm': False,
    },
    'n_scales': 4,
    'exact_reconstruction': True,
    'undecimated': False,
    'clip': False,
}

model = IstaLearnlet(
    n_iterations=10,
    forward_operator=tf_masked_shifted_normed_fft2d,
    adjoint_operator=tf_masked_shifted_normed_ifft2d,
    **learnlet_params,
)
model.compile(
    optimizer=Adam(lr=1e-3),
    loss='mse',
    metrics=[keras_psnr, keras_ssim,],
)

In [5]:
n_epochs = 500
run_id = f'ista_learnlet_ellipses_{int(time.time())}'
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
print(run_id)

ista_learnlet_ellipses_1583422610


In [6]:
def l_rate_schedule(epoch):
    return max(1e-3 / 2**(epoch//25), 1e-5)
lrate_cback = LearningRateScheduler(l_rate_schedule)

In [7]:
chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=False)
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir, 
    histogram_freq=0, 
    write_graph=False, 
    write_images=False, 
    profile_batch=0,
)
tqdm_cb = tfa.callbacks.TQDMProgressBar(metrics_format="{name}: {value:e}")
# val_noisy, val_gt = next(iter(im_ds_val))
# tboard_image_cback = TensorBoardImage(
#     log_dir=log_dir + '/images',
#     image=val_gt[0:1],
#     noisy_image=val_noisy[0:1],
# )

W0305 16:36:50.018442 140518637909760 callbacks.py:886] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [8]:
%%time
# overfitting trials
data = next(iter(kspace_ds_train))
val_data = next(iter(kspace_ds_val))
model.fit(
    x=data[0], 
    y=data[1], 
#     validation_data=val_data, 
    batch_size=1, 
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback, lrate_cback],
    callbacks=[tqdm_cb, tboard_cback, lrate_cback],
    epochs=5, 
    verbose=2, 
    shuffle=False,
)

Train on 1 samples


HBox(children=(FloatProgress(value=0.0, description='Training', layout=Layout(flex='2'), max=5.0, style=Progre…

Epoch 1/5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), layout=Layout(disp…

Epoch 1/5

1/1 - 22s - loss: 11.6582 - keras_psnr: -1.0666e+01 - keras_ssim: 0.0401
Epoch 2/5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), layout=Layout(disp…

Epoch 2/5

1/1 - 1s - loss: 13.9506 - keras_psnr: -1.1446e+01 - keras_ssim: 0.0419
Epoch 3/5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), layout=Layout(disp…

Epoch 3/5

1/1 - 1s - loss: 13.2389 - keras_psnr: -1.1219e+01 - keras_ssim: 0.0413
Epoch 4/5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), layout=Layout(disp…

Epoch 4/5

1/1 - 1s - loss: 12.3783 - keras_psnr: -1.0927e+01 - keras_ssim: 0.0403
Epoch 5/5


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1.0), HTML(value='')), layout=Layout(disp…

Epoch 5/5

1/1 - 1s - loss: 11.5033 - keras_psnr: -1.0608e+01 - keras_ssim: 0.0393

CPU times: user 1min, sys: 9.71 s, total: 1min 10s
Wall time: 33.8 s


In [9]:
%debug

E0305 16:37:23.808835 140518637909760 interactiveshell.py:1178] No traceback has been produced, nothing to debug.
