In [1]:
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from keras.layers import Input, Reshape, Dense, Dropout, Flatten, Convolution2D, UpSampling2D
from keras.models import Sequential, Model
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras.initializers import normal

Using TensorFlow backend.


In [2]:
# Latent-space size: length of the noise vector fed to the generator.
randomDim = 100

# Seed NumPy's global RNG so the noise vectors and minibatch indices drawn
# with np.random below are repeatable.
# NOTE(review): this presumably does not seed the Keras/TensorFlow backend
# RNG used for weight initialization — confirm if full reproducibility matters.
np.random.seed(1000)

In [3]:
# Load MNIST. Only the training images are used below: rescale their pixel
# values from [0, 255] to [-1, 1] (the range of the generator's tanh output)
# and append a trailing channel axis, giving channels-last (N, 28, 28, 1).
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.expand_dims((X_train.astype(np.float32) - 127.5) / 127.5, axis=-1)

In [4]:
# Sanity-check the preprocessing before training: channels-last 28x28
# single-channel images (the discriminator's input_shape), scaled into
# [-1, 1] (the generator's tanh output range).
assert X_train.shape[1:] == (28, 28, 1), X_train.shape
assert -1.0 <= X_train.min() and X_train.max() <= 1.0, (X_train.min(), X_train.max())
X_train.shape

(60000, 28, 28, 1)

In [5]:
# Adam with learning rate 2e-4 and beta_1 = 0.5 (these match the DCGAN
# paper's settings, presumably the source of the choice).
# NOTE(review): this single optimizer instance is passed to every compile()
# in the notebook (generator, discriminator, gan). Optimizers carry internal
# state, so separate instances per model may be safer — confirm for the
# Keras version in use.
adam = Adam(beta_1=0.5, lr=0.0002)

In [6]:
#generator

# Generator: maps a randomDim-long noise vector to a 28x28x1 image.
# Dense -> 7x7x128 map, upsample to 14x14, 64-filter 5x5 conv, upsample to
# 28x28, then a 1-filter 5x5 conv whose tanh keeps pixels in [-1, 1].
generator = Sequential([
    Dense(128*7*7, input_dim=randomDim, kernel_initializer='normal'),
    LeakyReLU(0.2),
    Reshape((7, 7, 128)),
    UpSampling2D(size=(2, 2)),
    Convolution2D(64, (5, 5), padding='same'),
    LeakyReLU(0.2),
    UpSampling2D(size=(2, 2)),
    Convolution2D(1, (5, 5), padding='same', activation='tanh'),
])
# In this notebook the generator is only trained through the combined `gan`
# model and otherwise used via predict(); this compile is kept as-is.
generator.compile(optimizer=adam, loss='binary_crossentropy')

In [7]:
# Discriminator 
# Discriminator: scores a 28x28x1 image with a single sigmoid output.
# Two stride-2 5x5 convolutions (64 then 128 filters) downsample
# 28x28 -> 14x14 -> 7x7, each followed by LeakyReLU(0.2) and 30% dropout.
discriminator = Sequential([
    Convolution2D(64, (5, 5), padding='same', input_shape=(28, 28, 1),
                  strides=(2, 2), kernel_initializer='normal'),
    LeakyReLU(0.2),
    Dropout(0.3),
    Convolution2D(128, (5, 5), padding='same', strides=(2, 2)),
    LeakyReLU(0.2),
    Dropout(0.3),
    Flatten(),
    Dense(1, activation='sigmoid'),
])
# Compiled while still trainable; the combined `gan` model is compiled later
# with this network frozen.
discriminator.compile(optimizer=adam, loss='binary_crossentropy')

In [8]:
# Combined Network
# Combined network: noise -> generator -> discriminator -> realness score.
# The discriminator is marked non-trainable *before* gan.compile() so that
# training `gan` updates only the generator's weights. This order must be
# preserved.
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(ganInput, ganOutput)
gan.compile(optimizer=adam, loss='binary_crossentropy')

In [9]:
# Per-epoch loss history: train() appends one discriminator loss and one
# combined-GAN (generator) loss per epoch; plotLoss() plots them.
dLosses, gLosses = [], []

In [10]:
def plotLoss(epoch):
    """Plot the per-epoch discriminator and GAN losses and save the figure.

    Reads the module-level ``dLosses`` and ``gLosses`` lists and writes
    ``images/dcgan_loss_epoch_<epoch>.png``, creating ``images/`` if it
    does not exist.

    Args:
        epoch: Epoch number; used only in the output file name.
    """
    import os

    outDir = 'images'
    # savefig() does not create missing directories.
    if not os.path.isdir(outDir):
        os.makedirs(outDir)

    fig, ax = plt.subplots(figsize=(10, 8))
    # train() numbers epochs from 1, so start the x axis at 1 instead of at
    # the list index 0.
    ax.plot(range(1, len(dLosses) + 1), dLosses, label="Discriminative loss")
    ax.plot(range(1, len(gLosses) + 1), gLosses, label="Generative loss")
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('DCGAN training losses')
    ax.legend()
    fig.savefig(os.path.join(outDir, 'dcgan_loss_epoch_%d.png' % epoch))
    # The figure is intentionally left open so that a notebook's inline
    # backend can still display it at the end of the calling cell.

