In [None]:
import numpy as np 
import os
import glob
%tensorflow_version 1.x
import skimage.io as io
import skimage.transform as trans
import tensorflow as tf
from keras import *
from keras import backend as K
from keras.callbacks import *
from keras.layers import *
from keras.models import *
from keras.optimizers import *
from keras.preprocessing.image import *
from keras.losses import *
tf.logging.set_verbosity(tf.logging.ERROR)
import cv2

from tensorflow.python.compat import compat
from tensorflow.python.framework import *
from tensorflow.python.ops import *
from tensorflow.python.util import *
import matplotlib.pyplot as plt
%matplotlib inline
!nvidia-smi

TensorFlow 1.x selected.


Using TensorFlow backend.


In [None]:
## PATHS AND EXPERIMENT INFO ##
# The loading cell reads network inputs from the "high" folders and the
# regression targets ("masks") from the "low" folders; a target is looked up
# under the same file name as its input image.
train_data_image_folder = 'data/LID/train/high'
train_data_mask_folder = 'data/LID/train/low'
val_data_image_folder = 'data/LID/val/high'
val_data_mask_folder = 'data/LID/val/low'
# Directory that receives the per-epoch weight checkpoints.
model_folder = 'clienet/models'

# Every image is resized to this size when loaded.
# NOTE(review): per_loss_vgg and the default model input shape hard-code
# 256x256x3, so changing these two values alone is not sufficient.
IMG_HEIGHT = 256
IMG_WIDTH = 256
NUM_EPOCHS = 60
BATCH_SIZE = 8

In [None]:
## LOAD TRAIN AND VALIDATION IMAGES AND RESPECTIVE MASKS ##

def _list_jpgs(folder):
    """Return the sorted names of the ``.jpg`` files directly inside ``folder``."""
    return sorted(name for name in os.listdir(folder) if name.endswith(".jpg"))


def _read_resized(path):
    """Read ``path`` with OpenCV and resize it to IMG_WIDTH x IMG_HEIGHT.

    Raises:
        IOError: if OpenCV cannot read the file (``cv2.imread`` returns None
            instead of raising, which would otherwise surface as an obscure
            error inside ``cv2.resize``).
    """
    img = cv2.imread(path)
    if img is None:
        raise IOError("Could not read image: " + path)
    # cv2.resize expects dsize as (width, height), not (height, width).
    return cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_CUBIC)


def _load_pairs(image_folder, mask_folder, file_names):
    """Load (image, mask) pairs that share a file name in the two folders.

    Returns:
        (images, masks): two lists of resized arrays, in ``file_names`` order.
    """
    images, masks = [], []
    for name in file_names:
        print(name)
        images.append(_read_resized(os.path.join(image_folder, name)))
        masks.append(_read_resized(os.path.join(mask_folder, name)))
    return images, masks


train_files = _list_jpgs(train_data_image_folder)
# The original filtered an undefined `test_files` here, which raised NameError.
val_files = _list_jpgs(val_data_image_folder)

print(train_files)
print(val_files)

# LOAD TRAIN IMAGES
train_images, train_masks = _load_pairs(train_data_image_folder, train_data_mask_folder, train_files)

# LOAD VAL IMAGES
val_images, val_masks = _load_pairs(val_data_image_folder, val_data_mask_folder, val_files)

# CREATE THE ARRAYS
train_image_array = np.array(train_images)
train_mask_array = np.array(train_masks)
val_image_array = np.array(val_images)
val_mask_array = np.array(val_masks)

print(train_image_array.shape, train_mask_array.shape)
print(val_image_array.shape, val_mask_array.shape)

In [None]:
## DEFINE METRICS AND LOSS COMPONENTS ##

# `K` is already imported in the first cell; re-importing keeps this cell
# runnable on its own.
from keras import backend as K
from keras.applications.vgg16 import VGG16
# NOTE(review): `vggmodel` is not referenced by any cell shown in this
# notebook (per_loss_vgg builds its own VGG16) -- confirm whether this extra
# model instance is still needed.
vggmodel = VGG16(include_top=False)

# PEAK SIGNAL TO NOISE RATIO
def psnr(y_true, y_pred):
    """Metric: ``tf.image.psnr`` between target and prediction.

    Both tensors are treated as having a dynamic range of 255.
    """
    peak_value = 255.0
    score = tf.image.psnr(y_true, y_pred, max_val=peak_value)
    return score

