# deep_deblur_keras
* [Deep Generative Filter for motion deblurring](https://arxiv.org/pdf/1709.03481.pdf)
* conference: ICCV 2017

## 导入相关的包

In [None]:
import glob as gb
import numpy as np
from PIL import Image
import os
import h5py
from tqdm import tqdm

## 构建数据集

In [None]:
def normalization(x):
    """Rescale pixel values from the [0, 255] range onto [-1, 1].

    Accepts anything that supports scalar division and subtraction
    (plain numbers as well as numpy arrays).
    """
    scaled = x / 127.5
    return scaled - 1
def format_image(image_path, size):
    """Load a side-by-side training image and split it into its two halves.

    The file at ``image_path`` holds the full (sharp) picture in its left
    half and the blurred picture in its right half. Each half is resized to
    ``size`` x ``size`` pixels.

    Args:
        image_path: path of the combined image file.
        size: edge length, in pixels, of the square output images.

    Returns:
        A ``(image_full, image_blur)`` tuple of numpy arrays, each of shape
        ``(size, size, 3)``.
    """
    # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
    resample = Image.LANCZOS
    # The context manager closes the underlying file handle once we are done.
    with Image.open(image_path) as image:
        # Force three channels so grayscale / RGBA files still match the
        # 3-channel shape the networks are built for.
        rgb = image.convert('RGB')
        width, height = rgb.size
        # Integer split point: crop boxes are pixel coordinates.
        half = width // 2
        # Note the full image is on the left, the blur image on the right
        image_full = rgb.crop((0, 0, half, height))
        image_blur = rgb.crop((half, 0, width, height))
        image_full = image_full.resize((size, size), resample)
        image_blur = image_blur.resize((size, size), resample)
    # return the numpy arrays
    return np.array(image_full), np.array(image_blur)

def build_hdf5(jpeg_dir, size=256):
    """Convert the JPEG images under ``jpeg_dir`` into ``data/data.h5``.

    For each of the ``train`` and ``test`` sub-directories, every ``*.jpg``
    file is split by ``format_image`` and the results are stored in the
    datasets ``<type>_data_full`` and ``<type>_data_blur``.

    Args:
        jpeg_dir: directory containing the ``train`` and ``test`` folders.
        size: edge length of the square images written to the file.
    """
    hdf5_file = os.path.join('data', 'data.h5')
    # h5py does not create missing parent directories.
    os.makedirs(os.path.dirname(hdf5_file), exist_ok=True)
    with h5py.File(hdf5_file, 'w') as f:
        for data_type in tqdm(['train', 'test'], desc='create HDF5 dataset from images'):
            pattern = os.path.join(jpeg_dir, data_type, '*.jpg')
            # glob returns files in arbitrary order; sort so the dataset
            # layout is reproducible between runs and machines.
            images_path = sorted(gb.glob(pattern))
            data_full = []
            data_blur = []
            for image_path in images_path:
                image_full, image_blur = format_image(image_path, size)
                data_full.append(image_full)
                data_blur.append(image_blur)
            f.create_dataset('%s_data_full' % data_type, data=data_full)
            f.create_dataset('%s_data_blur' % data_type, data=data_blur)

def load_data(data_type):
    """Read one split from ``data/data.h5`` and scale it to [-1, 1].

    Args:
        data_type: prefix of the dataset names to read, e.g. ``'train'``
            or ``'test'``.

    Returns:
        A ``(data_full, data_blur)`` tuple of float32 arrays, in that order.
    """
    with h5py.File('data/data.h5', 'r') as hdf5:
        # [:] copies the datasets into memory, so they outlive the file handle
        sharp_raw = hdf5['%s_data_full' % data_type][:]
        blurred_raw = hdf5['%s_data_blur' % data_type][:]
    sharp = normalization(sharp_raw.astype(np.float32))
    blurred = normalization(blurred_raw.astype(np.float32))
    return sharp, blurred

def generate_image(full, blur, generated, path, epoch=None, index=None):
    """Save side-by-side comparison images (full | blur | generated).

    The three batches are expected in the [-1, 1] range and are mapped back
    to 8-bit pixel values before saving.

    Args:
        full: batch of full (sharp) images, indexed as ``[i, :, :, :]``.
        blur: batch of blurred images, same layout.
        generated: batch of generator outputs, same layout.
        path: filename prefix (typically a directory ending in ``/``).
        epoch: zero-based epoch number, or None.
        index: zero-based batch number, or None.

    Files are named ``<epoch+1>_<index+1>_<i>.png`` when both ``epoch`` and
    ``index`` are given, otherwise ``<i>.png``, where ``i`` is the position
    of the sample inside the batch.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    def to_pixels(batch):
        # Clip first: casting out-of-range floats to uint8 wraps around.
        return np.clip(batch * 127.5 + 127.5, 0, 255).astype(np.uint8)

    full = to_pixels(full)
    blur = to_pixels(blur)
    generated = to_pixels(generated)
    for i in range(generated.shape[0]):
        image = np.concatenate((full[i], blur[i], generated[i]), axis=1)
        if (epoch is not None) and (index is not None):
            # Include the sample position; without it every sample of the
            # batch was written to the same file and only the last survived.
            name = str(epoch + 1) + '_' + str(index + 1) + '_' + str(i) + '.png'
        else:
            name = str(i) + '.png'
        Image.fromarray(image).save(path + name)
# format_image('data/small/test/301.jpg', size=256)
# build_hdf5('data/small')
# img_full, img_blur = load_data('train')
# print(img_full, '\n', len(img_blur))

## 定义网络结构

In [None]:
from keras.layers import Input, concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D
from keras.layers.core import Dropout, Dense, Flatten, Lambda
from keras.layers.merge import Average
from keras.layers.normalization import BatchNormalization
from keras.models import Model
from keras.utils.vis_utils import plot_model
# Hyper-parameter "chr" from the paper; used below as the base number of
# convolution filters and as the side length of a discriminator patch.
channel_rate = 64
# Note: the image height/width must be a multiple of the patch height/width,
# because the discriminator cuts the image into a grid of whole patches.
image_shape = (256, 256, 3)
# Shape of one patch fed to the patch-level discriminator.
patch_shape = (channel_rate, channel_rate, 3)

### 生成器

In [None]:
def dense_block(inputs, dilation_factor=None):
    """Build one dense block on top of ``inputs`` and return its output tensor.

    Layout: LeakyReLU -> 1x1 conv (4 * channel_rate filters) -> BatchNorm ->
    LeakyReLU -> 3x3 conv (channel_rate filters) -> BatchNorm -> Dropout.

    Args:
        inputs: input tensor.
        dilation_factor: dilation rate for the 3x3 convolution; when None
            the convolution is built without a ``dilation_rate`` argument.
    """
    # Bottleneck: activation followed by a 1x1 convolution
    bottleneck = LeakyReLU(alpha=0.2)(inputs)
    bottleneck = Convolution2D(filters=4 * channel_rate, kernel_size=(1, 1), padding='same')(bottleneck)
    bottleneck = BatchNormalization()(bottleneck)
    activated = LeakyReLU(alpha=0.2)(bottleneck)

    # The 3x3 convolutions along the dense field alternate between plain
    # ('spatial') and dilated convolutions; the caller selects which one
    # through dilation_factor.
    conv_kwargs = dict(filters=channel_rate, kernel_size=(3, 3), padding='same')
    if dilation_factor is not None:
        conv_kwargs['dilation_rate'] = dilation_factor
    features = Convolution2D(**conv_kwargs)(activated)
    features = BatchNormalization()(features)

    # Regularisation; the layer used here is Dropout (rate 0.5), not a
    # Gaussian-noise layer.
    return Dropout(rate=0.5)(features)

def generator_model():
    """Build the generator network and return it as a Keras ``Model``.

    Structure: a 3x3 head convolution, a dense field of ten dense blocks
    (every second block dilated, with rates 1, 2, 3, 2, 1), a tail, a global
    skip connection back to the head, and a 3-channel tanh output.
    The input height and width are left unspecified.
    """
    # Input image; spatial dimensions are variable
    inputs = Input(shape=(None, None, 3))

    # The Head
    head = Convolution2D(filters=4 * channel_rate, kernel_size=(3, 3), padding='same')(inputs)

    # The Dense Field: None means a plain convolution, a tuple a dilated one
    # (dilation is used at every even-numbered block).
    schedule = (None, (1, 1), None, (2, 2), None, (3, 3), None, (2, 2), None, (1, 1))
    features = head
    for rate in schedule[:-1]:
        block_out = dense_block(inputs=features, dilation_factor=rate)
        # each block sees the concatenation of everything before it
        features = concatenate([features, block_out])
    # the output of the last block is not concatenated; it feeds the tail alone
    last_block = dense_block(inputs=features, dilation_factor=schedule[-1])

    # The Tail
    tail = LeakyReLU(alpha=0.2)(last_block)
    tail = Convolution2D(filters=4 * channel_rate, kernel_size=(1, 1), padding='same')(tail)
    tail = BatchNormalization()(tail)

    # The Global Skip Connection
    merged = concatenate([head, tail])
    merged = Convolution2D(filters=channel_rate, kernel_size=(3, 3), padding='same')(merged)
    # PReLU is not used here: the original author notes it depends on a
    # fixed input shape.
    merged = LeakyReLU(alpha=0.2)(merged)

    # Output Image
    outputs = Convolution2D(filters=3, kernel_size=(3, 3), padding='same', activation='tanh')(merged)
    return Model(inputs=inputs, outputs=outputs, name='Generator')

### 判别器

In [None]:
def discriminator_model():
    """Build the patch-based discriminator and return it as a Keras ``Model``.

    A small "PatchGAN" classifier (four stride-2 convolutions followed by a
    sigmoid unit) scores a single patch of shape ``patch_shape``. The full
    discriminator cuts an ``image_shape`` input into a grid of
    non-overlapping patches, scores each one with the shared patch
    classifier and averages the scores.
    """
    def make_patch_slicer(rows, cols):
        # Bind the bounds at creation time. A lambda written inline in the
        # loop would look up the loop variables when called, so every slicer
        # would use the LAST patch's bounds whenever the model is applied to
        # a new tensor (e.g. when stacked on the generator).
        row_start, row_stop = rows
        col_start, col_stop = cols

        def slice_patch(z):
            return z[:, row_start:row_stop, col_start:col_stop, :]
        return slice_patch

    # PatchGAN: classifier for one patch
    inputs = Input(shape=patch_shape)
    x = Convolution2D(filters=channel_rate, kernel_size=(3, 3), strides=(2, 2), padding="same")(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Convolution2D(filters=2 * channel_rate, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Convolution2D(filters=4 * channel_rate, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Convolution2D(filters=4 * channel_rate, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    outputs = Dense(units=1, activation='sigmoid')(x)
    patch_model = Model(inputs=inputs, outputs=outputs, name='PatchGAN')

    # Discriminator: apply the patch classifier over a grid of patches
    inputs = Input(shape=image_shape)
    list_row_idx = [(i * channel_rate, (i + 1) * channel_rate)
                    for i in range(image_shape[0] // patch_shape[0])]
    list_col_idx = [(i * channel_rate, (i + 1) * channel_rate)
                    for i in range(image_shape[1] // patch_shape[1])]
    list_patch = []
    for row_idx in list_row_idx:
        for col_idx in list_col_idx:
            x_patch = Lambda(make_patch_slicer(row_idx, col_idx))(inputs)
            list_patch.append(x_patch)

    scores = [patch_model(patch) for patch in list_patch]
    outputs = Average()(scores)
    model = Model(inputs=inputs, outputs=outputs, name='Discriminator')
    return model

### 结合

In [None]:
def generator_containing_discriminator(generator, discriminator):
    """Stack the discriminator on top of the generator.

    Args:
        generator: model applied to an input of shape ``image_shape``.
        discriminator: model applied to the generator's output.

    Returns:
        A ``Model`` whose output is the discriminator's output for the
        generated image.
    """
    blurred = Input(shape=image_shape)
    verdict = discriminator(generator(blurred))
    return Model(inputs=blurred, outputs=verdict)

# g = generator_model()
# g.summary()
# d = discriminator_model()
# d.summary()
# plot_model(d)
# m = generator_containing_discriminator(generator_model(), discriminator_model())
# m.summary()

## 训练网络

### Base_functions

In [None]:
import keras.backend as K
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.models import Model
# Input shape given to the VGG16 network built in perceptual_loss.
image_shape = (256, 256, 3)
# Weight applied to the perceptual loss term in generator_loss.
K_1 = 145
# Weight applied to the L1 loss term in generator_loss.
K_2 = 170
def l1_loss(y_true, y_pred):
    """Return the mean absolute difference between prediction and target."""
    residual = y_pred - y_true
    return K.mean(K.abs(residual))
def perceptual_loss(y_true, y_pred):
    """Return the mean squared difference of VGG16 ``block3_conv3`` features.

    An ImageNet-pretrained VGG16 (without its top layers) is truncated at
    ``block3_conv3`` and applied to both the target and the prediction.
    """
    backbone = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    feature_extractor = Model(inputs=backbone.input,
                              outputs=backbone.get_layer('block3_conv3').output)
    # the feature extractor is only a fixed measuring device; do not train it
    feature_extractor.trainable = False
    true_features = feature_extractor(y_true)
    pred_features = feature_extractor(y_pred)
    return K.mean(K.square(true_features - pred_features))
def generator_loss(y_true, y_pred):
    """Return the weighted sum ``K_1 * perceptual_loss + K_2 * l1_loss``."""
    perceptual_term = K_1 * perceptual_loss(y_true, y_pred)
    pixel_term = K_2 * l1_loss(y_true, y_pred)
    return perceptual_term + pixel_term
def adversarial_loss(y_true, y_pred):
    """Return ``-log(y_pred)``, the generator's adversarial loss.

    ``y_true`` is ignored; it is present only because Keras loss functions
    take ``(y_true, y_pred)``.
    """
    # A saturated sigmoid can output exactly 0, which would make the loss
    # infinite and the gradients NaN; keep the argument of log away from 0.
    return -K.log(K.clip(y_pred, K.epsilon(), 1.0))
#a, b = data_utils.format_image('data/small/test/301.jpg', size=256)
#print(l1_loss(a.astype(np.float32), b.astype(np.float32)))

### 训练整个网络

In [None]:
def train(batch_size, epoch_num):
    """Train the generator and discriminator on the ``train`` split.

    Per batch: the discriminator is trained on real and generated images,
    then the generator is trained through the stacked (generator +
    discriminator) model with the adversarial loss, and finally directly
    with ``generator_loss``.

    Interim comparison images go to ``result/interim/`` and weights to
    ``weight/``, both every 30 batches; weights are also saved at the end of
    every epoch.

    Args:
        batch_size: number of samples per training batch.
        epoch_num: number of passes over the training data.
    """
    # Note: load_data returns (full, blur); blur is the network input x
    y_train, x_train = load_data(data_type='train')
    # GAN
    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator(g, d)
    # compile the models, use default optimizer parameters
    g.compile(optimizer='adam', loss=generator_loss)
    # discriminator uses binary cross entropy loss
    d.compile(optimizer='adam', loss='binary_crossentropy')
    # Keras fixes the set of trainable weights at compile time, so the
    # discriminator must be frozen BEFORE the stacked model is compiled;
    # toggling `trainable` afterwards has no effect. `d` itself was compiled
    # while trainable, so d.train_on_batch still updates it.
    d.trainable = False
    # adversarial net uses adversarial loss
    d_on_g.compile(optimizer='adam', loss=adversarial_loss)

    # make sure the output locations exist before the first write
    os.makedirs('result/interim', exist_ok=True)
    os.makedirs('weight', exist_ok=True)

    batch_count = x_train.shape[0] // batch_size
    # labels: 1 for full (real) images, 0 for generated images
    real_labels = np.ones((batch_size, 1))
    discriminator_labels = np.concatenate((real_labels, np.zeros((batch_size, 1))))

    for epoch in range(epoch_num):
        print('epoch: ', epoch + 1, '/', epoch_num)
        print('batches: ', batch_count)
        for index in range(batch_count):
            # select a batch of data
            batch = slice(index * batch_size, (index + 1) * batch_size)
            image_blur_batch = x_train[batch]
            image_full_batch = y_train[batch]
            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

            # output generated images every 30 iterations
            if (index % 30 == 0) and (index != 0):
                generate_image(image_full_batch, image_blur_batch, generated_images,
                               'result/interim/', epoch, index)
            # the full images come first, the generated images second,
            # matching the order of discriminator_labels
            x = np.concatenate((image_full_batch, generated_images))

            # train discriminator
            d_loss = d.train_on_batch(x, discriminator_labels)
            print('batch %d d_loss : %f' % (index + 1, d_loss))
            # train adversarial net (generator tries to be classified as real)
            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, real_labels)
            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))
            # train generator
            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)
            print('batch %d g_loss : %f' % (index + 1, g_loss))
            # output weights for generator and discriminator every 30 iterations
            if (index % 30 == 0) and (index != 0):
                g.save_weights('weight/generator_weights.h5', True)
                d.save_weights('weight/discriminator_weights.h5', True)

        # Also save once per epoch, so short epochs (30 batches or fewer)
        # still leave weights behind.
        g.save_weights('weight/generator_weights.h5', True)
        d.save_weights('weight/discriminator_weights.h5', True)

## 测试网络

In [None]:
def test(batch_size):
    """Run the trained generator on the ``test`` split and save the results.

    Loads generator weights from ``weight/generator_weights.h5`` and writes
    one comparison image per test sample under ``result/finally/``.

    Args:
        batch_size: batch size used for prediction.
    """
    # Note: load_data returns (full, blur); blur is the network input x
    y_test, x_test = load_data(data_type='test')
    g = generator_model()
    g.load_weights('weight/generator_weights.h5')
    generated_images = g.predict(x=x_test, batch_size=batch_size)
    # make sure the output directory exists before writing into it
    os.makedirs('result/finally', exist_ok=True)
    generate_image(y_test, x_test, generated_images, 'result/finally/')

## 使用网络

In [None]:
def test_pictures(batch_size):
    """Deblur every ``data/test/*.jpeg`` picture with the trained generator.

    Loads generator weights from ``weight/generator_weights.h5`` and writes
    the outputs to ``result/test/<i>.png``, where ``i`` is the position of
    the picture in the sorted list of input paths. All pictures are stacked
    into one array, so they must share the same width and height.

    Args:
        batch_size: batch size used for prediction.
    """
    data_path = 'data/test/*.jpeg'
    # sort so that output index i maps to a predictable input file
    images_path = sorted(gb.glob(data_path))
    data_blur = []
    for image_path in images_path:
        # close each file promptly and force three channels, which is what
        # the generator's input layer expects
        with Image.open(image_path) as image_blur:
            data_blur.append(np.array(image_blur.convert('RGB')))

    data_blur = np.array(data_blur).astype(np.float32)
    data_blur = normalization(data_blur)

    g = generator_model()
    g.load_weights('weight/generator_weights.h5')
    generated_images = g.predict(x=data_blur, batch_size=batch_size)
    # map the tanh output from [-1, 1] back to [0, 255]
    generated = generated_images * 127.5 + 127.5
    os.makedirs('result/test', exist_ok=True)
    for i in range(generated.shape[0]):
        image_generated = generated[i, :, :, :]
        Image.fromarray(image_generated.astype(np.uint8)).save('result/test/' + str(i) + '.png')