# Low light face enhancement with Retinex-Net

## 1.import library

In [1]:
#this project is for low light face enhancement
import numpy as np
import sklearn
import tensorflow as tf
import os
import scipy
import random
from PIL import Image

  from ._conv import register_converters as _register_converters


## 2.define hyper parameters

In [2]:
#define all parameters
load_pretrain = 1
auto_checkpoint = 0
seed = 2018
patch = 48 
lr = 0.001 #learning rate
batch = 15 #batch number, better a factor of total image number
epoch = 100 #total epoch for both training Decom-Net + Enhance-Net
dk1 = 64 #decom-net kernel number for conv1
dk2 = 64
dk3 = 64
dk4 = 64
dk5 = 64
dk6 = 64
dk7 = 4

ek1 = 64 #enhance-net kernel number for conv1
ek2 = 64
ek3 = 64
rk = 64 #resize layer kernel numebr

random.seed = seed

## 3.define local functions

In [3]:
def nearestNeighborScaling4D( source, newHt,newWid):
    target = tf.image.resize_nearest_neighbor(source, (newHt,newWid))
    return target

'''
def nearestNeighborScaling4D( source, newHt,newWid):
     #souce: 4D tensor, 4D=[batch,height,width,channel]
     #newHt: target Height
     #newWid: target Width
    width  = int(source.shape[2])
    height = int(source.shape[1])
    print("source shape:",source.shape)    
    newX = tf.Variable(tf.zeros(shape=[source.shape[0],newHt+1,1,source.shape[-1]]), name="tempX")
    for x in range(0, newWid):  
        newY = tf.Variable(tf.zeros(shape=[source.shape[0],1,1,source.shape[-1]]), name="tempY")        
        for y in range(0, newHt):
            srcX = int( round( float(x) / float(newWid) * float(width) ) )
            srcY = int( round( float(y) / float(newHt) * float(height) ) )
            srcX = min( srcX, width-1)
            srcY = min( srcY, height-1)
            sourceSlice = tf.slice(source,[0,srcY,srcX,0],[source.shape[0],1,1,source.shape[-1]])
            newY = tf.concat([newY,sourceSlice],axis=1)
        newX = tf.concat([newX,newY],axis=2)
    target = tf.slice(newX,[0,1,1,0],[source.shape[0],newHt,newWid,source.shape[-1]])
    print("target shape:",target.shape)
    return target

def nearestNeighborScaling3D( source, newHt,newWid):
     #souce: 3D array, 3D=[height,width,channel]
     #newHt: target Height
     #newWid: target Width
    width  = int(source.shape[1])
    height = int(source.shape[0])
    target = np.zeros((newHt,newWid,source.shape[-1]),dtype = np.float32)
    for x in range(0, newWid):  
        for y in range(0, newHt):
            srcX = int( round( float(x) / float(newWid) * float(width) ) )
            srcY = int( round( float(y) / float(newHt) * float(height) ) )
            srcX = min( srcX, width-1)
            srcY = min( srcY, height-1)
            target[y,x,:] = source[srcY,srcX,:]
    return target
'''

# random data augmentation
def data_aug(image0, image1):
    #total 4*2*2 = 16 choice
    mode = np.random.randint(2,size=3)   #mode is an array, 3 elements
    rotate = random.randint(0,4)
    new_image0 = image0
    new_image1 = image1
    if mode[0]==1:  #rotate 0/90/180/270
        new_image0 = tf.image.rot90(new_image0,k=rotate)
        new_image1 = tf.image.rot90(new_image1,k=rotate)
    if mode[1]==1: #flip left/right
        new_image0 = tf.image.flip_left_right(new_image0)
        new_image1 = tf.image.flip_left_right(new_image1)
    if mode[2]==1:  #flip up/down
        new_image0 = tf.image.flip_up_down(new_image0)
        new_image1 = tf.image.flip_up_down(new_image1)
    return new_image0,new_image1 

