# RM-MMDnet

In [1]:
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
import os
import time
from glob import glob
from scipy.misc import imsave as ims
from random import randint
from data_providers import *
slim = tf.contrib.slim
import scipy as sp

# Dataset flag; none of the visible cells read it.
cifar = True
# IPython magic: pulls numpy and matplotlib names into the interactive
# namespace (see the "Populating the interactive namespace" output below).
%pylab inline

# config = tf.ConfigProto(
#     gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.2),
#     device_count = {'GPU': 1}
# )

Populating the interactive namespace from numpy and matplotlib


`%matplotlib` prevents importing * from pylab and numpy
  "\n`%matplotlib` prevents importing * from pylab and numpy"


In [2]:
def transform(image, npx=64, is_crop=True):
    """Convert an image to a float array rescaled by ``value / 127.5 - 1``.

    With this scaling an input value of 0 maps to -1 and 255 maps to 1.

    Args:
        image: array-like of pixel values.
        npx: size argument forwarded to ``center_crop`` when cropping.
        is_crop: when true, the image is first passed through
            ``center_crop(image, npx)``; otherwise it is used as given.

    Returns:
        A numpy array holding the rescaled pixel values.
    """
    pixels = center_crop(image, npx) if is_crop else image
    return np.array(pixels) / 127.5 - 1.

In [3]:
# Shared settings read by the cells below.
#   batch_size : number of rows per batch (used by computeRatio and the
#                graph-construction cell)
#   image_dim  : 32*32*3, i.e. the element count of one 32x32x3 image
#   c, h, w    : presumably channels / height / width of one image
params = dict(
    batch_size=64,
    image_dim=32 * 32 * 3,
    c=3,
    h=32,
    w=32,
)

def makeScaleMatrix(num_gen, num_orig):
    """Build the column of per-row weights used to scale kernel values.

    Args:
        num_gen: number of leading rows; each gets the weight ``1/num_gen``.
        num_orig: number of trailing rows; each gets ``-1/num_orig``.

    Returns:
        A tensor of shape ``[num_gen + num_orig, 1]`` with the positive
        weights stacked on top of the negative ones.
    """
    gen_weight = 1.0 / num_gen
    orig_weight = 1.0 / num_orig

    # positive block for the first num_gen rows
    positive_part = tf.constant(gen_weight, shape=[num_gen, 1])
    # negated block for the remaining num_orig rows
    negative_part = -tf.constant(orig_weight, shape=[num_orig, 1])

    return tf.concat([positive_part, negative_part], 0)

