In [1]:
from ops import batch_norm,linear,conv2d,deconv2d,lrelu
from image_helpers import *
import tensorflow as tf
from glob import glob
import numpy as np
import os,time

### Model Defintion

![alt text](images/model.png)

In [2]:
#Parameter Defintion
is_crop=True
batch_size=64
image_size=108
sample_size=64
image_shape=[100,16,3]
z_dim=100
gf_dim=64
df_dim=64
learning_rate=0.0002
beta1=0.5

In [3]:
#Batch Normalisation objects
d_bn1 = batch_norm(name='d_bn1')
d_bn2 = batch_norm(name='d_bn2')
d_bn3 = batch_norm(name='d_bn3')
g_bn0 = batch_norm(name='g_bn0')
g_bn1 = batch_norm(name='g_bn1')
g_bn2 = batch_norm(name='g_bn2')
g_bn3 = batch_norm(name='g_bn3')

In [4]:
def discriminator(image,reuse=False):
    if reuse:
        tf.get_variable_scope().reuse_variables()
        
    h0=lrelu(conv2d(image,df_dim,name='d_h0_conv'))
    h1=lrelu(d_bn1(conv2d(h0,df_dim*2,name='d_h1_conv')))
    h2=lrelu(d_bn2(conv2d(h1,df_dim*4,name='d_h2_conv')))
    h3=lrelu(d_bn3(conv2d(h2,df_dim*8,name='d_h3_conv')))
    h4=linear(tf.reshape(h3,[batch_size,-1]),1,'d_h3_lin')
             
    return tf.nn.sigmoid(h4),h4

In [5]:
def generator(z):
    z_=linear(z,gf_dim*8*4*4,'g_h0_lin')
    h0=tf.nn.relu(g_bn0(tf.reshape(z_,[-1,4,4,gf_dim*8])))
    h1=tf.nn.relu(g_bn1(deconv2d(h0,[batch_size,8,8,gf_dim*4],name='g_h1')))
    h2=tf.nn.relu(g_bn2(deconv2d(h1,[batch_size,16,16,gf_dim*2],name='g_h2')))
    h3=tf.nn.relu(g_bn3(deconv2d(h2,[batch_size,32,32,gf_dim*1],name='g_h3')))
    h4 = deconv2d(h3, [batch_size, 64, 64, 3], name='g_h4')
    
    return tf.nn.tanh(h4)

In [6]:
#Building model
images=tf.placeholder(tf.float32,[batch_size]+image_shape,name='real_images')
sample_images=tf.placeholder(tf.float32,[sample_size]+image_shape,name="sample_images")
z=tf.placeholder(tf.float32,[None,z_dim])

G=generator(z)
D,D_logits=discriminator(images)
D_,D_logits_=discriminator(G,reuse=True)

#cost fn
d_loss_real=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits,tf.ones_like(D)))
d_loss_fake=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_,tf.zeros_like(D_)))
d_loss=d_loss_real+d_loss_fake

g_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_,tf.ones_like(D_)))
#For image completion
mask=tf.placeholder(tf.float32,[None]+image_shape,name="mask")

contextual_loss=tf.reduce_sum(tf.contrib.layers.flatten(tf.abs(tf.mul(mask,G)-tf.mul(mask,images))))
perceptual_loss=g_loss

complete_loss=contextual_loss+perceptual_loss

ValueError: Trying to share variable d_h3_lin/Matrix, but specified shape (8192, 1) and found shape (3584, 1).

In [100]:
#Optimizers
t_vars=tf.trainable_variables()

d_vars=[var for var in t_vars if 'd_' in var.name]
g_vars=[var for var in t_vars if 'g_' in var.name]

d_optim=tf.train.AdamOptimizer(learning_rate,beta1=beta1).minimize(d_loss,var_list=d_vars)
g_optim=tf.train.AdamOptimizer(learning_rate,beta1=beta1).minimize(g_loss,var_list=g_vars)
complete_optim=tf.train.AdamOptimizer(learning_rate,beta1=beta1).minimize(complete_loss,var_list=g_vars)

## Data Points

In [101]:
sess=tf.Session()
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver()

In [102]:
#DATASET 
data=glob(os.path.join('img_align_celeba/','*.jpg'))

