In [1]:
import os

# Work from the repository root so relative paths ('checkpoints/', 'logs/') land there.
# Guarded so re-running the cell does not climb one more directory each time.
# NOTE(review): assumes the repository root contains the `learning_wavelets`
# package directory -- confirm; otherwise this behaves like the original `%cd ..`.
if not os.path.isdir('learning_wavelets'):
    os.chdir('..')
print(os.getcwd())

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment this cell to make sure we are only using the CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
# Automatically reload edited project modules (e.g. learning_wavelets.*) before running code.
%load_ext autoreload
%autoreload 2
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
# from tensorflow_addons.callbacks import TQDMProgressBar
from tqdm import tqdm_notebook

from learning_wavelets.datasets import im_dataset_div2k
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.filters_cback import NormalizeWeights
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment
from learning_wavelets.keras_utils.thresholding import SoftThresholding
from learning_wavelets.learned_wavelet import learned_wavelet

Using TensorFlow backend.


In [4]:
# Seed every RNG the pipeline might draw from so runs are reproducible.
# tf.random.set_seed only fixes TensorFlow's global seed; Python's `random`
# module and NumPy keep their own independent generators.
import random

import numpy as np

random_seed = 1
random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

In [5]:
# Data configuration. `noise_std` is passed to both datasets so that the noise
# level, the thresholding level and the run id (next cell) always agree.
noise_std = 30
batch_size = 8
patch_size = 256
n_samples_train = 800  # used below to derive steps_per_epoch; presumably the DIV2K training-set size -- confirm
im_ds_train = im_dataset_div2k(mode='training', batch_size=batch_size, patch_size=patch_size, noise_std=noise_std)
# One sample per validation batch.
im_ds_val = im_dataset_div2k(mode='validation', batch_size=1, patch_size=patch_size, noise_std=noise_std)

In [6]:
# Model / run configuration.
n_coarse = 1
# Soft-threshold level: 2 x noise_std divided by 255 -- presumably because the
# images are scaled to [0, 1] while noise_std is on the 0-255 scale (confirm).
thresh = 2 * noise_std / 255
run_params = dict(
    n_scales=5,
    n_details=256,
    n_coarse=n_coarse,
    mixing_details=False,
    denoising_activation=SoftThresholding(thresh),
    wav_pooling=True,
    wav_use_bias=False,
    wav_normed=True,
    filters_normed=['details', 'coarse'],
)
n_epochs = 250
# Run name = noise level + launch timestamp, so every run gets its own logs/checkpoints.
run_id = f'learned_wavelet_div2k_{noise_std}_{int(time.time())}'
# The doubled braces keep `{epoch:02d}` as a literal placeholder for ModelCheckpoint to fill in.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

learned_wavelet_div2k_30_1574790444


In [7]:
def l_rate_schedule(epoch, *, initial_lr=1e-3, decay_every=25, min_lr=1e-5):
    """Step-decay learning-rate schedule.

    The rate starts at ``initial_lr`` and is halved every ``decay_every``
    epochs, but never drops below ``min_lr``. With the defaults the rate is
    1e-3 for epochs 0-24, 5e-4 for 25-49, ..., and 1e-5 from epoch 175 on.

    Parameters
    ----------
    epoch : int
        Zero-based epoch index.
    initial_lr : float
        Learning rate for the first ``decay_every`` epochs.
    decay_every : int
        Number of epochs between two halvings (must be positive).
    min_lr : float
        Lower bound on the returned learning rate.

    Returns
    -------
    float
        Learning rate to use for ``epoch``.

    Raises
    ------
    ValueError
        If ``decay_every`` is not positive.
    """
    # The tuning knobs are keyword-only on purpose: tf.keras's
    # LearningRateScheduler first tries ``schedule(epoch, current_lr)`` and only
    # falls back to ``schedule(epoch)`` on TypeError. A second positional
    # parameter would silently receive the current learning rate instead.
    if decay_every <= 0:
        raise ValueError(f'decay_every must be positive, got {decay_every}')
    n_halvings = epoch // decay_every
    return max(initial_lr / 2**n_halvings, min_lr)
