In [1]:
# Run the rest of the notebook from the parent directory (presumably the
# repository root, so that the relative 'checkpoints/' and 'logs/' paths used
# further down resolve there -- confirm).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves up one more level each time.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # this is just to make sure we are running only on the CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.filters_cback import NormalizeWeights
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment
from learning_wavelets.keras_utils.thresholding import SoftThresholding
from learning_wavelets.learned_wavelet import learned_wavelet

Using TensorFlow backend.


In [4]:
# Fix TensorFlow's global random seed so runs are repeatable on the TF side.
SEED = 1
tf.random.set_seed(SEED)

In [5]:
# Dataset configuration: which image source to use, the noise level passed to
# the generators, and whether to work on greyscale images.
source, noise_std, grey = 'div2k', 30, True

# Build the train / validation / test generators in a single call; the helper
# also returns `size` and the number of training samples.
generators = im_generators(
    source,
    batch_size=1,
    validation_split=0.1,
    no_augment=False,
    noise_std=noise_std,
    grey=grey,
)
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = generators

Found 720 images belonging to 1 classes.
Found 720 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 100 images belonging to 1 classes.
Found 100 images belonging to 1 classes.


In [6]:
# One coarse channel for greyscale images, three otherwise.
n_coarse = 1 if grey else 3

# Soft-thresholding level: twice the noise standard deviation divided by 255
# (presumably because images are scaled to [0, 1] -- confirm in im_generators).
thresh = 2 * noise_std / 255

# Keyword arguments forwarded verbatim to `learned_wavelet`.
run_params = dict(
    n_scales=5,
    n_details=256,
    n_coarse=n_coarse,
    n_groupping=256,
    denoising_activation=SoftThresholding(thresh),
    wav_pooling=True,
    wav_use_bias=False,
    wav_normed=True,
    filters_normed=['details', 'coarse'],
)
n_epochs = 500

# A unique run identifier (timestamped) names both the log directory and the
# checkpoint files; '{epoch:02d}' is left as a placeholder for Keras to fill.
timestamp = int(time.time())
run_id = f'learned_wavelet_{source}_{noise_std}_{timestamp}'
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
print(run_id)

learned_wavelet_div2k_30_1573569554


In [7]:
# --- Checkpointing -----------------------------------------------------------
# Saves the full model (not only the weights) every `n_epochs` epochs.
# NOTE: `period` is deprecated in this TensorFlow version (it triggers a
# warning); it is kept as-is because `save_freq` uses different units.
chkpt_cback = ModelCheckpoint(
    chkpt_path,
    period=n_epochs,
    save_weights_only=False,
)

# --- TensorBoard scalars / graph ---------------------------------------------
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)

# --- Progress bars -----------------------------------------------------------
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
# Expose the batch hooks under the `on_train_batch_*` names, which are the
# ones tf.keras reports calling during training.
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end

# --- TensorBoard images ------------------------------------------------------
# Log the first validation sample (kept as a batch of one) with its noisy input.
val_noisy, val_gt = im_gen_val[0]
tboard_image_cback = TensorBoardImage(
    log_dir=log_dir + '/images',
    image=val_gt[:1],
    noisy_image=val_noisy[:1],
)

# --- Normalisation adjustment ------------------------------------------------
norm_cback = NormalisationAdjustment(momentum=0.99, n_pooling=5)
# Same hook aliasing as for the progress-bar callback above.
norm_cback.on_train_batch_end = norm_cback.on_batch_end

W1112 15:39:14.780871 140640970262272 callbacks.py:859] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [8]:
# Number of image channels: one for greyscale, three otherwise.
n_channels = 1 if grey else 3

# Build the learned-wavelet model; the spatial dimensions are left unspecified
# (None) so the model accepts images of any height and width.
model = learned_wavelet(input_size=(None, None, n_channels), lr=1e-3, **run_params)

# `Model.summary` prints the table itself and returns None, so it must not be
# wrapped in print() (that would append a stray "None" after the table).
model.summary(line_length=150)

