<a href="https://colab.research.google.com/github/singhamritanshu/GeneratinMNISTImages/blob/main/Generating_MNIST_images_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Silence Python warnings, presumably to hide TensorFlow 1.x deprecation noise in the notebook output.
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import tensorflow as tf
# NOTE(review): tensorflow.examples.tutorials, tf.logging, tf.placeholder, tf.layers, etc. are
# TensorFlow 1.x APIs; this notebook will not run unchanged on TensorFlow 2.x.
from tensorflow.examples.tutorials.mnist import input_data
# Only show TensorFlow log messages at ERROR level or above.
tf.logging.set_verbosity(tf.logging.ERROR)

import matplotlib.pyplot as plt
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

# Start from an empty default graph so re-running the notebook does not add duplicate ops/variables.
tf.reset_default_graph()

In [None]:
# Load the MNIST train/validation/test splits using "data/mnist" as the data directory;
# one_hot=True encodes labels as one-hot vectors (labels are not used by the GAN itself).
data = input_data.read_data_sets(train_dir="data/mnist", one_hot=True)

In [None]:
# Preview one real training digit: the flat 784-pixel row reshaped to a 28x28 grayscale image.
preview_image = data.train.images[13].reshape(28, 28)
plt.imshow(preview_image, cmap="gray")

**Defining the Generator**

In [None]:
def generator(z, reuse=None):
    """Map a batch of noise vectors to flattened 28x28 images.

    Architecture: two 128-unit dense layers with leaky-ReLU activations,
    then a 784-unit dense layer with tanh, so every generated pixel lies in
    [-1, 1] (the training loop rescales real images to the same range).

    Args:
        z: noise tensor; presumably shape (batch, 100) to match the ``z``
            placeholder defined later in the notebook — TODO confirm for
            other callers.
        reuse: forwarded to ``tf.variable_scope('generator')``; pass ``True``
            to share the already-created generator variables.

    Returns:
        Tensor with 784 tanh-activated values per input row.
    """
    layer = z
    with tf.variable_scope('generator', reuse=reuse):
        # Hidden stack: the dense layers are created in the same order as
        # before, so the auto-generated variable names are unchanged.
        for width in (128, 128):
            layer = tf.layers.dense(inputs=layer, units=width,
                                    activation=tf.nn.leaky_relu)
        return tf.layers.dense(inputs=layer, units=784,
                               activation=tf.nn.tanh)

**Defining the Discriminator**

In [None]:
def discriminator(X, reuse=None):
    """Score a batch of flattened images as real vs. generated.

    Architecture: two 128-unit dense layers with leaky-ReLU activations,
    followed by a single linear output unit.

    The raw logits are returned (no sigmoid applied). The notebook's losses use
    ``tf.nn.sigmoid_cross_entropy_with_logits``, which applies the sigmoid
    internally. If probabilities are needed, use ``tf.sigmoid(logits)``.

    Args:
        X: image tensor; presumably shape (batch, 784), matching the ``x``
            placeholder and the generator output — TODO confirm for other callers.
        reuse: forwarded to ``tf.variable_scope('discriminator')``; pass
            ``True`` on the second call so real and fake images are scored by
            the same weights.

    Returns:
        Tensor of logits with one unit per input row.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        hidden1 = tf.layers.dense(inputs=X, units=128, activation=tf.nn.leaky_relu)
        hidden2 = tf.layers.dense(inputs=hidden1, units=128, activation=tf.nn.leaky_relu)
        # Linear output: no activation, so these are unnormalized logits.
        logits = tf.layers.dense(inputs=hidden2, units=1)
        return logits

In [None]:
# Real images: one flattened 28x28 digit (784 floats) per row; batch size left open.
x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
# Generator input: one 100-dimensional noise vector per row; batch size left open.
z = tf.placeholder(dtype=tf.float32, shape=[None, 100])

In [None]:
# Wire up the GAN graph. Statement order matters:
#   - generator(z) creates the 'generator' variables;
#   - the first discriminator call (reuse=None) creates the 'discriminator' variables;
#   - the second call passes reuse=True so generated images are scored by the SAME weights.
fake_x = generator(z)
D_logits_real = discriminator(x)
D_logits_fake = discriminator(fake_x,reuse=True)

In [None]:
# Discriminator loss on real images: sigmoid cross-entropy against target label 1,
# averaged over the batch.
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_logits_real),
        logits=D_logits_real,
    )
)

In [None]:
# Discriminator loss on generated images: sigmoid cross-entropy against target
# label 0, averaged over the batch.
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(D_logits_fake),
        logits=D_logits_fake,
    )
)

In [None]:
# Total discriminator objective: real-image term plus fake-image term.
D_loss = tf.add(D_loss_real, D_loss_fake)

In [None]:
# Generator loss: score generated images against target label 1, i.e. the
# generator is rewarded when the discriminator classifies its output as real.
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(D_logits_fake),
        logits=D_logits_fake,
    )
)

In [None]:
training_vars = tf.trainable_variables()

# Split the trainable variables by the variable scope they were created in
# ('discriminator' / 'generator'), so each optimizer only updates its own network.
# Match the full scope prefix, including the '/', rather than a loose substring
# such as 'gen' or 'dis'. A substring test would also capture any unrelated
# variable whose name merely contains those letters.
theta_D = [var for var in training_vars if var.name.startswith('discriminator/')]
theta_G = [var for var in training_vars if var.name.startswith('generator/')]

In [None]:
# One Adam optimizer per network (learning rate 1e-3). var_list restricts each
# update to its own network's weights. Creation order (D, then G) is kept unchanged.
D_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(
    D_loss, var_list=theta_D)
G_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(
    G_loss, var_list=theta_G)

In [None]:
# Training hyperparameters: images per mini-batch and full passes over the training set.
batch_size, num_epochs = 100, 1000

# Created after both optimizers, so it also covers the variables Adam added to the graph.
init = tf.global_variables_initializer()

In [None]:
with tf.Session() as session:

    # Initialise every variable in the graph (network weights and optimizer state).
    session.run(init)

    # Mini-batches per epoch is loop-invariant, so compute it once.
    num_batches = data.train.num_examples // batch_size

    for epoch in range(num_epochs):

        for i in range(num_batches):

            # next_batch returns (images, labels); only the images are used.
            batch = data.train.next_batch(batch_size)

            # Flatten to (batch_size, 784), then map pixels through 2p - 1.
            # This presumably takes read_data_sets' [0, 1] pixels to [-1, 1],
            # matching the generator's tanh output range.
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images * 2 - 1

            # Fresh uniform noise in [-1, 1] as generator input.
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, 100))

            feed_dict = {x: batch_images, z: batch_noise}

            # One discriminator step, then one generator step, on the same batch.
            session.run(D_optimizer, feed_dict=feed_dict)
            session.run(G_optimizer, feed_dict=feed_dict)

        # Every 100th epoch: report losses and display one generated sample.
        if epoch % 100 == 0:
            # The losses are only needed for this report. Evaluating them here,
            # on the epoch's last batch after its updates, prints exactly what
            # per-batch evaluation would, without two extra forward passes per batch.
            discriminator_loss = D_loss.eval(feed_dict)
            generator_loss = G_loss.eval(feed_dict)
            print("Epoch: {}, iteration: {}, Discriminator Loss:{}, Generator Loss: {}".format(epoch,i,discriminator_loss,generator_loss))

            # Run the generator on the last batch's noise.
            _fake_x = fake_x.eval(feed_dict)

            # Show the first sample in grayscale, consistent with the real-digit preview.
            plt.imshow(_fake_x[0].reshape(28, 28), cmap="gray")
            plt.show()