def patch_gen_do(image0,image1):
    imh = tf.shape(image0)[1]
    imw = tf.shape(image0)[2]
    image0_patch = tf.Variable(lambda:tf.zeros(shape=[1,patch,patch,3]), name="image_patch0")
    image1_patch = tf.Variable(lambda:tf.zeros(shape=[1,patch,patch,3]), name="image_patch1")
    for i in range(batch):
        randh =random.randint(0,400-patch) #random.randint(0,imh-patch)
        randw =random.randint(0,600-patch) #random.randint(0,imw-patch)
        image0_temp = tf.slice(image0,[i,randh,randw,0],[1,patch,patch,-1]) #get slice
        image1_temp = tf.slice(image1,[i,randh,randw,0],[1,patch,patch,-1])
        image0_temp,image1_temp = data_aug(image0_temp,image1_temp) #dat augmentation
        image0_patch = tf.concat([image0_patch,image0_temp],axis=0)
        image1_patch = tf.concat([image1_patch,image1_temp],axis=0)
    image0_patch = tf.slice(image0_patch,[1,0,0,0],[-1,-1,-1,-1])
    image1_patch = tf.slice(image1_patch,[1,0,0,0],[-1,-1,-1,-1])
    return image0_patch,image1_patch    

def patch_gen_bypass(image0,image1):
    in_batch = tf.shape(image0)[0]
    image0_slice = tf.slice(image0,[0,0,0,0],[in_batch,-1,-1,-1])
    image1_slice = tf.slice(image1,[0,0,0,0],[in_batch,-1,-1,-1])
    return image0_slice,image1_slice    

#generate random patch with data augmentation
#def patch_gen(image0,image1,bypass_patch_gen): #can't use if/else inside


#extend channel form 3 to 4, by adding max_in_channel
def channel_extension(image0,image1):
    image0_ex = tf.reduce_max(image0, axis=3, keepdims=True)
    image1_ex = tf.reduce_max(image1, axis=3, keepdims=True)
    image0_new = tf.concat([image0_ex,image0],axis=3)  ##?? the sequence?
    image1_new = tf.concat([image1_ex,image1],axis=3)
    return image0_new,image1_new
    

#from original image pair --> input image pair, resize to 400*600
def input_parser_pair(img_path0,img_path1):
    # read the img from file
    img_file0 = tf.read_file(img_path0)
    img_file1 = tf.read_file(img_path1)
    img_decoded0 = tf.image.decode_image(img_file0,dtype=tf.float32, channels=3) #3D tensor
    img_decoded1 = tf.image.decode_image(img_file1,dtype=tf.float32, channels=3)
    image_stack0 = tf.stack([img_decoded0],axis=0) #3D->4D
    image_stack1 = tf.stack([img_decoded1],axis=0)
    img0_resize = tf.image.resize_nearest_neighbor(image_stack0,[400,600]) #only support 4D input
    img1_resize = tf.image.resize_nearest_neighbor(image_stack1,[400,600])
    img0_resize_3D = (tf.squeeze(img0_resize)) #4D->3D
    img1_resize_3D = (tf.squeeze(img1_resize))
    return img0_resize_3D,img1_resize_3D


#input 4D array to batch number of images,saved to file
def array_to_image(inarray,epoch, fpath, prefix=""):
    batch = inarray.shape[0]
    if (inarray.shape[3]==3): #RGB
        for i in range(batch):
            image_pre = np.uint8(inarray[i,:,:,:]*255)
            im = Image.fromarray(image_pre,mode='RGB')
            im.save(os.path.join(fpath, prefix+'epoch'+str(epoch)+'img'+str(i)+'.png'),'png')
    elif (inarray.shape[3]==1): #gray image, for I
        for i in range(batch):
            image_pre = np.uint8(inarray[i,:,:,0]*255)
            im = Image.fromarray(image_pre,mode='L')
            im.save(os.path.join(fpath, prefix+'epoch'+str(epoch)+'img'+str(i)+'.png'),'png')

