# Low light face enhancement with Retinex-Net

## 1.import library

In [1]:
#this project is for low light face enhancement
import numpy as np
import sklearn
import tensorflow as tf
import os
import scipy
import random

  from ._conv import register_converters as _register_converters


## 2.define hyper parameters

In [2]:
#define all parameters
seed = 2018
patch = 48 
lr = 0.01 #learning rate
batch = 5 #batch number, better a factor of total image number
epoch = 1
dk1 = 64 #decom-net kernel number for conv1
dk2 = 64
dk3 = 64
dk4 = 64
dk5 = 64
dk6 = 64
dk7 = 4

ek1 = 64 #enhance-net kernel number for conv1
ek2 = 64
ek3 = 64
rk = 64 #resize layer kernel numebr

lamda00=1   
lamda01=0.01
lamda10=0.01
lamda11=1
lamda_g= -10
lamda_ir = 0.001
lamda_is = 0.1
lamda_is_enh = 3
random.seed = seed

## 3.define local functions

In [3]:
def nearestNeighborScaling4D( source, newHt,newWid):
    target = tf.image.resize_nearest_neighbor(source, (newHt,newWid))
    return target

'''
def nearestNeighborScaling4D( source, newHt,newWid):
     #souce: 4D tensor, 4D=[batch,height,width,channel]
     #newHt: target Height
     #newWid: target Width
    width  = int(source.shape[2])
    height = int(source.shape[1])
    print("source shape:",source.shape)    
    newX = tf.Variable(tf.zeros(shape=[source.shape[0],newHt+1,1,source.shape[-1]]), name="tempX")
    for x in range(0, newWid):  
        newY = tf.Variable(tf.zeros(shape=[source.shape[0],1,1,source.shape[-1]]), name="tempY")        
        for y in range(0, newHt):
            srcX = int( round( float(x) / float(newWid) * float(width) ) )
            srcY = int( round( float(y) / float(newHt) * float(height) ) )
            srcX = min( srcX, width-1)
            srcY = min( srcY, height-1)
            sourceSlice = tf.slice(source,[0,srcY,srcX,0],[source.shape[0],1,1,source.shape[-1]])
            newY = tf.concat([newY,sourceSlice],axis=1)
        newX = tf.concat([newX,newY],axis=2)
    target = tf.slice(newX,[0,1,1,0],[source.shape[0],newHt,newWid,source.shape[-1]])
    print("target shape:",target.shape)
    return target
'''

def nearestNeighborScaling3D( source, newHt,newWid):
    '''
     souce: 3D array, 3D=[height,width,channel]
     newHt: target Height
     newWid: target Width
    '''
    width  = int(source.shape[1])
    height = int(source.shape[0])
    target = np.zeros((newHt,newWid,source.shape[-1]),dtype = np.float32)
    for x in range(0, newWid):  
        for y in range(0, newHt):
            srcX = int( round( float(x) / float(newWid) * float(width) ) )
            srcY = int( round( float(y) / float(newHt) * float(height) ) )
            srcX = min( srcX, width-1)
            srcY = min( srcY, height-1)
            target[y,x,:] = source[srcY,srcX,:]
    return target

# random data augmentation
def data_aug(image0, image1):
    #total 4*2*2 = 16 choice
    mode = np.random.randint(2,size=3)   #mode is an array, 3 elements
    rotate = random.randint(0,3)
    new_image0 = image0
    new_image1 = image1
    if mode[0]==1:  #rotate 0/90/180/270
        new_image0 = tf.image.rot90(new_image0,k=rotate)
        new_image1 = tf.image.rot90(new_image0,k=rotate)
    if mode[1]==1: #flip left/right
        new_image0 = tf.image.flip_left_right(new_image0)
        new_image1 = tf.image.flip_left_right(new_image1)
    if mode[2]==1:  #flip up/down
        new_image0 = tf.image.flip_up_down(new_image0)
        new_image1 = tf.image.flip_up_down(new_image1)
    return new_image0,new_image1 

