This code was used for the simulations in Figure 3 and Appendix E.6 of our paper. (This code is for our algorithm, GDA, and Unrolled GANs, on the four Gaussian mixture dataset.)

For our algorithm, set the hyperparameters unrolling_steps=0, disc_steps=6, rate=4. For GDA with 1 discriminator step, set the hyperparameters unrolling_steps=0, disc_steps=1, rate=1. For GDA with 6 discriminator steps, set the hyperparameters unrolling_steps=0, disc_steps=6, rate=1. For Unrolled GANs, set the hyperparameters unrolling_steps=6, disc_steps=1, rate=1. For OMD, use the optimAdam optimizer and disc_steps=1, rate=1.

(Note that the part of our algorithm which saves the generator and discriminator weights is not implemented efficiently in TensorFlow in the particular code we give here. It instead copies the weights into numpy and then back into TensorFlow, which is a very inefficient process. For this reason, each iteration of our algorithm runs very slowly in this code. For an efficient implementation of our algorithm, with the weights "saved" directly in TensorFlow, see instead our code for MNIST or CIFAR-10.)

In [1]:
%pylab inline
from collections import OrderedDict
import tensorflow as tf
ds = tf.contrib.distributions
slim = tf.contrib.slim
        
from keras.optimizers import Adam

# moviepy is treated as an optional dependency: if it cannot be imported the
# notebook still runs and `generate_movie` records that it is unavailable.
try:
    from moviepy.video.io.bindings import mplfig_to_npimage
    import moviepy.editor as mpy
    generate_movie = True
except Exception:
    # Deliberately broad (moviepy can fail at import time for reasons other
    # than ImportError), but unlike a bare `except:` this does not swallow
    # KeyboardInterrupt or SystemExit.
    print("Warning: moviepy not found.")
    generate_movie = False

Populating the interactive namespace from numpy and matplotlib

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



Using TensorFlow backend.


In [2]:
# Keep a handle on the contrib implementation; the graph_replace wrapper defined below delegates to it.
_graph_replace = tf.contrib.graph_editor.graph_replace

def remove_original_op_attributes(graph):
    """Reset the ``_original_op`` attribute of every operation in ``graph`` to None.

    Args:
        graph: object exposing ``get_operations()``; each returned operation
            is modified in place.
    """
    operations = graph.get_operations()
    for operation in operations:
        setattr(operation, "_original_op", None)
        
def graph_replace(*args, **kwargs):
    """Wrapper around contrib's graph_replace that works with TF 1.0.

    Clears the ``_original_op`` attribute on every op of the current default
    graph, then forwards all arguments unchanged to the contrib
    implementation and returns its result.
    """
    default_graph = tf.get_default_graph()
    remove_original_op_attributes(default_graph)
    replaced = _graph_replace(*args, **kwargs)
    return replaced

### Utility functions

In [3]:
def extract_update_dict(update_ops):
    """Collect, for each updated variable, the tensor holding its new value.

    Only ``Assign`` and ``AssignAdd`` ops are understood. Update ops of any
    other type are skipped without raising.

    Args:
        update_ops: list of update ops, typically computed using Keras'
            opt.get_updates()

    Returns:
        OrderedDict mapping ``var.value()`` of each updated variable to the
        tensor it is updated to (for ``AssignAdd``: the variable plus the
        increment), in the order the updates were given.
    """
    variables_by_name = {}
    for variable in tf.global_variables():
        variables_by_name[variable.name] = variable

    replacements = OrderedDict()
    for update_op in update_ops:
        # The first input of the op is the variable being written, the second the operand.
        target = variables_by_name[update_op.op.inputs[0].name]
        operand = update_op.op.inputs[1]
        op_type = update_op.op.type
        if op_type == 'AssignAdd':
            replacements[target.value()] = target + operand
        elif op_type == 'Assign':
            replacements[target.value()] = operand
        # Any other op type is intentionally ignored.
    return replacements

