In [1]:
# Change the working directory to the parent of the notebooks folder (the
# captured output shows the resulting cwd). Presumably this lets the local
# `learning_wavelets` package import and keeps the relative 'checkpoints/' and
# 'logs' paths inside the repository — confirm.
# NOTE(review): not idempotent — re-running this cell moves up one more level.
%cd ..

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
# # this is just to make sure we only use the CPU
# import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
%load_ext autoreload
%autoreload 2
import os.path as op
import time

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
import tensorflow_addons as tfa
import tensorflow as tf
from tqdm import tqdm_notebook

from learning_wavelets.data.datasets import im_dataset_div2k, im_dataset_bsd500
from learning_wavelets.evaluate import psnr, ssim
from learning_wavelets.keras_utils.image_tboard_cback import TensorBoardImage
from learning_wavelets.models.unet import unet

In [4]:
# Fix TensorFlow's global random seed so TF-side random ops are repeatable.
# NOTE(review): numpy / `random` are not seeded here; if the dataset helpers
# draw noise or patches with them, runs can still differ — confirm.
tf.random.set_seed(seed=1)

In [5]:
# Noise configuration: training uses a pair of std values (presumably a
# (min, max) range sampled per image — confirm); validation uses one fixed std.
noise_std_train = (0, 55)
noise_std_val = 30
batch_size = 8

# Dataset selection: 'bsd500' picks BSD500, any other value falls back to DIV2K.
source = 'bsd500'
data_func, n_samples_train = (
    (im_dataset_bsd500, 400) if source == 'bsd500' else (im_dataset_div2k, 800)
)

# Build the training pipeline first, then the validation one (batch size 1),
# both with patch_size 256 and without returning the noise level.
im_ds_train, im_ds_val = (
    data_func(
        mode=split_mode,
        batch_size=split_batch,
        patch_size=256,
        noise_std=split_std,
        return_noise_level=False,
    )
    for split_mode, split_batch, split_std in (
        ('training', batch_size, noise_std_train),
        ('validation', 1, noise_std_val),
    )
)

In [6]:
# U-Net architecture hyper-parameters, forwarded as keyword arguments to `unet`.
run_params = dict(
    n_layers=5,
    # n_layers=2,
    pool='max',
    # 64, 128, 256, 512, 1024: channel count doubles at each of the 5 levels.
    layers_n_channels=[64 * 2**level for level in range(5)],
    # layers_n_channels=[16, 32],
    layers_n_non_lins=2,
    non_relu_contract=False,
    bn=True,
)
n_epochs = 2

# Run identifier: model tag, dataset, the two training noise stds, Unix timestamp.
run_id = '_'.join([
    'unet_dynamic_st',
    str(source),
    str(noise_std_train[0]),
    str(noise_std_train[1]),
    str(int(time.time())),
])
# '{epoch:02d}' stays a literal placeholder for ModelCheckpoint to fill per epoch.
chkpt_path = 'checkpoints/' + run_id + '-{epoch:02d}.hdf5'
print(run_id)

unet_dynamic_st_bsd500_0_55_1582625310


