In [1]:
# Move the working directory one level up (the recorded output shows the
# resulting directory). Later cells rely on it: `study` is imported as a
# top-level module and the TensorBoard log path is relative.
# NOTE(review): this cell is not idempotent -- running it a second time in the
# same kernel moves up one more level.
%cd ..

/home/zaccharie/workspace/understanding-unets


In [2]:
%matplotlib nbagg
import math
import os.path as op
import time

from keras.callbacks import TensorBoard, LearningRateScheduler
from keras.initializers import Constant
from keras.layers import Conv2D, MaxPooling2D, concatenate, Dropout, UpSampling2D, Input, AveragePooling2D, BatchNormalization
from keras.models import Model
from keras.optimizers import Adam
from keras_tqdm import TQDMNotebookCallback
import matplotlib.pyplot as plt
import numpy as np

from modopt.signal.wavelet import get_mr_filters
from study import handle_source

Using TensorFlow backend.


In [3]:
# Dataset selection. `handle_source` (project helper) returns six values for
# the chosen source: presumably three data generators (train / val / test),
# the number of training samples, the image side length and the number of
# channels -- TODO confirm against `study.handle_source`.
source = 'cifar_grey'
(
    im_gen_train,
    im_gen_val,
    im_gen_test,
    n_samples_train,
    size,
    n_channels,
) = handle_source(source)

In [4]:
# Wavelet filters (both the high-pass and the low-pass one), wrapped as Keras
# constant initializers so they can be used as fixed convolution kernels.
wavelet_id = '2'
data_shape = (size, size)
# The option strings are forwarded to ModOpt; presumably `-t` selects the
# transform type and `-n` the number of scales -- TODO confirm.
high_pass_filter, low_pass_filter = get_mr_filters(
    data_shape,
    opt=[f'-t {wavelet_id}', '-n 2'],
    coarse=True,
)
# Cast to float32 before handing the arrays to the Keras initializers.
keras_high_filter = Constant(high_pass_filter.astype(np.float32))
keras_low_filter = Constant(low_pass_filter.astype(np.float32))

In [5]:
# Network hyper-parameters
# Number of filters of every learned convolution: four per image channel.
n_extension = n_channels * 4
# Spatial size of the learned convolution kernels.
kernel_size = 3

In [6]:
# Two-level wavelet-decomposition network.
#
# Each level splits its input with a fixed high-pass and a fixed low-pass
# convolution. A learned convolution is applied on the high-pass branch; the
# low-pass branch of the first level is average-pooled and decomposed again.
# The second level is then merged, upsampled, passed through a learned
# convolution and concatenated with the first-level high-pass features before
# the final sigmoid convolution.
#
# The layers are created in the same order as before the refactoring, so the
# automatically generated layer names (`conv2d_1`, `conv2d_2`, ...) are kept.


def _fixed_filter_conv(filter_initializer):
    """Create a non-trainable 5x5 convolution initialised with a fixed filter.

    The layer has `n_channels` output filters, no bias and a linear
    activation; `n_channels` is read from the notebook namespace.

    Parameters
    ----------
    filter_initializer : keras initializer
        Initializer providing the kernel values (e.g. `keras_high_filter`).

    Returns
    -------
    Conv2D
        The layer, not yet applied to any tensor, with `trainable = False`.
    """
    layer = Conv2D(
        n_channels,
        5,
        activation='linear',
        padding='same',
        use_bias=False,
        kernel_initializer=filter_initializer,
    )
    # Frozen before the layer is called, so the kernel is never updated.
    layer.trainable = False
    return layer


def _learned_conv(n_filters, activation='relu'):
    """Create a trainable `kernel_size` convolution with he_normal init.

    Parameters
    ----------
    n_filters : int
        Number of output filters.
    activation : str
        Keras activation name; 'relu' unless stated otherwise.

    Returns
    -------
    Conv2D
        The layer, not yet applied to any tensor.
    """
    return Conv2D(
        n_filters,
        kernel_size,
        activation=activation,
        padding='same',
        kernel_initializer='he_normal',
    )


input_im = Input((size, size, n_channels))

# First decomposition
## high pass
high_1_layer = _fixed_filter_conv(keras_high_filter)
high_1 = high_1_layer(input_im)
high_1 = _learned_conv(n_extension)(high_1)  # learned filter on high frequency component

