In [1]:
# Run from the repository root so `learning_wavelets` is importable and the
# relative 'logs/' and 'checkpoints/' paths used below resolve there.
# Guarded (unlike a bare `%cd ..`) so re-running this cell does not keep
# climbing the directory tree.
# NOTE(review): assumes the `learning_wavelets` package directory sits at the
# repository root and the notebook lives one level below it — confirm.
import os
if not os.path.isdir('learning_wavelets'):
    os.chdir('..')
print(os.getcwd())

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment to hide every CUDA device from TensorFlow and force CPU-only execution.
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
# Re-import edited modules (e.g. the local learning_wavelets package) before each cell runs.
%load_ext autoreload
%autoreload 2
import os.path as op
import random
import time

import numpy as np
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras_tqdm import TQDMNotebookCallback
from tensorflow import set_random_seed
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.filters_cback import NormalizeWeights
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.keras_utils.thresholding import SoftThresholding
from learning_wavelets.learned_wavelet import learned_wavelet

Using TensorFlow backend.


In [4]:
# Seed every RNG the pipeline may draw from, not only TensorFlow's, so runs
# are reproducible.
# NOTE(review): whether im_generators draws its noise/augmentation from
# numpy's global RNG is not visible here; seeding it is a cheap safeguard.
# Worker processes (use_multiprocessing=True during training) may still
# diverge — TODO confirm.
seed = 1
random.seed(seed)
np.random.seed(seed)
set_random_seed(seed)

In [5]:
# Dataset / noise configuration.
source = 'div2k'
noise_std = 30  # presumably on the 0-255 intensity scale — confirm in im_generators
grey = True  # grey-level (single-channel) images instead of RGB

# Build the train / validation / test generators. `size` is the side length of
# the square model input; `n_samples_train` is the image count returned by
# im_generators (it matches train + validation before the split here).
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = im_generators(
    source,
    batch_size=1,
    validation_split=0.1,  # hold out 10% of the training images
    no_augment=False,      # i.e. augmentation enabled
    noise_std=noise_std,
    grey=grey,
)

Found 720 images belonging to 1 classes.
Found 720 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 100 images belonging to 1 classes.
Found 100 images belonging to 1 classes.


In [6]:
# Coarse-scale channels mirror the image channels: 1 for grey-level, 3 for RGB.
n_coarse = 1 if grey else 3

# Keyword arguments forwarded to learned_wavelet(...) when the model is built.
run_params = dict(
    n_scales=5,
    n_details=256,
    n_coarse=n_coarse,
    n_groupping=256,  # spelling must match the learned_wavelet keyword
    denoising_activation=SoftThresholding(0.1),
    wav_pooling=False,
)
n_epochs = 250

# The timestamp makes each run id (and hence its log/checkpoint paths) unique.
run_id = f'learned_wavelet_{source}_{noise_std}_{int(time.time())}'
# `{{epoch:02d}}` is brace-escaped so the literal placeholder survives the
# f-string; ModelCheckpoint fills in the epoch number when saving.
chkpt_path = f'checkpoints/{run_id}-{{epoch:02d}}.hdf5'
print(run_id)

learned_wavelet_div2k_30_1571384947


In [7]:
# Keras callbacks for the training run.
# NOTE(review): with period=n_epochs the model is written only once, after the
# final epoch — an interrupted run leaves no checkpoint at all. Lower the
# period if intermediate snapshots are wanted.
chkpt_cback = ModelCheckpoint(chkpt_path, period=n_epochs, save_weights_only=False)
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)
# Notebook progress bars; metrics shown in scientific notation.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")
# One fixed validation pair for image logging; `[0:1]` keeps the batch axis.
# Presumably TensorBoardImage logs the model's output on it each epoch.
val_noisy, val_gt = im_gen_val[0]
tboard_image_cback = TensorBoardImage(
    # op.join, consistent with how log_dir itself is built above.
    log_dir=op.join(log_dir, 'images'),
    image=val_gt[0:1],
    noisy_image=val_noisy[0:1],
)

In [8]:
# Input channels: 1 for grey-level images, 3 for RGB.
n_channels = 1 if grey else 3
# lr is presumably the optimizer learning rate used by learned_wavelet.
model = learned_wavelet(input_size=(size, size, n_channels), lr=1e-3, **run_params)
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" to the cell output.
model.summary(line_length=150)

Instructions for updating:
Colocations handled automatically by placer.


  bias=bias,
  identifier=identifier.__class__.__name__))
  bias=bias,
  bias=bias,


______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             (None, 256, 256, 1)              0                                                                   
______________________________________________________________________________________________________________________________________________________
conv2d_2 (Conv2D)                                (None, 256, 256, 1)              10                input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
activation_2 (Activation)                        (None, 256, 256, 1)              0           

In [None]:
%%time
model.fit_generator(
    im_gen_train, 
#     steps_per_epoch=int((1-validation_split) * n_samples_train / batch_size), 
    steps_per_epoch=n_samples_train, 
    epochs=n_epochs,
    validation_data=im_gen_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback],
    max_queue_size=100,
    use_multiprocessing=True,
    workers=35,
    shuffle=False,
)

Instructions for updating:
Use tf.cast instead.


HBox(children=(IntProgress(value=0, description='Training', max=250, style=ProgressStyle(description_width='in…

  .format(dtypeobj_in, dtypeobj_out))


HBox(children=(IntProgress(value=0, description='Epoch 0', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 13', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 14', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 15', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 16', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 17', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 18', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 19', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 20', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 21', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 22', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 23', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 24', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 25', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 26', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 27', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 28', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 29', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 30', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 31', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 32', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 33', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 34', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 35', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 36', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 37', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 38', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 39', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 40', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 41', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 42', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 43', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 44', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 45', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 46', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 47', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 48', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 49', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 50', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 51', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 52', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 53', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 54', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 55', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 56', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 57', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 58', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 59', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 60', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 61', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 62', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 63', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 64', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 65', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 66', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 67', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 68', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 69', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 70', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 71', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 72', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 73', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 74', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 75', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 76', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 77', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 78', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 79', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 80', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 81', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 82', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 83', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 84', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 85', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 86', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 87', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 88', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 89', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 90', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 91', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 92', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 93', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 94', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 95', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 96', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 97', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 98', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 99', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 100', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 101', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 102', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 103', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 104', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 105', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 106', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 107', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 108', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 109', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 110', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 111', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 112', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 113', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 114', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 115', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 116', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 117', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 118', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 119', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 120', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 121', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 122', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 123', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 124', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 125', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 126', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 127', max=800, style=ProgressStyle(description_width='i…

HBox(children=(IntProgress(value=0, description='Epoch 128', max=800, style=ProgressStyle(description_width='i…

In [None]:
# %%time
# # Disabled sanity check: fit a single training batch for 200 epochs (should overfit), then print the PSNR/SSIM of the noisy input vs. ground truth as a baseline.
# data = im_gen_train[0]
# val_data = im_gen_val[0]
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0].shape[0], 
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback],
#     epochs=200, 
#     verbose=2, 
#     shuffle=False,
# )
# print('Original metrics')
# print(psnr(*data))
# print(ssim(*data))