'''
#test array_to_image function
im = np.asarray(Image.open("E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\high\\9.png")).reshape((1, 400, 600, 3))/255.0
array_to_image(im,0,"E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\myresult\\train" , "HH")
'''

'\n#test array_to_image function\nim = np.asarray(Image.open("E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\high\\9.png")).reshape((1, 400, 600, 3))/255.0\narray_to_image(im,0,"E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\myresult\\train" , "HH")\n'

## 4.build network

In [4]:
#build network
with tf.name_scope('input_image'):
    imlow = tf.placeholder(tf.float32, shape=[None,None,None,3],name = 'input_low')   #input low light image
    imnorm = tf.placeholder(tf.float32, shape=[None,None,None,3], name='input_norm') #input norm light image
    bypass_patch_gen = tf.placeholder(tf.bool)
    xlow_pre, xnorm_pre = tf.cond(bypass_patch_gen,lambda:patch_gen_bypass(imlow,imnorm),lambda:patch_gen_do(imlow,imnorm))
    xlow, xnorm = channel_extension(xlow_pre,xnorm_pre) #channel extention 3->4 
       
with tf.variable_scope("decom", reuse=tf.AUTO_REUSE):  #define weight and bias
    #shared weight between low light Decom-net and normal light Decom-net
    weight1 = tf.get_variable(name='w1', shape=[3,3,4,dk1],  initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight2 = tf.get_variable(name='w2', shape=[3,3,dk1,dk2], initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight3 = tf.get_variable(name='w3', shape=[3,3,dk2,dk3],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight4 = tf.get_variable(name='w4', shape=[3,3,dk3,dk4],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight5 = tf.get_variable(name='w5', shape=[3,3,dk4,dk5],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight6 = tf.get_variable(name='w6', shape=[3,3,dk5,dk6],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight7 = tf.get_variable(name='w7', shape=[3,3,dk6,dk7], initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)

    #bias for normal light conv
    b1 = tf.get_variable("b1", shape=[dk1], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b2 = tf.get_variable("b2", shape=[dk2], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b3 = tf.get_variable("b3", shape=[dk3], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b4 = tf.get_variable("b4", shape=[dk4], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b5 = tf.get_variable("b5", shape=[dk5], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b6 = tf.get_variable("b6", shape=[dk6], initializer=tf.constant_initializer(0.001), dtype=tf.float32)
    b7 = tf.get_variable("b7", shape=[dk7], initializer=tf.constant_initializer(0.001), dtype=tf.float32)    
    
with tf.name_scope('decom_net_low'):
    lowcc1 = tf.nn.conv2d(xlow, weight1, strides=[1,1,1,1], padding='SAME', name="cc1")
        
    lowcc2 = tf.nn.conv2d(lowcc1, weight2, strides=[1,1,1,1], padding='SAME', name="cc2")
    lowconv2 = tf.nn.relu(lowcc2 + b2, name="conv2")        
            
    lowcc3 = tf.nn.conv2d(lowcc2, weight3, strides=[1,1,1,1], padding='SAME', name="cc3")
    lowconv3 = tf.nn.relu(lowcc3 + b3, name="conv3")  
        
    lowcc4 = tf.nn.conv2d(lowcc3, weight4, strides=[1,1,1,1], padding='SAME', name="cc4")
    lowconv4 = tf.nn.relu(lowcc4 + b4, name="conv4")  

    lowcc5 = tf.nn.conv2d(lowcc4, weight5, strides=[1,1,1,1], padding='SAME', name="cc5")
    lowconv5 = tf.nn.relu(lowcc5 + b5, name="conv5")  
        
    lowcc6 = tf.nn.conv2d(lowcc5, weight6, strides=[1,1,1,1], padding='SAME', name="cc6")
    lowconv6 = tf.nn.relu(lowcc6 + b6, name="conv6")  
        
    lowcc7 = tf.nn.conv2d(lowcc6, weight7, strides=[1,1,1,1], padding='SAME', name="cc7")
    lowconv7 = tf.nn.sigmoid(lowcc7 + b7, name="conv7") 
    print("low conv7 shape:",lowconv7.shape)

with tf.name_scope('decom_net_normal'):
    normcc1 = tf.nn.conv2d(xnorm, weight1, strides=[1,1,1,1], padding='SAME', name="cc1")
        
    normcc2 = tf.nn.conv2d(normcc1, weight2, strides=[1,1,1,1], padding='SAME', name="cc2")
    normconv2 = tf.nn.relu(normcc2 + b2, name="conv2")        
        
    normcc3 = tf.nn.conv2d(normcc2, weight3, strides=[1,1,1,1], padding='SAME', name="cc3")
    normconv3 = tf.nn.relu(normcc3 + b3, name="conv3")  
        
    normcc4 = tf.nn.conv2d(normcc3, weight4, strides=[1,1,1,1], padding='SAME', name="cc4")
    normconv4 = tf.nn.relu(normcc4 + b4, name="conv4")  
        
    normcc5 = tf.nn.conv2d(normcc4, weight5, strides=[1,1,1,1], padding='SAME', name="cc5")
    normconv5 = tf.nn.relu(normcc5 + b5, name="conv5")  
        
    normcc6 = tf.nn.conv2d(normcc5, weight6, strides=[1,1,1,1], padding='SAME', name="cc6")
    normconv6 = tf.nn.relu(normcc6 + b6, name="conv6")  
        
    normcc7 = tf.nn.conv2d(normcc6, weight7, strides=[1,1,1,1], padding='SAME', name="cc7")
    normconv7 = tf.nn.sigmoid(normcc7 + b7, name="conv7")  
    print("normal conv7 shape:",normconv7.shape)
        
with tf.name_scope('decom_output_low'):
    Rlow = tf.slice(lowconv7,[0,0,0,0],[-1,-1,-1,3],name="Rlow_output")   #output R low image
    Ilow = tf.slice(lowconv7,[0,0,0,3],[-1,-1,-1,1],name='Ilow_output') #output I low image        
    
with tf.name_scope('decom_output_norm'):
    Rnorm = tf.slice(normconv7,[0,0,0,0],[-1,-1,-1,3],name="Rnorm_output")   #output R norm image
    Inorm = tf.slice(normconv7,[0,0,0,3],[-1,-1,-1,1],name='Inorm_output') #output I norm image 
    print("Ilow shape:",Ilow.shape)
    print("Rlow shape:",Rlow.shape)
    print("Inorm shape:",Inorm.shape)
    print("Rnorm shape:",Rnorm.shape) 

with tf.name_scope('enhance_net'):
    with tf.name_scope("encoder_decoder"):
        pre_enh0 = tf.concat([Rlow,Ilow],axis=3) #concat on channel
        pre_enh1 = tf.layers.conv2d(inputs=pre_enh0,filters=ek1,kernel_size=[3,3],padding='same', activation=None)   
        encconv1  =  tf.layers.conv2d(inputs=pre_enh1,filters=ek1,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample1')
        encconv2  =  tf.layers.conv2d(inputs=encconv1,filters=ek2,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample2')
        encconv3  =  tf.layers.conv2d(inputs=encconv2,filters=ek3,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample3')
        
        #import types
        #print(type(normconv7))
        conv3_shape = encconv3.shape
        upsample1 = tf.image.resize_nearest_neighbor(encconv3, (tf.shape(encconv2)[1],tf.shape(encconv2)[2]))
        upconv1  =  tf.layers.conv2d(inputs=upsample1,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv1')
        upresidual1 =  tf.add(upconv1,encconv2,name="upres1")
        
        upresidual1_shape = upresidual1.shape
        upsample2 = tf.image.resize_nearest_neighbor(upresidual1, (tf.shape(encconv1)[1],tf.shape(encconv1)[2]))
        upconv2  =  tf.layers.conv2d(inputs=upsample2,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv2')
        upresidual2 =  tf.add(upconv2,encconv1,name="upres2")
        
        upresidual2_shape = upresidual2.shape
        upsample3 = tf.image.resize_nearest_neighbor(upresidual2, (tf.shape(pre_enh1)[1],tf.shape(pre_enh1)[2]))
        upconv3  =  tf.layers.conv2d(inputs=upsample3,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv3')
        upresidual3 =  tf.add(upconv3,pre_enh1,name="upres3")  
        print("upresidual3 shape:",upresidual3.shape)
        
    with tf.name_scope("upsample_concat"):
        upresidual3_shape = tf.shape(upresidual3)
        preconcat1 = tf.image.resize_nearest_neighbor(upresidual1, (upresidual3_shape[1],upresidual3_shape[2])) 
        preconcat2 = tf.image.resize_nearest_neighbor(upresidual2, (upresidual3_shape[1],upresidual3_shape[2])) 
        preconcat3 = upresidual3
        concat = tf.concat([preconcat1,preconcat2,preconcat3],axis=3)  #concat on channel direction
        print("concat shape:",concat.shape)
        
    with tf.name_scope('reconstruct_illumination'):
        resizeconv = tf.layers.conv2d(inputs=concat,filters=rk,kernel_size=[1,1],strides=[1,1],padding='same',activation=None,name='resize_conv')
        reconI = tf.layers.conv2d(inputs=resizeconv,filters=1,kernel_size=[3,3],strides=[1,1],padding='same',activation=None,name='reconstruct_illumination')
        print("reconI shape:",reconI.shape)
        
with tf.name_scope("reconstruct_Image"):
    denoiseR = Rlow
    reconI_ex = tf.concat([reconI,reconI,reconI],axis=3)
    reconImage = tf.multiply(reconI_ex,denoiseR,name="reconstruct_Image")

low conv7 shape: (?, ?, ?, 4)
normal conv7 shape: (?, ?, ?, 4)
Ilow shape: (?, ?, ?, 1)
Rlow shape: (?, ?, ?, 3)
Inorm shape: (?, ?, ?, 1)
Rnorm shape: (?, ?, ?, 3)
upresidual3 shape: (?, ?, ?, 64)
concat shape: (?, ?, ?, 192)
reconI shape: (?, ?, ?, 1)


## 5.define loss function

In [5]:
lamda00=1   
lamda01=0.001
lamda10=0.001
lamda11=1
lamda_g= -10
lamda_ir = 0.01
lamda_is = 0.1
lamda_is_enh = 3

##decom-net loss function
Ilow_ex = tf.concat([Ilow,Ilow,Ilow],axis=3)
Inorm_ex = tf.concat([Inorm,Inorm,Inorm],axis=3)
xlow_image = tf.slice(xlow,[0,0,0,1],[-1,-1,-1,3])  #real image patch
xnorm_image = tf.slice(xnorm,[0,0,0,1],[-1,-1,-1,3]) #real image patch

Lrecon00 = tf.abs(tf.multiply(Ilow_ex,Rlow)-xlow_image)
Lrecon01 = tf.abs(tf.multiply(Ilow_ex,Rnorm)-xlow_image)
Lrecon10 = tf.abs(tf.multiply(Inorm_ex,Rlow)-xnorm_image)
Lrecon11 = tf.abs(tf.multiply(Inorm_ex,Rnorm)-xnorm_image)
Lrecon_decom = tf.reduce_mean(lamda00*Lrecon00+lamda01*Lrecon01+lamda10*Lrecon10+lamda11*Lrecon11)
Lir_decom = tf.reduce_mean(tf.abs(Rlow-Rnorm))

Rlow_gray = tf.image.rgb_to_grayscale(Rlow)
Rnorm_gray = tf.image.rgb_to_grayscale(Rnorm) 
'''
Lis_decom = tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(Rlow)[0])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(Rlow)[1])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(Rnorm)[0])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(Rnorm)[1]))
'''
Lis_decom = tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow)[0])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rlow_gray)[0]))) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow)[1])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rlow_gray)[1]))) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm)[0])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rnorm_gray)[0]))) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm)[1])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rnorm_gray)[1])))

