# Low-light face enhancement with Retinex-Net

## 1. Import libraries

In [None]:
#this project is for low light face enhancement
import numpy as np
import sklearn
import tensorflow as tf
import os
import scipy

## 2. Define hyperparameters

In [None]:
# ---- Hyperparameters ---------------------------------------------------------
# Input image geometry; the placeholders are shaped [batch, inh, inw, inc].
inw, inh, inc = 600, 400, 3   # width, height, channels

# Optimisation.
lr = 0.01     # learning rate for the gradient-descent optimizer
batch = 5     # batch size; should evenly divide the number of training images
epoch = 100   # number of passes over the training set (Dataset.repeat count)

# Decom-Net: output channels of conv layers 1..7. The dk7 output channels are
# split in half into illumination (I) and reflectance (R).
dk1, dk2, dk3, dk4, dk5, dk6, dk7 = 6, 12, 24, 48, 24, 12, 6

# Enhance-Net: filters of the three stride-2 encoder convolutions.
ek1, ek2, ek3 = 3, 3, 3

# Loss weights. lamdaXY weights the reconstruction term that pairs
# illumination I_X with reflectance R_Y (0 = low light, 1 = normal light).
lamda00, lamda01, lamda10, lamda11 = 1, 0.01, 0.01, 1
lamda_g = -10     # scale applied to reflectance gradients inside exp() in the smoothness loss
lamda_ir = 0.001  # weight of the reflectance-consistency loss
lamda_is = 0.1    # weight of the illumination-smoothness loss

## 3. Define helper functions

In [None]:
def nearestNeighborScaling4D(source, newHt, newWid):
    """Nearest-neighbor resize of a batch of images using TensorFlow.

    Args:
        source: 4-D tensor laid out as [batch, height, width, channel].
        newHt: target height.
        newWid: target width.

    Returns:
        4-D tensor of shape [batch, newHt, newWid, channel].
    """
    # NOTE(review): TF's built-in resize is used rather than a hand-written
    # per-pixel tf.slice/tf.concat loop, which would add O(newHt * newWid)
    # ops to the graph.
    return tf.image.resize_nearest_neighbor(source, (newHt, newWid))

def nearestNeighborScaling3D(source, newHt, newWid):
    """Nearest-neighbor resize of a single image held in a NumPy array.

    Args:
        source: 3-D array laid out as [height, width, channel].
        newHt: target height.
        newWid: target width.

    Returns:
        float32 array of shape (newHt, newWid, channel). Output pixel (y, x)
        copies source pixel (round(y / newHt * height), round(x / newWid * width)),
        clamped to the last row / column. Rounding is round-half-to-even,
        matching Python's built-in ``round``.
    """
    source = np.asarray(source)
    height = int(source.shape[0])
    width = int(source.shape[1])
    # Compute the source index once per output column / row instead of once
    # per pixel. The float64 expression is the same i / new * old sequence of
    # IEEE operations, and np.round is round-half-to-even like round(), so the
    # chosen indices are identical to a per-pixel Python loop.
    src_x = np.round(np.arange(newWid, dtype=np.float64) / float(newWid) * float(width))
    src_y = np.round(np.arange(newHt, dtype=np.float64) / float(newHt) * float(height))
    src_x = np.minimum(src_x.astype(np.intp), width - 1)
    src_y = np.minimum(src_y.astype(np.intp), height - 1)
    target = np.zeros((newHt, newWid, source.shape[-1]), dtype=np.float32)
    # Outer-product fancy indexing gathers the (newHt, newWid, channel) block in one step.
    target[...] = source[src_y[:, None], src_x[None, :], :]
    return target


def input_parser(img_path):
    """Read one image file and decode it with three channels.

    Args:
        img_path: path of the image file (string or scalar string tensor).

    Returns:
        The decoded image tensor in decode_image's default dtype; pixel
        values are not rescaled here.
    """
    return tf.image.decode_image(tf.read_file(img_path), channels=3)

