<a href="https://colab.research.google.com/github/sayakpaul/TF-2.0-Hacks/blob/master/GANs_with_TF_2_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook follows [this amazing tutorial on GANs](https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f) and tries to port the code to TensorFlow 2.0. 


## Install `Tensorflow 2.0`

In [0]:
!pip install tensorflow-gpu==2.0.0-beta1

## Imports

In [0]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow import keras

%matplotlib inline

## Helper function to generate a distribution (normal) for the real data

In [0]:
def get_distribution_sampler(mu, sigma):
  return lambda n: tf.convert_to_tensor(np.random.normal(mu, sigma, (1, n)))

## Helper function to generate a uniform distribution for the generator network

In [0]:
def get_generator_input_sampler():
  return lambda m, n: tf.convert_to_tensor(np.random.rand(m, n))

## The Generator network class

In [0]:
class Generator(keras.Model):
  def __init__(self, input_size, hidden_size, output_size):
    super(Generator, self).__init__()
    self.map1 = keras.layers.Dense(hidden_size, input_shape=input_size, activation='tanh')
    self.map2 = keras.layers.Dense(hidden_size, activation='tanh')
    self.map3 = keras.layers.Dense(output_size, activation='linear')
  
  def call(self, inputs):
    x = self.map1(inputs)
    x = self.map2(x)
    x = self.map3(x)
    return x

## The Discriminator network class

In [0]:
class Discriminator(keras.Model):
  def __init__(self, input_size, hidden_size, output_size):
    super(Discriminator, self).__init__()
    self.map1 = keras.layers.Dense(hidden_size, input_shape=input_size, activation='sigmoid')
    self.map2 = keras.layers.Dense(hidden_size, activation='sigmoid')
    self.map3 = keras.layers.Dense(output_size, activation='sigmoid')
  
  def call(self, inputs):
    x = self.map1(inputs)
    return self.map3(self.map2(x))

In [0]:
def get_moments(d):
  # https://stats.stackexchange.com/questions/126346/why-kurtosis-of-a-normal-distribution-is-3-instead-of-0
  # Return the first 4 moments of the data provided
  d = tf.transpose(d, (1, 0))
  mean = tf.reduce_mean(d)
  diffs = (d - mean)
  var = tf.reduce_mean(tf.pow(diffs, 2.0))
  std = tf.sqrt(var)
  zscores = diffs / std
  skews = tf.reduce_mean(tf.pow(zscores, 3.0))
  kurtoses = tf.reduce_mean(tf.pow(zscores, 4.0)) - 3.0 # excess kurtosis, should be 0 for Gaussian
  return tf.stack([mean, std, skews, kurtoses], axis=0)

In [0]:
def stats(d):
    return [np.mean(d), np.std(d)]

## Model hyperparameters and other constants

In [0]:
# Model parameters
g_input_size = 1      # Random noise dimension coming into generator, per output vector
g_hidden_size = 5     # Generator complexity
g_output_size = 1     # Size of generated output vector
d_input_size = 500    # Minibatch size - cardinality of distributions
d_hidden_size = 10    # Discriminator complexity
d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
minibatch_size = d_input_size

d_learning_rate = 1e-3
g_learning_rate = 1e-3
sgd_momentum = 0.9

num_epochs = 5000
print_interval = 100
d_steps = 20
g_steps = 20

dfe, dre, ge = 0, 0, 0
d_real_data, d_fake_data, g_fake_data = None, None, None

## Data generation parameters

In [0]:
data_mean = 4
data_stddev = 1.25

d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

## Initialize the networks

In [0]:
G = Generator(input_size=(500, 1),
                  hidden_size=g_hidden_size,
                  output_size=g_output_size)

D = Discriminator(input_size=(1,4),
                  hidden_size=d_hidden_size,
                  output_size=d_output_size)

## Declare the loss and optimizers

In [0]:
criterion = tf.keras.losses.BinaryCrossentropy(from_logits=True)  
d_optimizer = tf.keras.optimizers.SGD(learning_rate=d_learning_rate, momentum=sgd_momentum)
g_optimizer = tf.keras.optimizers.SGD(learning_rate=g_learning_rate, momentum=sgd_momentum)

## One forward and backward pass with the Discriminator network with real data

*We do not update the parameters with these gradients.*

In [0]:
# d_real_data = d_sampler(d_input_size)

# with tf.GradientTape() as tape:
#   d_real_decision = D(get_moments(d_real_data).reshape((1,4)))
#   d_real_error = criterion(d_real_decision, np.ones((1,1)))  # ones = true
# d_real_grads = tape.gradient(d_real_error, D.trainable_weights) # compute/store gradients, but don't change params
# d_real_grads[0].numpy()