sample_z=np.random.uniform(-1,1,size=(sample_size,z_dim))
sample_files=data[0:sample_size]
sample=[get_image(sample_file,image_size,is_crop) for sample_file in sample_files]
sample_images=np.reshape(np.array(sample).astype(np.float32),[sample_size]+image_shape)


## Training the model

In [103]:
#Training
counter=1
start_time=time.time()
for epoch in range(1):
    np.random.shuffle(data)
    batchidxs=int(len(data)/batch_size)
    
    for idx in range(batchidxs):
        #try:
            batch_files=data[idx*batch_size:(idx+1)*batch_size]
            batch=[get_image(batch_file,image_size,is_crop=is_crop) for batch_file in batch_files]
            batch_images=np.reshape(np.array(batch).astype(np.float32),[batch_size]+image_shape)

            batch_z=np.random.uniform(-1,1,[batch_size,z_dim]).astype(np.float32)

            #mask
            scale=0.25
            mask_=np.ones([batch_size]+image_shape).astype(np.float32)
            l=int(64*scale)
            u=int(64*(1.0-scale))
            mask_[:,l:u,l:u,:]=0.0
            #inverse mask
            scale=0.25
            imask_=np.zeros([batch_size]+image_shape).astype(np.float32)
            l=int(64*scale)
            u=int(64*(1.0-scale))
            imask_[:,l:u,l:u,:]=1.0
            
            
            fd={z:batch_z,images:batch_images}
            sess.run([g_optim,d_optim],feed_dict=fd)
            sess.run([complete_optim],feed_dict={z:batch_z,images:batch_images,mask:mask_})
            c_loss,dloss,gloss=sess.run([complete_loss,d_loss,g_loss],feed_dict={z:batch_z,images:batch_images,mask:mask_})
            print(counter,c_loss,dloss,gloss)
            
            if np.mod(counter,5)==0:
                sample_generated,dl,gl=sess.run([G,d_loss,g_loss],feed_dict={z:sample_z,images:sample_images})
                original_part=np.multiply(sample_images,mask_)
                generated_part=np.multiply(sample_generated,imask_)
                total=np.add(original_part,generated_part)
                save_images(total,'samples\\')
                print('[Sample] d_loss: %.8f, g_loss: %.8f' % (dl, gl))
            counter+=1  
        #except:
         #   continue

1 282487.0 7.2893 0.00100419
2 276528.0 0.0993834 2.89487
3 274097.0 1.12919 0.519172
4 246674.0 1.34889 0.378413
5 256940.0 0.111176 3.92779
[Sample] d_loss: 0.11600902, g_loss: 5.41791153
6 249439.0 1.29371 0.538651
7 246668.0 0.28509 4.2912
8 228919.0 1.40452 0.440909
9 244363.0 0.73556 15.9906
10 236065.0 0.0146455 12.1339
[Sample] d_loss: 0.17829044, g_loss: 13.38587952
11 248971.0 0.0783687 3.24196
12 241467.0 4.29185 0.0225754
13 239579.0 1.54853 26.4669
14 242600.0 0.151014 25.2905
15 220674.0 0.0228339 16.4789
[Sample] d_loss: 0.09336887, g_loss: 18.01052094
16 236997.0 0.0429531 5.2789
17 227873.0 4.27723 0.0832719
18 231394.0 1.18376 24.6325
19 227557.0 0.474589 25.0789
20 227269.0 0.0977529 16.9231
[Sample] d_loss: 0.07977862, g_loss: 18.91727066
21 229288.0 0.0904242 7.49572
22 227796.0 2.97267 0.512974
23 243101.0 0.312964 14.9833
24 234674.0 0.685183 15.013
25 244656.0 0.212154 7.82399
[Sample] d_loss: 0.14891370, g_loss: 10.20151997
26 228032.0 2.20547 0.859715
27 21855

KeyboardInterrupt: 

## Saving the model

In [106]:
saver.save(sess,"checkpoint\\image_completion.chk")

'checkpoint\\image_completion.chk'

In [None]:
saver.restore(sess, "checkpoint\\all_variables.chk")
print(sess.run(tf.all_variables()))