### Data creation

In [4]:
def sample_mog(batch_size, n_mixture=5, std=0.01, radius=1.0):
    """Build a tensor that samples from a 2-D mixture of Gaussians on a circle.

    Component centres are placed at ``n_mixture`` angles spaced evenly over
    [0, 2*pi] with both endpoints included, so the first and last centres
    coincide (up to floating-point error). Every component has diagonal
    scale ``std`` and the mixing distribution is built from all-zero logits.

    Args:
        batch_size: number of points to draw.
        n_mixture: number of mixture components.
        std: per-dimension scale of every component.
        radius: radius of the circle the centres lie on.

    Returns:
        The result of ``sample(batch_size)`` on the mixture distribution.
    """
    angles = np.linspace(0, 2 * np.pi, n_mixture)
    centers_x = radius * np.sin(angles)
    centers_y = radius * np.cos(angles)
    mixing = ds.Categorical(tf.zeros(n_mixture))
    components = []
    for cx, cy in zip(centers_x.ravel(), centers_y.ravel()):
        components.append(ds.MultivariateNormalDiag([cx, cy], [std, std]))
    mixture = ds.Mixture(mixing, components)
    return mixture.sample(batch_size)

### Generator and discriminator architectures

In [5]:
def generator(z, output_dim=2, n_hidden=128, n_layer=2):
    """Generator network: a stack of ReLU fully connected layers plus a linear output layer.

    All layers are created inside the "generator" variable scope.

    Args:
        z: input noise tensor.
        output_dim: number of units of the final linear layer.
        n_hidden: number of units of each hidden layer.
        n_layer: number of hidden layers.

    Returns:
        Output tensor of the final linear layer.
    """
    hidden_sizes = [n_hidden] * n_layer
    with tf.variable_scope("generator"):
        hidden = slim.stack(z, slim.fully_connected, hidden_sizes, activation_fn=tf.nn.relu)
        return slim.fully_connected(hidden, output_dim, activation_fn=None)

def discriminator(x, n_hidden=128, n_layer=2, reuse=False):
    """Discriminator network: a stack of ReLU fully connected layers plus a single linear output unit.

    All layers are created inside the "discriminator" variable scope.

    Args:
        x: input tensor of points to score.
        n_hidden: number of units of each hidden layer.
        n_layer: number of hidden layers.
        reuse: forwarded to ``tf.variable_scope``; pass True to share the
            weights created by an earlier call.

    Returns:
        Output tensor of the final one-unit linear layer (no activation).
    """
    hidden_sizes = [n_hidden] * n_layer
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden = slim.stack(x, slim.fully_connected, hidden_sizes, activation_fn=tf.nn.relu)
        return slim.fully_connected(hidden, 1, activation_fn=None)

### Hyperparameters

In [16]:
# Experiment configuration read by the graph-construction functions and the training cells.
params = {
    # Data / model sizes
    'batch_size': 512,
    # Adam settings (separate learning rates for the two players)
    'disc_learning_rate': 1e-4,
    'gen_learning_rate': 1e-3,
    'beta1': 0.5,
    'epsilon': 1e-8,
    # Training length and plotting interval (in iterations)
    'max_iter': 1500,
    'viz_every': 100,
    'z_dim': 256,
    'x_dim': 2,
    # Algorithm selection knobs (see the notes at the top of the notebook)
    'unrolling_steps': 0,
    'disc_steps': 6,
    'rate': 4,
    'type': "cross",
}

## Train GDA!

