In [1]:
# Move up one directory, to the repository root according to the printed path
# below, so the relative paths used later ('checkpoints/...', 'logs/...')
# resolve there. Presumably this is also what makes the local
# learning_wavelets package importable — TODO confirm.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # Uncomment this cell to force CPU-only execution (it hides every CUDA device from this process).
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
# Re-import edited modules (e.g. learning_wavelets) before each cell runs, so
# library changes are picked up without restarting the kernel.
%autoreload 2
import os
import os.path as op
import random
import time

import numpy as np
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras_tqdm import TQDMNotebookCallback
from tensorflow import set_random_seed
from tqdm import tqdm_notebook

from learning_wavelets.data import im_generators
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.unet import unet

Using TensorFlow backend.


In [4]:
# Seed every random number generator the run may draw from:
# - Python's `random`;
# - NumPy (presumably used by the image generators for augmentation and noise;
#   TODO confirm in learning_wavelets.data);
# - TensorFlow's graph-level seed.
# Seeding TensorFlow alone left any NumPy-driven data pipeline unseeded.
# NOTE(review): with use_multiprocessing=True in the fit cell, worker
# processes may still make the data order non-deterministic.
SEED = 1
random.seed(SEED)
np.random.seed(SEED)
set_random_seed(SEED)

In [5]:
# Dataset configuration. Images are corrupted with noise of standard deviation
# `noise_std` (presumably Gaussian; TODO confirm in learning_wavelets.data).
source = 'div2k'
noise_std = 30
grey = True
# Named so that later cells can refer to them instead of repeating literals.
batch_size = 1
validation_split = 0.1
# `size` is the side of the square training images. Per the output below,
# `n_samples_train` counts the images *before* the validation split
# (720 train + 80 validation).
im_gen_train, im_gen_val, im_gen_test, size, n_samples_train = im_generators(
    source,
    batch_size=batch_size,
    validation_split=validation_split,
    no_augment=False,
    noise_std=noise_std,
    grey=grey,
)

Found 720 images belonging to 1 classes.
Found 720 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 80 images belonging to 1 classes.
Found 100 images belonging to 1 classes.
Found 100 images belonging to 1 classes.


In [6]:
# U-Net architecture. The channel list drives the depth of the network: use
# e.g. [16, 32] for a quick 2-level debug run.
layers_n_channels = [64, 128, 256, 512, 1024]
run_params = {
    # `n_layers` always equals the number of channel entries, so derive it
    # rather than keeping the two in sync by hand.
    'n_layers': len(layers_n_channels),
    'pool': 'max',
    'layers_n_channels': layers_n_channels,
    'layers_n_non_lins': 2,
    'non_relu_contract': False,
    'bn': False,
}
n_epochs = 250
# Timestamped id shared by the checkpoint file names and the TensorBoard logs.
run_id = f'unet_{source}_{noise_std}_{int(time.time())}'
# '{epoch:02d}' is deliberately left unformatted: ModelCheckpoint fills it in.
chkpt_path = f'checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
# Create the checkpoint directory now. HDF5 saving does not create missing
# parent directories, so otherwise the only save (after the last epoch) would
# fail at the very end of training.
os.makedirs(op.dirname(chkpt_path), exist_ok=True)
print(run_id)

unet_div2k_30_1571407692


In [7]:
# Save the full model (save_weights_only=False) every `n_epochs` epochs. Since
# training runs for exactly `n_epochs` epochs, this saves once, at the end.
chkpt_cback = ModelCheckpoint(
    chkpt_path,
    period=n_epochs,
    save_weights_only=False,
)

# TensorBoard scalars and graph for this run, without histograms or weight images.
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
)

# Notebook progress bars; metric values are shown in scientific notation.
tqdm_cb = TQDMNotebookCallback(metric_format="{name}: {value:e}")

# Keep one (noisy, ground-truth) validation pair for image logging. The `:1`
# slice keeps the leading batch axis.
val_noisy, val_gt = im_gen_val[0]
image_log_dir = log_dir + '/images'
tboard_image_cback = TensorBoardImage(
    log_dir=image_log_dir,
    image=val_gt[:1],
    noisy_image=val_noisy[:1],
)

In [8]:
# Grey-scale images have a single channel; otherwise three.
n_channels = 1 if grey else 3
model = unet(input_size=(size, size, n_channels), lr=1e-3, **run_params)
# `summary` prints the table itself and returns None. Wrapping it in print()
# emitted a stray "None" line.
model.summary(line_length=150)

Instructions for updating:
Colocations handled automatically by placer.
______________________________________________________________________________________________________________________________________________________
Layer (type)                                     Output Shape                     Param #           Connected to                                      
input_1 (InputLayer)                             (None, 256, 256, 1)              0                                                                   
______________________________________________________________________________________________________________________________________________________
conv2d_1 (Conv2D)                                (None, 256, 256, 64)             640               input_1[0][0]                                     
______________________________________________________________________________________________________________________________________________________
conv2d_2 (Conv2D)     

In [None]:
%%time
model.fit_generator(
    im_gen_train, 
#     steps_per_epoch=int((1-validation_split) * n_samples_train / batch_size), 
    steps_per_epoch=n_samples_train, 
    epochs=n_epochs,
    validation_data=im_gen_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
    callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback],
    max_queue_size=100,
    use_multiprocessing=True,
    workers=35,
    shuffle=False,
)

Instructions for updating:
Use tf.cast instead.


HBox(children=(IntProgress(value=0, description='Training', max=250, style=ProgressStyle(description_width='in…

  .format(dtypeobj_in, dtypeobj_out))


HBox(children=(IntProgress(value=0, description='Epoch 0', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 1', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 2', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 3', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 4', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 5', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 6', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 7', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 8', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 9', max=800, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 10', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 11', max=800, style=ProgressStyle(description_width='in…

HBox(children=(IntProgress(value=0, description='Epoch 12', max=800, style=ProgressStyle(description_width='in…

In [None]:
# %%time
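# # Overfitting sanity check (disabled): fit a single training batch for 200 epochs, then print PSNR/SSIM between that batch's noisy input and ground truth as a baseline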
# data = im_gen_train[0]
# val_data = im_gen_val[0]
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=data[0].shape[0], 
#     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback],
#     epochs=200, 
#     verbose=2, 
#     shuffle=False,
# )
# print('Original metrics')
# print(psnr(*data))
# print(ssim(*data))