## low pass
low_1_layer = _fixed_filter_conv(keras_low_filter)
low_1 = low_1_layer(input_im)
low_1 = AveragePooling2D(pool_size=(2, 2))(low_1)  # halve the spatial resolution


# Second decomposition (applied on the pooled low-pass component)
## high pass
high_2_layer = _fixed_filter_conv(keras_high_filter)
high_2 = high_2_layer(low_1)
high_2 = _learned_conv(n_extension)(high_2)  # learned filter on high frequency component
## low pass
low_2_layer = _fixed_filter_conv(keras_low_filter)
low_2 = low_2_layer(low_1)

# Recombine the second level and bring it back to the input resolution
merge_2 = concatenate([low_2, high_2], axis=3)
merge_2 = UpSampling2D(size=(2, 2))(merge_2)
merge_2 = _learned_conv(n_extension)(merge_2)

# Recombine with the first-level high frequency features
merge_1 = concatenate([merge_2, high_1], axis=3)

# Sigmoid output with as many channels as the input image
output = _learned_conv(n_channels, activation='sigmoid')(merge_1)

model = Model(input_im, output)

model.compile(optimizer=Adam(lr=1e-4), loss='mean_squared_error')

Instructions for updating:
Colocations handled automatically by placer.


In [7]:
# Print the per-layer summary of the model: layer type, output shape,
# parameter count and the layer(s) each one is connected to.
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 32, 32, 1)    0                                            
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 32, 32, 1)    25          input_1[0][0]                    
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 16, 16, 1)    0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 16, 16, 1)    25          average_pooling2d_1[0][0]        
__________________________________________________________________________________________________
conv2d_6 (

In [8]:
# training params
epochs = 50
batch_size = 32
# Fraction used to split the number of steps per epoch between the training
# and the validation generators.
validation_split = 0.1
# Fixed, human-readable run name; it becomes the TensorBoard log sub-directory.
# Use `str(int(time.time()))` instead to get a unique id for every run.
run_id = 'wavedecnet_with_lrate'
print(run_id)

wavedecnet_with_lrate


In [9]:
def step_decay(epoch):
    """Step-wise learning rate schedule: start at 0.1 and halve periodically.

    The rate is `0.1 * 0.5 ** floor((1 + epoch) / 5)`. Because of the `1 +`,
    with a 0-based epoch index the first halving happens at epoch 4 and then
    every 5 epochs (9, 14, ...).

    Parameters
    ----------
    epoch : int
        Index of the epoch.

    Returns
    -------
    float
        Learning rate to use for that epoch.
    """
    base_rate, factor, period = 0.1, 0.5, 5.0
    n_drops = math.floor((1 + epoch) / period)
    return base_rate * math.pow(factor, n_drops)

# Keras callback that sets the optimizer learning rate from `step_decay`;
# with verbose=1 it prints the rate it sets for each epoch.
# It only has an effect when passed in the `callbacks` list of the training
# call, where it is currently commented out.
lrate = LearningRateScheduler(step_decay, verbose=1)

In [10]:
# TensorBoard logs of this run are written to experiments/logs/<run_id>
# (relative to the current working directory).
log_dir = op.join('experiments', 'logs', run_id)
tboard_cback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    batch_size=batch_size,
    write_graph=True,
    write_images=True,
)

# Number of batches drawn per epoch from the training / validation generators.
train_steps = int((1-validation_split) * n_samples_train / batch_size)
val_steps = int(validation_split * n_samples_train / batch_size)

# NOTE(review): the `lrate` scheduler is disabled here, but the recorded output
# of this cell contains LearningRateScheduler messages, i.e. it was produced
# with the scheduler enabled. Re-enable it or re-run so code and output agree.
training_callbacks = [
    tboard_cback,
#     lrate,
    TQDMNotebookCallback(),
]

model.fit_generator(
    im_gen_train,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=im_gen_val,
    validation_steps=val_steps,
    verbose=2,
    callbacks=training_callbacks,
)

Instructions for updating:
Use tf.cast instead.