In [17]:
def initialize_setup():
    """Reset the default graph and build the cross-entropy GAN training graph.

    Creates the data sampler, the noise sampler, the generator and
    discriminator networks, the saddle-point objective and one training op
    per player. All settings are read from the module-level ``params`` dict.

    The discriminator minimises ``loss`` with a Keras Adam optimizer; the
    generator minimises ``-unrolled_loss`` with ``tf.train.AdamOptimizer``.
    When ``params['unrolling_steps']`` is 0, ``unrolled_loss`` is ``loss``
    itself.

    Returns:
        Tuple ``(loss, unrolled_loss, g_train_op, d_train_op, samples, data)``:
            loss: scalar sigmoid cross-entropy objective (real labelled 1,
                fake labelled 0).
            unrolled_loss: ``loss`` evaluated with the discriminator
                variables replaced by their values after
                ``unrolling_steps`` optimizer updates.
            g_train_op: op applying one generator update.
            d_train_op: op applying one discriminator update.
            samples: tensor of generator outputs.
            data: tensor of samples from the target mixture.
    """
    tf.reset_default_graph()

    data = sample_mog(params['batch_size'])
    noise = ds.Normal(tf.zeros(params['z_dim']),
                      tf.ones(params['z_dim'])).sample(params['batch_size'])
    # Construct generator and discriminator nets
    with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=0.8)):
        samples = generator(noise, output_dim=params['x_dim'])
        real_score = discriminator(data)
        # Second call shares the weights created by the first one
        fake_score = discriminator(samples, reuse=True)

    # Saddle objective
    loss = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score, labels=tf.ones_like(real_score)) +
        tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score, labels=tf.zeros_like(fake_score)))

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

    # Vanilla discriminator update (a single step; callers run the op repeatedly for more steps)
    d_opt = Adam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
    updates = d_opt.get_updates(disc_vars, [], loss)
    d_train_op = tf.group(*updates, name="d_train_op")

    # Unroll optimization of the discriminator
    if params['unrolling_steps'] > 0:
        # Get dictionary mapping from variables to their update value after one optimization step
        update_dict = extract_update_dict(updates)
        cur_update_dict = update_dict
        for i in range(params['unrolling_steps'] - 1):
            # Compute variable updates given the previous iteration's updated variable
            cur_update_dict = graph_replace(update_dict, cur_update_dict)
        # Final unrolled loss uses the parameters at the last time step
        unrolled_loss = graph_replace(loss, cur_update_dict)
    else:
        unrolled_loss = loss

    # Optimize the generator on the unrolled loss (the generator maximises the objective)
    g_train_opt = tf.train.AdamOptimizer(params['gen_learning_rate'], beta1=params['beta1'], epsilon=params['epsilon'])
    g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)

    return loss, unrolled_loss, g_train_op, d_train_op, samples, data


In [None]:
### GDA algorithm
from tqdm.notebook import tqdm

# Plain GDA training: `reps` independent runs, each on a freshly built graph.
xmax = 3
fs = []
frames = []
np_samples = []
n_batches_viz = 100
viz_every = params['viz_every']

reps = 20
for _ in tqdm(range(reps)):
    # initialize_setup() resets the default graph, so every repetition starts from scratch
    loss, unrolled_loss, g_train_op, d_train_op, samples, data = initialize_setup()
    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())

        # Last observed loss values; plain GDA never compares against them
        f = [10000000]
        points = []
        for i in tqdm(range(params['max_iter'])):

            # NOTE(review): hard-coded to 1, so params['disc_steps'] is ignored in this cell -- confirm intended
            disc_steps = 1
            rate = params['rate']

            #take one generator gradient update and one discriminator gradient update
            f_new, _, _ = sess.run([[unrolled_loss, loss], g_train_op, d_train_op])

            #take the additional discriminator gradient updates
            for _extra_step in range(disc_steps - 1):
                f_new, _ = sess.run([[loss], d_train_op])

            f = f_new

            if i % viz_every == 0:
                np_samples.append(np.vstack([sess.run(samples) for _ in range(n_batches_viz)]))
                # Plot exactly the batch that is stored in `points`
                xx, yy = sess.run([samples, data])
                points.append((xx, yy))
                fig = figure(figsize=(5,5))
                plt.scatter(yy[:, 0], yy[:, 1], c='r', edgecolor='none')
                plt.scatter(xx[:, 0], xx[:, 1], edgecolor='none')
                plt.axis('off')
                plt.xlim(-1.5, 1.5)
                plt.ylim(-1.5, 1.5)
                plt.show()
    finally:
        # Release the session's resources before the next repetition opens a new one
        sess.close()


