In [1]:
# Move one directory up (presumably from the notebook folder to the repository root) so the
# relative 'checkpoints/' and 'logs' paths used later in this notebook resolve from there.
# NOTE: not idempotent — re-running this cell climbs one more directory level.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment to force CPU-only execution (hides all CUDA GPUs from TensorFlow)
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
# Automatically reload edited Python modules (e.g. learning_wavelets) before each cell executes.
%load_ext autoreload
%autoreload 2
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm_notebook

from learning_wavelets.datasets import im_dataset_div2k, im_dataset_bsd500
from learning_wavelets.evaluate import keras_psnr, keras_ssim, center_keras_psnr
from learning_wavelets.keras_utils.filters_cback import NormalizeWeights
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment
from learning_wavelets.keras_utils.thresholding import DynamicSoftThresholding, DynamicHardThresholding
from learning_wavelets.learnlet_model import Learnlet

Using TensorFlow backend.


In [4]:
# Seed TensorFlow's global RNG so TF random ops (e.g. weight initialisation) are reproducible.
# Only TensorFlow is seeded here; NumPy and Python's `random` are not.
tf.random.set_seed(1)

In [5]:
# Training noise std is given as a (low, high) pair — presumably a range sampled by
# data_func, TODO confirm — while validation uses one fixed level.
noise_std_train = (0, 55)
noise_std_val = 30
batch_size = 8
source = 'bsd500'
# Loader and hard-coded sample count for the chosen source; anything other than
# 'bsd500' falls back to the DIV2K loader.
data_func, n_samples_train = (
    (im_dataset_bsd500, 400) if source == 'bsd500' else (im_dataset_div2k, 800)
)
# Options shared by both splits (batch size, 256px patches, noise level returned
# alongside each sample); only the split name and the noise std differ.
shared_ds_kwargs = dict(batch_size=batch_size, patch_size=256, return_noise_level=True)
im_ds_train = data_func(mode='training', noise_std=noise_std_train, **shared_ds_kwargs)
im_ds_val = data_func(mode='validation', noise_std=noise_std_val, **shared_ds_kwargs)

In [6]:
# Learnlet hyper-parameters, later unpacked as Learnlet(**run_params).
run_params = dict(
    denoising_activation='dynamic_soft_thresholding',
    learnlet_analysis_kwargs=dict(
        n_tiling=64,
        mixing_details=False,
        skip_connection=True,
        kernel_size=11,
    ),
    learnlet_synthesis_kwargs=dict(
        res=True,
        kernel_size=13,
    ),
    threshold_kwargs=dict(
        noise_std_norm=True,
    ),
    n_scales=5,
    n_reweights_learn=3,
    clip=False,
)
# Settings tried previously and currently disabled:
#   alpha = 2
#   'wav_type': 'bior'
#   'exact_reconstruction_weight': 0
n_epochs = 500
# Run name = fixed prefix, data source, training noise bounds, launch timestamp (s).
run_id = f'learnlet_subclassing_st_{source}_{noise_std_train[0]}_{noise_std_train[1]}_{int(time.time())}'
# The doubled braces leave a literal `{epoch:02d}` placeholder for ModelCheckpoint to fill.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

learnlet_subclassing_st_bsd500_0_55_1581766374