def input_parser_pair(img_path0, img_path1):
    """Read and decode a (low-light, normal-light) image pair.

    Args:
        img_path0: path of the low-light image.
        img_path1: path of the matching normal-light image.

    Returns:
        Tuple of two float32 image tensors with 3 channels each, values in [0, 1].
    """
    def _decode(path):
        # Read one file and decode it as float32 RGB.
        contents = tf.read_file(path)
        # Requesting dtype=tf.float32 makes decode_image convert the pixels
        # (via convert_image_dtype) into [0, 1]; dividing by 255 again would
        # squash them into [0, 1/255].
        return tf.image.decode_image(contents, dtype=tf.float32, channels=3)

    return _decode(img_path0), _decode(img_path1)

## 4. Build the network

In [None]:
# ---- Network inputs ----------------------------------------------------------
# Paired low-light / normal-light images, fed through feed_dict at training
# time; both are shaped [batch, inh, inw, inc].
with tf.name_scope('input_image'):
    xlow = tf.placeholder(
        dtype=tf.float32,
        shape=(batch, inh, inw, inc),
        name='input_low',
    )
    print(xlow.shape)
    xnorm = tf.placeholder(
        dtype=tf.float32,
        shape=(batch, inh, inw, inc),
        name='input_norm',
    )
        
with tf.variable_scope("decom", reuse=tf.AUTO_REUSE):  # define weight and bias
    # 3x3 conv kernels shared by the low-light and normal-light Decom-Net
    # branches. weight1's input depth is `inc` so it always matches the
    # channel count of the input placeholders.
    _kernel_init = tf.contrib.layers.xavier_initializer_conv2d()
    weight1 = tf.get_variable(name='w1', shape=[3, 3, inc, dk1], initializer=_kernel_init, dtype=tf.float32)
    weight2 = tf.get_variable(name='w2', shape=[3, 3, dk1, dk2], initializer=_kernel_init, dtype=tf.float32)
    weight3 = tf.get_variable(name='w3', shape=[3, 3, dk2, dk3], initializer=_kernel_init, dtype=tf.float32)
    weight4 = tf.get_variable(name='w4', shape=[3, 3, dk3, dk4], initializer=_kernel_init, dtype=tf.float32)
    weight5 = tf.get_variable(name='w5', shape=[3, 3, dk4, dk5], initializer=_kernel_init, dtype=tf.float32)
    weight6 = tf.get_variable(name='w6', shape=[3, 3, dk5, dk6], initializer=_kernel_init, dtype=tf.float32)
    weight7 = tf.get_variable(name='w7', shape=[3, 3, dk6, dk7], initializer=_kernel_init, dtype=tf.float32)

    # Each branch owns its biases; all start at 0.01.
    _bias_init = tf.constant_initializer(0.01)
    # bias for low light conv
    lowb1 = tf.get_variable("lowb1", shape=[dk1], initializer=_bias_init, dtype=tf.float32)
    lowb2 = tf.get_variable("lowb2", shape=[dk2], initializer=_bias_init, dtype=tf.float32)
    lowb3 = tf.get_variable("lowb3", shape=[dk3], initializer=_bias_init, dtype=tf.float32)
    lowb4 = tf.get_variable("lowb4", shape=[dk4], initializer=_bias_init, dtype=tf.float32)
    lowb5 = tf.get_variable("lowb5", shape=[dk5], initializer=_bias_init, dtype=tf.float32)
    lowb6 = tf.get_variable("lowb6", shape=[dk6], initializer=_bias_init, dtype=tf.float32)
    lowb7 = tf.get_variable("lowb7", shape=[dk7], initializer=_bias_init, dtype=tf.float32)
    # bias for normal light conv
    normb1 = tf.get_variable("normb1", shape=[dk1], initializer=_bias_init, dtype=tf.float32)
    normb2 = tf.get_variable("normb2", shape=[dk2], initializer=_bias_init, dtype=tf.float32)
    normb3 = tf.get_variable("normb3", shape=[dk3], initializer=_bias_init, dtype=tf.float32)
    normb4 = tf.get_variable("normb4", shape=[dk4], initializer=_bias_init, dtype=tf.float32)
    normb5 = tf.get_variable("normb5", shape=[dk5], initializer=_bias_init, dtype=tf.float32)
    normb6 = tf.get_variable("normb6", shape=[dk6], initializer=_bias_init, dtype=tf.float32)
    normb7 = tf.get_variable("normb7", shape=[dk7], initializer=_bias_init, dtype=tf.float32)
        