# Use l_rate_schedule to set the optimizer's learning rate at each epoch of fit().
lrate_cback = LearningRateScheduler(l_rate_schedule)

In [8]:
# Save the full model every n_epochs epochs, i.e. once, after the last epoch.
# NOTE(review): Keras warns that `period` is deprecated; its replacement
# `save_freq` counts samples or batches depending on the TF version, so only
# migrate after checking the installed TF release.
chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=False)
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    profile_batch=0,  # disable the profiler
)
# keras_tqdm presumably only implements the legacy on_batch_* hooks (confirm);
# alias them to the on_train_batch_* names that tf.keras invokes during fit.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end
# Log one fixed (noisy, ground-truth) validation pair to TensorBoard,
# keeping the batch dimension via the [0:1] slice.
val_noisy, val_gt = next(iter(im_ds_val))
tboard_image_cback = TensorBoardImage(
    log_dir=op.join(log_dir, 'images'),
    image=val_gt[0:1],
    noisy_image=val_noisy[0:1],
)
norm_cback = NormalisationAdjustment(momentum=0.99, n_pooling=5)
# Same aliasing as for tqdm_cb, so the adjustment hook runs after every training batch.
norm_cback.on_train_batch_end = norm_cback.on_batch_end

W1126 18:47:24.810588 139855565879040 callbacks.py:863] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [9]:
n_channels = 1  # single-channel input
# Spatial dims are None so the model accepts images of any size.
model = learned_wavelet(input_size=(None, None, n_channels), lr=1e-3, **run_params)
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" to the output.
model.summary(line_length=150)

Model: "model"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, None, None, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
low_pass_filtering_1 (Conv2D)                    (None, None, None, 1)            25                input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
average_pooling2d (AveragePooling2D)             (None, None, None, 1)         

In [None]:
%%time
model.fit(
    im_ds_train, 
    steps_per_epoch=int(n_samples_train / batch_size), 
#     steps_per_epoch=5, 
    epochs=n_epochs,
    validation_data=im_ds_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
#     callbacks=[tqdm_cb, norm_cback],
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback, lrate_cback],
    shuffle=False,
)

HBox(children=(IntProgress(value=0, description='Training', max=250, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 0', style=ProgressStyle(description_width='initial')), …

HBox(children=(IntProgress(value=0, description='Epoch 1', style=ProgressStyle(description_width='initial')), …

HBox(children=(IntProgress(value=0, description='Epoch 2', style=ProgressStyle(description_width='initial')), …

W1126 18:49:00.855328 139855565879040 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.355224). Check your callbacks.
W1126 18:49:01.298155 139855565879040 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.261720). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 3', style=ProgressStyle(description_width='initial')), …

HBox(children=(IntProgress(value=0, description='Epoch 4', style=ProgressStyle(description_width='initial')), …

W1126 18:50:27.874625 139855565879040 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.294275). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 5', style=ProgressStyle(description_width='initial')), …

W1126 18:51:11.220229 139855565879040 callbacks.py:244] Method (on_train_batch_end) is slow compared to the batch update (0.312131). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 6', style=ProgressStyle(description_width='initial')), …

HBox(children=(IntProgress(value=0, description='Epoch 7', style=ProgressStyle(description_width='initial')), …

HBox(children=(IntProgress(value=0, description='Epoch 8', style=ProgressStyle(description_width='initial')), …

In [None]:
# %%time
# # overfitting trials (stale: `im_gen_train` / `im_gen_val` are not defined anywhere in this notebook)
# data = im_gen_train[0]
# val_data = im_gen_val[0]
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0].shape[0], 
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback],
#     epochs=200, 
#     verbose=2, 
#     shuffle=False,
# )
# print('Original metrics')
# print(psnr(*data))
# print(ssim(*data))

In [None]:
%matplotlib nbagg
import matplotlib.pyplot as plt
plt.figure()
plt.plot(norm_cback.stds_lists[3])
plt.ylim([0.3, 1.3])