Model: "model"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, None, None, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
low_pass_filtering_1 (Conv2D)                    (None, None, None, 1)            25                input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
average_pooling2d (AveragePooling2D)             (None, None, None, 1)         

In [9]:
# %%time
# model.fit_generator(
#     im_gen_train, 
# #     steps_per_epoch=int((1-validation_split) * n_samples_train / batch_size), 
#     steps_per_epoch=5, 
#     epochs=n_epochs,
#     validation_data=im_gen_val,
# #     validation_steps=int(validation_split * n_samples_train / batch_size),
#     validation_steps=1,
#     verbose=0,
#     callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback],
#     max_queue_size=100,
#     use_multiprocessing=True,
#     workers=35,
#     shuffle=False,
# )

In [10]:
%%time
# Overfitting trial: fit the model on the first training batch only, as a
# sanity check of the optimisation.
data = im_gen_train[0]
val_data = im_gen_val[0]  # only needed if `validation_data` is re-enabled below
train_x, train_y = data

model.fit(
    x=train_x,
    y=train_y,
    # validation_data=val_data,
    batch_size=train_x.shape[0],  # the whole batch in a single step
    callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback],
    epochs=200,
    verbose=2,
    shuffle=False,
)

# Reference metrics computed directly between the two arrays of the batch,
# i.e. without going through the model.
print('Original metrics')
for metric in (psnr, ssim):
    print(metric(train_x, train_y))

Train on 1 samples


