In [1]:
%cd /volatile/home/Zaccharie/workspace/understanding-unets

/volatile/home/Zaccharie/workspace/understanding-unets


In [2]:
%matplotlib nbagg
import os.path as op
import time

from keras.callbacks import TensorBoard
from keras.initializers import Constant
from keras.layers import Conv2D, MaxPooling2D, concatenate, Dropout, UpSampling2D, Input, AveragePooling2D, BatchNormalization
from keras.models import Model
from keras.optimizers import Adam
from keras_tqdm import TQDMNotebookCallback
import matplotlib.pyplot as plt
import numpy as np

from modopt.signal.wavelet import get_mr_filters
from study import handle_source

Using TensorFlow backend.


In [3]:
source = 'cifar_grey'
im_gen_train, im_gen_val, im_gen_test, n_samples_train, size, n_channels = handle_source(source)

In [4]:
# low pass filter
wavelet_id = '2'
data_shape = (size, size)
low_pass_filter = get_mr_filters(data_shape, opt=[f'-t {wavelet_id}', '-n 2'], coarse=True)[-1].astype(np.float32)
print(f'Filter shape: {low_pass_filter.shape}')
keras_low_filter = Constant(low_pass_filter)

Filter shape: (5, 5)


In [5]:
# network params
n_extension = 64
kernel_size = 3

In [6]:
input_im = Input((size, size, n_channels))
extended_input_im = Conv2D(
    n_extension,
    kernel_size, 
    activation='relu',
    padding='same', 
    kernel_initializer='he_normal',
)(input_im)
high_1 = Conv2D(
    n_extension,
    kernel_size, 
    activation='relu',
    padding='same', 
    kernel_initializer='he_normal',
)(extended_input_im)
low_1_layer = Conv2D(
    n_extension, 
    kernel_size, 
    activation='linear',
    padding='same', 
    use_bias=False, 
    kernel_initializer=keras_low_filter,
)
low_1_layer.trainable = False
low_1 = low_1_layer(extended_input_im)
low_1 = AveragePooling2D(pool_size=(2, 2))(low_1)

high_2 = Conv2D(
    n_extension,
    kernel_size, 
    activation='relu',
    padding='same', 
    kernel_initializer='he_normal',
)(low_1)
low_2_layer = Conv2D(
    n_extension, 
    kernel_size, 
    activation='linear',
    padding='same', 
    use_bias=False, 
    kernel_initializer=keras_low_filter,
)
low_2_layer.trainable = False
low_2 = low_2_layer(low_1)

# merge_2 = concatenate([low_2, high_2, low_1], axis=3)  # with skip connection
merge_2 = concatenate([low_2, high_2], axis=3)
merge_2 = UpSampling2D(size=(2, 2))(merge_2)
merge_2 = Conv2D(
    n_extension,
    kernel_size, 
    activation='relu',
    padding='same', 
    kernel_initializer='he_normal',
)(merge_2)

# merge_1 = concatenate([merge_2, high_1, extended_input_im], axis=3)  # with skip connection
merge_1 = concatenate([merge_2, high_1], axis=3)

output = Conv2D(
    n_channels,
    kernel_size, 
    activation='sigmoid',
    padding='same', 
    kernel_initializer='he_normal',
)(merge_1)

model = Model(input_im, output)

model.compile(optimizer=Adam(lr=1e-4), loss='mean_squared_error')

Instructions for updating:
Colocations handled automatically by placer.


