# Generative Adversarial Networks

The objective today is to use a neural network library written from scratch to build a generative adversarial network that generates data resembling the ever-popular MNIST data set.  

First, let's get a few questions out of the way.

## Why?

1.  Scala - because it is a great language
2.  ND4J - it is the main linear algebra library for DL4J (Deeplearning4j).  The author's objective was to shorten the gap between JVM languages and NumPy or MATLAB
3.  From Scratch - it is challenging a.f. and a lot of fun.  

In [1]:
%%classpath add mvn 
org.nd4j nd4j-native-platform 0.7.2
org.nd4j nd4s_2.11 0.7.2

In [2]:
%%classpath add mvn  
org.scalanlp breeze_2.11 0.13.2
org.scalanlp breeze-natives_2.11 0.13.2
org.scalanlp breeze-viz_2.11 0.13.2

In [4]:
%classpath add jar ../scala-miniflow/target/scala-2.11/scala-miniflow_2.11-0.1.0-SNAPSHOT.jar

##  Feed Forward Neural networks

Here we will begin to get into how we will model our neural network framework. Per _Deep Learning_ by Goodfellow et al., the feedforward neural network is called such because

* (feedforward) Information flows through the function being evaluated from the input $x$, through intermediate computations used to define $f$, and finally to the output $y$. We are not considering any feedback connections.
* (network) They are typically represented by composing together many different functions. The model is associated with a directed acyclic graph describing how the functions are composed together.
* (neural) The models are loosely inspired by neuroscience.  

While this is from scratch, we will spend little time on the framework, with the exception of an example of how to set up a simple neural network.  

Before that, we'll introduce our main building block - the Node class.  

Our Node will

* take incoming nodes as arguments
* have a method which captures outbound nodes
* have a forward method (feed forward)
* have a backward method (for back propagation)
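
To make the four bullets above concrete, here is a minimal sketch of what such a Node could look like. It is not the library's actual class: the names `NodeSketch`, `Mul`, `upstream` and `demo` are made up for illustration, and values are plain `Double`s instead of ND4J arrays. The `gradients(this)` convention is chosen to match the way the notebook later reads `n.gradients(n)`.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-in for the library's Node: scalar values, no ND4J.
object NodeSketch {
  abstract class Node(val inbound: Node*) {
    // outbound nodes are captured as other nodes name this one as an input
    val outbound = mutable.ArrayBuffer.empty[Node]
    inbound.foreach(_.outbound += this)

    var value: Double = 0.0
    // derivative of the graph's final output with respect to each node seen here
    val gradients = mutable.Map.empty[Node, Double]

    def forward(): Unit
    def backward(): Unit

    // gradient flowing back into this node; 1.0 if it is the final output
    protected def upstream: Double =
      if (outbound.isEmpty) 1.0 else outbound.map(_.gradients(this)).sum
  }

  class Input extends Node() {
    def forward(): Unit = ()  // the value is set from outside
    def backward(): Unit = { gradients(this) = upstream }
  }

  class Mul(a: Node, b: Node) extends Node(a, b) {
    def forward(): Unit = { value = a.value * b.value }
    def backward(): Unit = {
      gradients(a) = upstream * b.value
      gradients(b) = upstream * a.value
    }
  }

  // y = x * w with x = 3 and w = 4; returns (y, dy/dx, dy/dw)
  def demo(): (Double, Double, Double) = {
    val x = new Input
    val w = new Input
    x.value = 3.0
    w.value = 4.0
    val y = new Mul(x, w)
    val graph = Seq(x, w, y)             // already in topological order
    graph.foreach(_.forward())           // feed forward
    graph.reverse.foreach(_.backward())  // back propagation
    (y.value, x.gradients(x), w.gradients(w))
  }
}
```

Calling `NodeSketch.demo()` walks the graph forward once and backward once, which is the same pattern the training loop uses on the real graphs.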


## GAN

Generative Adversarial Networks, or GANs for short, are amazing.  You put two networks in a game, where one network's objective is to trick the other network. 

In what follows, the trickster (aka the generator) will generate images. Those images will be mixed in with real images, and the other network (aka the discriminator) will attempt to label the images correctly as fake vs real.  

Based on how we set up the objective functions for both networks:

* The generator's objective will help it get better at generating fake images
* The discriminator's objective will help it get better at picking out the fake images.  

By the end the generator will be able to take a vector of noise and turn it into something that resembles an image from the MNIST dataset.  I say resembles because these are very hard to train and design (and it has been built from scratch), so we should not expect state of the art results.  

## Why does this work?

Let's refer to the Generator as $G$ and the Discriminator as $D$.  

$G$ takes as an input a $1 \times n$ vector of random noise, and outputs a $1 \times p$ vector that represents an image.  

$D$ takes as an input a $1 \times p$ vector representing an image and produces a scalar between 0 and 1, which represents the probability that the input is real.  

Finally, let $x$ represent an image and $z$ a noise vector.  

### Discriminator objective 

The cost function for $D$ should be obvious - binary cross entropy. 

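$$cost_D = - \sum \left[ 1_{y = real}\ln(p) + 1_{y = fake}\ln(1 - p) \right]$$

Here $p$ is the discriminator's output for an image, i.e. the probability it assigns to the image being real, and the sum runs over the images in the batch.  The smaller $cost_D$ is, the better the discriminator is at discriminating.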
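
As a quick sanity check of the formula, here is a small sketch in plain Scala. It is not the library's `BceWithLogits` node: it works on probabilities rather than logits, and the names `DiscriminatorCostSketch` and `bce` are made up for illustration.

```scala
object DiscriminatorCostSketch {
  // labels: 1.0 marks a real image, 0.0 a fake one
  // probs:  discriminator outputs p, each strictly between 0 and 1
  def bce(labels: Seq[Double], probs: Seq[Double]): Double = {
    require(labels.length == probs.length, "labels and probs must have the same length")
    val logLikelihood = labels.zip(probs).map { case (y, p) =>
      y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    }.sum
    -logLikelihood
  }

  val labels = Seq(1.0, 0.0)                          // one real image, one fake image
  val confidentAndRight = bce(labels, Seq(0.9, 0.1))  // roughly 0.21
  val confidentAndWrong = bce(labels, Seq(0.1, 0.9))  // roughly 4.61
}
```

A discriminator that is confident and right pays a small cost; one that is confident and wrong pays a large one.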

### Generator objective

For $G$, we want $G$ to generate real looking pictures, so suppose that we told the discriminator that the images $G$ generated are real?  That would be equivalent to the following cost function 

$$cost_G = -\sum \ln( (D \circ G)(z) )$$

So, the cost for $G$ is the negative natural log of the discriminator evaluated at the images created by the generator.  Suppose that $D \circ G$ is small (for the sake of argument, only slightly bigger than 0).  Then $\ln(D \circ G)$ will be very negative, and multiplying it by $-1$ gives a very large number.  So when the cost of $G$ is very large, the discriminator knows it is seeing a fake.  On the other hand, if $D \circ G$ is closer to 1, then $\ln(D \circ G)$ will be very small (in the absolute sense), and we'll see $cost_G$ be very small.  This implies that the discriminator believes the data we provided is real!  So when we minimize this function, the optimization works towards making the generator generate real looking images!!  Pretty damn cool if you ask me.
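
The argument above can be checked with a few numbers. The sketch below is plain Scala with made-up names (`GeneratorCostSketch`, `generatorCost`, `bceAllReal`), not the library's API. It also shows that binary cross entropy with every label set to "real" collapses to $cost_G$.

```scala
object GeneratorCostSketch {
  // dOfG: discriminator outputs D(G(z)) for a batch of generated images
  def generatorCost(dOfG: Seq[Double]): Double = {
    val logSum = dOfG.map(p => math.log(p)).sum
    -logSum
  }

  // binary cross entropy with every label y set to 1 ("real"):
  // the (1 - y) * ln(1 - p) term vanishes, leaving only -ln(p)
  def bceAllReal(dOfG: Seq[Double]): Double = {
    val y = 1.0
    val terms = dOfG.map(p => y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    -(terms.sum)
  }

  val caught = generatorCost(Seq(0.01))  // D is sure the image is fake: roughly 4.61
  val fooled = generatorCost(Seq(0.99))  // D believes the image is real: roughly 0.01
}
```

Minimizing this cost pushes the generator away from the `caught` case and towards the `fooled` case.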

## Building the graph



When we train the GAN, we'll do a forward and backward pass on the generator's graph.  

![generator graph](img/generator_graph.png)

and then we do a forward and a backward pass on the discriminator's graph.

![discriminator graph](img/discriminator_graph.png)

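It is important to note that we will not use any derivatives calculated for the discriminator's parameters during the back prop of the generator's graph!  Those derivatives get computed along the way, but only the generator's parameters are updated in that step.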

In [18]:
import com.github.timsetsfire.nn.node._
import com.github.timsetsfire.nn.activation._
import com.github.timsetsfire.nn.costfunctions._
import com.github.timsetsfire.nn.batchnormalization._
import com.github.timsetsfire.nn.regularization.Dropout
import com.github.timsetsfire.nn.optimize._
import com.github.timsetsfire.nn.graph._

import scala.util.Try
import org.nd4j.linalg.factory.Nd4j
import org.nd4j.linalg.api.ndarray.INDArray
import org.nd4j.linalg.ops.transforms.Transforms.{sigmoid,exp,log,pow,sqrt}
import org.nd4s.Implicits._

import breeze.linalg.DenseMatrix
import breeze.plot._

OutputCell.HIDDEN

## Get Data 



We'll be using the test data set from MNIST.  

Run `curl -s https://pjreddie.com/media/files/mnist_test.csv > ./data/mnist_test.csv` in a terminal (or Git Bash on Windows).

## Import Data

In [19]:
// column 0 holds the digit label; columns 1 to 784 hold the pixel values, scaled here to [0, 1]
val x_ = Nd4j.readNumpy("data/mnist_test.csv", ",").getColumns( (1 until 785):_*).div(255d)
OutputCell.HIDDEN

## Inputs

For inputs, we'll have the following:

* Images - this will be used for real and fake images alike.  It will serve as the input to our _discriminator_.  
* Labels - this will be used for real and fake labels.  It will serve as an input to our _discriminator_, in particular to its objective function.
* Noise - this is the noise that will be used to generate images.  It will be an input to the _generator_.
* FakeLabels - these are the labels we attach to generated images when we tune the generator.  At that point we "connect" the generator output to the discriminator input.  We forward propagate the fake data through, but we label it as real.  When we backprop, we backprop from the output of the discriminator all the way back to the input of the generator.  This has the effect of updating the generator's parameters so as to make the generated images look better.
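
The two labelling schemes are easy to mix up, so here is a small sketch of them in plain Scala. The object and function names are made up for illustration. The layout follows what the training loop further down does with `Nd4j.ones`, `Nd4j.zeros` and `Nd4j.concat`.

```scala
object LabelSketch {
  // generator step: every generated image is labelled 1 ("real")
  def generatorStepLabels(batchSize: Int): Vector[Double] =
    Vector.fill(batchSize)(1.0)

  // discriminator step: fake images first (label 0), then real images (label 1)
  def discriminatorStepLabels(batchSize: Int): Vector[Double] =
    Vector.fill(batchSize)(0.0) ++ Vector.fill(batchSize)(1.0)
}
```

With a batch size of 128, the discriminator step therefore sees 256 rows: 128 generated images followed by 128 real ones.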

In [23]:
val images = new Input()
images.setName("images")
val labels = new Input()
labels.setName("labels")
val noise = new Input()  // noise is used to generate fake images
noise.setName("noise")
val fakeLabels = new Input()
fakeLabels.setName("fake_labels")
OutputCell.HIDDEN

## Generator Network

This is the network that will generate fake data. Notice that the network does not "end" with an objective.  It ends with the `fakeImages` output.   

In [24]:
val h1Generator= LeakyReLU(noise, (100,128), 0.2)
h1Generator.setName("generator_hidden1")

val h2Generator= LeakyReLU(h1Generator, (128, 256), 0.2)
h2Generator.setName("generator_hidden2")

val h3Generator= LeakyReLU(h2Generator, (256, 512), 0.2)
h3Generator.setName("generator_hidden3")

val fakeImages = Sigmoid(h3Generator, (512,784))  // the output, 784, is the dimensions of the image (28x28)
fakeImages.setName("fake_images")

OutputCell.HIDDEN

In [25]:
val generator = topologicalSort{ buildGraph(fakeImages) }
val generatorTrainables = generator.filter{ _.getClass.getSimpleName == "Variable" }

[[Variable@58043370, Variable@1d1ecbd0, Variable@50d68622, Variable@21dd9b1a, Variable@5011a007, Variable@548aa46, Variable@33af23eb, Variable@6ae7e583]]

## Discriminator Network

In [26]:
val h1Discrim = LeakyReLU(images, (784,256), 0.1)  // the input is the dimensions of the image
h1Discrim.setName("discriminator_hidden_layer1")

val d1 = new Dropout(h1Discrim, 0.20)
d1.setName("dropout_h1_layer")

val h2Discrim = LeakyReLU(d1, (256,64), 0.1)
h2Discrim.setName("discriminator_hidden_layer2")

val d2 = new Dropout(h2Discrim, 0.20)
d2.setName("dropout_h2_layer")

val h3Discrim = LeakyReLU(d2, (64,16), 0.1)
h3Discrim.setName("discriminator_hidden_layer3")

val logits = Linear(h3Discrim, (16, 1))
logits.setName("discriminator_logits")

val cost = new BceWithLogits(labels, logits)
cost.setName("cost_function")

OutputCell.HIDDEN

In [27]:
val discriminator = topologicalSort{ buildGraph(cost) }
val discriminatorTrainables = discriminator.filter{ _.getClass.getSimpleName == "Variable" }

[[Variable@34805f2f, Variable@11edaa5f, Variable@3f8ef57, Variable@3537a120, Variable@11c068f5, Variable@1d07b279, Variable@69c0b628, Variable@15034985]]

## Initialize

Next, we'll initialize the trainable parameters of the networks.  

In [28]:
// initialize discriminator parameters
discriminatorTrainables.foreach{ node =>
    val size = node.size
    val (m,n) = (size._1.asInstanceOf[Int], size._2.asInstanceOf[Int])
    node.value = Nd4j.randn(m, n) * math.sqrt(3/(m.toDouble + n.toDouble))
  }

// initialize generator parameters
generatorTrainables.foreach{ node =>
    val size = node.size
    val (m,n) = (size._1.asInstanceOf[Int], size._2.asInstanceOf[Int])
    node.value = Nd4j.randn(m, n) * math.sqrt(3/(m.toDouble + n.toDouble))
  }



## Adam

We'll use Adam to train this graph, and we're not going to cover the details of the algorithm.

## Setting up first and second moment maps for Adam Optimizer

In [29]:
val Array(xrows, xcols) = x_.shape
val batchSize = 128
val stepsPerEpoch = xrows / batchSize

val firstMomentGenerator = generatorTrainables.map{ i => (i, Nd4j.zerosLike(i.value))}.toMap
val secondMomentGenerator = generatorTrainables.map{ i => (i, Nd4j.zerosLike(i.value))}.toMap
val firstMomentDiscriminator = discriminatorTrainables.map{ i => (i, Nd4j.zerosLike(i.value))}.toMap
val secondMomentDiscriminator = discriminatorTrainables.map{ i => (i, Nd4j.zerosLike(i.value))}.toMap
val t = new java.util.concurrent.atomic.AtomicInteger
OutputCell.HIDDEN

## Dropout

We're using dropout throughout the discriminator.  What we do not want, though, is to have dropout turned on when we are training the generator, so below is a helper function to set the `train` field of the dropout nodes to true or false as needed.

In [30]:
def setDropoutTraining(n: Node, training: Boolean = false): Unit = {
  n.asInstanceOf[Dropout[Node]].train = training
}

setDropoutTraining: (n: com.github.timsetsfire.nn.node.Node, training: Boolean)Unit
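
To see why the flag matters, here is a rough sketch of how a dropout layer typically behaves in the two modes. This is an assumption about the general technique ("inverted" dropout, with the second argument read as the drop probability), not a copy of the library's `Dropout` node. The class below is hypothetical.

```scala
import scala.util.Random

object DropoutSketch {
  // Hypothetical stand-in for a dropout layer; `train` mirrors the flag toggled above.
  final class Dropout(rate: Double, rng: Random = new Random(0)) {
    var train: Boolean = true

    def forward(xs: Vector[Double]): Vector[Double] =
      if (!train) xs  // dropout off: activations pass through unchanged
      else {
        val keep = 1.0 - rate
        // zero each unit with probability `rate`, rescale the survivors by 1 / keep
        xs.map(x => if (rng.nextDouble() < keep) x / keep else 0.0)
      }
  }
}
```

With `train` set to false the layer is deterministic, so the gradient that flows back to the generator is not perturbed by randomly dropped units.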


In [31]:
var stepSize: Double = 0.002 // 0.001 tf default
val beta1: Double = 0.2  // 0.9 tf default
val beta2: Double = 0.999  // 0.999 tf default
val delta: Double = 1e-8
OutputCell.HIDDEN

1.0E-8
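
These four values plug into the Adam update that the training loop applies to every trainable. As a reference point, here is a scalar sketch of a single Adam step. The names `AdamSketch`, `State` and `step` are made up for illustration. The loop itself does the same arithmetic element-wise on ND4J arrays.

```scala
object AdamSketch {
  // m: first moment estimate, v: second moment estimate, t: number of steps taken
  final case class State(m: Double, v: Double, t: Int)

  def step(theta: Double, grad: Double, s: State,
           stepSize: Double = 0.002, beta1: Double = 0.2,
           beta2: Double = 0.999, delta: Double = 1e-8): (Double, State) = {
    val t = s.t + 1
    val m = beta1 * s.m + (1 - beta1) * grad          // running mean of the gradient
    val v = beta2 * s.v + (1 - beta2) * grad * grad   // running mean of the squared gradient
    val mHat = m / (1 - math.pow(beta1, t))           // bias correction
    val vHat = v / (1 - math.pow(beta2, t))
    val updated = theta - stepSize * mHat / (math.sqrt(vHat) + delta)
    (updated, State(m, v, t))
  }

  // one step from a zero state, with parameter 1.0 and gradient 0.5
  val (theta1, state1) = step(1.0, 0.5, State(0.0, 0.0, 0))
}
```

On the very first step the bias-corrected moments reduce to the gradient and its square, so the parameter moves by roughly `stepSize` in the direction opposite to the gradient, whatever the gradient's magnitude.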

Below, we are going to track 16 noise vectors and how they look after $n$ epochs.

In [32]:
val noiseDataForPicture = Nd4j.rand(16,100).mul(2).sub(1)
OutputCell.HIDDEN

In [33]:
for(epoch <- 0 to 100) {

      var loss = 0d
      var genCost = 0d
      var n = 0d
    
      for(steps <- 0 to stepsPerEpoch) {

        t.addAndGet(1)

        val noiseData = Nd4j.rand(batchSize,100).mul(2).sub(1)
        val fakeLabelData = Nd4j.ones(batchSize, 1)

        val generatorFeedDict: Map[Node, INDArray] = Map(
          noise -> noiseData,
          fakeLabels -> fakeLabelData
        )

        // generator
          
        // set the training flag of the discriminator's dropout nodes to false
        discriminator.filter{ _.getClass.getSimpleName == "Dropout"}.foreach(d => setDropoutTraining(d, false))
          
        // initialize the input nodes for the generator
        generatorFeedDict.foreach{ case (n, v) => n.forward(v)}
        // feed forward the generator network
        generator.foreach(_.forward())
        
        // pass the fake images created by the generator to the discriminator
        images.forward(fakeImages.value)
        // pass the fake labels to the discriminator as well.  We're telling the discriminator
        // that the images are real, and forward prop
        labels.forward(fakeLabels.value)
        discriminator.foreach(_.forward())
          
        // now we backprop the discriminator
        discriminator.reverse.foreach(_.backward())
        
        // the generator and discriminator are connected by the images node at this point
        // so we grab the derivative out of the discriminator input images, and put it into
        // the generator output fake images
        fakeImages.backward(images.gradients(images).dup)
        // and we continue to backprop
        generator.reverse.tail.foreach(_.backward())
       
        // accumulate the generator cost
        genCost += (cost.value.sumT*batchSize)

        // update the parameters of the generator via Adam optimizer
        for( n <- generatorTrainables) {
          firstMomentGenerator(n).muli(beta1).addi(n.gradients(n).mul(1 - beta1))
          secondMomentGenerator(n).muli(beta2).addi( pow(n.gradients(n),2).mul(1 - beta2))
          val fhat = firstMomentGenerator(n).div(1 - math.pow(beta1, t.get))
          val shat = secondMomentGenerator(n).div(1 - math.pow(beta2, t.get))
          n.value.addi( fhat.mul(-stepSize).div(sqrt(shat).add(delta)))
        }

        // forward through the generator
        generator.foreach(_.forward())
        
        // get the fake data
        val fakeImageData = fakeImages.value
          
        // shuffle x
        Nd4j.shuffle(x_,1)
          
        // grab batchsize rows of x
        val realImageData = x_.getRows((0 until batchSize):_*)
        
        // label the real images with a 1
        val realLabelData = Nd4j.ones(batchSize, 1)
          
        // label the fake images with a 0
        val fakeLabelData0 = Nd4j.zeros(batchSize, 1)

        // concatenate the data and build the feed dictionary for the discriminator
        val labelData = Nd4j.concat(0, fakeLabelData0, realLabelData)
        val imageData = Nd4j.concat(0, fakeImageData, realImageData)
        val discriminatorFeedDict: Map[Node, INDArray] = Map(
          images -> imageData,
          labels -> labelData
        )
        
        // turn dropout back on
        discriminator.filter{ _.getClass.getSimpleName == "Dropout"}.foreach(d => setDropoutTraining(d, true))
          
        // feed forward
        discriminatorFeedDict.foreach{ case (n, v) => n.forward(v)}
        discriminator.foreach(_.forward())
          
        // back prop
        discriminator.reverse.foreach(_.backward())
          
        // update the discriminator weights.  
        for( n <- discriminatorTrainables) {
          firstMomentDiscriminator(n).muli(beta1).addi(n.gradients(n).mul(1d - beta1))
          secondMomentDiscriminator(n).muli(beta2).addi( pow(n.gradients(n),2).mul(1d - beta2))
          val fhat = firstMomentDiscriminator(n).div(1 - math.pow(beta1, t.get))
          val shat = secondMomentDiscriminator(n).div(1 - math.pow(beta2, t.get))
          n.value.addi( fhat.mul(-stepSize).div(sqrt(shat).add(delta)))
        }

        loss += ((cost.value(0,0)) * images.value.shape.apply(0))
        n += images.value.shape.apply(0)
      }
    
      if(epoch % 10 == 0) {
        print(f"epoch: ${epoch}")
        print(f"\tdiscriminator -> loss: ${loss / n.toDouble}%2.3f")
        println(f"\tgenerator -> loss: ${genCost / (n.toDouble/2d)}%2.3f")
//         noise.forward(noiseDataForPicture)
//         generator.foreach(_.forward())
//         val f4 = Figure()
          
//         for (i <- 0 until 16) {
//           val dig1 = fakeImages.value.getRow(i).dup.data.asDouble()
//           val da = (for{ i <- 0 to 28} yield dig1.drop(i*28).take(28)).init
//           val dig1b = DenseMatrix(da.reverse:_*) //.reshape(28,28)
//           f4.subplot(4,4,i) += image(dig1b)
//         }
//         f4.saveas(s"resources/fig${epoch}.png")
          
      }
}

epoch: 0	discriminator -> loss: 0.546	generator -> loss: 4.900
epoch: 10	discriminator -> loss: 0.447	generator -> loss: 2.083
epoch: 20	discriminator -> loss: 0.590	generator -> loss: 1.165
epoch: 30	discriminator -> loss: 0.626	generator -> loss: 1.033
epoch: 40	discriminator -> loss: 0.645	generator -> loss: 0.932
epoch: 50	discriminator -> loss: 0.656	generator -> loss: 0.889
epoch: 60	discriminator -> loss: 0.656	generator -> loss: 0.869
epoch: 70	discriminator -> loss: 0.655	generator -> loss: 0.872
epoch: 80	discriminator -> loss: 0.656	generator -> loss: 0.864
epoch: 90	discriminator -> loss: 0.653	generator -> loss: 0.874
epoch: 100	discriminator -> loss: 0.648	generator -> loss: 0.899


In [34]:
noise.forward(noiseDataForPicture)
generator.foreach(_.forward())

val d1 = fakeImages.value.getRows(1,2,3,4).data.asDouble
val d2 = fakeImages.value.getRows(5,6,7,8).data.asDouble
val d3 = fakeImages.value.getRows(9,10,11,12).data.asDouble
val d4 = fakeImages.value.getRows(13,14,15,0).data.asDouble

val items = 4 * 28

val da = (for{ i <- 0 to items} yield d1.drop(i*28).take(28) ++ (d2.drop(i*28).take(28)) ++ (d3.drop(i*28).take(28)) ++ (d4.drop(i*28).take(28))).init 


val hm = new HeatMap
hm.data_=( da.reverse )
hm.color_=(GradientColor.GREEN_YELLOW_WHITE )

hm



I had a hard time generating the associated figures here, so I had to run the Main method in the jar, with 100 epochs.  

First Epoch

<img src="img/genfig0.png" alt="drawing" width="350"/>

Epoch 10
<img src="img/genfig10.png" alt="drawing" width="350"/>

Epoch 20
<img src="img/genfig20.png" alt="drawing" width="350"/>

Epoch 30
<img src="img/genfig30.png" alt="drawing" width="350"/>

Epoch 40
<img src="img/genfig40.png" alt="drawing" width="350"/>

Epoch 50
<img src="img/genfig50.png" alt="drawing" width="350"/>

Epoch 60
<img src="img/genfig60.png" alt="drawing" width="350"/>

Epoch 70
<img src="img/genfig70.png" alt="drawing" width="350"/>

Epoch 80
<img src="img/genfig80.png" alt="drawing" width="350"/>

Epoch 90
<img src="img/genfig90.png" alt="drawing" width="350"/>

Epoch 100
<img src="img/genfig100.png" alt="drawing" width="350"/>