# In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import * 
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import * 
import numpy as np
import matplotlib.pyplot as plt
import random


class GAN:
    """Fully-connected GAN that learns to generate 28x28 MNIST digits.

    Three Keras models are built:
      * ``discriminator``: maps an image to a real/fake probability (sigmoid).
      * ``generator``: maps a ``seed_dim``-long Gaussian noise vector to an
        image whose pixels lie in [-1, 1] (tanh output).
      * ``combined``: generator followed by the frozen discriminator; training
        it updates only the generator's weights.
    """

    def __init__(self):
        self.img_size = 28
        self.img_shape = (self.img_size, self.img_size, 1)
        self.seed_dim = 100  # length of the noise vector fed to the generator

        # Adam(learning_rate=0.0002, beta_1=0.5) for all three models.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=Adam(0.0002, 0.5),
            metrics=['accuracy'])

        # The generator is never trained directly (only via ``combined``);
        # compiling it is not required for ``predict``.
        self.generator = self.build_generator()
        self.generator.compile(
            loss='binary_crossentropy',
            optimizer=Adam(0.0002, 0.5))

        self.combined = self.build_combined()
        self.combined.compile(
            loss='binary_crossentropy',
            optimizer=Adam(0.0002, 0.5))

    @staticmethod
    def _to_tanh_range(pixels):
        """Scale pixel values from [0, 255] to float32 in [-1, 1].

        This matches the generator's tanh output range, so real and generated
        images share the same intensity scale.
        """
        return (np.asarray(pixels).astype('float32') - 127.5) / 127.5

    @staticmethod
    def _from_tanh_range(imgs):
        """Map values in [-1, 1] (generator output) back to [0, 1] for display."""
        return 0.5 * (imgs + 1)

    def build_combined(self):
        """Stack generator -> frozen discriminator for training the generator.

        Returns:
            A ``tf.keras.Model`` mapping noise of shape ``(seed_dim,)`` to the
            discriminator's real/fake probability for the generated image.
        """
        noise = Input(shape=(self.seed_dim,))
        img = self.generator(noise)
        # Freeze the discriminator inside the combined model so that
        # combined.train_on_batch only updates generator weights. The
        # standalone discriminator was compiled earlier while still trainable;
        # this relies on tf.keras keeping the trainable state captured at
        # compile() time.
        # NOTE(review): Keras 3 changed how `trainable` interacts with
        # compile(); verify the discriminator still learns if upgrading.
        self.discriminator.trainable = False
        fake = self.discriminator(img)
        return tf.keras.Model(noise, fake)

    def build_generator(self):
        """Build the MLP generator: noise (seed_dim,) -> image img_shape in [-1, 1]."""
        model = tf.keras.Sequential()

        model.add(Dense(256, input_shape=(self.seed_dim,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        # Wrap the Sequential stack in a functional Model with an explicit Input.
        noise = Input(shape=(self.seed_dim,))
        img = model(noise)
        return tf.keras.Model(noise, img)

    def build_discriminator(self):
        """Build the MLP discriminator: image img_shape -> probability the image is real."""
        model = tf.keras.Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))

        model.summary()

        img = Input(shape=self.img_shape)
        fake_or_real = model(img)
        return tf.keras.Model(img, fake_or_real)

    def train(self, epochs, batch_size, interval):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: number of training iterations. Each iteration processes
                one mini-batch, not a full pass over the dataset.
            batch_size: noise vectors per generator update. Each discriminator
                step sees ``batch_size // 2`` real and as many generated images.
            interval: show a sample grid every ``interval`` iterations; a
                falsy value (e.g. 0) disables plotting.
        """
        (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
        # Real images must use the generator's tanh range [-1, 1]. Dividing by
        # 255 instead would put them in [-0.5, 0.5], letting the discriminator
        # separate real from fake by intensity range alone.
        x_train = self._to_tanh_range(x_train)
        x_train = np.expand_dims(x_train, axis=3)  # add channel axis
        assert x_train.shape[1:] == self.img_shape

        half = batch_size // 2
        # Labels are the same every iteration, so build them once.
        real_labels = np.ones((half, 1))
        fake_labels = np.zeros((half, 1))
        # The generator is trained to make the discriminator output "real".
        gen_labels = np.ones((batch_size, 1))

        for epoch in range(epochs):
            indexes = np.random.randint(0, x_train.shape[0], half)
            imgs_real = x_train[indexes]

            # Train the discriminator on half real, half generated images.
            noise = np.random.normal(0, 1, (half, self.seed_dim))
            # verbose=0: avoid printing a progress bar on every iteration.
            imgs_fake = self.generator.predict(noise, verbose=0)

            l1 = self.discriminator.train_on_batch(imgs_real, real_labels)
            l2 = self.discriminator.train_on_batch(imgs_fake, fake_labels)
            # Each result is [loss, accuracy] (metrics=['accuracy']); average them.
            disc_loss = 0.5 * np.add(l1, l2)

            # Train the generator through the frozen discriminator.
            noise = np.random.normal(0, 1, (batch_size, self.seed_dim))
            gen_loss = self.combined.train_on_batch(noise, gen_labels)

            print("%d D %f %f G %f" % (epoch, disc_loss[0], disc_loss[1], gen_loss))
            if interval and (epoch + 1) % interval == 0:
                self.plot()

    def plot(self):
        """Show a 10x10 grid of digits sampled from the generator."""
        noise = np.random.normal(0, 1, (100, self.seed_dim))
        imgs = self._from_tanh_range(self.generator.predict(noise, verbose=0))
        assert imgs.shape == (100, self.img_size, self.img_size, 1)

        _, axs = plt.subplots(10, 10)
        for i in range(10):
            for j in range(10):
                axs[i, j].imshow(imgs[i + j * 10, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
        plt.show()
        
# Driver: 20000 single-batch iterations of 64 images, previewing samples
# from the generator every 2000 iterations.
EPOCHS = 20000
BATCH_SIZE = 64
PREVIEW_INTERVAL = 2000

gan = GAN()
gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, interval=PREVIEW_INTERVAL)