## WESPE_GAN: Weakly Supervised Photo Enhancer for Digital Cameras

In [51]:
import tensorflow as tf
import keras

In [52]:
import matplotlib.pyplot as plt
import numpy as np
import scipy
import glob
import cv2
import os

In [53]:
import scipy.stats as st
from keras.layers import Conv2D, BatchNormalization, Activation, Add, Flatten, Dense, Input, Lambda, DepthwiseConv2D
from keras.initializers import glorot_uniform
from keras.models import Model

In [54]:
class WESPE_GAN(object):
  
  def __init__(self, input_shape):
    self.input_shape = input_shape
    self.batch_size = 30
    self.patch_size = 96
    self.w_content = 0.1
    self.w_texture = 3
    self.w_color = 20
    self.w_tv = 1/400
    self.optimizer = keras.optimizers.Adam(lr=1e-4)
    
    self.discriminator_Dc = self.build_discriminator("Dc", "blur")
    self.discriminator_Dt = self.build_discriminator("Dt", "grayscale")
    self.discriminator_Dc.compile(loss=keras.losses.binary_crossentropy, optimizer=self.optimizer, metrics=['accuracy'])
    self.discriminator_Dt.compile(loss=keras.losses.binary_crossentropy, optimizer=self.optimizer, metrics=['accuracy'])
    
    self.generator = self.build_generator("G")
    self.inverse_generator = self.build_generator("F")
    img = Input(shape=self.input_shape)
    # dslr_img = Input(shape=self.input_shape)
    enhanced_img = self.generator(img)
    self.discriminator_Dc.trainable = False
    self.discriminator_Dt.trainable = False
    Dc_valid = self.discriminator_Dc(enhanced_img)
    Dt_valid = self.discriminator_Dt(enhanced_img)
    # Dc_valid_dslr = self.discriminator_Dc(dslr_img)
    # Dt_valid_dslr = self.discriminator_Dt(dslr_img)
    reconstructed_img = self.inverse_generator(enhanced_img)
    self.combined_model = Model(inputs=img, outputs=[Dc_valid, Dt_valid, reconstructed_img])
    self.combined_model.compile(loss=self.build_generator_loss(img, enhanced_img, reconstructed_img), optimizer=self.optimizer)
  
  def gaussian_kernel(self, kernlen=21, nsig=3, channels=1):
    interval = (2*nsig+1.)/(kernlen)
    x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw/kernel_raw.sum()
    out_filter = np.array(kernel, dtype = np.float32)
    out_filter = out_filter.reshape((kernlen, kernlen, 1, 1))
    out_filter = np.repeat(out_filter, channels, axis = 2)
    return out_filter

  def gaussian_blur(self, X):
    kernel_var = self.gaussian_kernel(21, 3, 3)
    return Lambda(lambda x:tf.nn.depthwise_conv2d(x, kernel_var , [1, 1, 1, 1], padding='SAME'))(X)
    
  def rgb_to_grayscale(self, X):
    return Lambda(lambda x:tf.image.rgb_to_grayscale(x))(X)
  
  def get_block(self, X, nn_name, block_number):
    X_shortcut = X
    X = Conv2D(64, kernel_size=3, strides=(1, 1), data_format="channels_last", use_bias=True, padding="same", kernel_initializer=glorot_uniform(seed=0), 
              name=nn_name+"_conv"+str(block_number)+"_1")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn"+str(block_number)+"_1")(X) # axis=-1 as data_format="channels_last"
    X = Activation('relu')(X)
    X = Conv2D(64, kernel_size=3, strides=(1, 1), data_format="channels_last", use_bias=True, padding="same", kernel_initializer=glorot_uniform(seed=0), 
              name=nn_name+"_conv"+str(block_number)+"_2")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn"+str(block_number)+"_2")(X)
    X = Activation('relu')(X)
    X = Add()([X, X_shortcut])
    return X
  
  def build_generator(self, nn_name):
    X_Input = Input(self.input_shape)
    X = Conv2D(64, kernel_size=9, strides=(1, 1), data_format="channels_last", activation="relu", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv0")(X_Input)
    X = self.get_block(X, nn_name, 1)
    X = self.get_block(X, nn_name, 2)
    X = self.get_block(X, nn_name, 3)
    X = self.get_block(X, nn_name, 4)
    X = Conv2D(64, kernel_size=3, strides=(1, 1), data_format="channels_last", activation="relu", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv5")(X)
    X = Conv2D(64, kernel_size=3, strides=(1, 1), data_format="channels_last", activation="relu", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv6")(X)
    X = Conv2D(64, kernel_size=9, strides=(1, 1), data_format="channels_last", activation="relu", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv7")(X)
    X = Conv2D(3, kernel_size=1, strides=(1, 1), data_format="channels_last", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv8")(X)
    model = Model(inputs=X_Input, outputs=X, name='Network_'+nn_name)
    return model
  
  def build_discriminator(self, nn_name, preprocess):
    X_Input = Input((96, 96, 3))
    if preprocess == "grayscale":
        X = self.rgb_to_grayscale(X_Input)
    elif preprocess == 'blur':
        X = self.gaussian_blur(X_Input)
    else:
        X = X_Input
    X = Conv2D(48, kernel_size=11, strides=(4, 4), data_format="channels_last", activation=keras.layers.LeakyReLU(alpha=0.2), use_bias=True,
               padding="same", kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv0")(X)
    X = Conv2D(128, kernel_size=5, strides=(2, 2), data_format="channels_last", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv1")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn1")(X)
    X = keras.layers.LeakyReLU(alpha=0.2)(X)
    X = Conv2D(192, kernel_size=3, strides=(1, 1), data_format="channels_last", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv2")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn2")(X)
    X = keras.layers.LeakyReLU(alpha=0.2)(X)
    X = Conv2D(192, kernel_size=3, strides=(1, 1), data_format="channels_last", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv3")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn3")(X)
    X = keras.layers.LeakyReLU(alpha=0.2)(X)
    X = Conv2D(128, kernel_size=3, strides=(2, 2), data_format="channels_last", use_bias=True, padding="same", 
               kernel_initializer=glorot_uniform(seed=0), name=nn_name+"_conv4")(X)
    X = BatchNormalization(axis=-1, center=True, scale=True, name=nn_name+"_bn4")(X)
    X = keras.layers.LeakyReLU(alpha=0.2)(X)
    X = Flatten(data_format="channels_last")(X)
    X = Dense(1024, activation=keras.layers.LeakyReLU(alpha=0.2))(X)
    X_out = Dense(1, activation='sigmoid')(X)
    model = Model(X_Input, X_out, name='Network_'+nn_name)
    return model
  
  def build_generator_loss(self, img, enhanced_img, reconstructed_img):
    mobile_net_v2 = keras.applications.mobilenet_v2.MobileNetV2(input_shape=(96, 96, 3), include_top=False, weights='imagenet')
    mobile_net_v2.trainable = False
    mobile_net_v2_truncated = Model(mobile_net_v2.input, mobile_net_v2.get_layer('Conv_1').output)
    content_loss = keras.backend.mean(keras.backend.square(mobile_net_v2_truncated(img) - mobile_net_v2_truncated(reconstructed_img)))
    tv_loss = keras.backend.mean(tf.image.total_variation(enhanced_img))
    
    def my_loss(y_true, y_pred):
      color_loss = keras.backend.mean(keras.losses.binary_crossentropy(y_true[0], y_pred[0]))
      texture_loss = keras.backend.mean(keras.losses.binary_crossentropy(y_true[1], y_pred[1]))
      generator_loss = self.w_content * content_loss + self.w_color * color_loss + self.w_texture * texture_loss + self.w_tv * tv_loss
      return generator_loss
    
    return my_loss
  
  def preprocess(self, img):
    mean_RGB = np.array([123.68, 116.779, 103.939])
    return (img - mean_RGB)/255
  
  def load_dataset(self, phone_dir, dslr_dir):
    train_imgs_iphone = [phone_dir+'{}.jpg'.format(i) for i in range(0, 3500)]
    train_imgs_dslr = [dslr_dir+'{}.jpg'.format(i) for i in range(0, 3500)]
    phone_images = []
    dslr_images = []
    for file1, file2 in zip(train_imgs_iphone, train_imgs_dslr):
      try:
        phone_images.append(cv2.resize(cv2.imread(file1, cv2.IMREAD_COLOR), (96, 96), interpolation=cv2.INTER_AREA))
        dslr_images.append(cv2.resize(cv2.imread(file2, cv2.IMREAD_COLOR), (96, 96), interpolation=cv2.INTER_AREA)) 
      except:
        continue
    print(len(phone_images))
    print(len(dslr_images))
    return phone_images, dslr_images

  def get_batch(self, phone_images, dslr_images, augmentation):
    phone_batch = np.zeros([self.batch_size, self.patch_size, self.patch_size, 3], dtype = 'float32')
    dslr_batch = np.zeros([self.batch_size, self.patch_size, self.patch_size, 3], dtype = 'float32')
    for img_no in range(self.batch_size):
        random_index = np.random.randint(len(phone_images))
        phone_patch = phone_images[random_index]
        dslr_patch = dslr_images[random_index]
        # randomly flip, rotate patch (assuming that the patch shape is square)
        if augmentation == True:
            probability = np.random.rand()
            if probability > 0.5:
                phone_patch = np.flip(phone_patch, axis = 0)
                dslr_patch = np.flip(dslr_patch, axis = 0)
            probability = np.random.rand()
            if probability > 0.5:
                phone_patch = np.flip(phone_patch, axis = 1)
                dslr_patch = np.flip(dslr_patch, axis = 1)
            probability = np.random.rand()
            if probability > 0.5:
                phone_patch = np.rot90(phone_patch)
                dslr_patch = np.rot90(dslr_patch)
        phone_batch[img_no,:,:,:] = self.preprocess(phone_patch)
        dslr_batch[img_no,:,:,:] = self.preprocess(dslr_patch)
    return phone_batch, dslr_batch
  
  def train(self, epochs=5000, batch_size=30, save_interval=20):
    phone_images, dslr_images = self.load_dataset('../input/train/train/iphone/', '../input/train/train/canon/')
    valid = np.ones((self.batch_size, 1))
    fake = np.zeros((self.batch_size, 1))
    print(self.discriminator_Dc.metrics_names)
    print(self.discriminator_Dt.metrics_names)
    print(self.combined_model.metrics_names)

    for epoch in range(epochs):
        # print("Epoch: ", epoch)
        # batch_X = tf.placeholder(dtype=tf.float32,shape=(self.batch_size, self.patch_size, self.patch_size, 3))
        # blur_op = self.gaussian_blur(batch_X)
        # sess = tf.Session()
        # blurred_enhanced_phone_batch = sess.run(blur_op, feed_dict={batch_X: enhanced_phone_batch})
        # blurred_dslr_batch = sess.run(blur_op, feed_dict={batch_X: dslr_batch})
        # sess.close()
        # grayscale_enhanced_phone_batch = [color.rgb2gray(img) for img in enhanced_phone_batch]
        # grayscale_dslr_batch = [color.rgb2gray(img) for img in dslr_batch]
        self.discriminator_Dc.trainable = True
        self.discriminator_Dt.trainable = True
        for j in range(1):
            phone_batch, dslr_batch = self.get_batch(phone_images, dslr_images, True)
            enhanced_phone_batch = self.generator.predict(phone_batch)
            Dc_loss_real = self.discriminator_Dc.train_on_batch(dslr_batch, valid)
            Dc_loss_fake = self.discriminator_Dc.train_on_batch(enhanced_phone_batch, fake)
            Dt_loss_real = self.discriminator_Dt.train_on_batch(dslr_batch, valid)
            Dt_loss_fake = self.discriminator_Dt.train_on_batch(enhanced_phone_batch, fake)
        self.discriminator_Dc.trainable = False
        self.discriminator_Dt.trainable = False
        for j in range(4):
            # phone_batch, dslr_batch = self.get_batch(phone_images, dslr_images, True)
            phone_batch, dslr_batch = self.get_batch(phone_images, dslr_images, True) # Can change False to True
            Dc_pred = self.discriminator_Dc.predict(dslr_batch)
            Dt_pred = self.discriminator_Dt.predict(dslr_batch)
            gen_loss = self.combined_model.train_on_batch([phone_batch], [Dc_pred, Dt_pred, dslr_batch])
        if epoch % 100 is  99:
            print("--------------------------------------------{0}----------------------------------------".format(epoch))
            print('Dc_loss_real:', Dc_loss_real)
            print('Dc_loss_fake:', Dc_loss_fake)
            print('Dt_loss_real:', Dt_loss_real)
            print('Dt_loss_fake:', Dt_loss_fake)
            print("Generator Loss: ", gen_loss)


In [55]:
gan_model = WESPE_GAN((96, 96, 3))

  identifier=identifier.__class__.__name__))


In [56]:
gan_model.generator.trainable_weights

[<tf.Variable 'G_conv0_7/kernel:0' shape=(9, 9, 3, 64) dtype=float32_ref>,
 <tf.Variable 'G_conv0_7/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_conv1_1_7/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'G_conv1_1_7/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn1_1_7/gamma:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn1_1_7/beta:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_conv1_2_7/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'G_conv1_2_7/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn1_2_7/gamma:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn1_2_7/beta:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_conv2_1_7/kernel:0' shape=(3, 3, 64, 64) dtype=float32_ref>,
 <tf.Variable 'G_conv2_1_7/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn2_1_7/gamma:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_bn2_1_7/beta:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'G_conv2_2_7/kerne

In [57]:
gan_model.generator.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_45 (InputLayer)           (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
G_conv0 (Conv2D)                (None, 96, 96, 64)   15616       input_45[0][0]                   
__________________________________________________________________________________________________
G_conv1_1 (Conv2D)              (None, 96, 96, 64)   36928       G_conv0[0][0]                    
__________________________________________________________________________________________________
G_bn1_1 (BatchNormalization)    (None, 96, 96, 64)   256         G_conv1_1[0][0]                  
__________________________________________________________________________________________________
activation

In [58]:
gan_model.inverse_generator.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_46 (InputLayer)           (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
F_conv0 (Conv2D)                (None, 96, 96, 64)   15616       input_46[0][0]                   
__________________________________________________________________________________________________
F_conv1_1 (Conv2D)              (None, 96, 96, 64)   36928       F_conv0[0][0]                    
__________________________________________________________________________________________________
F_bn1_1 (BatchNormalization)    (None, 96, 96, 64)   256         F_conv1_1[0][0]                  
__________________________________________________________________________________________________
activation

In [59]:
gan_model.combined_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_47 (InputLayer)           (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
Network_G (Model)               (None, 96, 96, 3)    718979      input_47[0][0]                   
__________________________________________________________________________________________________
Network_Dc (Model)              (None, 1)            5669057     Network_G[1][0]                  
__________________________________________________________________________________________________
Network_Dt (Model)              (None, 1)            5657441     Network_G[1][0]                  
__________________________________________________________________________________________________
Network_F 

In [60]:
gan_model.discriminator_Dc.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_43 (InputLayer)        (None, 96, 96, 3)         0         
_________________________________________________________________
lambda_15 (Lambda)           (None, 96, 96, 3)         0         
_________________________________________________________________
Dc_conv0 (Conv2D)            (None, 24, 24, 48)        17472     
_________________________________________________________________
Dc_conv1 (Conv2D)            (None, 12, 12, 128)       153728    
_________________________________________________________________
Dc_bn1 (BatchNormalization)  (None, 12, 12, 128)       512       
_________________________________________________________________
leaky_re_lu_86 (LeakyReLU)   (None, 12, 12, 128)       0         
_________________________________________________________________
Dc_conv2 (Conv2D)            (None, 12, 12, 192)       221376    
__________

  'Discrepancy between trainable weights and collected trainable'


In [61]:
gan_model.discriminator_Dt.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_44 (InputLayer)        (None, 96, 96, 3)         0         
_________________________________________________________________
lambda_16 (Lambda)           (None, 96, 96, 1)         0         
_________________________________________________________________
Dt_conv0 (Conv2D)            (None, 24, 24, 48)        5856      
_________________________________________________________________
Dt_conv1 (Conv2D)            (None, 12, 12, 128)       153728    
_________________________________________________________________
Dt_bn1 (BatchNormalization)  (None, 12, 12, 128)       512       
_________________________________________________________________
leaky_re_lu_92 (LeakyReLU)   (None, 12, 12, 128)       0         
_________________________________________________________________
Dt_conv2 (Conv2D)            (None, 12, 12, 192)       221376    
__________

  'Discrepancy between trainable weights and collected trainable'


In [62]:
gan_model.train(epochs=1000)

3500
3500
['loss', 'acc']
['loss', 'acc']
['loss', 'Network_Dc_loss', 'Network_Dt_loss', 'Network_F_loss']
--------------------------------------------99----------------------------------------
Dc_loss_real: [3.7133937e-06, 1.0]
Dc_loss_fake: [0.0003523166, 1.0]
Dt_loss_real: [0.38510722, 0.76666665]
Dt_loss_fake: [0.008477378, 1.0]
Generator Loss:  [42.86021, 1.6617748, 22.82245, 18.375984]
--------------------------------------------199----------------------------------------
Dc_loss_real: [1.3987433e-06, 1.0]
Dc_loss_fake: [0.00829271, 1.0]
Dt_loss_real: [0.19569421, 0.96666664]
Dt_loss_fake: [0.5290775, 0.6666667]
Generator Loss:  [-16.266788, 21.035378, 18.08718, -55.389347]
--------------------------------------------299----------------------------------------
Dc_loss_real: [1.3797901e-05, 1.0]
Dc_loss_fake: [5.630855e-07, 1.0]
Dt_loss_real: [0.44486132, 0.76666665]
Dt_loss_fake: [0.50736576, 0.8]
Generator Loss:  [86.95972, 1.9448267, 42.317368, 42.697517]


KeyboardInterrupt: 

### References:

https://stackoverflow.com/questions/38553927/batch-normalization-in-convolutional-neural-network  
https://github.com/eriklindernoren/Keras-GAN/blob/master/dcgan/dcgan.py  
https://github.com/JuheonYi/WESPE-TensorFlow  
https://keras.io/applications/  
https://medium.com/tensorflow/neural-style-transfer-creating-art-with-deep-learning-using-tf-keras-and-eager-execution-7d541ac31398  
https://gogul09.github.io/software/flower-recognition-deep-learning   
https://github.com/keras-team/keras-applications/blob/master/keras_applications/mobilenet_v2.py
https://www.pyimagesearch.com/2018/12/24/how-to-use-keras-fit-and-fit_generator-a-hands-on-tutorial/  
https://github.com/keras-team/keras/issues/7491  
https://keras.io/getting-started/functional-api-guide/#multi-input-and-multi-output-models  