Ldecom = Lrecon_decom + lamda_ir*Lir_decom + lamda_is*Lis_decom
print("image gradient shape:",len(tf.image.image_gradients(Ilow_ex)))
print(tf.image.image_gradients(Ilow_ex)[0])

##enhance-net loss function
Lrecon_enh = tf.reduce_mean(tf.abs(reconI_ex*denoiseR-xnorm_image))
'''
Lis_enh = tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(denoiseR)[0])) + \
          tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(denoiseR)[1]))
'''
Lis_enh = tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI)[0])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rlow_gray)[0]))) + \
          tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI)[1])*tf.exp(lamda_g*tf.abs(tf.image.image_gradients(Rlow_gray)[1])))

Lenh = Lrecon_enh + lamda_is_enh*Lis_enh

image gradient shape: 2
Tensor("Reshape_18:0", shape=(?, ?, ?, 3), dtype=float32)


## 6.generate dataset and iterator

In [6]:
#get file name list
low_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\low\\*.png"
norm_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\high\\*.png"
syn_low_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\syn\\low\\*.png"
syn_norm_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\syn\\high\\*.png"
low_names = tf.train.match_filenames_once(low_dir) #return all matched names, a variable
norm_names = tf.train.match_filenames_once(norm_dir) 
syn_low_names = tf.train.match_filenames_once(syn_low_dir)
syn_norm_names = tf.train.match_filenames_once(syn_norm_dir)
low_list = tf.Variable("",dtype=tf.string) 
norm_list = tf.Variable("",dtype=tf.string) 