HBox(children=(IntProgress(value=0, description='Training', max=20, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=1, style=ProgressStyle(description_width='initi…

Epoch 1/20
1/1 - 4s - loss: 1.0000 - keras_psnr: 11.6873 - keras_ssim: 0.0139


HBox(children=(IntProgress(value=0, description='Epoch 1', max=1, style=ProgressStyle(description_width='initi…

Epoch 2/20


W1112 15:39:23.570662 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.223439). Check your callbacks.


1/1 - 0s - loss: 1.1981 - keras_psnr: 10.9025 - keras_ssim: 0.0020


HBox(children=(IntProgress(value=0, description='Epoch 2', max=1, style=ProgressStyle(description_width='initi…

Epoch 3/20


W1112 15:39:23.839141 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.143260). Check your callbacks.


1/1 - 0s - loss: 2.5604 - keras_psnr: 7.6044 - keras_ssim: 0.0025


HBox(children=(IntProgress(value=0, description='Epoch 3', max=1, style=ProgressStyle(description_width='initi…

Epoch 4/20


W1112 15:39:24.099596 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.134546). Check your callbacks.


1/1 - 0s - loss: 1.4214 - keras_psnr: 10.1602 - keras_ssim: 0.0291


HBox(children=(IntProgress(value=0, description='Epoch 4', max=1, style=ProgressStyle(description_width='initi…

Epoch 5/20


W1112 15:39:24.363886 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.144053). Check your callbacks.


1/1 - 0s - loss: 1.7114 - keras_psnr: 9.3539 - keras_ssim: 0.0575


HBox(children=(IntProgress(value=0, description='Epoch 5', max=1, style=ProgressStyle(description_width='initi…

Epoch 6/20


W1112 15:39:24.634099 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.148367). Check your callbacks.


1/1 - 0s - loss: 1.1942 - keras_psnr: 10.9168 - keras_ssim: 0.0367


HBox(children=(IntProgress(value=0, description='Epoch 6', max=1, style=ProgressStyle(description_width='initi…

Epoch 7/20


W1112 15:39:24.905940 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.146099). Check your callbacks.


1/1 - 0s - loss: 0.9842 - keras_psnr: 11.7567 - keras_ssim: 0.0253


HBox(children=(IntProgress(value=0, description='Epoch 7', max=1, style=ProgressStyle(description_width='initi…

Epoch 8/20


W1112 15:39:25.171598 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.143696). Check your callbacks.


1/1 - 0s - loss: 0.9113 - keras_psnr: 12.0906 - keras_ssim: 0.0189


HBox(children=(IntProgress(value=0, description='Epoch 8', max=1, style=ProgressStyle(description_width='initi…

Epoch 9/20


W1112 15:39:25.428303 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.133069). Check your callbacks.


1/1 - 0s - loss: 0.9533 - keras_psnr: 11.8953 - keras_ssim: 0.0124


HBox(children=(IntProgress(value=0, description='Epoch 9', max=1, style=ProgressStyle(description_width='initi…

Epoch 10/20


W1112 15:39:25.694834 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.140751). Check your callbacks.


1/1 - 0s - loss: 0.6821 - keras_psnr: 13.3490 - keras_ssim: 0.0149


HBox(children=(IntProgress(value=0, description='Epoch 10', max=1, style=ProgressStyle(description_width='init…

Epoch 11/20


W1112 15:39:25.983670 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.142037). Check your callbacks.


1/1 - 0s - loss: 1.0112 - keras_psnr: 11.6391 - keras_ssim: 0.0310


HBox(children=(IntProgress(value=0, description='Epoch 11', max=1, style=ProgressStyle(description_width='init…

Epoch 12/20


W1112 15:39:26.260824 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.147560). Check your callbacks.


1/1 - 0s - loss: 0.3865 - keras_psnr: 15.8154 - keras_ssim: 0.0773


HBox(children=(IntProgress(value=0, description='Epoch 12', max=1, style=ProgressStyle(description_width='init…

Epoch 13/20


W1112 15:39:26.527197 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.142250). Check your callbacks.


1/1 - 0s - loss: 0.5742 - keras_psnr: 14.0970 - keras_ssim: 0.0861


HBox(children=(IntProgress(value=0, description='Epoch 13', max=1, style=ProgressStyle(description_width='init…

Epoch 14/20


W1112 15:39:26.798288 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.136515). Check your callbacks.


1/1 - 0s - loss: 0.7339 - keras_psnr: 13.0312 - keras_ssim: 0.0681


HBox(children=(IntProgress(value=0, description='Epoch 14', max=1, style=ProgressStyle(description_width='init…

Epoch 15/20


W1112 15:39:27.058813 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.135068). Check your callbacks.


1/1 - 0s - loss: 0.6575 - keras_psnr: 13.5085 - keras_ssim: 0.0767


HBox(children=(IntProgress(value=0, description='Epoch 15', max=1, style=ProgressStyle(description_width='init…

Epoch 16/20


W1112 15:39:27.323155 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.138602). Check your callbacks.


1/1 - 0s - loss: 0.4227 - keras_psnr: 15.4267 - keras_ssim: 0.1232


HBox(children=(IntProgress(value=0, description='Epoch 16', max=1, style=ProgressStyle(description_width='init…

Epoch 17/20


W1112 15:39:27.585628 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.136806). Check your callbacks.


1/1 - 0s - loss: 0.2882 - keras_psnr: 17.0902 - keras_ssim: 0.1601


HBox(children=(IntProgress(value=0, description='Epoch 17', max=1, style=ProgressStyle(description_width='init…

Epoch 18/20


W1112 15:39:27.847534 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.136147). Check your callbacks.


1/1 - 0s - loss: 0.3909 - keras_psnr: 15.7672 - keras_ssim: 0.1201


HBox(children=(IntProgress(value=0, description='Epoch 18', max=1, style=ProgressStyle(description_width='init…

Epoch 19/20


W1112 15:39:28.124894 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.142864). Check your callbacks.


1/1 - 0s - loss: 0.5253 - keras_psnr: 14.4836 - keras_ssim: 0.1039


HBox(children=(IntProgress(value=0, description='Epoch 19', max=1, style=ProgressStyle(description_width='init…

Epoch 20/20


W1112 15:39:28.386079 140640970262272 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.140896). Check your callbacks.


1/1 - 0s - loss: 0.3806 - keras_psnr: 15.8826 - keras_ssim: 0.1330

Original metrics
23.271581883963698
0.3025488964153973
CPU times: user 9.73 s, sys: 1.85 s, total: 11.6 s
Wall time: 11.3 s
