In [None]:
import numpy as np
from keras.datasets import mnist
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Reshape, Activation, BatchNormalization, Flatten
from keras.layers import UpSampling2D, Conv2D, MaxPool2D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam, SGD

# Negative-slope coefficient handed to LeakyReLU wherever activation='lrelu'.
LRELU: float = 0.2

def custom_conv(x, filters, kernel=3, bn=True, activation='relu', ud_sample=None):
    """Convolutional block: [UpSampling2D] -> Conv2D -> [BatchNorm] -> act -> [MaxPool2D].

    Args:
        x: input tensor.
        filters: number of convolution filters.
        kernel: convolution kernel size; padding is 'same'.
        bn: if True, insert BatchNormalization between the convolution and
            the activation.
        activation: 'lrelu' selects LeakyReLU(LRELU); any other value (name or
            callable) is passed to keras.layers.Activation.
        ud_sample: 'up' upsamples before the convolution, 'down' max-pools
            after the activation; any other value (default None, the legacy
            string 'None' included) performs no resampling.

    Returns:
        The block's output tensor.
    """
    if ud_sample == 'up':
        x = UpSampling2D()(x)

    # He init for ReLU-family activations ('relu', 'lrelu', ...), Glorot
    # otherwise (tanh, sigmoid, ...). The isinstance guard keeps callables working.
    relu_family = isinstance(activation, str) and 'relu' in activation
    initialization = 'he_uniform' if relu_family else 'glorot_uniform'
    out = Conv2D(filters, kernel, padding='same', kernel_initializer=initialization)(x)
    if bn:
        out = BatchNormalization()(out)
    # 'lrelu' is not a Keras activation name, so build the LeakyReLU layer here.
    if activation == 'lrelu':
        out = LeakyReLU(LRELU)(out)
    else:
        out = Activation(activation)(out)

    if ud_sample == 'down':
        out = MaxPool2D()(out)

    return out

def custom_dense(x, units, bn=True, activation='lrelu'):
    """Dense block: Dense -> [BatchNormalization] -> activation.

    Args:
        x: input tensor.
        units: number of output units of the Dense layer.
        bn: if True, insert BatchNormalization between Dense and the activation.
        activation: 'lrelu' selects LeakyReLU(LRELU); any other value (name or
            callable) is passed to keras.layers.Activation.

    Returns:
        The block's output tensor.
    """
    # He init for ReLU-family activations, Glorot for the rest (sigmoid, tanh,
    # linear). The isinstance guard keeps callable activations from crashing.
    relu_family = isinstance(activation, str) and 'relu' in activation
    initialization = 'he_uniform' if relu_family else 'glorot_uniform'
    # 'lrelu' is not a Keras activation name, so build the LeakyReLU layer here.
    activation_fn = LeakyReLU(LRELU) if activation == 'lrelu' else Activation(activation)
    out = Dense(units, kernel_initializer=initialization)(x)
    if bn:
        out = BatchNormalization()(out)
    out = activation_fn(out)

    return out

def generator_model(input_shape=(100,)):
    """Build the generator network: noise vector -> single-channel image.

    Two dense blocks widen the noise to 128*7*7 features. These are reshaped to
    a 7x7x128 map and passed through two upsampling 5x5 conv blocks. The last
    block has one filter, no batch norm and a tanh activation, so pixel
    values lie in [-1, 1].

    Returns:
        A keras Model from ``Input(input_shape)`` to the generated image.
    """
    noise = Input(input_shape)

    hidden = custom_dense(noise, 1024)
    hidden = custom_dense(hidden, 128 * 7 * 7)

    feature_map = Reshape((7, 7, 128))(hidden)
    feature_map = custom_conv(feature_map, 64, 5, ud_sample='up')
    image = custom_conv(feature_map, 1, 5, ud_sample='up', activation='tanh', bn=False)

    return Model(noise, image)

def discriminator_model(input_shape=(28,28,1)):
    """Build the discriminator network: image -> single sigmoid score.

    Two 5x5 conv blocks, each followed by max pooling, are flattened and fed
    to a 1024-unit dense block and then a one-unit sigmoid output (no batch norm).

    Returns:
        A keras Model from ``Input(input_shape)`` to the score.
    """
    image = Input(input_shape)

    features = custom_conv(image, 64, 5, ud_sample='down')
    features = custom_conv(features, 128, 5, ud_sample='down')

    flat = Flatten()(features)
    hidden = custom_dense(flat, 1024)
    score = custom_dense(hidden, 1, bn=False, activation='sigmoid')

    return Model(image, score)