## One forward and backward pass with the Discriminator network with the fake data

In [0]:
# d_gen_input = gi_sampler(minibatch_size, g_input_size)
# with tf.GradientTape() as tape:
#   with tape.stop_recording():
#     d_fake_data = G(d_gen_input)
#   d_fake_decision = D(get_moments(d_fake_data.numpy().T).reshape((1,4)))
#   d_fake_error = criterion(d_fake_decision, np.zeros((1,1)))
# d_fake_grads = tape.gradient(d_fake_error, D.trainable_weights) 
# print(d_fake_grads[0].numpy())
# d_optimizer.apply_gradients(zip(d_fake_grads, D.trainable_weights)) # Only optimizes D's parameters

## One forward and backward pass with the Generator network

In [0]:
# gen_input = gi_sampler(minibatch_size, g_input_size)
# with tf.GradientTape() as tape:
#   g_fake_data = G(gen_input)
#   dg_fake_decision = D(tf.reshape(get_moments_tf(g_fake_data), (1, 4)))
#   g_error = criterion(dg_fake_decision, np.ones((1,1)))
# g_grads = tape.gradient(g_error, G.trainable_weights)
# g_optimizer.apply_gradients(zip(g_grads, G.trainable_weights))

In [0]:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        #  1A: Train D on real
        d_real_data = tf.convert_to_tensor(d_sampler(d_input_size))
        with tf.GradientTape() as tape:
          d_real_decision = D(tf.reshape(get_moments(d_real_data), (1,4)))
          d_real_error = criterion(d_real_decision, tf.convert_to_tensor(np.ones((1,1))))  # ones = true
        d_real_grads = tape.gradient(d_real_error, D.trainable_weights) # compute/store gradients, but don't change params
        
        #  1B: Train D on fake
        d_gen_input = tf.convert_to_tensor(gi_sampler(minibatch_size, g_input_size))
        with tf.GradientTape() as tape:
          with tape.stop_recording():
            d_fake_data = G(d_gen_input)
          d_fake_decision = D(tf.reshape(get_moments(d_fake_data), (1, 4)))
          d_fake_error = criterion(d_fake_decision, tf.convert_to_tensor(np.zeros((1,1))))
        d_fake_grads = tape.gradient(d_fake_error, D.trainable_weights) 
        d_optimizer.apply_gradients(zip(d_fake_grads, D.trainable_weights)) # Only optimizes D's parameters

        dre, dfe = d_real_error.numpy(), d_fake_error.numpy()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        gen_input = tf.convert_to_tensor(gi_sampler(minibatch_size, g_input_size))
        with tf.GradientTape() as tape:
          g_fake_data = G(gen_input)
          dg_fake_decision = D(tf.reshape(get_moments(g_fake_data), (1, 4)))
          g_error = criterion(dg_fake_decision, tf.convert_to_tensor(np.ones((1,1))))
        g_grads = tape.gradient(g_error, G.trainable_weights)
        g_optimizer.apply_gradients(zip(g_grads, G.trainable_weights))

        ge = g_error.numpy()

    if epoch % print_interval == 0:
        print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
              (epoch, dre, dfe, ge, stats(d_real_data.numpy()), stats(d_fake_data.numpy())))

W0825 16:46:38.575693 140618269017984 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/nn_impl.py:182: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Epoch 0: D (1.0886667966842651 real_err, 0.6931471824645996 fake_err) G (1.05702543258667 err); Real Dist ([3.969248032721655, 1.2093934306513305]),  Fake Dist ([-0.16539627014683225, 0.07419386361400125]) 
Epoch 100: D (1.087760329246521 real_err, 0.6931471824645996 fake_err) G (1.0470203161239624 err); Real Dist ([4.066130571911628, 1.2346445557152335]),  Fake Dist ([-0.30376876966542266, 0.11773658791402217]) 
Epoch 200: D (1.0887233018875122 real_err, 0.6931471824645996 fake_err) G (1.0439269542694092 err); Real Dist ([3.932937238259309, 1.2277380207905655]),  Fake Dist ([-0.45581279916255424, 0.05699663380595091]) 
Epoch 300: D (1.089498519897461 real_err, 0.6931471824645996 fake_err) G (1.0407992601394653 err); Real Dist ([3.9936716043970866, 1.26235733971329]),  Fake Dist ([-0.7447241589696177, 0.008942774214919355]) 
Epoch 400: D (1.0871673822402954 real_err, 0.6931471824645996 fake_err) G (1.0361649990081787 err); Real Dist ([3.984654450017385, 1.2337837894636985]),  Fake Dist

The network is still not properly configured. 