## Train using Minimax algorithm

In [None]:
from tqdm.notebook import tqdm
# Training with accept/reject steps: `reps` independent runs, each on a freshly built graph.
xmax = 3
fs = []
frames = []
np_samples = []
n_batches_viz = 100
viz_every = params['viz_every']

reps = 20
ps = []
for _ in tqdm(range(reps)):

    loss, unrolled_loss, g_train_op, d_train_op, samples, data = initialize_setup()

    # The variable lists do not change during a run, so look them up once per repetition
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")

    # Build the "restore generator weights" op once, fed through placeholders.
    # Creating tf.assign ops inside the training loop would add new nodes to
    # the graph on every rejected step.
    gen_restore_inputs = [tf.placeholder(v.dtype.base_dtype, shape=v.get_shape()) for v in gen_vars]
    gen_restore_op = tf.group(*[tf.assign(v, ph) for v, ph in zip(gen_vars, gen_restore_inputs)])

    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())

        #initialize at a very high loss value (to ensure that the first step is accepted)
        f = [10000000]

        for i in (range(params['max_iter'])):
            fs.append(f)
            disc_steps = params["disc_steps"]
            rate = params['rate']

            # Iterations that are not a multiple of `rate` are accept/reject steps
            is_accept_reject_step = i % rate != 0

            if is_accept_reject_step:
                #copy the current generator and discriminator weights in case the algorithm decides to go back to them later
                gen_vars_old, disc_vars_old = sess.run([gen_vars, disc_vars])

            #take one generator gradient update and one discriminator gradient update
            f_new, _, _ = sess.run([[unrolled_loss, loss], g_train_op, d_train_op])

            #take the additional discriminator updates
            for _extra_step in range(disc_steps - 1):
                f_new, _ = sess.run([[loss], d_train_op])

            if is_accept_reject_step:
                #Keep the weights which lead to the smaller loss
                # (f_new and f are lists of fetched loss values; they are compared as lists)
                if f_new > f:
                    print("reject")
                    #If the old weights are better, replace the new generator weights with the old weights
                    # NOTE(review): only the generator is restored; disc_vars_old is saved but not written back -- confirm intended
                    sess.run(gen_restore_op, feed_dict=dict(zip(gen_restore_inputs, gen_vars_old)))
                else:
                    print("accept")

            # NOTE(review): f is updated to f_new even after a reject -- confirm intended
            f = f_new

            #plot the output
            if i % viz_every == 0:
                np_samples.append(np.vstack([sess.run(samples) for _ in range(n_batches_viz)]))
                xx, yy = sess.run([samples, data])
                fig = figure(figsize=(5,5))
                scatter(yy[:, 0], yy[:, 1], c='r', edgecolor='none')
                scatter(xx[:, 0], xx[:, 1], edgecolor='none')
                axis('off')
                plt.xlim(-1.5, 1.5)
                plt.ylim(-1.5, 1.5)
                plt.show()
    finally:
        # Release the session's resources before the next repetition opens a new one
        sess.close()
        

## Train using OMD

In [27]:
from optimAdam import *