with tf.Session() as sess:    
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    filesfind0,filesfind1,filesfind2,filesfind3 = sess.run((low_names,norm_names,syn_low_names,syn_norm_names))
    low_string  = [""]*(filesfind0.shape[0]+filesfind2.shape[0])
    norm_string = [""]*(filesfind1.shape[0]+filesfind3.shape[0])   
    assert filesfind0.shape[0]==filesfind1.shape[0]
    assert filesfind2.shape[0]==filesfind3.shape[0]
    for i in range(0,filesfind0.shape[0]):
        low_string[i]  = filesfind0[i].decode()  #from byte to string
        norm_string[i] = filesfind1[i].decode()
    for w in range(0,filesfind2.shape[0]):
        low_string[filesfind0.shape[0]+w]  = filesfind2[w].decode()  #from byte to string
        norm_string[filesfind0.shape[0]+w] = filesfind3[w].decode()    
    low_list = low_string
    norm_list = norm_string

print("low light file number:",len(low_list))
print("norm light file number:",len(norm_list))
print("low file0",low_list[0])
print("norm file0",norm_list[0])
#im = np.asarray(Image.open(low_list[0]))/255.0
#print(im)


#build data set
total_train_images = len(low_list)
tr_data = tf.data.Dataset.from_tensor_slices((low_list,norm_list))   #buid dataset from file name
tr_data = tr_data.map(input_parser_pair) #map file name to file
tr_data = tr_data.shuffle(buffer_size=total_train_images).batch(batch).repeat(epoch) #form batch, pack 1 batch into 1 element
#tr_data = tr_data.batch(batch).repeat(epoch)  #no shuffle