#generate random patch with data augmentation
def patch_gen(image0,image1):
    imh = int(image0.shape[1])
    imw = int(image0.shape[2])
    image0_patch = tf.Variable(tf.zeros(shape=[1,patch,patch,3]), name="image_patch0")
    image1_patch = tf.Variable(tf.zeros(shape=[1,patch,patch,3]), name="image_patch1")
    for i in range(batch):
        randh = random.randint(0,imh-patch)
        randw = random.randint(0,imw-patch)
        image0_temp = tf.slice(image0,[i,randh,randw,0],[1,patch,patch,-1]) #get slice
        image1_temp = tf.slice(image1,[i,randh,randw,0],[1,patch,patch,-1])
        image0_temp,image1_temp = data_aug(image0_temp,image1_temp) #dat sugmentation
        image0_patch = tf.concat([image0_patch,image0_temp],axis=0)
        image1_patch = tf.concat([image1_patch,image1_temp],axis=0)
    image0_patch = tf.slice(image0_patch,[1,0,0,0],[-1,-1,-1,-1])
    image1_patch = tf.slice(image1_patch,[1,0,0,0],[-1,-1,-1,-1])
    return image0_patch,image1_patch

#extend channel form 3 to 4, by adding max_in_channel
def channel_extension(image0,image1):
    image0_ex = tf.reduce_max(image0, axis=3, keepdims=True)
    image1_ex = tf.reduce_max(image1, axis=3, keepdims=True)
    image0_new = tf.concat([image0_ex,image0],axis=3)  ##?? the sequence?
    image1_new = tf.concat([image1_ex,image1],axis=3)
    return image0_new,image1_new
    

#from original image pair --> input image pair
def input_parser_pair(img_path0,img_path1):
    # read the img from file
    img_file0 = tf.read_file(img_path0)
    img_file1 = tf.read_file(img_path1)
    img_decoded0 = tf.image.decode_image(img_file0,dtype=tf.float32, channels=3) #3D tensor
    img_decoded1 = tf.image.decode_image(img_file1,dtype=tf.float32, channels=3)
    image_stack0 = tf.stack([img_decoded0],axis=0) #3D->4D
    image_stack1 = tf.stack([img_decoded1],axis=0)
    img0_resize = tf.image.resize_nearest_neighbor(image_stack0,[400,600]) #only support 4D input
    img1_resize = tf.image.resize_nearest_neighbor(image_stack1,[400,600])
    img0_resize_3D = (tf.squeeze(img0_resize))/255.0 #4D->3D
    img1_resize_3D = (tf.squeeze(img0_resize))/255.0
    return img0_resize_3D,img1_resize_3D
    ''' 
    image_stack0 = tf.stack([img_decoded0],axis=0) #3D->4D
    image_stack1 = tf.stack([img_decoded1],axis=0)
    print("image_stack0 shape:",image_stack0.shape)
    image_patch0 = tf.image.extract_image_patches(image_stack0,ksizes=[1,patch,patch,1],strides=[1,patch,patch,1],rates=[1,1,1,1],padding="VALID")
    image_patch1 = tf.image.extract_image_patches(image_stack1,ksizes=[1,patch,patch,1],strides=[1,patch,patch,1],rates=[1,1,1,1],padding="VALID")
    return image_patch0,image_patch1    
    '''
    #return img_decoded0,img_decoded1

## 4.build network

In [4]:
#build network
with tf.name_scope('input_image'):
    imlow = tf.placeholder(tf.float32, shape=[batch,400,600,3],name = 'input_low')   #input low light image
    imnorm = tf.placeholder(tf.float32, shape=[batch,400,600,3], name='input_norm') #input norm light image
    xlow_pre, xnorm_pre = patch_gen(imlow,imnorm)  #generate patch with data augmentation
    xlow, xnorm = channel_extension(xlow_pre,xnorm_pre) #channel extention 3->4 