In [7]:
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 32, 32, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 64)   640         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 32, 32, 64)   36864       conv2d_1[0][0]                   
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 16, 16, 64)   0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_5 (

In [8]:
# training params
epochs = 50
batch_size = 32
validation_split = 0.1
run_id = str(int(time.time()))
print(run_id)

1559575607


In [9]:
log_dir = op.join('experiments', 'logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_images=True,
)
model.fit_generator(
    im_gen_train,
    steps_per_epoch=int((1-validation_split) * n_samples_train / batch_size),
    epochs=epochs,
    validation_data=im_gen_val,
    validation_steps=int(validation_split * n_samples_train / batch_size),
    verbose=2,
    callbacks=[tboard_cback, TQDMNotebookCallback()],
)

Instructions for updating:
Use tf.cast instead.


HBox(children=(IntProgress(value=0, description='Training', max=50, style=ProgressStyle(description_width='ini…

HBox(children=(IntProgress(value=0, description='Epoch 0', max=1406, style=ProgressStyle(description_width='in…

Epoch 1/50
 - 21s - loss: 0.0080 - val_loss: 0.0033
Epoch 2/50


HBox(children=(IntProgress(value=0, description='Epoch 1', max=1406, style=ProgressStyle(description_width='in…

 - 20s - loss: 0.0031 - val_loss: 0.0029
Epoch 3/50


HBox(children=(IntProgress(value=0, description='Epoch 2', max=1406, style=ProgressStyle(description_width='in…

 - 18s - loss: 0.0028 - val_loss: 0.0027
Epoch 4/50


HBox(children=(IntProgress(value=0, description='Epoch 3', max=1406, style=ProgressStyle(description_width='in…

 - 18s - loss: 0.0027 - val_loss: 0.0026
Epoch 5/50


HBox(children=(IntProgress(value=0, description='Epoch 4', max=1406, style=ProgressStyle(description_width='in…

 - 20s - loss: 0.0026 - val_loss: 0.0025
Epoch 6/50


HBox(children=(IntProgress(value=0, description='Epoch 5', max=1406, style=ProgressStyle(description_width='in…

 - 18s - loss: 0.0025 - val_loss: 0.0025
Epoch 7/50


HBox(children=(IntProgress(value=0, description='Epoch 6', max=1406, style=ProgressStyle(description_width='in…

 - 19s - loss: 0.0024 - val_loss: 0.0024
Epoch 8/50


HBox(children=(IntProgress(value=0, description='Epoch 7', max=1406, style=ProgressStyle(description_width='in…

 - 18s - loss: 0.0024 - val_loss: 0.0024
Epoch 9/50


HBox(children=(IntProgress(value=0, description='Epoch 8', max=1406, style=ProgressStyle(description_width='in…

 - 17s - loss: 0.0024 - val_loss: 0.0024
Epoch 10/50


HBox(children=(IntProgress(value=0, description='Epoch 9', max=1406, style=ProgressStyle(description_width='in…

 - 18s - loss: 0.0024 - val_loss: 0.0025
Epoch 11/50


HBox(children=(IntProgress(value=0, description='Epoch 10', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 12/50


HBox(children=(IntProgress(value=0, description='Epoch 11', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 13/50


HBox(children=(IntProgress(value=0, description='Epoch 12', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 14/50


HBox(children=(IntProgress(value=0, description='Epoch 13', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0023 - val_loss: 0.0023
Epoch 15/50


HBox(children=(IntProgress(value=0, description='Epoch 14', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 16/50


HBox(children=(IntProgress(value=0, description='Epoch 15', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 17/50


HBox(children=(IntProgress(value=0, description='Epoch 16', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 18/50


HBox(children=(IntProgress(value=0, description='Epoch 17', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 19/50


HBox(children=(IntProgress(value=0, description='Epoch 18', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0023 - val_loss: 0.0023
Epoch 20/50


HBox(children=(IntProgress(value=0, description='Epoch 19', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 21/50


HBox(children=(IntProgress(value=0, description='Epoch 20', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 22/50


HBox(children=(IntProgress(value=0, description='Epoch 21', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0023
Epoch 23/50


HBox(children=(IntProgress(value=0, description='Epoch 22', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0023 - val_loss: 0.0022
Epoch 24/50


HBox(children=(IntProgress(value=0, description='Epoch 23', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0023 - val_loss: 0.0022
Epoch 25/50


HBox(children=(IntProgress(value=0, description='Epoch 24', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0023 - val_loss: 0.0022
Epoch 26/50


HBox(children=(IntProgress(value=0, description='Epoch 25', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0023 - val_loss: 0.0023
Epoch 27/50


HBox(children=(IntProgress(value=0, description='Epoch 26', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0023 - val_loss: 0.0023
Epoch 28/50


HBox(children=(IntProgress(value=0, description='Epoch 27', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 29/50


HBox(children=(IntProgress(value=0, description='Epoch 28', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0022 - val_loss: 0.0022
Epoch 30/50


HBox(children=(IntProgress(value=0, description='Epoch 29', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0023
Epoch 31/50


HBox(children=(IntProgress(value=0, description='Epoch 30', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0022 - val_loss: 0.0022
Epoch 32/50


HBox(children=(IntProgress(value=0, description='Epoch 31', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0022 - val_loss: 0.0022
Epoch 33/50


HBox(children=(IntProgress(value=0, description='Epoch 32', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0022 - val_loss: 0.0022
Epoch 34/50


HBox(children=(IntProgress(value=0, description='Epoch 33', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 35/50


HBox(children=(IntProgress(value=0, description='Epoch 34', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 36/50


HBox(children=(IntProgress(value=0, description='Epoch 35', max=1406, style=ProgressStyle(description_width='i…

 - 17s - loss: 0.0022 - val_loss: 0.0022
Epoch 37/50


HBox(children=(IntProgress(value=0, description='Epoch 36', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 38/50


HBox(children=(IntProgress(value=0, description='Epoch 37', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 39/50


HBox(children=(IntProgress(value=0, description='Epoch 38', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 40/50


HBox(children=(IntProgress(value=0, description='Epoch 39', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0022 - val_loss: 0.0022
Epoch 41/50


HBox(children=(IntProgress(value=0, description='Epoch 40', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 42/50


HBox(children=(IntProgress(value=0, description='Epoch 41', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 43/50


HBox(children=(IntProgress(value=0, description='Epoch 42', max=1406, style=ProgressStyle(description_width='i…

 - 20s - loss: 0.0022 - val_loss: 0.0022
Epoch 44/50


HBox(children=(IntProgress(value=0, description='Epoch 43', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 45/50


HBox(children=(IntProgress(value=0, description='Epoch 44', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0022 - val_loss: 0.0022
Epoch 46/50


HBox(children=(IntProgress(value=0, description='Epoch 45', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 47/50


HBox(children=(IntProgress(value=0, description='Epoch 46', max=1406, style=ProgressStyle(description_width='i…

 - 20s - loss: 0.0022 - val_loss: 0.0022
Epoch 48/50


HBox(children=(IntProgress(value=0, description='Epoch 47', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 49/50


HBox(children=(IntProgress(value=0, description='Epoch 48', max=1406, style=ProgressStyle(description_width='i…

 - 18s - loss: 0.0022 - val_loss: 0.0022
Epoch 50/50


HBox(children=(IntProgress(value=0, description='Epoch 49', max=1406, style=ProgressStyle(description_width='i…

 - 19s - loss: 0.0022 - val_loss: 0.0022



<keras.callbacks.History at 0x7f87087bb470>