def computeRatio(x, gen_x, sigma = [1,2,4,8,16]):
    """Compute a kernel-based ratio vector and an MMD-style distance.

    Both inputs are flattened to 2-D and stacked into one matrix with the
    rows of ``gen_x`` first and the rows of ``x`` after them.  An RBF kernel
    matrix is then built over all row pairs for every bandwidth in ``sigma``.

    Args:
        x: batch tensor; flattened with ``slim.flatten``.
        gen_x: batch tensor; flattened with ``slim.flatten``.  Its rows
            occupy the leading block of the stacked matrix.
        sigma: sequence of bandwidths; the kernel exponent is divided by
            each entry in turn.  Must be non-empty.

    Returns:
        A pair ``(ratio, distance)``:
          * ``ratio`` has shape ``[num_gen, 1]`` and is the mean over the
            bandwidths of ``inv(K[gen, gen]) * K[gen, x] * ones``.
          * ``distance`` is a scalar: the square root of the weighted kernel
            sum ``sum(S * K)`` accumulated over all bandwidths.
    """
    x = slim.flatten(x)
    gen_x = slim.flatten(gen_x)

    # Row counts come from the static batch dimension when it is known;
    # otherwise fall back to the configured batch size.
    num_gen = gen_x.get_shape().as_list()[0] or params['batch_size']
    num_orig = x.get_shape().as_list()[0] or params['batch_size']

    # concatenation of both batches:
    # first 'num_gen' rows come from 'gen_x', the next 'num_orig' from 'x'
    X = tf.concat([gen_x, x],0)

    # dot product between all combinations of rows in 'X'
    XX = tf.matmul(X, tf.transpose(X))

    # dot product of rows with themselves
    X2 = tf.reduce_sum(X * X, 1, keep_dims = True)

    # exponent entries of the RBF kernel (without the sigma) for each
    # combination of the rows in 'X'
    # -0.5 * (x^Tx - 2*x^Ty + y^Ty)
    exponent = XX - 0.5 * X2 - 0.5 * tf.transpose(X2)

    # scaling constants for each of the rows in 'X'
    s = makeScaleMatrix(num_gen, num_orig)

    # scaling factors of each of the kernel values, corresponding to the
    # exponent values
    S = tf.matmul(s, tf.transpose(s))

    # column of ones that sums the cross-kernel block over the 'x' rows;
    # it does not depend on the bandwidth, so build it once
    ones_col = tf.ones(shape=(num_orig, 1))

    ratio_sum = 0
    mmd = 0

    # for each bandwidth parameter, accumulate the ratio and the MMD value
    for bandwidth in sigma:

        # kernel values for each combination of the rows in 'X'
        kernel_val = tf.exp(1.0 / bandwidth * exponent)

        # 'de'/'nu' presumably stand for the denominator / numerator of a
        # density ratio -- TODO confirm against the method description
        k_de_de = kernel_val[:num_gen, :num_gen]
        k_de_nu = kernel_val[:num_gen, num_gen:]

        # NOTE(review): the kernel block is inverted without any ridge term;
        # confirm it is well conditioned for the features fed in.
        ratio_sum += tf.matmul(tf.matmul(tf.matrix_inverse(k_de_de), k_de_nu), ones_col)
        mmd += tf.reduce_sum(S * kernel_val)

    return ratio_sum/float(len(sigma)), tf.sqrt(mmd)


In [4]:
# def discriminator(image, reuse=False, i=0):
    
#     with tf.variable_scope("disc") as scope:
#         if reuse:
#                 scope.reuse_variables()
#         h=32
#         w=32
#         h0 = image
#         h0 = lrelu(conv2d(h0, 3, df_dim, name='d_h0_conv')) #16x16x32
#         h1 = lrelu(tf.contrib.layers.batch_norm(conv2d(h0, df_dim, df_dim, name='d_h1_conv'+str(i)))) #8x8x64
#         h2 = lrelu(tf.contrib.layers.batch_norm(conv2d(h1, df_dim, df_dim*2, name='d_h2_conv'+str(i)))) #4x4x128
#         h3 = tf.reshape(h2, [batchsize, -1])
#         h4 = dense(h3, 4*4*df_dim*2, 128, scope='d_h4_lin'+str(i)) #2048
#         print h0
#         print h1
#         print h2
#         print h3
#         print h4
#         return h4
    
def discriminator(image, reuse=False, i=0):
    """Map a batch of images to 128-dimensional feature vectors.

    The network is three ``conv2d`` layers with leaky-ReLU activations
    (batch norm on the second and third), a reshape to one row per sample,
    and a final ``dense`` layer with 128 outputs.  Every intermediate tensor
    is printed while the graph is being built.

    Reads the module-level ``df_dim`` (base filter count) and ``batchsize``
    (used to reshape the conv output), so both must be defined before the
    first call.

    Args:
        image: batch of images with 3 channels.
        reuse: when true, variables of the "disc" scope are reused instead
            of being created.
        i: suffix appended to the names of all layers except the first
            convolution, whose name is fixed.

    Returns:
        The output tensor of the final dense layer.
    """
    with tf.variable_scope("disc") as scope:
        if reuse:
            scope.reuse_variables()

        # Shape comments below hold for 32x32 inputs and df_dim == 32.
        h0 = lrelu(conv2d(image, 3, df_dim, name='d_h0_conv')) #16x16x32
        h1 = lrelu(tf.contrib.layers.batch_norm(conv2d(h0, df_dim, df_dim*2, name='d_h1_conv'+str(i)))) #8x8x64
        h2 = lrelu(tf.contrib.layers.batch_norm(conv2d(h1, df_dim*2, df_dim*2*2, name='d_h2_conv'+str(i)))) #4x4x128
        h3 = tf.reshape(h2, [batchsize, -1]) #2048 values per sample
        h4 = dense(h3, 4*4*df_dim*4, 128, scope='d_h4_lin'+str(i))

        # Single-argument print() prints the same under Python 2 and 3.
        for layer in (h0, h1, h2, h3, h4):
            print(layer)
        return h4
    

        
    

        
