In [20]:
import numpy as np 
import matplotlib.pyplot as plt 
from keras.datasets import mnist 
from keras.layers import (Activation,BatchNormalization,Dense,Concatenate,Embedding,Flatten,Input,Multiply,Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D,Conv2DTranspose
from keras.models import Model,Sequential 
from keras.optimizers import Adam

In [21]:
# Model parameters: MNIST image geometry and GAN sizes.
imgRows, imgCols, channels = 28, 28, 1   # 28x28 single-channel MNIST digits
zDim = 100                               # length of the generator's noise vector
nClasses = 10                            # number of digit classes used for conditioning
imgShape = (imgRows, imgCols, channels)

In [22]:
def coreGen(zDim):
    """Build the unconditioned generator trunk.

    A dense projection turns a length-``zDim`` vector into a 7x7x256 feature
    map. Three transposed convolutions follow ('same' padding; strides
    2, 1, 2), taking it to 14x14x128, then 14x14x64, then 28x28x1. A final
    tanh bounds pixel values to [-1, 1].

    Returns an (uncompiled) Sequential model.
    """
    stack = [
        Dense(7 * 7 * 256, input_dim=zDim),
        Reshape((7, 7, 256)),
        # 7x7x256 -> 14x14x128
        Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        # 14x14x128 -> 14x14x64
        Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'),
        BatchNormalization(),
        LeakyReLU(alpha=0.01),
        # 14x14x64 -> 28x28x1
        Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'),
        Activation('tanh'),
    ]
    return Sequential(stack)

In [23]:
def cGEN(zDim):
    """Wrap the core generator so its output is conditioned on a class label.

    The integer label is embedded into a ``zDim``-length vector. That vector
    is multiplied element-wise with the noise vector ``z``, and the product
    is fed to the core generator.

    Returns a Model mapping ``[z, label]`` to a generated image.
    """
    noise = Input(shape=(zDim,))
    digit = Input(shape=(1,), dtype='int32')
    # Embed the label, then flatten the length-1 sequence axis away.
    embedded = Embedding(nClasses, zDim, input_length=1)(digit)
    embedded = Flatten()(embedded)
    conditioned = Multiply()([noise, embedded])
    trunk = coreGen(zDim)
    return Model([noise, digit], trunk(conditioned))

In [24]:
def coreDis(imgShape):
    """Build the discriminator trunk.

    The input is an image of ``imgShape`` with one extra channel appended.
    ``cDis`` concatenates the label embedding as that channel, so the input
    shape is ``(rows, cols, channels + 1)``. Three stride-2 convolutions
    downsample it, and a single sigmoid unit scores it.

    Returns an (uncompiled) Sequential model.
    """
    model = Sequential()
    # Only the first layer of a Sequential model takes an input shape; the
    # +1 channel is the label plane concatenated in cDis.
    model.add(Conv2D(64, kernel_size=3, strides=2,
                     input_shape=(imgShape[0], imgShape[1], imgShape[2] + 1),
                     padding='same'))
    model.add(LeakyReLU(alpha=0.01))
    # Later layers infer their input shape from the previous layer. No
    # input_shape is passed: it would be ignored, and imgShape would be wrong
    # anyway, since these layers receive downsampled feature maps.
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model

In [25]:
def cDis(imgShape):
    """Wrap the core discriminator so it is conditioned on a class label.

    The integer label is embedded into a vector with one entry per image
    element (``np.prod(imgShape)``). The vector is reshaped to ``imgShape``
    and concatenated with the image along the last (channel) axis, and the
    result is scored by ``coreDis``.

    Returns a Model mapping ``[img, label]`` to the discriminator's output.
    """
    image = Input(shape=imgShape)
    digit = Input(shape=(1,), dtype='int32')
    pixelCount = np.prod(imgShape)
    labelPlane = Embedding(nClasses, pixelCount, input_length=1)(digit)
    labelPlane = Reshape(imgShape)(labelPlane)
    stacked = Concatenate(axis=-1)([image, labelPlane])
    trunk = coreDis(imgShape)
    return Model([image, digit], trunk(stacked))

In [26]:
def buildCGAN(G,D):
    """Chain a generator and a discriminator into the combined CGAN model.

    Parameters
    ----------
    G : Model taking ``[z, label]`` and returning an image (as built by cGEN).
    D : Model taking ``[img, label]`` and returning a score (as built by cDis).

    Returns
    -------
    Model mapping ``[z, label]`` to D's output on G's image for that label.
    The same label tensor conditions both G and D.
    """
    z = Input(shape = (zDim,))
    # Labels are class indices. G's and D's own label Inputs are declared
    # int32, so declare the same dtype here instead of the default float32.
    label = Input(shape = (1,),dtype = 'int32')
    img = G([z,label])
    classification = D([img,label])
    model = Model([z,label],classification)
    return model
# Build and compile the conditional discriminator on its own; train() calls
# discriminator.train_on_batch on real and generated batches.
discriminator = cDis(imgShape)
discriminator.compile(loss = 'binary_crossentropy',optimizer = Adam(),metrics = ['accuracy'])
# Freeze the discriminator AFTER its own compile but BEFORE the combined
# model is built and compiled. Presumably this relies on Keras capturing the
# trainable flag at compile time, so training `cgan` updates only the
# generator while `discriminator` still trains via its earlier compile.
# NOTE(review): confirm this behavior for the installed Keras version.
discriminator.trainable = False
generator = cGEN(zDim)
# Combined model: [z, label] -> generator -> discriminator score.
cgan = buildCGAN(generator,discriminator)
cgan.compile(loss = 'binary_crossentropy',optimizer = Adam())

    

In [27]:
def sample_images(image_grid_rows = 2,image_grid_columns = 5):
    """Plot a grid of generated images, each titled with its conditioning digit.

    Parameters
    ----------
    image_grid_rows, image_grid_columns : int
        Grid dimensions; ``rows * columns`` images are generated.

    Uses the module-level ``generator``, ``zDim`` and ``nClasses``. The cell
    at row-major position ``k`` is conditioned on label ``k % nClasses``, so
    the default 2x5 grid shows labels 0-9 once each.
    """
    n_images = image_grid_rows * image_grid_columns
    # The generator's noise Input is (batch, zDim); the previous
    # (rows, cols, zDim) array did not match it.
    z = np.random.normal(0, 1, (n_images, zDim))
    # One label per grid cell, shaped (batch, 1) to match the label Input.
    labels = (np.arange(n_images) % nClasses).reshape(-1, 1)
    genImgs = generator.predict([z, labels])
    # Rescale tanh output from [-1, 1] to [0, 1] for display.
    genImgs = 0.5 * genImgs + 0.5
    # squeeze=False keeps axs 2-D even for single-row/column grids.
    fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(10, 4),
                            sharex=True, sharey=True, squeeze=False)
    # axs.flat iterates row-major, matching the label order above.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(genImgs[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
        ax.set_title("Digit: %d" % labels[cnt, 0])
    plt.show()

