In [1]:
import os
import numpy as np
from matplotlib.image import imread, imsave
from matplotlib import pyplot as plt 

import tensorflow as tf
import tensorflow.keras as keras

In [2]:
class data_generator:
    """Serve training images from a directory of ``.jpg`` files, scaled by 1/255.

    File names are listed once at construction; images are read lazily on
    each request.
    """
    def __init__(self,dir_name):
        """
        Args:
            dir_name: directory holding the ``.jpg`` images. A trailing path
                separator is optional; paths are built with ``os.path.join``.

        Raises:
            ValueError: if the directory contains no ``.jpg`` files (every
                read method would otherwise fail later with an obscure
                ZeroDivisionError / randint error).
        """
        self.dir_name=dir_name
        # Sorted so the epoch order is reproducible (os.listdir order is arbitrary).
        self.fileNames=sorted(x for x in os.listdir(self.dir_name) if x.endswith('.jpg'))
        # Index of the next file getEpochData will read; wraps around at n_files.
        self.count=0
        self.n_files=len(self.fileNames)
        if self.n_files==0:
            raise ValueError("no .jpg files found in %r" % (dir_name,))

    def _read(self,index):
        """Read file number ``index`` from the directory and divide its pixels by 255."""
        return imread(os.path.join(self.dir_name,self.fileNames[index]))/255.0

    def getRandomOne(self):
        """Return one uniformly chosen image wrapped in a batch of size 1.

        Returns:
            ``np.array`` of shape ``(1,) + image.shape`` (presumably
            (1, H, W, 3) for RGB JPEGs — depends on ``imread``).
        """
        return np.array([self._read(np.random.randint(self.n_files))])

    def getEpochData(self,size):
        """Return the next ``size`` images in file order, wrapping around at the end.

        Successive calls continue where the previous one stopped (``self.count``).
        All images must share one shape so they can be stacked.

        Args:
            size: number of images to read.

        Returns:
            ``np.array`` of shape ``(size,) + image.shape``.
        """
        print("Reading Image Bank size:",size)
        res=[]
        for _ in range(size):
            res.append(self._read(self.count))
            self.count=(self.count+1)%self.n_files
        res=np.array(res)
        print("Read Image Bank:",res.shape)
        return res
            