# STRUCTURAL SIMILARITY INDEX
def ssim(y_true, y_pred):
    """Metric: ``tf.image.ssim`` between target and prediction.

    Both tensors are treated as having a dynamic range of 255.
    """
    dynamic_range = 255.0
    similarity = tf.image.ssim(y_true, y_pred, max_val=dynamic_range)
    return similarity

# ABSOLUTE BRIGHTNESS
def ab(y_true, y_pred):
    """Metric: absolute difference between the mean values of the two tensors.

    Only the first three channels of the last axis are averaged; the mean is
    taken over every remaining element (whole batch included), so the result
    is a single scalar.
    """
    mean_true = K.mean(y_true[:, :, :, :3])
    mean_pred = K.mean(y_pred[:, :, :, :3])
    return K.abs(mean_true - mean_pred)

# STRUCTURAL LOSS
def ssim_loss(y_true, y_pred):
    """Structural loss: one minus ``tf.image.ssim`` (dynamic range 255)."""
    similarity = tf.image.ssim(y_true, y_pred, max_val=255.0)
    dissimilarity = 1 - similarity
    return dissimilarity

# PERCEPTUAL LOSS
def per_loss_vgg(img_true, img_generated):
    """Perceptual loss built from pixel MSE plus VGG16 feature MSEs.

    The loss is the mean squared pixel difference plus the mean squared
    difference of three VGG16 activations, weighted 2 (block1_conv2),
    4 (block2_conv2) and 8 (block3_conv3), all divided by 15 * 256 * 256.

    NOTE(review): a fresh ImageNet VGG16 is instantiated on every call, and
    the input shape is fixed to 256x256x3 regardless of IMG_HEIGHT/IMG_WIDTH.
    """
    image_shape = (256, 256, 3)
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)

    # (VGG16 layer name, weight of its feature-space MSE term)
    feature_terms = (
        ('block1_conv2', 2),
        ('block2_conv2', 4),
        ('block3_conv3', 8),
    )

    # Pixel-space term first, then the feature terms in increasing depth.
    loss = K.mean(K.square(img_true - img_generated))
    for layer_name, term_weight in feature_terms:
        extractor = Model(inputs=vgg.input, outputs=vgg.get_layer(layer_name).output)
        extractor.trainable = False
        feature_gap = extractor(img_true) - extractor(img_generated)
        loss = loss + term_weight * K.mean(K.square(feature_gap))

    normalisation = 15 * 256 * 256
    return loss / normalisation

# WEIGHTED PATCH-WISE EUCLIDEAN LOSS
def wpw(y_true, y_pred, weight = 4, percentage = 0.25, patches_per_row = 16):
    """Weighted patch-wise squared-error loss on grayscale images.

    Both inputs are reduced to one gray channel (0.39, 0.5 and 0.11 times
    channels 0, 1 and 2). The gray prediction is split into a grid of
    ``patches_per_row`` x ``patches_per_row`` patches; patches whose summed
    predicted intensity is at or below a threshold have their squared error
    multiplied by ``weight``, all other patches by 1.

    Args:
        y_true: target tensor, indexed as [batch, row, col, channel].
        y_pred: predicted tensor with the same layout.
        weight: multiplier applied to the error of the selected patches.
        percentage: fraction of the per-image patch count used to pick the
            threshold index in the sorted patch sums.
        patches_per_row: number of patches along each image side.

    Returns:
        Scalar tensor: weighted sum of per-patch squared errors divided by a
        constant normalisation factor.

    Raises:
        ValueError: if the image side is not a positive multiple of
            ``patches_per_row``.

    NOTE(review): the threshold is taken from the patch sums of the whole
    batch flattened together, while the index and normalisation are derived
    from the per-image patch count -- confirm this is intended for batch > 1.
    """
    gray_pred = 0.39 * y_pred[:, :, :, 0] + 0.5 * y_pred[:, :, :, 1] + 0.11 * y_pred[:, :, :, 2]
    gray_true = 0.39 * y_true[:, :, :, 0] + 0.5 * y_true[:, :, :, 1] + 0.11 * y_true[:, :, :, 2]

    # Only axis 1 is inspected, i.e. the images are assumed to be square.
    side = int(gray_pred.shape[1])
    if patches_per_row <= 0 or side % patches_per_row != 0:
        raise ValueError(
            "image side %d is not a positive multiple of patches_per_row=%r"
            % (side, patches_per_row))
    # Integer division: this value is used as a tensor shape and a conv
    # stride, both of which must be ints (true division produced a float).
    patch_length = side // patches_per_row

    no_of_patches = patches_per_row * patches_per_row
    no_of_patches_to_consider = int(no_of_patches * percentage)
    normalization_factor = int(no_of_patches * ((weight - 1) * percentage + 1) * side * side)

    # Convolving with an all-ones kernel at stride == kernel size yields the
    # sum over each non-overlapping patch.
    patch_filter = tf.ones([patch_length, patch_length, 1, 1], tf.float32)
    patch_strides = [1, patch_length, patch_length, 1]

    patch_brightness = tf.nn.conv2d(
        tf.expand_dims(gray_pred, -1),
        filter=patch_filter, strides=patch_strides, padding='SAME')

    sorted_sums = tf.sort(tf.reshape(patch_brightness, shape=[-1]))
    threshold_sum = sorted_sums[no_of_patches_to_consider]
    mask = tf.cast(patch_brightness <= threshold_sum, tf.float32)
    # `weight` where the mask is 1, and 1 everywhere else.
    patch_weights = tf.add(tf.multiply(float(weight), mask), tf.subtract(float(1), mask))

    squared_error = tf.expand_dims(tf.square(gray_pred - gray_true), -1)
    patch_error = tf.nn.conv2d(
        squared_error,
        filter=patch_filter, strides=patch_strides, padding='SAME')

    loss = tf.reduce_sum(tf.multiply(patch_weights, patch_error))
    return loss / normalization_factor