In [7]:
def l_rate_schedule(epoch, *, initial_lr=1e-3, drop_every=25, min_lr=1e-5):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    The rate starts at ``initial_lr``, is halved every ``drop_every`` epochs
    and is floored at ``min_lr``. The defaults reproduce the original
    hard-coded schedule (1e-3, halved every 25 epochs, floor 1e-5).

    The tuning knobs are keyword-only on purpose. tf.keras'
    ``LearningRateScheduler`` first tries ``schedule(epoch, lr)`` and falls
    back to ``schedule(epoch)`` on ``TypeError``. A second positional
    parameter would therefore silently receive the optimizer's current rate.

    Args:
        epoch: zero-based epoch index supplied by Keras.
        initial_lr: learning rate used for epochs ``0 .. drop_every - 1``.
        drop_every: number of epochs between successive halvings.
        min_lr: lower bound on the returned learning rate.

    Returns:
        The learning rate (float) to use for ``epoch``.
    """
    n_halvings = epoch // drop_every
    # `0.5 ** n` is a float power: it underflows to 0.0 for very large n
    # (the floor then applies). The old `x / 2**n` raised OverflowError once
    # the int 2**n exceeded the float range (n >= 1024).
    return max(initial_lr * 0.5 ** n_halvings, min_lr)
# Keras callback that applies `l_rate_schedule`; it is passed to `model.fit`.
lrate_cback = LearningRateScheduler(schedule=l_rate_schedule)

In [8]:
# Checkpoint the full model (not only weights) to `chkpt_path` every `n_epochs`.
# NOTE(review): `period` is deprecated in newer tf.keras in favour of
# `save_freq`, which counts batches rather than epochs — confirm against the
# installed TF version before upgrading.
chkpt_cback = ModelCheckpoint(
    chkpt_path,
    save_weights_only=False,
    period=n_epochs,
)

# TensorBoard logs under logs/<run_id>; histograms, graph, images and
# profiling are all switched off.
log_dir = op.join('logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    profile_batch=0,
    histogram_freq=0,
    write_images=False,
    write_graph=False,
)

# Notebook progress bar; metric values are formatted in exponent notation.
tqdm_cb = tfa.callbacks.TQDMProgressBar(metrics_format="{name}: {value:e}")

# Disabled: image summaries of one validation sample.
# val_noisy, val_gt = next(iter(im_ds_val))
# tboard_image_cback = TensorBoardImage(
#     log_dir=log_dir + '/images',
#     image=val_gt[0:1],
#     noisy_image=val_noisy[0:1],
# )



In [9]:
# Build the U-Net on single-channel inputs. Height and width are left as None
# (variable spatial size; presumably still constrained by the pooling depth —
# confirm). `lr=1e-3` matches the epoch-0 value of `l_rate_schedule`.
n_channels = 1
model = unet(input_size=(None, None, n_channels), lr=1e-3, **run_params)
# `Model.summary` prints the table itself and returns None. Call it directly:
# wrapping it in `print` only adds a stray "None" line after the summary.
model.summary(line_length=114)

Model: "model"
__________________________________________________________________________________________________________________
Layer (type)                         Output Shape             Param #       Connected to                          
input_1 (InputLayer)                 [(None, None, None, 1)]  0                                                   
__________________________________________________________________________________________________________________
conv2d (Conv2D)                      (None, None, None, 64)   640           input_1[0][0]                         
__________________________________________________________________________________________________________________
conv2d_1 (Conv2D)                    (None, None, None, 64)   36928         conv2d[0][0]                          
__________________________________________________________________________________________________________________
max_pooling2d (MaxPooling2D)         (None, None, None, 64)   0  

In [10]:
%%time
model.fit(
    im_ds_train, 
    steps_per_epoch=20, 
#     steps_per_epoch=5, 
    epochs=n_epochs,
    validation_data=im_ds_val,
#     validation_steps=int(validation_split * n_samples_train / batch_size),
    validation_steps=1,
    verbose=0,
#     callbacks=[tqdm_cb, tboard_cback, chkpt_cback, lrate_cback],
    callbacks=[tqdm_cb, lrate_cback],
#     callbacks=[tqdm_cb, tboard_cback, chkpt_cback, tboard_image_cback, norm_cback, lrate_cback],
    shuffle=False,
)

HBox(children=(FloatProgress(value=0.0, description='Training', layout=Layout(flex='2'), max=2.0, style=Progre…

Epoch 1/2


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=20.0), HTML(value='')), layout=Layout(dis…


Epoch 2/2


HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=20.0), HTML(value='')), layout=Layout(dis…



CPU times: user 25.5 s, sys: 7.45 s, total: 32.9 s
Wall time: 24.1 s


<tensorflow.python.keras.callbacks.History at 0x7f127032a978>

In [11]:
# %%time
# # overfitting trials
# data = next(iter(im_ds_train))
# val_data = next(iter(im_ds_val))
# model.fit(
#     x=data[0], 
#     y=data[1], 
# #     validation_data=val_data, 
#     batch_size=1, 
# #     callbacks=[tqdm_cb, tboard_cback, tboard_image_cback, norm_cback, lrate_cback],
#     callbacks=[tqdm_cb, tboard_cback,],
#     epochs=50, 
#     verbose=2, 
#     shuffle=False,
# )
# # print('Original metrics')
# # print(psnr(data[0].numpy(), data[1].numpy()))
# # print(ssim(data[0].numpy(), data[1].numpy()))