def initialize_omd_setup():
    """Reset the default graph and build the OMD training graph.

    Creates the data sampler, the noise sampler, the generator and
    discriminator networks, the objective ``mean(fake_score - real_score)``
    and one ``optimAdam`` training op per player. All settings are read from
    the module-level ``params`` dict.

    The discriminator minimises ``loss`` and the generator minimises
    ``-loss``. Discriminator unrolling is not applied in this setup: the
    generator is always trained on the plain loss.

    Returns:
        Tuple ``(loss, g_train_op, d_train_op, samples, data)``:
            loss: scalar objective tensor.
            g_train_op: op applying one generator update.
            d_train_op: op applying one discriminator update.
            samples: tensor of generator outputs.
            data: tensor of samples from the target mixture.
    """
    tf.reset_default_graph()

    data = sample_mog(params['batch_size'])

    noise = ds.Normal(tf.zeros(params['z_dim']),
                      tf.ones(params['z_dim'])).sample(params['batch_size'])
    # Construct generator and discriminator nets
    with slim.arg_scope([slim.fully_connected], weights_initializer=tf.orthogonal_initializer(gain=0.8)):
        samples = generator(noise, output_dim=params['x_dim'])
        real_score = discriminator(data)
        # Second call shares the weights created by the first one
        fake_score = discriminator(samples, reuse=True)

    # Wasserstein loss
    loss = tf.reduce_mean(fake_score - real_score)

    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

    # Discriminator update (a single step; callers run the op repeatedly for more steps)
    d_opt = optimAdam(lr=params['disc_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
    d_updates = d_opt.get_updates(disc_vars, [], loss)
    d_train_op = tf.group(*d_updates, name="d_train_op")

    # Generator update on the (non-unrolled) loss; the generator maximises the objective
    g_opt = optimAdam(lr=params['gen_learning_rate'], beta_1=params['beta1'], epsilon=params['epsilon'])
    g_updates = g_opt.get_updates(gen_vars, [], -loss)
    g_train_op = tf.group(*g_updates, name="g_train_op")

    return loss, g_train_op, d_train_op, samples, data

In [None]:
from tqdm.notebook import tqdm

# OMD training: `reps` independent runs, each on a freshly built graph.
reps = 20
ps = []
for _ in tqdm(range(reps)):

    loss, g_train_op, d_train_op, samples, data = initialize_omd_setup()

    # The variable lists do not change during a run, so look them up once per repetition
    disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")
    gen_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")

    # Weight clipping for WGAN: clamp every discriminator variable to [-0.1, 0.1].
    # The op is built once here; creating assign/clip ops inside the training
    # loop would add new nodes to the graph on every iteration.
    clip_disc_weights_op = tf.group(*[
        tf.assign(v, tf.clip_by_value(v, clip_value_min=-0.1, clip_value_max=0.1))
        for v in disc_vars
    ])

    sess = tf.InteractiveSession()
    try:
        sess.run(tf.global_variables_initializer())

        # Per-repetition bookkeeping (reset at the start of every repetition)
        xmax = 3
        fs = []
        frames = []
        np_samples = []
        n_batches_viz = 100
        viz_every = params['viz_every']

        # Last observed loss values, recorded in `fs`
        f = [10000000]
        disc_steps = 1
        for i in tqdm(range(params['max_iter'])):

            fs.append(f)

            #take one generator gradient update and one discriminator gradient update
            f_new, _, _ = sess.run([[loss, loss], g_train_op, d_train_op])

            #take the additional discriminator gradient updates
            for _extra_step in range(disc_steps - 1):
                f_new, _ = sess.run([[loss], d_train_op])
            f = f_new

            # Clip the discriminator weights after the updates
            sess.run(clip_disc_weights_op)

            #plot the output
            if i % viz_every == 0:
                np_samples.append(np.vstack([sess.run(samples) for _ in range(n_batches_viz)]))
                xx, yy = sess.run([samples, data])
                fig = figure(figsize=(5,5))
                ps.append((xx,yy))
                scatter(yy[:, 0], yy[:, 1], c='r', edgecolor='none')
                scatter(xx[:, 0], xx[:, 1], edgecolor='none')
                axis('off')
                plt.xlim(-1.5, 1.5)
                plt.ylim(-1.5, 1.5)

                show()
    finally:
        # Release the session's resources before the next repetition opens a new one
        sess.close()