def generator(z):
    """Build the generator: latent vectors in, 32x32x3 tanh images out.

    A dense layer projects ``z`` to a 4x4 feature map, which is upsampled by
    three ``conv_transpose`` layers (4 -> 8 -> 16 -> 32).  Batch norm and
    ReLU follow every stage except the last, which uses tanh.  Every
    intermediate tensor is printed while the graph is being built.

    Reads the module-level ``z_dim`` (latent width) and ``batchsize`` (fixed
    batch dimension of the transposed-convolution output shapes).

    Args:
        z: latent batch of width ``z_dim``.

    Returns:
        Tensor of shape ``[batchsize, 32, 32, 3]`` produced by ``tf.nn.tanh``.
    """
    with tf.variable_scope("gen") as scope:
        # Local base filter count; shadows any module-level gf_dim.
        gf_dim=32
        z2 = dense(z, z_dim, 4*4*gf_dim*4, scope='g_h0_lin')
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(tf.reshape(z2, [-1, 4, 4, gf_dim*4]))) # 4x4x128
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(conv_transpose(h0, [batchsize, 8, 8, gf_dim*2], "g_h1"))) #8x8x64
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(conv_transpose(h1, [batchsize, 16, 16, gf_dim*1], "g_h2"))) #16x16x32
        h3 = tf.nn.tanh(conv_transpose(h2, [batchsize, 32, 32, 3], "g_h4")) #32x32x3

        # Single-argument print() prints the same under Python 2 and 3.
        for layer in (h0, h1, h2, h3):
            print(layer)
        return h3
        

In [12]:
# Start from an empty graph so this cell can be re-run safely.
tf.reset_default_graph()

# ---- hyper-parameters and bookkeeping -------------------------------------
sdata=[]
batchsize = params['batch_size']
iscrop = False
imageshape = [32, 32, 3]
z_dim = 128
gf_dim = 32
df_dim = 32
c_dim = 3
learningrate_gen = 1e-4
learningrate = 1e-4
beta1 = 0.5
NUM_PROJ=1
NUM_MMD = 10

# ---- graph inputs and generator -------------------------------------------
images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")

zin = tf.placeholder(tf.float32, [None, z_dim], name="z")
G = generator(zin)

# GAN
dloss=0.0
gloss=0.0
closs=0.0

# Discriminator features for the real batch and for the generated batch;
# the second call reuses the variables created by the first.
DL = discriminator(images,i=1)
GL = discriminator(G,reuse=True,i=1)



r_nu,mmd = computeRatio(GL, DL) #correct
# r_de only appears in the commented-out loss variants below.
r_de,_ = computeRatio(DL, GL)

# The generator minimises the MMD-style distance between the two feature sets.
gloss = tf.reduce_mean(mmd)

# dloss = tf.reduce_mean(tf.square(r_de)) - 2*tf.reduce_mean(r_nu)
# '%s' formatting keeps the printed text identical under Python 2 and 3.
print('r_Nu %s' % (r_nu,))
# NOTE(review): the log is only defined where r_nu + 1e-6 > 0 -- confirm the
# ratio estimate cannot go negative.
dloss =  -tf.reduce_mean(tf.log(r_nu+1e-6),0) #-2*tf.reduce_mean(tf.square(r_de-1)) - tf.reduce_mean(r_nu)
print('dloss %s' % (dloss,))
t_vars = tf.trainable_variables()

# Split the trainable variables by the scope name embedded in their names.
g_vars = [var for var in t_vars if 'gen' in var.name]
d_vars = [var for var in t_vars if 'disc' in var.name]

# Parameter counts: all variables, discriminator only, generator only.
print(np.sum([np.prod(v.get_shape().as_list()) for v in t_vars]))
print(np.sum([np.prod(v.get_shape().as_list()) for v in d_vars]))
print(np.sum([np.prod(v.get_shape().as_list()) for v in g_vars]))

