In [1]:
# Move up one directory (presumably to the repository root, judging by the printed
# path) so that relative paths used later, e.g. `checkpoints/` and `logs/`, resolve there.
# NOTE: not idempotent -- re-running this cell moves up yet another level.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # this just to make sure we are using only on CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
import os.path as op
import random
import time

import numpy as np
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
from keras_tqdm import TQDMNotebookCallback
import tensorflow as tf
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.filters_cback import NormalizeWeights
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment
from learning_wavelets.keras_utils.thresholding import SoftThresholding
from learning_wavelets.learned_wavelet import learned_wavelet

Using TensorFlow backend.


In [4]:
# Seed every RNG the run may draw from. TensorFlow ops use tf's RNG; numpy and
# Python's `random` are seeded too because Keras-style image generators commonly
# use them for shuffling and augmentation (assumption about `im_generators`
# internals -- TODO confirm).
# NOTE: training below loads batches with `use_multiprocessing=True` and many
# workers, so batch order may still vary from run to run.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [5]:
# Data: `im_generators` returns train / validation / test generators for `source`,
# with added noise of standard deviation `noise_std` (presumably on the 0-255 pixel
# scale -- cf. the `/255` used when computing the threshold).
source = 'div2k'
noise_std = 30
grey = True  # single-channel (greyscale) images instead of 3-channel colour
batch_size = 1
validation_split = 0.1  # fraction of the training images held out for validation
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = im_generators(
    source,
    batch_size=batch_size,
    validation_split=validation_split,
    no_augment=False,  # i.e. data augmentation stays enabled
    noise_std=noise_std,
    grey=grey,
)

Found 720 images belonging to 1 classes.
Found 720 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 100 images belonging to 1 classes.
Found 100 images belonging to 1 classes.


In [6]:
# Coarse (low-pass) channel count follows the image channels: 1 for grey, 3 for colour.
n_coarse = 1 if grey else 3
# Soft-threshold level: twice the noise std, divided by 255
# (presumably because pixel values are normalised to [0, 1]).
thresh = 2 * noise_std / 255
run_params = dict(
    n_scales=5,
    n_details=256,
    n_coarse=n_coarse,
    n_groupping=256,  # (sic) must match the keyword expected by `learned_wavelet`
    denoising_activation=SoftThresholding(thresh),
    wav_pooling=True,
    wav_use_bias=False,
    wav_normed=True,
    filters_normed=['details', 'coarse'],
)
n_epochs = 10
# Unique identifier (source, noise level, UNIX timestamp) shared by logs and checkpoints.
run_id = f'learned_wavelet_{source}_{noise_std}_{int(time.time())}'
# `{{epoch:02d}}` is escaped so ModelCheckpoint, not the f-string, fills in the epoch.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

learned_wavelet_div2k_30_1573570871


In [7]:
# Full-model checkpoints (architecture + weights), named with `chkpt_path`.
# NOTE(review): `period` is deprecated in tf.keras (this cell logs a warning pointing
# to `save_freq`); it is kept so a checkpoint is still written every `n_epochs` epochs.
chkpt_cback = ModelCheckpoint(chkpt_path, save_weights_only=False, period=n_epochs)

log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,  # no weight histograms
    write_graph=True,
    write_images=False,
)

# Notebook progress bars. keras_tqdm exposes the `on_batch_*` hooks; alias them to the
# `on_train_batch_*` names that tf.keras calls during fit (presumably because
# keras_tqdm targets standalone Keras).
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
tqdm_cb.on_train_batch_begin = tqdm_cb.on_batch_begin
tqdm_cb.on_train_batch_end = tqdm_cb.on_batch_end

# Log one fixed validation pair (first image of the first validation batch) as images;
# slicing with `[:1]` keeps the leading batch dimension.
val_noisy, val_gt = im_gen_val[0]
tboard_image_cback = TensorBoardImage(
    log_dir=op.join(log_dir, 'images'),
    image=val_gt[:1],
    noisy_image=val_noisy[:1],
)

# `n_pooling=5` presumably mirrors `n_scales` in `run_params` -- TODO confirm.
# NOTE(review): the training log flags `on_train_batch_end` as slow relative to the
# batch update; this aliased hook and the tqdm one above are candidates to profile.
norm_cback = NormalisationAdjustment(momentum=0.99, n_pooling=5)
norm_cback.on_train_batch_end = norm_cback.on_batch_end

W1112 16:01:11.391887 140599946725120 callbacks.py:859] `period` argument is deprecated. Please use `save_freq` to specify the frequency in number of samples seen.


In [8]:
# Input: greyscale images have 1 channel, colour images 3; height and width are left
# as None so the model is built for variable spatial sizes.
n_channels = 1 if grey else 3
model = learned_wavelet(input_size=(None, None, n_channels), lr=1e-3, **run_params)
# `summary` prints the table itself and returns None, so it is not wrapped in print()
# (which would emit a stray "None" after the table).
model.summary(line_length=150)