def generator_containing_discriminator(generator, discriminator):
    """Stack ``discriminator`` on top of ``generator`` for training the generator.

    Side effect: sets ``discriminator.trainable = False`` so that, once the
    returned model is compiled, its loss only updates generator weights. The
    caller must set it back to True before compiling the discriminator on its
    own (``train`` does this).

    Built with the functional API (``Input``/``Model``, as in the other
    builders) rather than ``Sequential``, which this notebook never imports.

    Returns:
        A keras Model mapping the generator's input to the discriminator's output.
    """
    discriminator.trainable = False
    # input_shape is (batch, ...); drop the batch axis for Input().
    noise = Input(shape=generator.input_shape[1:])
    validity = discriminator(generator(noise))
    return Model(noise, validity)

In [None]:
def train(BATCH_SIZE=128, epochs=100, noise_dim=100):
    """Train the DCGAN on MNIST with alternating discriminator / generator steps.

    Each batch:
      1. Train the discriminator on the real batch (label 1) and an equally
         sized generated batch (label 0).
      2. Train the generator through the stacked model on fresh noise
         labelled 1.

    Every 20 batches a grid of generated samples is saved to
    ``<epoch>_<index>.png``. Every 10 batches both networks' weights are saved
    to the files ``generator`` and ``discriminator``.

    Args:
        BATCH_SIZE: number of real (and of generated) images per step.
        epochs: number of passes over the training images.
        noise_dim: length of the uniform[-1, 1] noise vectors fed to the generator.

    NOTE(review): ``Image`` (presumably PIL.Image) and ``combine_images`` are
    not defined in this notebook and must exist before calling this.
    ``combine_images`` receives channels-last (BATCH_SIZE, 28, 28, 1) arrays.
    """
    # Only the training images are used. Scale pixels from [0, 255] to [-1, 1]
    # to match the generator's tanh output.
    (X_train, _), _ = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    # The models are channels-last (28, 28, 1), so append the channel axis at
    # the end. A (N, 1, 28, 28) layout would not match the discriminator input
    # or the generator output it is concatenated with.
    X_train = X_train[..., np.newaxis]

    discriminator = discriminator_model()
    generator = generator_model((noise_dim,))
    discriminator_on_generator = \
        generator_containing_discriminator(generator, discriminator)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    generator.compile(loss='binary_crossentropy', optimizer="SGD")
    # Compiled while the discriminator is frozen (set by the helper above).
    discriminator_on_generator.compile(
        loss='binary_crossentropy', optimizer=g_optim)
    # Unfreeze so the discriminator's own compiled model updates its weights.
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)

    num_batches = X_train.shape[0] // BATCH_SIZE
    # Targets are identical for every batch, so build them once.
    d_labels = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)
    g_labels = np.array([1] * BATCH_SIZE)

    for epoch in range(epochs):
        print("Epoch is", epoch)
        print("Number of batches", num_batches)
        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, noise_dim))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = generator.predict(noise, verbose=0)
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch)+"_"+str(index)+".png")
            X = np.concatenate((image_batch, generated_images))
            d_loss = discriminator.train_on_batch(X, d_labels)
            print("batch %d d_loss : %f" % (index, d_loss))

            noise = np.random.uniform(-1, 1, (BATCH_SIZE, noise_dim))
            # Per the Keras FAQ, changing `trainable` after compile() only takes
            # effect on recompilation. These toggles are kept from the original
            # flow; the freezing that matters happened before compile() above.
            discriminator.trainable = False
            g_loss = discriminator_on_generator.train_on_batch(noise, g_labels)
            discriminator.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

In [30]:
# Standalone generator / discriminator instances for interactive inspection.
g, d = generator_model(), discriminator_model()

In [37]:
# Is the generator's weight set marked trainable? Nothing in this notebook freezes it.
g.trainable

True

In [38]:
# NOTE(review): the saved output False comes from In [34] having run earlier.
# That cell called generator_containing_discriminator(g, d), which sets
# d.trainable = False. Run top-to-bottom on a fresh kernel, this cell shows True.
d.trainable

False

In [34]:
# Stacked generator -> discriminator model; side effect: sets d.trainable = False.
xx = generator_containing_discriminator(generator=g, discriminator=d)

In [36]:
# The stacked model's own flag stays True; only the nested discriminator was frozen.
xx.trainable

True