g_optim = tf.train.AdamOptimizer(learningrate_gen, beta1=beta1).minimize(gloss, var_list=g_vars)
d_optim = tf.train.AdamOptimizer(learningrate,     beta1=beta1).minimize(dloss, var_list=d_vars)


start_time = time.time()
display_z = np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32)

batch = CIFAR10DataProvider(batch_size=params['batch_size'])
batch_idxs = batch.num_batches
batch_idx = batch_idxs


# sess = tf.InteractiveSession(config=config)
sess = tf.InteractiveSession()
# One initialisation is enough; the former second pass through the deprecated
# tf.initialize_all_variables() only re-ran the same initialisers.
sess.run(tf.global_variables_initializer())



Tensor("gen/Relu:0", shape=(?, 4, 4, 128), dtype=float32)
Tensor("gen/Relu_1:0", shape=(64, 8, 8, 64), dtype=float32)
Tensor("gen/Relu_2:0", shape=(64, 16, 16, 32), dtype=float32)
Tensor("gen/Tanh:0", shape=(64, 32, 32, 3), dtype=float32)
Tensor("disc/lrelu/add:0", shape=(64, 16, 16, 32), dtype=float32)
Tensor("disc/lrelu_1/add:0", shape=(64, 8, 8, 64), dtype=float32)
Tensor("disc/lrelu_2/add:0", shape=(64, 4, 4, 128), dtype=float32)
Tensor("disc/Reshape:0", shape=(64, 2048), dtype=float32)
Tensor("disc/d_h4_lin1/add:0", shape=(64, 128), dtype=float32)
Tensor("disc_1/lrelu/add:0", shape=(64, 16, 16, 32), dtype=float32)
Tensor("disc_1/lrelu_1/add:0", shape=(64, 8, 8, 64), dtype=float32)
Tensor("disc_1/lrelu_2/add:0", shape=(64, 4, 4, 128), dtype=float32)
Tensor("disc_1/Reshape:0", shape=(64, 2048), dtype=float32)
Tensor("disc_1/d_h4_lin1/add:0", shape=(64, 128), dtype=float32)
r_Nu Tensor("div:0", shape=(64, 1), dtype=float32)
dloss Tensor("Neg_2:0", shape=(1,), dtype=float32)
857955
42

In [9]:
# '------Training GAN--------------------'
# Sample images are written to ./results; create the directory up front so
# the first dump does not fail when it is missing.
if not os.path.isdir("./results"):
    os.makedirs("./results")

for epoch in range(200):
    batch.new_epoch()

    for idx in range(batch_idxs):
        batch_images,_=batch.next()
        batch_z = np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32)

#         for i in range(1):
#             sess.run([d_optim],
#                  feed_dict={ images: batch_images, zin: batch_z })

        # One joint update of discriminator and generator per batch.
        sess.run([d_optim,g_optim],
                 feed_dict={ images: batch_images, zin: batch_z })

#     '---------Printing intermediate results-------------'
    # Every 10th epoch: report the generator loss on the last batch and dump
    # an 8x8 grid of generated samples.
    if epoch % 10 == 0:
        print(gloss.eval({zin: batch_z,images: batch_images}))

        print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idx, time.time() - start_time,))

        # Only G is fetched; the fed 'images' placeholder is not needed back.
        sdata = sess.run(G,feed_dict={ images: batch_images, zin: batch_z })
        sdata = np.expand_dims(sdata[:64],0)
        ims("./results/"+str(epoch)+"fake.jpg",merge(sdata[0],[8,8]))


# After training: draw 250 batches from the generator and stack them.
gen_images = np.vstack([sess.run(G,feed_dict={ images: batch_images, \
            zin: np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32) }) for _ in range(250)])

# np.save writes binary data, so open the file in binary mode and close it
# deterministically.
with open("./rmn","wb") as out_file:
    np.save(out_file,gen_images)

0.86079043
Epoch: [ 0] [ 780/ 781] time: 34.7600, 
0.56675994
Epoch: [10] [ 780/ 781] time: 211.6749, 
0.5974993
Epoch: [20] [ 780/ 781] time: 388.6723, 
0.6454807
Epoch: [30] [ 780/ 781] time: 565.3031, 


KeyboardInterrupt: 