HBox(children=(IntProgress(value=0, description='Training', max=50, style=ProgressStyle(description_width='ini…

Epoch 1/50

Epoch 00001: LearningRateScheduler setting learning rate to 0.1.


HBox(children=(IntProgress(value=0, description='Epoch 0', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0053 - val_loss: 0.0035
Epoch 2/50

Epoch 00002: LearningRateScheduler setting learning rate to 0.1.


HBox(children=(IntProgress(value=0, description='Epoch 1', max=1406, style=ProgressStyle(description_width='in…

 - 39s - loss: 0.0037 - val_loss: 0.0035
Epoch 3/50

Epoch 00003: LearningRateScheduler setting learning rate to 0.1.


HBox(children=(IntProgress(value=0, description='Epoch 2', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0037 - val_loss: 0.0042
Epoch 4/50

Epoch 00004: LearningRateScheduler setting learning rate to 0.1.


HBox(children=(IntProgress(value=0, description='Epoch 3', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0037 - val_loss: 0.0033
Epoch 5/50

Epoch 00005: LearningRateScheduler setting learning rate to 0.05.


HBox(children=(IntProgress(value=0, description='Epoch 4', max=1406, style=ProgressStyle(description_width='in…

 - 39s - loss: 0.0032 - val_loss: 0.0032
Epoch 6/50

Epoch 00006: LearningRateScheduler setting learning rate to 0.05.


HBox(children=(IntProgress(value=0, description='Epoch 5', max=1406, style=ProgressStyle(description_width='in…

 - 39s - loss: 0.0032 - val_loss: 0.0032
Epoch 7/50

Epoch 00007: LearningRateScheduler setting learning rate to 0.05.


HBox(children=(IntProgress(value=0, description='Epoch 6', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0032 - val_loss: 0.0034
Epoch 8/50

Epoch 00008: LearningRateScheduler setting learning rate to 0.05.


HBox(children=(IntProgress(value=0, description='Epoch 7', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0033 - val_loss: 0.0031
Epoch 9/50

Epoch 00009: LearningRateScheduler setting learning rate to 0.05.


HBox(children=(IntProgress(value=0, description='Epoch 8', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0033 - val_loss: 0.0034
Epoch 10/50

Epoch 00010: LearningRateScheduler setting learning rate to 0.025.


HBox(children=(IntProgress(value=0, description='Epoch 9', max=1406, style=ProgressStyle(description_width='in…

 - 38s - loss: 0.0031 - val_loss: 0.0035
Epoch 11/50

Epoch 00011: LearningRateScheduler setting learning rate to 0.025.


HBox(children=(IntProgress(value=0, description='Epoch 10', max=1406, style=ProgressStyle(description_width='i…

 - 38s - loss: 0.0032 - val_loss: 0.0031
Epoch 12/50

Epoch 00012: LearningRateScheduler setting learning rate to 0.025.


HBox(children=(IntProgress(value=0, description='Epoch 11', max=1406, style=ProgressStyle(description_width='i…

 - 39s - loss: 0.0032 - val_loss: 0.0031
Epoch 13/50

Epoch 00013: LearningRateScheduler setting learning rate to 0.025.


HBox(children=(IntProgress(value=0, description='Epoch 12', max=1406, style=ProgressStyle(description_width='i…

 - 37s - loss: 0.0032 - val_loss: 0.0032
Epoch 14/50

Epoch 00014: LearningRateScheduler setting learning rate to 0.025.


HBox(children=(IntProgress(value=0, description='Epoch 13', max=1406, style=ProgressStyle(description_width='i…

 - 39s - loss: 0.0032 - val_loss: 0.0032
Epoch 15/50

Epoch 00015: LearningRateScheduler setting learning rate to 0.0125.


HBox(children=(IntProgress(value=0, description='Epoch 14', max=1406, style=ProgressStyle(description_width='i…

 - 38s - loss: 0.0031 - val_loss: 0.0032
Epoch 16/50

Epoch 00016: LearningRateScheduler setting learning rate to 0.0125.


HBox(children=(IntProgress(value=0, description='Epoch 15', max=1406, style=ProgressStyle(description_width='i…

 - 40s - loss: 0.0031 - val_loss: 0.0031
Epoch 17/50

Epoch 00017: LearningRateScheduler setting learning rate to 0.0125.


HBox(children=(IntProgress(value=0, description='Epoch 16', max=1406, style=ProgressStyle(description_width='i…

KeyboardInterrupt: 