#get iterator
iterator = tr_data.make_initializable_iterator() #make iterator
next_element = iterator.get_next()  #next batch actually
training_init_op = iterator.make_initializer(tr_data)

low light file number: 1485
norm light file number: 1485
low file0 E:\MyDownloads\Download\1006\RetinexNet-master\data\our485\low\10.png
norm file0 E:\MyDownloads\Download\1006\RetinexNet-master\data\our485\high\10.png


## 7.begin training
Train Decom_Net for first half time; then train Enhance_Net for another half time with Decom_Net fixed.    
During training, the input is image patches;
During evaluation and testing, the input is 400*600 image

In [7]:
##run training
decom_checkpoint = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\mycheckpoint\\decom\\"
enhance_checkpoint = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\mycheckpoint\\enhance\\"
train_decom_path = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\myresult\\train\\decom\\"
train_enh_path = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\myresult\\train\\enh\\"
test_result_path = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\myresult\\test\\"
tboard = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\mycheckpoint\\tboard\\"
variable_not_in_decom = [var for var in tf.trainable_variables() if 'decom' not in var.name]
variable_in_decom = [var for var in tf.trainable_variables() if 'decom' in var.name]
#optimizer_decom = tf.train.GradientDescentOptimizer(lr).minimize(Ldecom,var_list = variable_in_decom)
#optimizer_enh = tf.train.GradientDescentOptimizer(lr).minimize(Lenh, var_list = variable_not_in_decom)
optimizer_decom = tf.train.AdamOptimizer(lr).minimize(Ldecom,var_list = variable_in_decom)
optimizer_enh = tf.train.AdamOptimizer(lr).minimize(Lenh, var_list = variable_not_in_decom)
saver = tf.train.Saver()
batch_cycle=0
epoch_cycle=0
all_batch = 0
total_batch = 1485/batch
print("each epoch contain batches:",total_batch)