# Decom-Net forward passes. Both branches use the shared kernels weight1..7
# with their own biases. Hidden layers are conv -> +bias -> ReLU, and the
# *activated* output (lowconvK / normconvK) feeds the next convolution.
# Feeding the raw conv output (lowccK) instead would skip every bias and ReLU
# in layers 1-6, collapsing them into a single linear map. The last layer
# uses a sigmoid.
with tf.name_scope('decom_net_low'):
    lowcc1 = tf.nn.conv2d(xlow, weight1, strides=[1, 1, 1, 1], padding='SAME', name="cc1")
    lowconv1 = tf.nn.relu(lowcc1 + lowb1, name="conv1")
    print("low conv1 shape:", lowconv1.shape)

    lowcc2 = tf.nn.conv2d(lowconv1, weight2, strides=[1, 1, 1, 1], padding='SAME', name="cc2")
    lowconv2 = tf.nn.relu(lowcc2 + lowb2, name="conv2")

    lowcc3 = tf.nn.conv2d(lowconv2, weight3, strides=[1, 1, 1, 1], padding='SAME', name="cc3")
    lowconv3 = tf.nn.relu(lowcc3 + lowb3, name="conv3")

    lowcc4 = tf.nn.conv2d(lowconv3, weight4, strides=[1, 1, 1, 1], padding='SAME', name="cc4")
    lowconv4 = tf.nn.relu(lowcc4 + lowb4, name="conv4")

    lowcc5 = tf.nn.conv2d(lowconv4, weight5, strides=[1, 1, 1, 1], padding='SAME', name="cc5")
    lowconv5 = tf.nn.relu(lowcc5 + lowb5, name="conv5")

    lowcc6 = tf.nn.conv2d(lowconv5, weight6, strides=[1, 1, 1, 1], padding='SAME', name="cc6")
    lowconv6 = tf.nn.relu(lowcc6 + lowb6, name="conv6")

    lowcc7 = tf.nn.conv2d(lowconv6, weight7, strides=[1, 1, 1, 1], padding='SAME', name="cc7")
    lowconv7 = tf.nn.sigmoid(lowcc7 + lowb7, name="conv7")
    print("low conv7 shape:", lowconv7.shape)

with tf.name_scope('decom_net_normal'):
    normcc1 = tf.nn.conv2d(xnorm, weight1, strides=[1, 1, 1, 1], padding='SAME', name="cc1")
    normconv1 = tf.nn.relu(normcc1 + normb1, name="conv1")

    normcc2 = tf.nn.conv2d(normconv1, weight2, strides=[1, 1, 1, 1], padding='SAME', name="cc2")
    normconv2 = tf.nn.relu(normcc2 + normb2, name="conv2")

    normcc3 = tf.nn.conv2d(normconv2, weight3, strides=[1, 1, 1, 1], padding='SAME', name="cc3")
    normconv3 = tf.nn.relu(normcc3 + normb3, name="conv3")

    normcc4 = tf.nn.conv2d(normconv3, weight4, strides=[1, 1, 1, 1], padding='SAME', name="cc4")
    normconv4 = tf.nn.relu(normcc4 + normb4, name="conv4")

    normcc5 = tf.nn.conv2d(normconv4, weight5, strides=[1, 1, 1, 1], padding='SAME', name="cc5")
    normconv5 = tf.nn.relu(normcc5 + normb5, name="conv5")

    normcc6 = tf.nn.conv2d(normconv5, weight6, strides=[1, 1, 1, 1], padding='SAME', name="cc6")
    normconv6 = tf.nn.relu(normcc6 + normb6, name="conv6")

    normcc7 = tf.nn.conv2d(normconv6, weight7, strides=[1, 1, 1, 1], padding='SAME', name="cc7")
    normconv7 = tf.nn.sigmoid(normcc7 + normb7, name="conv7")
    print("normal conv7 shape:", normconv7.shape)
        