# TOTAL LOSS
def total_loss(y_true,y_pred):
    """Training objective: weighted sum of the three loss components.

    Combines the perceptual loss (weight 1), one minus SSIM (weight 1) and
    the weighted patch-wise loss (weight 0.1).
    """
    w_per, w_ssim, w_wpw = 1, 1, 0.1

    perceptual_term = w_per * per_loss_vgg(y_true, y_pred)
    structural_term = w_ssim * (1 - ssim(y_true, y_pred))
    patch_term = w_wpw * wpw(y_true, y_pred)

    return perceptual_term + structural_term + patch_term


In [None]:
## DEFINE MODEL COMPONENTS ##

def _define_conv_block(
    input_, layers, filters,
    kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    """Chain ``layers`` identically configured Conv2D layers onto ``input_``.

    At least one convolution is always applied, even when ``layers`` is less
    than 1. Extra keyword arguments are forwarded to every Conv2D.

    Returns:
        The output tensor of the last convolution.
    """
    conv_options = dict(
        activation=activation,
        padding=padding,
        kernel_initializer=kernel_initializer,
        **kwargs)

    output_ = input_
    # max(1, ...) keeps the unconditional first convolution.
    for _ in range(max(1, layers)):
        output_ = Conv2D(filters, kernel_size, **conv_options)(output_)
    return output_

def _define_sep_conv_block(
    input_, layers, filters,
    kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    """Chain ``layers`` identically configured SeparableConv2D layers onto ``input_``.

    At least one separable convolution is always applied, even when
    ``layers`` is less than 1. Extra keyword arguments are forwarded to every
    SeparableConv2D.

    Returns:
        The output tensor of the last separable convolution.
    """
    def new_layer():
        return SeparableConv2D(
            filters, kernel_size,
            activation=activation, padding=padding,
            kernel_initializer=kernel_initializer, **kwargs)

    output_ = new_layer()(input_)
    remaining = layers - 1
    while remaining > 0:
        output_ = new_layer()(output_)
        remaining -= 1
    return output_

def _define_aspp_block(input_, filters, dilation_rates=list((1, 3, 5, 7)),
                       kernel_size=3, padding='same', activation='relu', kernel_initializer='he_normal', **kwargs):
    """Apply parallel dilated SeparableConv2D branches and concatenate them.

    One branch is created per entry of ``dilation_rates`` (the same rate is
    used on both spatial axes); every branch reads ``input_`` directly. The
    branch outputs are concatenated along axis 3, in ``dilation_rates`` order.

    Returns:
        The concatenated tensor.
    """
    branches = [
        SeparableConv2D(
            filters, kernel_size,
            dilation_rate=(rate, rate),
            activation=activation, padding=padding,
            kernel_initializer=kernel_initializer, **kwargs)(input_)
        for rate in dilation_rates
    ]
    return concatenate(branches, axis=3)

def encoder_decoder_with_aspp_blocks(input_shape = (256, 256, 3)):
    """Build the encoder-decoder network with ASPP blocks and skip connections.

    Encoder: one plain conv block followed by four stages, each made of 2x2
    max-pooling, an ASPP block and a separable-conv block (64, 128, 256 and
    512 filters; the deepest stage uses 4 separable convs, the others 2).
    Decoder: four 2x up-sampling steps, each followed by a separable conv and
    concatenation with the matching encoder output; the first three steps are
    followed by a separable-conv block. Two plain 3x3 convolutions (64 then 3
    filters, both ReLU) produce the output.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def downsample(tensor):
        return MaxPooling2D(pool_size=(2, 2))(tensor)

    def aspp_stage(tensor, filters, sep_layers=2):
        # ASPP block first, then the separable-conv block on its output.
        return _define_sep_conv_block(_define_aspp_block(tensor, filters), sep_layers, filters)

    def upsample_and_merge(deep, skip, filters):
        # The skip tensor goes first in the concatenation.
        up = SeparableConv2D(
            filters, 2, activation='relu', padding='same',
            kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(deep))
        return concatenate([skip, up], axis=3)

    inputs = Input(input_shape)

    # Encoder
    enc1 = _define_conv_block(inputs, 2, 64)
    enc2 = aspp_stage(downsample(enc1), 64)
    enc3 = aspp_stage(downsample(enc2), 128)
    enc4 = aspp_stage(downsample(enc3), 256)
    bottleneck = aspp_stage(downsample(enc4), 512, sep_layers=4)

    # Decoder
    dec4 = _define_sep_conv_block(upsample_and_merge(bottleneck, enc4, 256), 2, 512)
    dec3 = _define_sep_conv_block(upsample_and_merge(dec4, enc3, 128), 2, 256)
    dec2 = _define_sep_conv_block(upsample_and_merge(dec3, enc2, 64), 2, 128)
    merged1 = upsample_and_merge(dec2, enc1, 64)

    # Output head
    head = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merged1)
    output = Conv2D(3, 3, activation='relu', padding='same', kernel_initializer='he_normal')(head)
    return Model(inputs=inputs, outputs=output)

In [None]:
## BUILD MODEL ##

# Model Compile
# The training objective is the combined loss returned by total_loss.
loss = total_loss

# Configure Optimizer
learning_rate = 1e-4
optimizer = Adam(lr = learning_rate)

# Others
# Metrics reported next to the loss. The training cell's checkpoint file name
# references val_psnr, val_ssim and val_ab, so those three must stay listed.
# NOTE(review): per_loss_vgg builds a new VGG16 each time it is called, so
# listing it here presumably creates another VGG16 besides the one used by
# the loss -- confirm the extra memory is acceptable.
metrics = [psnr, ssim, ab,per_loss_vgg, ssim_loss,wpw]

# Build
# Uses the default (256, 256, 3) input shape of the builder.
model = encoder_decoder_with_aspp_blocks()
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
model.summary()

In [None]:
## TRAIN ##
# exist_ok=True replaces the check-then-create pattern: it is a no-op when the
# directory already exists and cannot fail if it appears between the two steps.
os.makedirs(model_folder, exist_ok=True)

# One weights-only checkpoint is written after every epoch (save_best_only is
# off); the file name embeds the epoch number and the validation loss, PSNR,
# SSIM and absolute-brightness metrics.
model_name = 'model_clienet_epoch-{epoch:03d}_val_loss-{val_loss:.4f}_psnr-{val_psnr:.3f}_ssim-{val_ssim:.3f}_ab-{val_ab:.3f}.h5'
checkpointer = ModelCheckpoint(os.path.join(model_folder, model_name), verbose=0, save_best_only=False, save_weights_only = True)

# Keep the returned History object so the per-epoch curves can be inspected
# after training instead of being discarded as the cell's output.
history = model.fit(
    x=train_image_array,
    y=train_mask_array,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    verbose=1,
    callbacks=[checkpointer],
    validation_data=(val_image_array,val_mask_array)
)