'''#test dataset
with tf.Session() as sess:
    # initialize the iterator on the training data
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(training_init_op)
    elem = sess.run(next_element)
    print(elem[0][0])
'''

##for TF-board
dec_loss_summary = tf.summary.scalar('Decom_Loss', Ldecom)
enh_loss_summary = tf.summary.scalar('Enhance_Loss', Lenh)
Ilow_hist = tf.summary.histogram('Ilow', Ilow)
reconI_hist = tf.summary.histogram('reconI', reconI)
merged = tf.summary.merge([dec_loss_summary, enh_loss_summary,Ilow_hist,reconI_hist])

'''
with tf.Session() as sess:
    # initialize the iterator on the training data
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(training_init_op)
    train_writer = tf.summary.FileWriter(tboard, sess.graph)  #define summary writter

    # get each element of the training dataset until the end is reached
    while True:
        try:  
            if(epoch_cycle < epoch/2):   #train decom-net
                elem = sess.run(next_element)  #get next batch input
                batch_cycle +=1
                all_batch +=1
                feed_train = {
                    imlow:elem[0],
                    imnorm:elem[1],
                    bypass_patch_gen:False
                }            
                sess.run(optimizer_decom,feed_dict=feed_train)
                
                if batch_cycle%10 == 0:
                    Ldecom_train = sess.run(Ldecom, feed_dict=feed_train)
                    print('**train Decom-Net: loss for batch ',batch_cycle," is ", Ldecom_train)
                    ##for TF-board
                    summary = sess.run(merged,feed_dict=feed_train)
                    train_writer.add_summary(summary,all_batch)
                
                if batch_cycle==total_batch:
                    batch_cycle =0
                    print("****end epoch:",epoch_cycle)
                    epoch_cycle +=1
                    
                if (auto_checkpoint==1 and (epoch_cycle+1)%10 == 0 and epoch_cycle>0 and batch_cycle==total_batch-1):
                    feed_eval = {
                        imlow:elem[0],  
                        imnorm:elem[1],
                        bypass_patch_gen:True 
                    }
                    print("******begin evaluate")
                    Ilow_ex_train, Rlow_train = sess.run([Ilow_ex, Rlow], feed_dict=feed_eval) 
                    save_decom = saver.save(sess,decom_checkpoint )  #save model
                    save_file = np.concatenate((elem[0][0:2],elem[1][0:2],Rlow_train, Ilow_ex_train), axis=2)
                    array_to_image(save_file,epoch_cycle,train_decom_path)

                    
            else: # train enhance-net
                elem = sess.run(next_element)  #get next batch input
                batch_cycle +=1
                all_batch +=1
                feed_train = {
                    imlow:elem[0],
                    imnorm:elem[1],
                    bypass_patch_gen:False
                }            
                sess.run(optimizer_enh,feed_dict=feed_train)
                ##for TF-board
                summary = sess.run(merged,feed_dict=feed_train)
                train_writer.add_summary(summary,all_batch)
                
                if batch_cycle%10 == 0:
                    Lenh_train = sess.run(Lenh, feed_dict=feed_train)
                    print('##train Enhance-Net: loss for batch ',batch_cycle," is ", Lenh_train)
                    ##for TF-board
                    summary = sess.run(merged,feed_dict=feed_train)
                    train_writer.add_summary(summary,all_batch)
                    
                if batch_cycle==total_batch:
                    batch_cycle =0
                    print("####end epoch:",int(epoch_cycle-epoch/2))  
                    epoch_cycle +=1
                    
                if (auto_checkpoint==1 and (epoch_cycle+1)%10 == 0 and epoch_cycle>0 and batch_cycle==total_batch-1):
                    feed_eval = {
                        imlow:elem[0],  
                        imnorm:elem[1],
                        bypass_patch_gen:True
                    }
                    print("######begin evaluate")                    
                    reconImage_train, reconI_ex_train,Ilow_ex_train = sess.run([reconImage, reconI_ex,Ilow_ex], feed_dict=feed_eval)
                    save_enhance = saver.save(sess,enhance_checkpoint ) #save model  
                    save_decom = saver.save(sess,decom_checkpoint )  #save model
                    save_file = np.concatenate((elem[0][0:2],elem[1][0:2],Ilow_ex_train,reconI_ex_train, reconImage_train), axis=2)
                    array_to_image(save_file,int(epoch_cycle-epoch/2),train_enh_path)

        except tf.errors.OutOfRangeError:
            print("End of training for Retinex-Net")
            break  
'''          