# Split each dk7-channel Decom-Net output along the channel axis: the first
# dk7 // 2 channels are taken as illumination (I), the next dk7 // 2 as
# reflectance (R).
with tf.name_scope('decom_output_low'):
    Ilow = tf.slice(lowconv7, begin=[0, 0, 0, 0], size=[batch, -1, -1, dk7 // 2], name="Ilow_output")
    Rlow = tf.slice(lowconv7, begin=[0, 0, 0, dk7 // 2], size=[batch, -1, -1, dk7 // 2], name='Rlow_output')

with tf.name_scope('decom_output_norm'):
    Inorm = tf.slice(normconv7, begin=[0, 0, 0, 0], size=[batch, -1, -1, dk7 // 2], name="Inorm_output")
    Rnorm = tf.slice(normconv7, begin=[0, 0, 0, dk7 // 2], size=[batch, -1, -1, dk7 // 2], name='Rnorm_output')
    # Report the static shapes of the four decomposition outputs.
    print("Ilow shape:", Ilow.shape)
    print("Rlow shape:", Rlow.shape)
    print("Inorm shape:", Inorm.shape)
    print("Rnorm shape:", Rnorm.shape)

# Enhance-Net: encoder/decoder over the low-light illumination Ilow with
# residual skip connections, multi-scale concatenation, and a final
# illumination reconstruction (reconI).
with tf.name_scope('enhance_net'):
    with tf.name_scope("encoder_decoder"):
        # Encoder: three stride-2 ReLU convolutions.
        encconv1 = tf.layers.conv2d(inputs=Ilow, filters=ek1, kernel_size=[3, 3], strides=[2, 2], padding='same', activation=tf.nn.relu, name='downsample1')
        encconv2 = tf.layers.conv2d(inputs=encconv1, filters=ek2, kernel_size=[3, 3], strides=[2, 2], padding='same', activation=tf.nn.relu, name='downsample2')
        encconv3 = tf.layers.conv2d(inputs=encconv2, filters=ek3, kernel_size=[3, 3], strides=[2, 2], padding='same', activation=tf.nn.relu, name='downsample3')

        # Decoder: each stage upsamples to the spatial size of the tensor it
        # is added to, and its conv emits that tensor's channel count, so the
        # residual tf.add always lines up. A blind 2x upsample breaks for odd
        # sizes, because a stride-2 'same' conv outputs ceil(n / 2)
        # (e.g. 401 -> 201 -> 101 -> 51, and 2 * 51 != 101). It also breaks
        # when ek1 != ek2.
        conv3_shape = encconv3.shape
        upsample1 = nearestNeighborScaling4D(encconv3, int(encconv2.shape[1]), int(encconv2.shape[2]))
        upconv1 = tf.layers.conv2d(inputs=upsample1, filters=int(encconv2.shape[-1]), kernel_size=[3, 3], strides=[1, 1], padding='same', activation=tf.nn.relu, name='upconv1')
        upresidual1 = tf.add(upconv1, encconv2, name="upres1")

        upresidual1_shape = upresidual1.shape
        upsample2 = nearestNeighborScaling4D(upresidual1, int(encconv1.shape[1]), int(encconv1.shape[2]))
        upconv2 = tf.layers.conv2d(inputs=upsample2, filters=int(encconv1.shape[-1]), kernel_size=[3, 3], strides=[1, 1], padding='same', activation=tf.nn.relu, name='upconv2')
        upresidual2 = tf.add(upconv2, encconv1, name="upres2")

        upresidual2_shape = upresidual2.shape
        upsample3 = nearestNeighborScaling4D(upresidual2, int(Ilow.shape[1]), int(Ilow.shape[2]))
        upconv3 = tf.layers.conv2d(inputs=upsample3, filters=int(Ilow.shape[-1]), kernel_size=[3, 3], strides=[1, 1], padding='same', activation=tf.nn.relu, name='upconv3')
        upresidual3 = tf.add(upconv3, Ilow, name="upres3")
        print("upresidual3 shape:", upresidual3.shape)

    with tf.name_scope("upsample_concat"):
        # Bring the two coarser decoder outputs to full resolution and stack
        # all three scales on the channel axis.
        upresidual3_shape = upresidual3.shape
        preconcat1 = nearestNeighborScaling4D(upresidual1, int(upresidual3_shape[1]), int(upresidual3_shape[2]))
        preconcat2 = nearestNeighborScaling4D(upresidual2, int(upresidual3_shape[1]), int(upresidual3_shape[2]))
        preconcat3 = upresidual3
        concat = tf.concat([preconcat1, preconcat2, preconcat3], axis=3)  # concat on channel direction
        print("concat shape:", concat.shape)

    with tf.name_scope('reconstruct_illumination'):
        # 1x1 conv fuses the concatenated scales down to `inc` channels,
        # then a 3x3 conv produces the enhanced illumination.
        resizeconv = tf.layers.conv2d(inputs=concat, filters=inc, kernel_size=[1, 1], strides=[1, 1], padding='same', activation=tf.nn.relu, name='resize_conv')
        reconI = tf.layers.conv2d(inputs=resizeconv, filters=inc, kernel_size=[3, 3], strides=[1, 1], padding='same', activation=tf.nn.relu, name='reconstruct_illumination')
        print("reconI shape:", reconI.shape)
        
# Final enhanced image = enhanced illumination x low-light reflectance.
# denoiseR is currently just an alias of Rlow; no denoising step is applied.
with tf.name_scope("reconstruct_Image"):
    denoiseR = Rlow
    reconImage = tf.multiply(x=reconI, y=denoiseR, name="reconstruct_Image")

## 5. Define the loss functions

In [None]:
## decom-net loss function
# Reconstruction (L1): each illumination x reflectance product is compared
# with the input image whose illumination it uses; lamdaXY weights I_X * R_Y
# (0 = low light, 1 = normal light).
Lrecon00 = tf.abs(tf.multiply(Ilow, Rlow) - xlow)
Lrecon01 = tf.abs(tf.multiply(Ilow, Rnorm) - xlow)
Lrecon10 = tf.abs(tf.multiply(Inorm, Rlow) - xnorm)
Lrecon11 = tf.abs(tf.multiply(Inorm, Rnorm) - xnorm)
Lrecon_decom = tf.reduce_mean(lamda00 * Lrecon00 + lamda01 * Lrecon01 + lamda10 * Lrecon10 + lamda11 * Lrecon11)
# Reflectance consistency: both exposures should produce the same reflectance.
Lir_decom = tf.reduce_mean(tf.abs(Rlow - Rnorm))


def _smoothness_loss(illum, reflect):
    """Edge-aware illumination smoothness term.

    Penalises |gradient of illum| in each direction, weighted by
    exp(lamda_g * |gradient of reflect|). The absolute values are required:
    with signed gradients the term is unbounded below. Because lamda_g is
    negative, a negative reflectance gradient would also make the exponential
    weight blow up.

    Args:
        illum: 4-D illumination tensor.
        reflect: reflectance tensor; multiplied elementwise with ``illum``'s
            gradients, so its shape must be compatible.

    Returns:
        Scalar tensor: mean vertical term + mean horizontal term.
    """
    illum_dy, illum_dx = tf.image.image_gradients(illum)
    refl_dy, refl_dx = tf.image.image_gradients(reflect)
    return (tf.reduce_mean(tf.abs(illum_dy) * tf.exp(lamda_g * tf.abs(refl_dy)))
            + tf.reduce_mean(tf.abs(illum_dx) * tf.exp(lamda_g * tf.abs(refl_dx))))


Lis_decom = _smoothness_loss(Ilow, Rlow) + _smoothness_loss(Inorm, Rnorm)
Ldecom = Lrecon_decom + lamda_ir * Lir_decom + lamda_is * Lis_decom


## enhance-net loss function
# L1 reconstruction: a signed mean would be minimised simply by shrinking the
# output, so the absolute error is averaged.
Lrecon_enh = tf.reduce_mean(tf.abs(reconI * denoiseR - xnorm))
Lis_enh = _smoothness_loss(reconI, denoiseR)
Lenh = Lrecon_enh + lamda_is * Lis_enh

## 6. Build the dataset and iterator

In [None]:
# get file name list
low_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\low\\*.png"
norm_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\high\\*.png"
low_names = tf.train.match_filenames_once(low_dir)  # return all matched names, a variable
norm_names = tf.train.match_filenames_once(norm_dir)
with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    filesfind0, filesfind1 = sess.run((low_names, norm_names))

# bytes -> str. Pairing below is purely positional, so sort both lists to
# make the i-th low-light image line up with the i-th normal-light image.
low_list = sorted(name.decode() for name in filesfind0)
norm_list = sorted(name.decode() for name in filesfind1)
if not low_list:
    raise ValueError("no training images matched %s" % low_dir)
if len(low_list) != len(norm_list):
    raise ValueError("low/normal image counts differ: %d vs %d" % (len(low_list), len(norm_list)))

print("low light file:", low_list[0])
print("low light file number:", len(low_list))
print("norm light file number:", len(norm_list))

# build data set
total_train_images = len(low_list)
tr_data = tf.data.Dataset.from_tensor_slices((low_list, norm_list))  # build dataset from file names
tr_data = tr_data.map(input_parser_pair)  # map file names to decoded images
# drop_remainder=True: the input placeholders have a fixed batch dimension,
# so a final partial batch could not be fed to the network.
tr_data = tr_data.shuffle(buffer_size=total_train_images).batch(batch, drop_remainder=True).repeat(epoch)

# get iterator
iterator = tr_data.make_initializable_iterator()
next_element = iterator.get_next()  # next batch
training_init_op = iterator.make_initializer(tr_data)

## 7. Train

### 7.1 Train Decom-Net

In [None]:
## run training
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(Ldecom)
# Batches per pass over the data, derived from the dataset rather than a
# hard-coded 485 / 5. Only full batches can be fed to the fixed-size
# placeholders, so this is a floor division.
steps_per_epoch = max(1, total_train_images // batch)
i = 0  # global step counter
j = 0  # completed-epoch counter
with tf.Session() as sess:
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    # initialize the iterator on the training data
    sess.run(training_init_op)

    # consume batches until the (repeated) dataset is exhausted
    while True:
        try:
            elem = sess.run(next_element)  # get next batch of input images
        except tf.errors.OutOfRangeError:
            print("End of training dataset.")
            break
        i += 1
        feed = {
            xlow: elem[0],
            xnorm: elem[1],
        }
        # One run both applies the update and returns this step's loss,
        # avoiding a second forward pass just for logging.
        _, loss_value = sess.run((optimizer, Ldecom), feed_dict=feed)
        if i % steps_per_epoch == 0:
            j += 1
            print("epoch:", j)
        if i % 20 == 0:
            print('loss for batch ', i, " is ", loss_value)