'''
with tf.name_scope('input_image'):
    imlow = tf.placeholder(tf.float32, shape=[batch,patch,patch,3],name = 'input_low')   #input low light image
    imnorm = tf.placeholder(tf.float32, shape=[batch,patch,patch,3], name='input_norm') #input norm light image
    xlow, xnorm = channel_extension(imlow,imnorm) #channel extention 3->4
'''        
with tf.variable_scope("decom", reuse=tf.AUTO_REUSE):  #define weight and bias
    #shared weight between low light Decom-net and normal light Decom-net
    weight1 = tf.get_variable(name='w1', shape=[3,3,4,dk1],  initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight2 = tf.get_variable(name='w2', shape=[3,3,dk1,dk2], initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight3 = tf.get_variable(name='w3', shape=[3,3,dk2,dk3],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight4 = tf.get_variable(name='w4', shape=[3,3,dk3,dk4],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight5 = tf.get_variable(name='w5', shape=[3,3,dk4,dk5],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight6 = tf.get_variable(name='w6', shape=[3,3,dk5,dk6],initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    weight7 = tf.get_variable(name='w7', shape=[3,3,dk6,dk7], initializer=tf.contrib.layers.xavier_initializer_conv2d(), dtype=tf.float32)
    #bias for low light conv
    lowb1 = tf.get_variable("lowb1", shape=[dk1], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb2 = tf.get_variable("lowb2", shape=[dk2], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb3 = tf.get_variable("lowb3", shape=[dk3], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb4 = tf.get_variable("lowb4", shape=[dk4], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb5 = tf.get_variable("lowb5", shape=[dk5], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb6 = tf.get_variable("lowb6", shape=[dk6], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    lowb7 = tf.get_variable("lowb7", shape=[dk7], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    #bias for normal light conv
    normb1 = tf.get_variable("normb1", shape=[dk1], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb2 = tf.get_variable("normb2", shape=[dk2], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb3 = tf.get_variable("normb3", shape=[dk3], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb4 = tf.get_variable("normb4", shape=[dk4], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb5 = tf.get_variable("normb5", shape=[dk5], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb6 = tf.get_variable("normb6", shape=[dk6], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
    normb7 = tf.get_variable("normb7", shape=[dk7], initializer=tf.constant_initializer(0.01), dtype=tf.float32)
        
with tf.name_scope('decom_net_low'):
    lowcc1 = tf.nn.conv2d(xlow, weight1, strides=[1,1,1,1], padding='SAME', name="cc1")
    lowconv1 = tf.add(lowcc1,lowb1, name="conv1")   #no activation funciton
    print("low conv1 shape:",lowconv1.shape)
        
    lowcc2 = tf.nn.conv2d(lowcc1, weight2, strides=[1,1,1,1], padding='SAME', name="cc2")
    lowconv2 = tf.nn.relu(lowcc2 + lowb2, name="conv2")        
    #print("low conv2 shape:",lowconv2.shape)
            
    lowcc3 = tf.nn.conv2d(lowcc2, weight3, strides=[1,1,1,1], padding='SAME', name="cc3")
    lowconv3 = tf.nn.relu(lowcc3 + lowb3, name="conv3")  
    #print("low conv3 shape:",lowconv3.shape)
        
    lowcc4 = tf.nn.conv2d(lowcc3, weight4, strides=[1,1,1,1], padding='SAME', name="cc4")
    lowconv4 = tf.nn.relu(lowcc4 + lowb4, name="conv4")  
    #print("low conv4 shape:",lowconv4.shape) 

    lowcc5 = tf.nn.conv2d(lowcc4, weight5, strides=[1,1,1,1], padding='SAME', name="cc5")
    lowconv5 = tf.nn.relu(lowcc5 + lowb5, name="conv5")  
    #print("low conv5 shape:",lowconv5.shape)
        
    lowcc6 = tf.nn.conv2d(lowcc5, weight6, strides=[1,1,1,1], padding='SAME', name="cc6")
    lowconv6 = tf.nn.relu(lowcc6 + lowb6, name="conv6")  
    #print("low conv6 shape:",lowconv6.shape)
        
    lowcc7 = tf.nn.conv2d(lowcc6, weight7, strides=[1,1,1,1], padding='SAME', name="cc7")
    lowconv7 = tf.nn.sigmoid(lowcc7 + lowb7, name="conv7") 
    print("low conv7 shape:",lowconv7.shape)

with tf.name_scope('decom_net_normal'):
    normcc1 = tf.nn.conv2d(xnorm, weight1, strides=[1,1,1,1], padding='SAME', name="cc1")
    normconv1 = tf.nn.relu(normcc1 + normb1, name="conv1")
        
    normcc2 = tf.nn.conv2d(normcc1, weight2, strides=[1,1,1,1], padding='SAME', name="cc2")
    normconv2 = tf.nn.relu(normcc2 + normb2, name="conv2")        
        
    normcc3 = tf.nn.conv2d(normcc2, weight3, strides=[1,1,1,1], padding='SAME', name="cc3")
    normconv3 = tf.nn.relu(normcc3 + normb3, name="conv3")  
        
    normcc4 = tf.nn.conv2d(normcc3, weight4, strides=[1,1,1,1], padding='SAME', name="cc4")
    normconv4 = tf.nn.relu(normcc4 + normb4, name="conv4")  
        
    normcc5 = tf.nn.conv2d(normcc4, weight5, strides=[1,1,1,1], padding='SAME', name="cc5")
    normconv5 = tf.nn.relu(normcc5 + normb5, name="conv5")  
        
    normcc6 = tf.nn.conv2d(normcc5, weight6, strides=[1,1,1,1], padding='SAME', name="cc6")
    normconv6 = tf.nn.relu(normcc6 + normb6, name="conv6")  
        
    normcc7 = tf.nn.conv2d(normcc6, weight7, strides=[1,1,1,1], padding='SAME', name="cc7")
    normconv7 = tf.nn.sigmoid(normcc7 + normb7, name="conv7")  
    print("normal conv7 shape:",normconv7.shape)
        
with tf.name_scope('decom_output_low'):
    Rlow = tf.slice(lowconv7,[0,0,0,0],[batch,-1,-1,3],name="Rlow_output")   #output R low image
    Ilow = tf.slice(lowconv7,[0,0,0,3],[batch,-1,-1,1],name='Ilow_output') #output I low image        
    
with tf.name_scope('decom_output_norm'):
    Rnorm = tf.slice(normconv7,[0,0,0,0],[batch,-1,-1,3],name="Rnorm_output")   #output R norm image
    Inorm = tf.slice(normconv7,[0,0,0,3],[batch,-1,-1,1],name='Inorm_output') #output I norm image 
    print("Ilow shape:",Ilow.shape)
    print("Rlow shape:",Rlow.shape)
    print("Inorm shape:",Inorm.shape)
    print("Rnorm shape:",Rnorm.shape) 

with tf.name_scope('enhance_net'):
    with tf.name_scope("encoder_decoder"):
        pre_enh = tf.concat([Rlow,Ilow],axis=3) #concat on channel
        pre_enh = tf.layers.conv2d(inputs=pre_enh,filters=ek1,kernel_size=[3,3],padding='same', activation=None)   
        encconv1  =  tf.layers.conv2d(inputs=pre_enh,filters=ek1,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample1')
        encconv2  =  tf.layers.conv2d(inputs=encconv1,filters=ek2,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample2')
        encconv3  =  tf.layers.conv2d(inputs=encconv2,filters=ek3,kernel_size=[3,3],strides=[2,2],padding='same',activation=tf.nn.relu,name='downsample3')
        
        #import types
        #print(type(normconv7))
        conv3_shape = encconv3.shape
        upsample1 = nearestNeighborScaling4D(encconv3, int(conv3_shape[1]*2),int(conv3_shape[2]*2))  #upsample,W*2,H*2 
        upconv1  =  tf.layers.conv2d(inputs=upsample1,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv1')
        upresidual1 =  tf.add(upconv1,encconv2,name="upres1")
        
        upresidual1_shape = upresidual1.shape
        upsample2 = nearestNeighborScaling4D(upresidual1, int(upresidual1_shape[1]*2),int(upresidual1_shape[2]*2))  #upsample,W*2,H*2 
        upconv2  =  tf.layers.conv2d(inputs=upsample2,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv2')
        upresidual2 =  tf.add(upconv2,encconv1,name="upres2")
        
        upresidual2_shape = upresidual2.shape
        upsample3 = nearestNeighborScaling4D(upresidual2, int(upresidual2_shape[1]*2),int(upresidual2_shape[2]*2))  #upsample,W*2,H*2 
        upconv3  =  tf.layers.conv2d(inputs=upsample3,filters=ek1,kernel_size=[3,3],strides=[1,1],padding='same',activation=tf.nn.relu,name='upconv3')
        upresidual3 =  tf.add(upconv3,Ilow,name="upres3")  
        print("upresidual3 shape:",upresidual3.shape)
        
    with tf.name_scope("upsample_concat"):
        upresidual3_shape = upresidual3.shape
        preconcat1 = nearestNeighborScaling4D(upresidual1, int(upresidual3_shape[1]),int(upresidual3_shape[2])) 
        preconcat2 = nearestNeighborScaling4D(upresidual2, int(upresidual3_shape[1]),int(upresidual3_shape[2])) 
        preconcat3 = upresidual3
        concat = tf.concat([preconcat1,preconcat2,preconcat3],axis=3)  #concat on channel direction
        print("concat shape:",concat.shape)
        
    with tf.name_scope('reconstruct_illumination'):
        resizeconv = tf.layers.conv2d(inputs=concat,filters=rk,kernel_size=[1,1],strides=[1,1],padding='same',activation=None,name='resize_conv')
        reconI = tf.layers.conv2d(inputs=resizeconv,filters=1,kernel_size=[3,3],strides=[1,1],padding='same',activation=None,name='reconstruct_illumination')
        print("reconI shape:",reconI.shape)
        
with tf.name_scope("reconstruct_Image"):
    denoiseR = Rlow
    reconI_ex = tf.concat([reconI,reconI,reconI],axis=3)
    reconImage = tf.multiply(reconI_ex,denoiseR,name="reconstruct_Image")

low conv1 shape: (5, 48, 48, 64)
low conv7 shape: (5, 48, 48, 4)
normal conv7 shape: (5, 48, 48, 4)
Ilow shape: (5, 48, 48, 1)
Rlow shape: (5, 48, 48, 3)
Inorm shape: (5, 48, 48, 1)
Rnorm shape: (5, 48, 48, 3)
upresidual3 shape: (5, 48, 48, 64)
concat shape: (5, 48, 48, 192)
reconI shape: (5, 48, 48, 1)


## 5.define loss function

In [5]:
##decom-net loss function
Ilow_ex = tf.concat([Ilow,Ilow,Ilow],axis=3)
Inorm_ex = tf.concat([Inorm,Inorm,Inorm],axis=3)
xlow_image = tf.slice(xlow,[0,0,0,1],[-1,-1,-1,3])  #real image patch
xnorm_image = tf.slice(xnorm,[0,0,0,1],[-1,-1,-1,3]) #real image patch

Lrecon00 = tf.abs(tf.multiply(Ilow_ex,Rlow)-xlow_image)
Lrecon01 = tf.abs(tf.multiply(Ilow_ex,Rnorm)-xlow_image)
Lrecon10 = tf.abs(tf.multiply(Inorm_ex,Rlow)-xnorm_image)
Lrecon11 = tf.abs(tf.multiply(Inorm_ex,Rnorm)-xnorm_image)
Lrecon_decom = tf.reduce_mean(lamda00*Lrecon00+lamda01*Lrecon01+lamda10*Lrecon10+lamda11*Lrecon11)
Lir_decom = tf.reduce_mean(tf.abs(Rlow-Rnorm))

Lis_decom = tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(Rlow)[0])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Ilow_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(Rlow)[1])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(Rnorm)[0])) + \
            tf.reduce_mean(tf.abs(tf.image.image_gradients(Inorm_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(Rnorm)[1]))
Ldecom = Lrecon_decom + lamda_ir*Lir_decom + lamda_is*Lis_decom


##enhance-net loss function
Lrecon_enh = tf.reduce_mean(reconI_ex*denoiseR-xnorm_image)
Lis_enh = tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI_ex)[0])*tf.exp(lamda_g*tf.image.image_gradients(denoiseR)[0])) + \
          tf.reduce_mean(tf.abs(tf.image.image_gradients(reconI_ex)[1])*tf.exp(lamda_g*tf.image.image_gradients(denoiseR)[1]))
Lenh = Lrecon_enh + lamda_is_enh*Lis_enh

## 6.generate dataset and iterator

In [6]:
#get file name list
low_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\low\\*.png"
norm_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\our485\\high\\*.png"
syn_low_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\syn\\low\\*.png"
syn_norm_dir = "E:\\MyDownloads\\Download\\1006\\RetinexNet-master\\data\\syn\\high\\*.png"
low_names = tf.train.match_filenames_once(low_dir) #return all matched names, a variable
norm_names = tf.train.match_filenames_once(norm_dir) 
syn_low_names = tf.train.match_filenames_once(syn_low_dir)
syn_norm_names = tf.train.match_filenames_once(syn_norm_dir)
low_list = tf.Variable("",dtype=tf.string) 
norm_list = tf.Variable("",dtype=tf.string) 

with tf.Session() as sess:    
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    filesfind0,filesfind1,filesfind2,filesfind3 = sess.run((low_names,norm_names,syn_low_names,syn_norm_names))
    low_string  = [""]*(filesfind0.shape[0]+filesfind2.shape[0])
    norm_string = [""]*(filesfind1.shape[0]+filesfind2.shape[0])   
    assert filesfind0.shape[0]==filesfind1.shape[0]
    assert filesfind2.shape[0]==filesfind3.shape[0]
    for i in range(0,filesfind0.shape[0]):
        low_string[i]  = filesfind0[i].decode()  #from byte to string
        norm_string[i] = filesfind1[i].decode()
    for w in range(0,filesfind2.shape[0]):
        low_string[filesfind0.shape[0]+w]  = filesfind2[w].decode()  #from byte to string
        norm_string[filesfind0.shape[0]+w] = filesfind3[w].decode()    
    low_list = low_string
    norm_list = norm_string

print("low light file number:",len(low_list))
print("norm light file number:",len(norm_list))

#build data set
total_train_images = len(low_list)
tr_data = tf.data.Dataset.from_tensor_slices((low_list,norm_list))   #buid dataset from file name
tr_data = tr_data.map(input_parser_pair) #map file name to file
tr_data = tr_data.shuffle(buffer_size=total_train_images).batch(batch).repeat(epoch) #form batch, pack 1 batch into 1 element

#get iterator
iterator = tr_data.make_initializable_iterator() #make iterator
next_element = iterator.get_next()  #next batch actually
training_init_op = iterator.make_initializer(tr_data)

low light file number: 1485
norm light file number: 1485


## 7.begin training

### 7.1 train decom-net

In [7]:
print("******** begin to train Decom Net ***********")
##run training
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(Ldecom)
i=0
j=0
with tf.Session() as sess:
    # initialize the iterator on the training data
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))
    sess.run(training_init_op)

    # get each element of the training dataset until the end is reached
    while True:
        try:
            i +=1
            elem = sess.run(next_element)  #get next batch input image
            feed = {
                imlow:elem[0],
                imnorm:elem[1]
            }            
            sess.run(optimizer,feed_dict=feed)
            if i%(1485/5)==0:
                j +=1
                print("epoch:",j)
            if i%200 == 0:
                print('loss for batch ',i," is ", sess.run(Ldecom, feed_dict=feed))
            if j%20 == 0:
                ## fixme: add save 
                pass
        except tf.errors.OutOfRangeError:
            print("End of training for DecomNet.")
            break    

******** begin to train Decom Net ***********
loss for batch  200  is  0.42809084
epoch: 1
End of training for DecomNet.


### 7.2 train enhance net
fix decom-net weight, only train enhance net
    

In [8]:
print("******** begin to train Enhance Net ***********")
variable_not_in_decom = [var for var in tf.trainable_variables() if 'decom' not in var.name]
#print(variable_not_in_decom)
optimizer = tf.train.GradientDescentOptimizer(lr).minimize(Lenh, var_list = variable_not_in_decom)
i=0
j=0
with tf.Session() as sess:
    # initialize the iterator on the training data
    sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))   
    sess.run(training_init_op)

    # get each element of the training dataset until the end is reached
    while True:
        try:
            i +=1
            elem = sess.run(next_element)  #get next batch input image
            feed = {
                imlow:elem[0],
                imnorm:elem[1]
            }            
            sess.run(optimizer,feed_dict=feed)
            if i%(1485/5)==0:
                j +=1
                print("epoch:",j)
            if i%200 == 0:
                print('loss for batch ',i," is ", sess.run(Lenh, feed_dict=feed))
            if j%20 == 0:
                ## fixme: add save 
                pass
        except tf.errors.OutOfRangeError:
            print("End of training for DecomNet.")
            break      

******** begin to train Enhance Net ***********
loss for batch  200  is  nan
epoch: 1
End of training for DecomNet.


## 8. image test