<a href="https://colab.research.google.com/github/pnalaba/mnistGAN/blob/master/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import Input,Dense,Flatten,LeakyReLU,BatchNormalization,Reshape
from tensorflow.python.keras.optimizers import Adam
from tensorflow.python.keras.models import Sequential, Model
import matplotlib as mpl
#mpl.use('TkAgg')
import matplotlib.pyplot as plt
from keras.datasets import mnist

In [0]:
import tensorflow as tf
import keras
from keras import backend as K

# See whether TensorFlow can use a GPU. tf.test.gpu_device_name() is public
# API (unlike the private K.tensorflow_backend._get_available_gpus(), which is
# gone from newer Keras releases). It returns a device string such as
# '/device:GPU:0', or '' when no GPU is visible.
# On Colab: Runtime -> Change runtime type -> Hardware accelerator -> GPU.
#
# (TF1 only: to pin devices explicitly, create a tf.Session with
#  tf.ConfigProto(device_count={'GPU': 1, 'CPU': 1}) and pass it to
#  keras.backend.set_session.)
gpu_device = tf.test.gpu_device_name()
gpu_device if gpu_device else 'No GPU found - training will run on the CPU'

In [25]:
# Use a folder on Google Drive as the output root. This works well when the
# notebook runs on Colab, and the generated images outlive the session.
#
# If you run this on your own computer instead, skip the Drive mount and set
# root_folder to a local directory; the images folder is created below.

from google.colab import drive
drive.mount('/content/drive')
root_folder = '/content/drive/My Drive/Colab Notebooks'

# save_imgs() writes samples to root_folder + "/gan/images"; create it up
# front so the first save during training does not fail.
os.makedirs(os.path.join(root_folder, 'gan', 'images'), exist_ok=True)


['/job:localhost/replica:0/task:0/device:GPU:0']

In [0]:
class GAN():
  """Fully-connected GAN for 28x28 single-channel images (MNIST).

  Builds three Keras models:
    * self.discriminator -- image -> probability that the image is real.
    * self.generator     -- 100-d noise vector -> image.
    * self.combined      -- noise -> generator -> frozen discriminator, used
      to train the generator.

  build_generator / build_discriminator (and train / save_imgs) are added by
  the `class GAN(GAN)` cells that follow this one.
  """
  def __init__(self):
    self.img_rows=28
    self.img_cols=28
    self.channels=1
    self.img_shape = (self.img_rows, self.img_cols, self.channels)

    # Give every compiled model its own Adam(lr=0.0002, beta_1=0.5) instance.
    # A single shared optimizer shares its step counter (and so Adam's bias
    # correction) between the discriminator and generator updates. Newer TF
    # optimizers may also refuse to update a variable set they were not
    # built for.
    def make_optimizer():
      return Adam(0.0002, 0.5)

    #Build and compile the discriminator
    self.discriminator = self.build_discriminator()
    self.discriminator.compile(loss='binary_crossentropy',
        optimizer=make_optimizer(),
        metrics=['accuracy'])

    #Build and compile the generator (it is trained through self.combined,
    # but compiling it also allows standalone use)
    self.generator = self.build_generator()
    self.generator.compile(loss='binary_crossentropy',
        optimizer=make_optimizer())

    #The generator takes noise as input and generates imgs
    z = Input(shape=(100,))
    img = self.generator(z)

    #For the combined model we will only train the generator
    self.discriminator.trainable = False

    #The discriminator takes generated images as input and determines validity
    valid = self.discriminator(img)

    #The combined model (stacked generator and discriminator) takes
    # noise as input => generates images => determines validity
    self.combined = Model(z, valid)
    self.combined.compile(loss='binary_crossentropy', optimizer=make_optimizer())

    # Restore the flag after compiling the combined model, to prevent the
    # Keras warning about updating parameters of a non-trainable model
    # (the discriminator was compiled on its own while still trainable).
    self.discriminator.trainable = True