In [11]:
def plotGeneratedImages(epoch, example=100, dim=(10,10), figsize=(10,10)):
    """Sample images from the generator and save them as a grid.

    Draws ``example`` standard-normal noise vectors, runs them through the
    module-level ``generator`` and writes the resulting images to
    ``images/dcgan_generated_images_epoch_<epoch>.png``, creating
    ``images/`` if it does not exist.

    Args:
        epoch: Epoch number; used only in the output file name.
        example: Number of images to generate.
        dim: (rows, cols) of the subplot grid; rows * cols must be >= example.
        figsize: Matplotlib figure size in inches.

    Raises:
        ValueError: if ``example`` images do not fit in a ``dim`` grid.
    """
    import os

    rows, cols = dim
    if example > rows * cols:
        raise ValueError("example=%d images do not fit in a %dx%d grid"
                         % (example, rows, cols))

    noise = np.random.normal(0, 1, size=[example, randomDim])
    generatedImages = generator.predict(noise)

    outDir = 'images'
    # savefig() does not create missing directories.
    if not os.path.isdir(outDir):
        os.makedirs(outDir)

    fig = plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(rows, cols, i + 1)
        # Generator output is channels-last (N, 28, 28, 1): drop the channel
        # axis to get the 2-D image. Indexing [i, 0] would instead select the
        # first pixel row, shape (28, 1).
        plt.imshow(generatedImages[i, :, :, 0], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(outDir, 'dcgan_generated_images_epoch_%d.png' % epoch))
    # Called repeatedly from the training loop: close the figure so open
    # figures do not accumulate in memory. The PNG on disk is the output.
    plt.close(fig)

In [12]:
def train(epochs=1, batchSize=128):
    """Train the DCGAN on X_train, alternating discriminator and generator steps.

    Each epoch runs ``batchCount = len(X_train) // batchSize`` steps:
      1. the discriminator is trained on ``batchSize`` real images sampled
         with replacement (label 0.9, i.e. one-sided label smoothing) plus
         ``batchSize`` generated images (label 0);
      2. the generator is trained through the combined ``gan`` model, with
         target 1 for freshly generated images.

    After each epoch the mean discriminator and GAN losses over that epoch's
    steps are appended to ``dLosses`` and ``gLosses``. Sample images are
    saved after epoch 1 and every 5th epoch, and a loss plot is saved after
    the final epoch.

    Args:
        epochs: Number of epochs to run (0 runs nothing).
        batchSize: Number of real images, and of generated images, per step.

    Raises:
        ValueError: if batchSize exceeds the number of training images, so
            not even one step per epoch is possible.
    """
    batchCount = X_train.shape[0] // batchSize
    if batchCount < 1:
        raise ValueError("batchSize (%d) exceeds the training set size (%d)"
                         % (batchSize, X_train.shape[0]))

    # Single-argument print calls render identically on Python 2 and 3; the
    # multi-argument form prints a tuple on Python 2.
    print("Epochs : %d" % epochs)
    print("BatchSize : %d" % batchSize)
    print("Batches per epoch : %d" % batchCount)

    # Labels are the same for every step, so build them once.
    yDis = np.zeros(2 * batchSize)
    yDis[:batchSize] = 0.9      # real half of the batch: smoothed "real" label
    yGen = np.ones(batchSize)   # generator step: push D to call fakes real

    for e in range(1, epochs + 1):
        print("-" * 20 + " Epoch : %d " % e + "-" * 20)
        epochDLosses = []
        epochGLosses = []
        for _ in tqdm(range(batchCount)):
            # Discriminator step: real minibatch followed by generated images.
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]
            generatedImages = generator.predict(noise)
            X = np.concatenate([imageBatch, generatedImages])

            # NOTE(review): depending on the Keras version, trainability is
            # captured at compile() or when the train function is first built
            # lazily on train_on_batch. Toggling here keeps both cases correct.
            discriminator.trainable = True
            epochDLosses.append(discriminator.train_on_batch(X, yDis))

            # Generator step through the combined model.
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            discriminator.trainable = False
            epochGLosses.append(gan.train_on_batch(noise, yGen))

        # Record the epoch mean rather than only the last minibatch's loss.
        dLosses.append(np.mean(epochDLosses))
        gLosses.append(np.mean(epochGLosses))

        if e == 1 or e % 5 == 0:
            plotGeneratedImages(e)

    # `e` is unbound when epochs < 1, so use `epochs` and skip the plot
    # when nothing was trained.
    if epochs >= 1:
        plotLoss(epochs)

In [None]:
# Full training run: 50 epochs of 128-image minibatches. The captured output
# below shows roughly 4-5 minutes per epoch on the machine that produced it.
train(epochs=50, batchSize=128)

  0%|          | 0/468 [00:00<?, ?it/s]

('Epochs : ', 50)
('BatchSize : ', 128)
('Batches per epoch : ', 468)
('--------------------', 'Epoch : 1', '--------------------')


100%|██████████| 468/468 [04:36<00:00,  1.98it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 2', '--------------------')


100%|██████████| 468/468 [04:52<00:00,  2.01it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 3', '--------------------')


100%|██████████| 468/468 [04:08<00:00,  2.03it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 4', '--------------------')


100%|██████████| 468/468 [04:05<00:00,  1.68it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 5', '--------------------')


100%|██████████| 468/468 [04:12<00:00,  1.87it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 6', '--------------------')


100%|██████████| 468/468 [04:12<00:00,  1.79it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 7', '--------------------')


100%|██████████| 468/468 [04:25<00:00,  1.70it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 8', '--------------------')


100%|██████████| 468/468 [04:15<00:00,  1.94it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 9', '--------------------')


100%|██████████| 468/468 [04:20<00:00,  1.88it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 10', '--------------------')


100%|██████████| 468/468 [04:14<00:00,  1.88it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 11', '--------------------')


100%|██████████| 468/468 [04:11<00:00,  1.94it/s]
  0%|          | 0/468 [00:00<?, ?it/s]

('--------------------', 'Epoch : 12', '--------------------')


 83%|████████▎ | 390/468 [04:03<00:44,  1.74it/s]