In [1]:
# Move up one directory (to the repository root, per the recorded output below) so the
# relative 'checkpoints/' and 'logs/' paths used later resolve there.
# NOTE(review): not idempotent — re-running this cell moves up yet another level.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment to hide all GPUs from TensorFlow and force CPU-only execution.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tqdm import tqdm_notebook

from learning_wavelets.datasets import im_dataset_div2k
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.dncnn import dncnn

Using TensorFlow backend.


In [4]:
# Fix TensorFlow's global random seed so random ops created afterwards (e.g. weight
# initialisers, and presumably any randomness in the dataset pipeline) are repeatable.
SEED = 1
tf.random.set_seed(SEED)

In [5]:
# Noise level passed to the dataset builder; it is also part of the run id below.
noise_std = 30
batch_size = 8
# NOTE(review): appears unused by live code here (presumably the DIV2K train-set size).
n_samples_train = 800
# Shared by both splits so the training and validation patch sizes cannot drift apart.
patch_size = 50

# Both datasets yield (noisy, ground_truth) pairs (see the unpacking in the callbacks cell).
im_ds_train = im_dataset_div2k(
    mode='training',
    batch_size=batch_size,
    patch_size=patch_size,
    noise_std=noise_std,
)
im_ds_val = im_dataset_div2k(
    mode='validation',
    batch_size=1,
    patch_size=patch_size,
    noise_std=noise_std,
)

In [6]:
# Architecture hyper-parameters, forwarded to dncnn() as keyword arguments.
run_params = dict(filters=64, depth=20)
n_epochs = 150

# Unique run identifier (noise level + launch timestamp) shared by logs and checkpoints.
run_id = f'dncnn_div2k_{noise_std}_{int(time.time())}'
# The doubled braces leave a literal `{epoch:02d}` for ModelCheckpoint to fill in.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

dncnn_div2k_30_1575738740


In [7]:
def l_rate_schedule(epoch):
    """Step-decay learning-rate schedule.

    Starts at 1e-3 and halves the rate every 25 epochs.

    Args:
        epoch: zero-based epoch index, as passed by Keras' LearningRateScheduler.

    Returns:
        The learning rate for that epoch, i.e. ``1e-3 / 2 ** (epoch // 25)``.
    """
    n_halvings = epoch // 25
    return 1e-3 / (2 ** n_halvings)
# Keras calls the schedule at the start of each epoch and sets the optimizer's learning
# rate to its result, which overrides the `lr` given to dncnn() when the model is built.
lrate_cback = LearningRateScheduler(schedule=l_rate_schedule)

In [8]:
# Save the full model (architecture + weights) every n_epochs epochs, i.e. once at the end.
# NOTE(review): `period` is deprecated (see the warning below); `save_freq` counts samples
# or batches depending on the TF version, so migrate with care.
chkpt_cback = ModelCheckpoint(
    filepath=chkpt_path,
    period=n_epochs,
    save_weights_only=False,
)

log_dir = op.join('logs', run_id)
# Scalars + graph only: no weight histograms or weight images, and profiling disabled.
tboard_cback = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True,
                           write_images=False, profile_batch=0)

# keras_tqdm presumably only implements the legacy `on_batch_*` hooks; alias them to the
# tf.keras `on_train_batch_*` names so the per-batch progress bar is driven.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end

# First (noisy, ground-truth) pair of the validation set, handed to the image callback.
val_noisy, val_gt = next(iter(im_ds_val))
tboard_image_cback = TensorBoardImage(
    log_dir=f'{log_dir}/images',
    image=val_gt[:1],
    noisy_image=val_noisy[:1],
)

W1207 18:12:20.846529 140512068499200 callbacks.py:863] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [9]:
# Single-channel inputs.
n_channels = 1
# Height/width are left as None so the model accepts inputs of any spatial size.
# NOTE(review): the LearningRateScheduler callback resets the optimizer's learning rate at
# the start of every epoch (1e-3 at epoch 0), so lr=1e-1 is presumably unused during fit.
model = dncnn(input_size=(None, None, n_channels), lr=1e-1, **run_params)
# summary() prints the table itself and returns None; wrapping it in print() only
# appended a stray "None" to the output.
model.summary(line_length=150)

Model: "model"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, None, None, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
conv2d (Conv2D)                                  (None, None, None, 64)           640               input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
activation (Activation)                          (None, None, None, 64)        

In [None]:
%%time
model.fit(
    im_ds_train, 
    steps_per_epoch=3000, 
#     steps_per_epoch=5, 
    epochs=n_epochs,
    validation_data=im_ds_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, lrate_cback],
    shuffle=False,
)

HBox(children=(IntProgress(value=0, description='Training', max=150, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=3000, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=3000, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=3000, style=ProgressStyle(description_width='in…

In [12]:
# Average loss and metrics over 200 validation batches. Keras returns [loss, *metrics] in
# compile order, presumably [loss, PSNR, SSIM] here (metrics are set inside dncnn()).
model.evaluate(x=im_ds_val, steps=200)



[0.0023909427807211614, 30.57302, 0.8163035]

In [None]:
# %%time
# # Overfitting sanity check (disabled): fit a single training batch for many epochs to verify the model can memorize it.
# data = next(iter(im_ds_train))
# val_data = next(iter(im_ds_val))
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=batch_size, 
# #     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback, lrate_cback],
#     callbacks=[tqdm_cb, tboard_cback, lrate_cback],
#     epochs=250, 
#     verbose=2, 
#     shuffle=False,
# )
# print('Original metrics')
# print(psnr(data[0].numpy(), data[1].numpy()))
# print(ssim(data[0].numpy(), data[1].numpy()))