In [0]:
class GAN(GAN):
  def build_generator(self):
    """Return a Keras Model mapping a 100-d noise vector to an image.

    Three hidden Dense blocks (256 -> 512 -> 1024 units, each followed by
    LeakyReLU(0.2) and BatchNormalization(momentum=0.8)) feed a tanh Dense
    layer with one unit per pixel. Its output is reshaped to self.img_shape.
    tanh keeps pixels in [-1, 1], the same range train() rescales MNIST to.
    """
    latent_shape = (100,)
    net = Sequential()

    for position, units in enumerate((256, 512, 1024)):
      # Only the first layer of a Sequential stack declares its input shape.
      if position == 0:
        net.add(Dense(units, input_shape=latent_shape))
      else:
        net.add(Dense(units))
      net.add(LeakyReLU(alpha=0.2))
      net.add(BatchNormalization(momentum=0.8))

    net.add(Dense(np.prod(self.img_shape), activation='tanh'))
    net.add(Reshape(self.img_shape))
    net.summary()

    # Wrap the Sequential stack in a functional Model with an explicit input.
    latent = Input(shape=latent_shape)
    return Model(latent, net(latent))

In [0]:
class GAN(GAN):
  def build_discriminator(self):
    """Return a Keras Model mapping an image to P(image is real).

    The image is flattened and passed through Dense(512) and Dense(256)
    hidden layers with LeakyReLU(0.2) activations. A single sigmoid unit
    produces the output.
    """
    input_shape = (self.img_rows, self.img_cols, self.channels)

    net = Sequential()
    net.add(Flatten(input_shape=input_shape))
    for units in (512, 256):
      net.add(Dense(units))
      net.add(LeakyReLU(alpha=0.2))
    net.add(Dense(1, activation='sigmoid'))
    net.summary()

    # Wrap the Sequential stack in a functional Model with an explicit input.
    image = Input(shape=input_shape)
    return Model(image, net(image))


In [0]:
class GAN(GAN):
  def train(self, epochs, batch_size=128, save_interval=50, log_interval=1):
    """Train the GAN on MNIST for `epochs` iterations.

    Here an "epoch" is one mini-batch step, not a full pass over the dataset.
    Each step:
      1. trains the discriminator on half a batch of real images and half a
         batch of generated images;
      2. trains the generator through self.combined on a full batch of noise
         labelled as real.

    Args:
      epochs: number of training iterations.
      batch_size: generator batch size. The discriminator sees
        batch_size // 2 real and as many generated images per iteration
        (at least one of each).
      save_interval: save a grid of samples every this many iterations;
        0 or None disables saving.
      log_interval: print losses every this many iterations. The default 1
        prints every iteration, as before; 0 or None disables printing.
    """
    # Load the dataset; the labels and the test split are not needed.
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale to [-1, 1] to match the generator's tanh output, and add the
    # trailing channel axis the discriminator input expects.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)

    # Without the max(), batch_size=1 would give empty discriminator batches.
    half_batch = max(1, batch_size // 2)

    # The targets never change between iterations, so build them once,
    # shaped like the discriminator's (n, 1) output.
    real_labels = np.ones((half_batch, 1))
    fake_labels = np.zeros((half_batch, 1))
    # The generator wants the discriminator to label the generated samples
    # as valid (ones)
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
      # ---------------------
      # Train Discriminator
      # ---------------------

      # Select a random half batch of real images
      idx = np.random.randint(0, X_train.shape[0], half_batch)
      imgs = X_train[idx]

      # Generate a half batch of new images
      noise = np.random.normal(0, 1, (half_batch, 100))
      gen_imgs = self.generator.predict(noise)

      # NOTE(review): Keras fixes which weights a model updates when it is
      # compiled (see __init__), so these trainable toggles are probably
      # no-ops. They are kept to preserve the original behaviour.
      self.discriminator.trainable = True
      d_loss_real = self.discriminator.train_on_batch(imgs, real_labels)
      d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_labels)
      # [loss, accuracy] averaged over the real and fake halves
      d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

      # ---------------------
      # Train Generator
      # ---------------------
      noise = np.random.normal(0, 1, (batch_size, 100))

      self.discriminator.trainable = False
      g_loss = self.combined.train_on_batch(noise, valid_y)

      # Report progress
      if log_interval and epoch % log_interval == 0:
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
              (epoch, d_loss[0], 100 * d_loss[1], g_loss))

      # If at save interval => save generated image samples
      if save_interval and epoch % save_interval == 0:
        self.save_imgs(epoch)