each epoch contain batches: 99.0


'\nwith tf.Session() as sess:\n    # initialize the iterator on the training data\n    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))\n    sess.run(training_init_op)\n    train_writer = tf.summary.FileWriter(tboard, sess.graph)  #define summary writter\n\n    # get each element of the training dataset until the end is reached\n    while True:\n        try:  \n            if(epoch_cycle < epoch/2):   #train decom-net\n                elem = sess.run(next_element)  #get next batch input\n                batch_cycle +=1\n                all_batch +=1\n                feed_train = {\n                    imlow:elem[0],\n                    imnorm:elem[1],\n                    bypass_patch_gen:False\n                }            \n                sess.run(optimizer_decom,feed_dict=feed_train)\n                \n                if batch_cycle%10 == 0:\n                    Ldecom_train = sess.run(Ldecom, feed_dict=feed_train)\n                    print(\'**trai

## 8. image test

In [8]:
#load saved model or test with current model
testface = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\MyTestFace_small\\"

testim0 = np.asarray(Image.open(os.path.join(testface,"IMG_2977.jpg")))/255.0
testim1 = np.asarray(Image.open(os.path.join(testface,"IMG_2993.jpg")))/255.0
testim2 = np.asarray(Image.open(os.path.join(testface,"IMG_3017.jpg")))/255.0
testim3 = np.asarray(Image.open(os.path.join(testface,"IMG_3030.jpg")))/255.0
testim = np.stack((testim0,testim1,testim2,testim3),axis=0) #concatenate as batch


if load_pretrain==1:
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, enhance_checkpoint)
        print("Model restored.") 
        #several images
        feed_test = {
                imlow: testim, 
                imnorm: testim,
                bypass_patch_gen:True
        }
        reconImage_test = sess.run(reconImage, feed_dict=feed_test)
        print(testim.shape)
        print(reconImage_test.shape)
        save_test = np.concatenate((testim,reconImage_test), axis=2)
        array_to_image(save_test,100,test_result_path)  
        print("test done!")
else:
    pass

INFO:tensorflow:Restoring parameters from E:\MyDownloads\Download\1006\RetinexNet-master\mycheckpoint\enhance\
Model restored.
(4, 618, 464, 3)
(4, 618, 464, 3)
test done!