In [7]:
def l_rate_schedule(epoch, *, initial_lr=1e-3, decay_every=25, min_lr=1e-5):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    The rate starts at ``initial_lr``, is halved every ``decay_every`` epochs
    and is clipped from below at ``min_lr``. The defaults reproduce the
    original hard-coded schedule (1e-3, halved every 25 epochs, floor 1e-5).

    Args:
        epoch (int): 0-based epoch index.
        initial_lr (float): learning rate for epochs ``0 .. decay_every - 1``.
        decay_every (int): number of epochs between two halvings.
        min_lr (float): lower bound on the returned learning rate.

    Returns:
        float: learning rate to use for ``epoch``.
    """
    # The tuning knobs are keyword-only on purpose: tf.keras' LearningRateScheduler
    # first tries ``schedule(epoch, current_lr)`` and only falls back to
    # ``schedule(epoch)`` on TypeError. A second positional parameter would
    # silently receive the optimizer's current lr and compound the decay.
    n_halvings = epoch // decay_every
    return max(initial_lr / 2 ** n_halvings, min_lr)
# Keras callback that sets the optimizer's learning rate from l_rate_schedule at the start of each epoch.
lrate_cback = LearningRateScheduler(l_rate_schedule)

In [8]:
# Save the full model (save_weights_only=False), filling the epoch number into chkpt_path.
# NOTE(review): `period` is deprecated in tf.keras in favour of `save_freq` (counted in batches);
# with period=n_epochs and a fit over n_epochs epochs, only the last epoch gets checkpointed.
chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=False)
# TensorBoard scalars/graph for this run go under logs/<run_id>; profile_batch=0 disables profiling.
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir, 
    histogram_freq=0, 
    write_graph=True, 
    write_images=False, 
    profile_batch=0,
)
# Notebook progress bar. The on_train_batch_* hooks are aliased to the on_batch_* ones,
# presumably because keras_tqdm only implements the legacy hook names that tf.keras no longer calls.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end
# tqdm_cb = TQDMProgressBar()
# Take one validation batch to log a fixed ground-truth / noisy pair as TensorBoard images.
# NOTE(review): the datasets are built with return_noise_level=True; if each element is
# ((noisy, noise_level), gt), then val_noisy is a tuple and val_noisy[0:1] slices the tuple,
# not the batch — confirm before adding tboard_image_cback to the fit callbacks.
val_noisy, val_gt = next(iter(im_ds_val))
tboard_image_cback = TensorBoardImage(
    log_dir=log_dir + '/images',
    image=val_gt[0:1],
    noisy_image=val_noisy[0:1],
)
# Normalisation-adjustment callback; its on_batch_end is exposed under the tf.keras hook name.
norm_cback = NormalisationAdjustment(momentum=0.99, n_pooling=5, dynamic_denoising=True)
norm_cback.on_train_batch_end = norm_cback.on_batch_end



In [9]:
# Build the Learnlet model from run_params and compile it with an MSE loss and the
# keras_psnr / keras_ssim / center_keras_psnr metrics.
model = Learnlet(**run_params)
model.compile(
    # `learning_rate` is the canonical Adam argument: the `lr` alias is deprecated in
    # tf.keras and rejected by Keras 3. The initial value is replaced at each epoch start
    # by lrate_cback (whose schedule also starts at 1e-3).
    optimizer=Adam(learning_rate=1e-3),
    loss='mse',
    metrics=[keras_psnr, keras_ssim, center_keras_psnr],
)
# summary() prints by itself (wrapping it in print() adds a stray "None"); for a subclassed
# model it only works once the model has been built, e.g. after the first fit/call.
# model.summary(line_length=114)

In [None]:
%%time
model.fit(
    im_ds_train, 
    steps_per_epoch=200, 
#     steps_per_epoch=5, 
    epochs=n_epochs,
    validation_data=im_ds_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, norm_cback, lrate_cback],
#     callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback, lrate_cback],
    shuffle=False,
)

HBox(children=(FloatProgress(value=0.0, description='Training', max=500.0, style=ProgressStyle(description_wid…

HBox(children=(FloatProgress(value=0.0, description='Epoch 0', max=200.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=200.0, style=ProgressStyle(description_widt…

In [None]:
# %%time
# # Overfitting sanity check: fit a single training batch to verify the model can drive its loss down
# data = next(iter(im_ds_train))
# val_data = next(iter(im_ds_val))
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=1, 
# #     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback, lrate_cback],
#     callbacks=[tqdm_cb, tboard_cback, lrate_cback],
#     epochs=n_epochs, 
#     verbose=2, 
#     shuffle=False,
# )

In [None]:
# Layers of the model whose name contains 'dynamic'; their alpha_* values are inspected below.
# NOTE(review): for a subclassed model, `model.layers` only lists directly tracked sub-layers —
# confirm this list is non-empty (thresholding layers nested deeper would be missed).
dyn = [layer for layer in model.layers if 'dynamic' in layer.name]

In [None]:
# Current `alpha_thresh` value of each selected layer, converted to NumPy.
[layer.alpha_thresh.numpy() for layer in dyn]

In [None]:
# Current `alpha_bias` value of each selected layer, converted to NumPy.
[layer.alpha_bias.numpy() for layer in dyn]