In [3]:
class tsop:
    """Namespace for stateless helpers used by the networks (never instantiated).

    Members are called on the class itself, e.g. ``tsop.conv2d(...)``.
    """

    @staticmethod
    def conv2d(x,n_filters,fsize,strides=1,padding='REFLECT',activation=tf.nn.relu,
               dilation_rate=1,name=None):
        """Conv2D layer with optional explicit padding done by ``tf.pad``.

        Args:
            x: 4-D tensor; axes 1 and 2 are the spatial axes that get padded
                (presumably NHWC, Keras' default channels_last layout).
            n_filters: number of output channels.
            fsize: square kernel size (odd sizes keep padded outputs aligned).
            strides: convolution stride.
            padding: 'REFLECT', 'SYMMETRIC' or 'CONSTANT' (any case) pads each
                spatial side by ``dilation_rate * (fsize // 2)`` with ``tf.pad``
                and then convolves with 'valid' padding, so a stride-1 conv with
                an odd kernel keeps the spatial size. Any other value (e.g.
                'SAME', 'VALID') is lower-cased and handed to Conv2D's own
                ``padding`` argument, which validates it.
            activation: activation applied by the Conv2D layer.
            dilation_rate: kernel dilation.
            name: optional prefix; the ops are named ``<name>_padding`` and
                ``<name>_conv``. With ``None`` TensorFlow/Keras pick names.

        Returns:
            The layer output tensor.
        """
        mode=padding.upper()
        if mode in ('REFLECT','SYMMETRIC','CONSTANT'):
            p=dilation_rate*(fsize//2)
            x=tf.pad(x,[[0,0],[p,p],[p,p],[0,0]],mode=mode,
                     name=None if name is None else name+'_padding')
            conv_padding='valid'
        else:
            # Previously every non-'REFLECT' value silently became a 'valid' conv.
            conv_padding=padding.lower()
        x=keras.layers.Conv2D(n_filters,kernel_size=fsize,strides=strides,padding=conv_padding,
                              activation=activation,dilation_rate=dilation_rate,
                              name=None if name is None else name+'_conv')(x)
        return x

    @staticmethod
    def getMask_256(X):
        """Random square hole mask plus the enclosing 128x128 local-critic window.

        Positions are hard-coded for 256x256 images (hence the name). One hole
        is drawn per call and applied to every sample of the batch.

        Args:
            X: array with a 4-element ``.shape`` (m, h, w, c); only the shape is used.

        Returns:
            (image_mask, local_mask), float64 arrays of shape (m, h, w, 1).
            image_mask is 1 on a square of side 2*(k//2), k in [96, 127],
            centred at (x, y) with x (axis 1) and y (axis 2) in [64, 191].
            local_mask is 1 on the 128x128 square with the same centre; it
            always contains the hole and stays inside a 256x256 image.
        """
        m,h,w,c=X.shape
        # (x,y) is the centre of the hole; x, y are integers in [64, 191].
        x=int(np.random.random()*128+64)
        y=int(np.random.random()*128+64)
        # k is the nominal hole size, an integer in [96, 127].
        k=int(np.random.random()*32+96)
        # image mask marks the region to inpaint with 1.
        image_mask=np.zeros((m,h,w,1))
        image_mask[:,x-k//2:x+k//2,y-k//2:y+k//2,:]=1
        # local mask marks the 128x128 window fed to the local discriminator.
        local_mask=np.zeros((m,h,w,1))
        local_mask[:,x-64:x+64,y-64:y+64,:]=1

        return image_mask,local_mask
    

In [4]:
def generator(x: tf.Tensor,mask: tf.Tensor):
    """Inpainting generator: synthesise the masked region of ``x``, keep the rest.

    The hole pixels are zeroed and the mask is appended as an extra input
    channel. The result passes through an encoder with two stride-2 stages,
    four dilated 3x3 convolutions (rates 2, 4, 8, 16), and two 4x4 stride-2
    transposed convolutions, and ends in a 3-channel sigmoid output.

    Args:
        x: image batch with fully static shape [m, h, w, c]. The output layer
            has 3 channels, so c is presumably 3.
        mask: tensor broadcastable to [m, h, w, c] and [m, h, w, 1]; the final
            blend treats 1 as "synthesise this pixel".

    Returns:
        ``x * (1 - mask) + prediction * mask``: original pixels outside the
        hole, generated pixels inside it.
    """
    # Reuse is requested only if 'generator' variables already exist.
    # NOTE(review): tf.keras layers appear to ignore variable_scope reuse, so a
    # second call would likely build new weights — confirm before calling twice.
    prior=[v for v in tf.global_variables() if v.name.startswith('generator')]
    with tf.variable_scope('generator',reuse=len(prior)>0):
        m,h,w,c=x.shape.as_list()
        hole=tf.broadcast_to(mask,[m,h,w,c])
        hole_channel=tf.broadcast_to(mask,[m,h,w,1])
        net=tf.concat((x*(1-hole),hole_channel),axis=-1)

        # (filters, kernel size, stride, dilation rate, layer name)
        encoder=(
            (64,5,1,1,'step0_0'),
            (128,3,2,1,'step1_0'),
            (128,3,1,1,'step1_1'),
            (256,3,2,1,'step2_0'),
            (256,3,1,1,'step2_1'),
            (256,3,1,1,'step2_2'),
            (256,3,1,2,'step3_0'),
            (256,3,1,4,'step3_1'),
            (256,3,1,8,'step3_2'),
            (256,3,1,16,'step3_3'),
            (256,3,1,1,'step4_0'),
            (256,3,1,1,'step4_1'),
        )
        for filters,ksize,stride,rate,label in encoder:
            net=tsop.conv2d(net,filters,fsize=ksize,strides=stride,
                            dilation_rate=rate,name=label)

        # (transposed-conv filters, its name, follow-up 3x3 conv filters, its name)
        decoder=(
            (128,'step5_0_deconv1',128,'step5_1'),
            (64,'step6_0_deconv2',32,'step6_1'),
        )
        for up_filters,up_label,conv_filters,conv_label in decoder:
            net=keras.layers.Conv2DTranspose(up_filters,kernel_size=4,strides=2,
                                             padding='same',name=up_label)(net)
            net=tsop.conv2d(net,conv_filters,fsize=3,name=conv_label)

        net=tsop.conv2d(net,3,fsize=3,name='step6_2',activation=tf.nn.sigmoid)
        res=x*(1-hole)+net*hole
    return res

def global_discriminator(x):
    """Whole-image critic: six stride-2 5x5 convolutions, flatten, linear score.

    Args:
        x: image batch (4-D, as required by tsop.conv2d).

    Returns:
        A [batch, 1] tensor of unbounded scores (Dense(1), no activation).
    """
    # Reuse is requested only if 'global_discriminator' variables already exist.
    # NOTE(review): tf.keras layers appear to create fresh weights on every call
    # regardless of variable_scope reuse, so each call may build an independent
    # critic — to share weights, call this once (e.g. on a concatenated batch).
    scope='global_discriminator'
    already_built=any(v.name.startswith(scope) for v in tf.global_variables())
    with tf.variable_scope(scope,reuse=already_built):
        for depth,filters in enumerate((64,128,256,512,512,512),start=1):
            x=tsop.conv2d(x,filters,fsize=5,strides=2,name='step%d'%depth)
        x=keras.layers.Flatten()(x)
        x=keras.layers.Dense(1)(x)
    return x

def local_discriminator(x,mask):
    """Local critic over a 128x128 window cut out of each image by ``mask``.

    Args:
        x: image batch with fully static shape [m, h, w, c].
        mask: tensor broadcastable to [m, h, w, c]; must select exactly one
            128x128 square per sample (as tsop.getMask_256's local mask does),
            otherwise the reshape below fails.

    Returns:
        A [batch, 1] tensor of unbounded scores (Dense(1), no activation).
    """
    # Reuse is requested only if 'local_discriminator' variables already exist.
    # NOTE(review): tf.keras layers appear to create fresh weights on every call
    # regardless of variable_scope reuse — confirm weight sharing across calls.
    scope='local_discriminator'
    already_built=any(v.name.startswith(scope) for v in tf.global_variables())
    with tf.variable_scope(scope,reuse=already_built):
        batch,height,width,channels=x.shape.as_list()
        full_mask=tf.broadcast_to(mask,[batch,height,width,channels])
        # boolean_mask returns the selected pixels as a flat vector in row-major
        # order, so each sample's window is contiguous and can be reshaped back.
        out=tf.reshape(tf.boolean_mask(x,full_mask),shape=(batch,128,128,channels))
        for depth,filters in enumerate((64,128,256,512,512),start=1):
            out=tsop.conv2d(out,filters,fsize=5,strides=2,name='step%d'%depth)
        out=keras.layers.Flatten()(out)
        out=keras.layers.Dense(1)(out)
    return out

def getInterpSample(original_images,generated_images):
    """Random per-sample convex combinations of real and generated images.

    For sample i returns ``eps_i * original_i + (1 - eps_i) * generated_i`` with
    one ``eps_i ~ U[0, 1)`` per sample, shared by all its pixels and channels.
    These are the points at which the gradient penalty is evaluated.

    Args:
        original_images, generated_images: tensors with the same fully static
            4-D shape [m, h, w, c].

    Returns:
        Tensor of shape [m, h, w, c].
    """
    with tf.variable_scope('getInterpSample'):
        m,_,_,_=original_images.shape.as_list()
        # Shape [m,1,1,1] lets the per-sample weight broadcast over h, w and c.
        weight=tf.random_uniform([m,1,1,1],minval=0.,maxval=1.)
        interp=weight*original_images+(1-weight)*generated_images
    return interp

def getGradientPenalty(out_local,out_global,interp,inpainting_mask,eps=1e-12):
    """One-sided, hole-masked gradient penalties for the local and global critics.

    For each critic output D: g = dD/d(interp) * mask, slope = per-sample L2
    norm of g, penalty = batch mean of relu(slope - 1). Only slopes above 1
    are penalised.

    Args:
        out_local, out_global: critic outputs computed from ``interp``.
        interp: interpolated image batch (4-D; norms reduce over axes 1-3).
        inpainting_mask: tensor broadcastable to ``interp``'s static shape;
            gradients are multiplied by it before taking norms.
        eps: added under the square root. sqrt has an infinite derivative at 0,
            so a sample whose masked gradient is all zero would otherwise give
            0/0 = NaN in the critic's update.

    Returns:
        (penalty_local, penalty_global) scalar tensors.
    """
    with tf.variable_scope('getGradientPenalty'):
        mask=tf.broadcast_to(inpainting_mask,interp.shape.as_list())

        def _penalty(out):
            """Mean of relu(||dout/dinterp * mask||_2 - 1) over the batch."""
            gradient=tf.multiply(tf.gradients(out,interp)[0],mask)
            slope=tf.sqrt(tf.reduce_sum(tf.square(gradient),axis=[1,2,3])+eps)
            return tf.reduce_mean(tf.nn.relu(slope-1))

        penalty_local=_penalty(out_local)
        penalty_global=_penalty(out_global)
    return penalty_local,penalty_global

In [11]:
class GAN:
    """Image-inpainting GAN with a global and a local WGAN-style critic.

    The critics' losses are ``mean(D(fake)) - mean(D(real)) + 10 * gradient
    penalty``; the generator minimises ``-mean(D(fake))`` for both critics plus
    a pixel MSE. Uses the TensorFlow 1.x graph/session API and reads 256x256
    images from ``training_data/``.
    """

    def __init__(self):
        self.data=data_generator("training_data/")
        self.total_epoches=1000
        self.batch_size=32
        self.N_each_epoch=320
        # Generator updates per discriminator update in train() (original value).
        # NOTE(review): WGAN-style training usually updates the critic more often
        # than the generator — confirm this ratio is intended.
        self.generator_steps=10

    def build_net_architecture(self):
        """Build the full training graph in a freshly reset default graph.

        Inputs are placeholders named "original_images" [batch,256,256,3],
        "images_mask" [batch,256,256,1] (1 on the hole to inpaint) and
        "localD_mask" [batch,256,256,1] (1 on the local critic's window);
        train() looks them up as "<name>:0".

        Returns:
            (D_loss, G_loss, generated_images)
        """
        tf.reset_default_graph()
        n=self.batch_size
        # Placeholders, not tf.get_variable: these are fed on every sess.run, so
        # they must not be trainable, initialised, or written into checkpoints.
        original_images=tf.placeholder(tf.float32,shape=[n,256,256,3],name="original_images")
        images_mask=tf.placeholder(tf.float32,shape=[n,256,256,1],name="images_mask")
        localD_mask=tf.placeholder(tf.float32,shape=[n,256,256,1],name="localD_mask")

        generated_images=generator(original_images,images_mask)
        interp=getInterpSample(original_images,generated_images)

        # Evaluate each critic ONCE on [real; generated; interpolated] stacked
        # along the batch axis, then split. tf.keras layers create new weights on
        # every call and ignore variable_scope reuse, so calling a discriminator
        # function three times would yield three unrelated critics, making the
        # Wasserstein estimate and the gradient penalty meaningless.
        critic_input=tf.concat([original_images,generated_images,interp],axis=0)
        critic_local_mask=tf.concat([localD_mask,localD_mask,localD_mask],axis=0)
        global_D_pos,global_D_neg,global_D_interp=tf.split(
            global_discriminator(critic_input),3,axis=0)
        local_D_pos,local_D_neg,local_D_interp=tf.split(
            local_discriminator(critic_input,critic_local_mask),3,axis=0)
        penalty_local,penalty_global=getGradientPenalty(local_D_interp,global_D_interp,interp,images_mask)

        local_discriminator_loss=tf.reduce_mean(local_D_neg)-tf.reduce_mean(local_D_pos)+10*penalty_local
        global_discriminator_loss=tf.reduce_mean(global_D_neg)-tf.reduce_mean(global_D_pos)+10*penalty_global
        # generator() copies pixels outside the hole from its input, so only the
        # hole contributes non-zero error to this mean.
        mse_generator_loss=tf.reduce_mean((generated_images-original_images)**2)

        D_loss=local_discriminator_loss+global_discriminator_loss
        G_loss=-tf.reduce_mean(local_D_neg)-tf.reduce_mean(global_D_neg)+mse_generator_loss

        return D_loss,G_loss,generated_images

    def network_visualization(self):
        """Write the current default graph to ``tensorboard/`` for TensorBoard.

        Shows whatever graph is current, so build_net_architecture() (or
        train()) must have run first for the GAN graph to appear.
        """
        tensorboard_dir='tensorboard/'
        os.makedirs(tensorboard_dir,exist_ok=True)
        writer=tf.summary.FileWriter(tensorboard_dir)
        try:
            writer.add_graph(tf.get_default_graph())
        finally:
            # close() flushes pending events so the graph actually reaches disk.
            writer.close()

    def train(self):
        """Build the graph and alternate critic / generator updates.

        Each epoch reads ``N_each_epoch`` images and splits them into full
        batches of ``batch_size``. Each batch gets one random hole (shared by
        the batch), then one discriminator step and ``generator_steps``
        generator steps on the same feed. Every 10th epoch, the first image of
        the last batch's inpainting is saved to ``out/<epoch>.png`` and a
        checkpoint to ``checkpoints/model.ckpt``.

        Raises:
            ValueError: if ``N_each_epoch < batch_size`` (no full batch per epoch).
        """
        total_batch=self.N_each_epoch//self.batch_size
        if total_batch==0:
            raise ValueError("N_each_epoch (%d) must be at least batch_size (%d)"
                             % (self.N_each_epoch,self.batch_size))

        D_loss,G_loss,generated_images=self.build_net_architecture()

        graph=tf.get_default_graph()
        original_images=graph.get_tensor_by_name("original_images:0")
        images_mask=graph.get_tensor_by_name("images_mask:0")
        localD_mask=graph.get_tensor_by_name("localD_mask:0")

        #define optimizers
        t_vars=tf.trainable_variables()
        d_vars=[var for var in t_vars if 'discriminator' in var.name]
        g_vars=[var for var in t_vars if 'generator' in var.name]

        d_opt=tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(D_loss,var_list=d_vars)
        g_opt=tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(G_loss,var_list=g_vars)

        #training
        saver=tf.train.Saver()
        os.makedirs('checkpoints/',exist_ok=True)
        os.makedirs('out/',exist_ok=True)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for epoch in range(self.total_epoches):
                imgs_epoch=self.data.getEpochData(self.N_each_epoch)

                for e in range(total_batch):
                    print("Epoch: %d, batch: %d###" % (epoch,e),end=", ")
                    imgs=imgs_epoch[e*self.batch_size:(e+1)*self.batch_size]
                    mask_gen,mask_local=tsop.getMask_256(imgs)
                    feed={original_images:imgs,images_mask:mask_gen,localD_mask:mask_local}

                    print("Training discriminators",end=" ")
                    _,D_loss_=sess.run([d_opt,D_loss],feed_dict=feed)
                    G_loss_=None  # stays None if generator_steps == 0
                    for i in range(self.generator_steps):
                        print("%d Training generator"%i,end=', ')
                        _,G_loss_=sess.run([g_opt,G_loss],feed_dict=feed)
                    print('---Epoch %s, D_loss: %s, G_loss: %s' % (epoch,D_loss_,G_loss_))

                #visualization: inpaint the epoch's last batch and checkpoint
                if epoch % 10 ==0:
                    img_out_t=sess.run(generated_images,feed_dict=feed)
                    imsave('out/%s.png' %(epoch),img_out_t[0])
                    saver.save(sess, "checkpoints/model.ckpt")
            print("Done!")

In [12]:
# Create the trainer; its constructor lists the .jpg files in training_data/.
gan = GAN()

In [None]:
# Long-running: builds the graph and trains for gan.total_epoches epochs,
# writing sample images to out/ and checkpoints to checkpoints/ every 10 epochs.
gan.train()

Reading Image Bank size: 320
Read Image Bank: (320, 256, 256, 3)
Epoch: 0, batch: 0###, Training discriminators 0 Training generator, 

In [22]:
# Rebuild the graph first so this cell does not depend on train() (long-running,
# possibly skipped or interrupted) having left the GAN graph as the default graph.
gan.build_net_architecture()
gan.network_visualization()

In [13]:
# Sanity check of tf.gradients on p = a @ b (gradient of sum(p)):
# d/da = ones(3,2) @ b.T -> all 4s, d/db = a.T @ ones(3,2) -> all 3s.
# A private graph keeps these demo ops out of the GAN's default graph.
with tf.Graph().as_default():
    a=tf.constant(1.0,shape=(3,4))
    b=tf.constant(2.0,shape=(4,2))
    p=tf.matmul(a,b)
    g=tf.gradients(p,[a,b])
    with tf.Session() as sess:
        res=sess.run(g)
print(res)
np.testing.assert_allclose(res[0],np.full((3,4),4.0))
np.testing.assert_allclose(res[1],np.full((4,2),3.0))

[array([[4., 4., 4., 4.],
       [4., 4., 4., 4.],
       [4., 4., 4., 4.]], dtype=float32), array([[3., 3.],
       [3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)]