In [0]:

class GAN(GAN):
  def save_imgs(self, epoch, rows=5, cols=5):
    """Save a rows x cols grid of generator samples as a PNG.

    The image is written to root_folder/gan/images/mnist_<epoch>.png. The
    directory is created if it does not exist yet.

    Args:
      epoch: training iteration, used in the file name.
      rows, cols: grid size (default 5 x 5, as before).
    """
    noise = np.random.normal(0, 1, (rows * cols, 100))
    gen_imgs = self.generator.predict(noise)

    # Rescale images from the generator's tanh range [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5

    # squeeze=False keeps `axs` 2-D even for a 1 x N or N x 1 grid;
    # axs.flat walks it row by row, matching the sample order.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for cnt, ax in enumerate(axs.flat):
      ax.imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
      ax.axis('off')

    out_dir = os.path.join(root_folder, "gan", "images")
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, "mnist_%d.png" % epoch))
    # Close this specific figure (not just the "current" one) so repeated
    # calls during training do not accumulate open figures.
    plt.close(fig)


In [23]:
# Seed NumPy's global RNG so the real-image indices and the noise vectors
# drawn by train() and save_imgs() are reproducible. TensorFlow's weight
# initialisers are not seeded here, so results can still vary between runs.
SEED = 0
np.random.seed(SEED)

gan = GAN()


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_4 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_28 (Dense)             (None, 512)               401920    
_________________________________________________________________
leaky_re_lu_20 (LeakyReLU)   (None, 512)               0         
_________________________________________________________________
dense_29 (Dense)             (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_21 (LeakyReLU)   (None, 256)               0         
_________________________________________________________________
dense_30 (Dense)             (None, 1)                 257       
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
_________________________________________________________________
____

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [27]:
# Training run configuration: number of mini-batch iterations, generator
# batch size, and how often (in iterations) to save a sample grid.
EPOCHS = 10000
BATCH_SIZE = 1024
SAVE_INTERVAL = 200

gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, save_interval=SAVE_INTERVAL)

0 [D loss: 0.573088, acc.: 74.61%] [G loss: 0.929625]
1 [D loss: 0.567481, acc.: 75.98%] [G loss: 0.935788]
2 [D loss: 0.579618, acc.: 74.02%] [G loss: 0.941367]
3 [D loss: 0.565158, acc.: 75.20%] [G loss: 0.942932]
4 [D loss: 0.561580, acc.: 74.51%] [G loss: 0.941705]
5 [D loss: 0.568682, acc.: 74.71%] [G loss: 0.926735]
6 [D loss: 0.573942, acc.: 73.73%] [G loss: 0.948941]
7 [D loss: 0.568117, acc.: 75.20%] [G loss: 0.945887]
8 [D loss: 0.568031, acc.: 74.51%] [G loss: 0.945382]
9 [D loss: 0.563688, acc.: 74.22%] [G loss: 0.946981]
10 [D loss: 0.575241, acc.: 74.41%] [G loss: 0.928592]
11 [D loss: 0.587830, acc.: 72.07%] [G loss: 0.939708]
12 [D loss: 0.572834, acc.: 74.61%] [G loss: 0.925547]
13 [D loss: 0.562286, acc.: 75.10%] [G loss: 0.934330]
14 [D loss: 0.565512, acc.: 76.27%] [G loss: 0.936603]
15 [D loss: 0.568875, acc.: 75.88%] [G loss: 0.945003]
16 [D loss: 0.570360, acc.: 74.71%] [G loss: 0.945205]
17 [D loss: 0.577090, acc.: 73.24%] [G loss: 0.955208]
18 [D loss: 0.559200