Model: "model"
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             [(None, None, None, 1)]          0                                                                   
______________________________________________________________________________________________________________________________________________________
low_pass_filtering_1 (Conv2D)                    (None, None, None, 1)            25                input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
average_pooling2d (AveragePooling2D)             (None, None, None, 1)         

In [9]:
%%time
model.fit_generator(
    im_gen_train,
    epochs=n_epochs,
    # Smoke-test settings: 5 training batches and 1 validation batch per epoch.
    # For a full pass, presumably steps_per_epoch=len(im_gen_train) and
    # validation_steps=len(im_gen_val), assuming the generators are keras Sequences.
    steps_per_epoch=5,
    validation_data=im_gen_val,
    validation_steps=1,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback],
    verbose=0,  # progress is reported by the tqdm notebook callback instead
    # Parallel batch loading; the worker count is machine-specific -- adjust per host.
    use_multiprocessing=True,
    workers=35,
    max_queue_size=100,
    shuffle=False,
)

HBox(children=(IntProgress(value=0, description='Training', max=10, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:21.788461 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.387348). Check your callbacks.
W1112 16:01:21.985952 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.179893). Check your callbacks.
W1112 16:01:22.194286 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.158560). Check your callbacks.
W1112 16:01:22.405254 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.138671). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 1', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:26.824437 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.149321). Check your callbacks.
W1112 16:01:27.035589 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.144917). Check your callbacks.
W1112 16:01:27.246660 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.140514). Check your callbacks.
W1112 16:01:27.458166 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.140227). Check your callbacks.
W1112 16:01:27.665353 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.139941). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 2', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:30.787612 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.174363). Check your callbacks.
W1112 16:01:31.000374 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.156259). Check your callbacks.
W1112 16:01:31.245677 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.158406). Check your callbacks.
W1112 16:01:31.460886 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.150919). Check your callbacks.
W1112 16:01:31.661001 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.143433). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 3', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:34.829657 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.201356). Check your callbacks.
W1112 16:01:35.039169 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.169822). Check your callbacks.
W1112 16:01:35.263656 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.149636). Check your callbacks.
W1112 16:01:35.467710 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.143962). Check your callbacks.
W1112 16:01:35.677603 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.138288). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 4', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:38.747855 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.194165). Check your callbacks.
W1112 16:01:38.969483 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.168755). Check your callbacks.
W1112 16:01:39.203340 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.154939). Check your callbacks.
W1112 16:01:39.416548 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.149142). Check your callbacks.
W1112 16:01:39.620813 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.143344). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 5', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:43.134669 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.192375). Check your callbacks.
W1112 16:01:43.380552 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.178525). Check your callbacks.
W1112 16:01:43.601836 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.164675). Check your callbacks.
W1112 16:01:43.814809 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.153239). Check your callbacks.
W1112 16:01:44.021260 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.141804). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 6', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:47.453667 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.202170). Check your callbacks.
W1112 16:01:47.707277 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.189343). Check your callbacks.
W1112 16:01:47.935419 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.176517). Check your callbacks.
W1112 16:01:48.161219 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.165951). Check your callbacks.
W1112 16:01:48.381062 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.155386). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 7', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:51.911108 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.197746). Check your callbacks.
W1112 16:01:52.204931 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.204907). Check your callbacks.
W1112 16:01:52.429196 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.197746). Check your callbacks.
W1112 16:01:52.641567 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.167871). Check your callbacks.
W1112 16:01:52.848496 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.137997). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 8', max=5, style=ProgressStyle(description_width='initi…

W1112 16:01:56.331531 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.203851). Check your callbacks.
W1112 16:01:56.584855 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.187215). Check your callbacks.
W1112 16:01:56.796975 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.170580). Check your callbacks.
W1112 16:01:57.007047 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.155760). Check your callbacks.
W1112 16:01:57.225702 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.146119). Check your callbacks.


HBox(children=(IntProgress(value=0, description='Epoch 9', max=5, style=ProgressStyle(description_width='initi…

W1112 16:02:00.554909 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.168361). Check your callbacks.
W1112 16:02:00.769980 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.155672). Check your callbacks.
W1112 16:02:01.039528 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.168361). Check your callbacks.
W1112 16:02:01.262384 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.159929). Check your callbacks.
W1112 16:02:01.486136 140599946725120 callbacks.py:241] Method (on_train_batch_end) is slow compared to the batch update (0.151497). Check your callbacks.



CPU times: user 17.5 s, sys: 28.2 s, total: 45.7 s
Wall time: 51 s


<tensorflow.python.keras.callbacks.History at 0x7fdeac04c6a0>

In [10]:
# %%time
# # overfitting trials
# data = im_gen_train[0]
# val_data = im_gen_val[0]
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0].shape[0], 
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback],
#     epochs=200, 
#     verbose=2, 
#     shuffle=False,
# )
# print('Original metrics')
# print(psnr(*data))
# print(ssim(*data))