In [29]:
#Training loop 
# Histories appended by train() every `sampleInterval` iterations:
# losses holds (discriminator loss, generator loss) pairs, and accuracies
# holds the discriminator accuracy in percent.
accuracies = []; losses = []
def train(iterations,batch_size,sampleInterval):
    """Train the conditional GAN on MNIST.

    Each iteration does two things:
    1. Updates the discriminator on a random batch of real images (target 1)
       and a batch of generated images with the same labels (target 0).
    2. Updates the generator through ``cgan``, asking the discriminator to
       call freshly generated images (with random labels) real.

    Parameters
    ----------
    iterations : int
        Number of training iterations.
    batch_size : int
        Images per discriminator/generator update.
    sampleInterval : int
        Every this many iterations, record losses/accuracy into the
        module-level ``losses`` / ``accuracies`` lists and plot samples.
    """
    (x_train,y_train),(_,_) = mnist.load_data()
    # Rescale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    x_train = (x_train.astype(np.float32) / 127.5) - 1.0
    x_train = np.expand_dims(x_train, axis=3)
    real = np.ones((batch_size, 1))
    # Generated images must be labelled 0. With ones here, the discriminator
    # was trained to call every image real and gave the generator no signal.
    fake = np.zeros((batch_size, 1))
    for iteration in range(iterations):
        # --- Discriminator update: a real batch and a generated batch ---
        idx = np.random.randint(0, x_train.shape[0], batch_size)
        imgs = x_train[idx]
        # Shape (batch, 1) to match the (1,)-shaped label Inputs.
        labels = y_train[idx].reshape(-1, 1)
        z = np.random.normal(0, 1, (batch_size, zDim))
        genImgs = generator.predict([z, labels])
        dLossReal = discriminator.train_on_batch([imgs, labels], real)
        dLossFake = discriminator.train_on_batch([genImgs, labels], fake)
        # Average [loss, accuracy] over the real and fake halves.
        dLoss = 0.5 * np.add(dLossReal, dLossFake)

        # --- Generator update through the combined model ---
        z = np.random.normal(0, 1, (batch_size, zDim))
        labels = np.random.randint(0, nClasses, batch_size).reshape(-1, 1)
        gLoss = cgan.train_on_batch([z, labels], real)

        if (iteration + 1) % sampleInterval == 0:
            losses.append((dLoss[0], gLoss))
            accuracies.append(dLoss[1] * 100)
            sample_images()