In [1]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, LSTM
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import pickle
import sys
import numpy as np

In [2]:
class VanillaDiscriminator:
    """Builder for a fully connected real/fake classifier over sequences.

    Args:
        max_sequence_len: sequence length of one sample; the built model takes
            inputs shaped (max_sequence_len, 1, 1) per sample.
    """

    def __init__(self, max_sequence_len):
        self.max_sequence_len = max_sequence_len

    def build_model(self):
        """Build the discriminator, print its layer summary, and return it.

        Layer stack: Flatten -> Dense(512) -> LeakyReLU(0.2) -> Dense(256)
        -> LeakyReLU(0.2) -> Dense(1, sigmoid). The Sequential stack is wrapped
        in a functional ``Model``, which is also stored on
        ``self.discriminator``.

        Returns:
            An uncompiled Keras ``Model`` mapping a (max_sequence_len, 1, 1)
            input to one sigmoid score per sample.
        """
        sample_shape = (self.max_sequence_len, 1, 1)
        stack = Sequential(
            [
                Flatten(input_shape=sample_shape),
                Dense(512),
                LeakyReLU(alpha=0.2),
                Dense(256),
                LeakyReLU(alpha=0.2),
                Dense(1, activation='sigmoid'),
            ],
            name='VanillaDiscriminator',
        )
        stack.summary()
        sample_in = Input(shape=sample_shape)
        self.discriminator = Model(sample_in, stack(sample_in))
        return self.discriminator

In [3]:
class VanillaGenerator:
    """Builder for an LSTM generator mapping a noise sequence to a sequence.

    Args:
        max_sequence_len: length of the noise input sequence and of the
            generated output vector.
    """

    def __init__(self, max_sequence_len):
        self.max_sequence_len = max_sequence_len

    def build_model(self):
        """Build the generator, print its layer summary, and return it.

        Layer stack: LSTM(256) -> Dropout(0.2) -> Dense(max_sequence_len, tanh).
        The Sequential stack is wrapped in a functional ``Model``, which is
        also stored on ``self.generator``.

        Returns:
            An uncompiled Keras ``Model`` mapping noise shaped
            (max_sequence_len, 1) to a vector of max_sequence_len values.
        """
        seq_len = self.max_sequence_len
        stack = Sequential(
            [
                LSTM(256, input_shape=(seq_len, 1)),
                Dropout(0.2),
                # tanh bounds every generated value to [-1, 1]
                Dense(seq_len, activation='tanh'),
            ],
            name='VanillaGenerator',
        )
        stack.summary()
        noise_in = Input(shape=(seq_len, 1))
        self.generator = Model(noise_in, stack(noise_in))
        return self.generator

In [4]:
class TextGAN:
    """Combined model: noise -> generator -> frozen discriminator -> score.

    Args:
        discriminator: Keras model that scores a generated sample.
        generator: Keras model that maps noise shaped (max_sequence_len, 1)
            to a sample the discriminator accepts.
        max_sequence_len: length of the noise sequence fed to the generator.
    """

    def __init__(self, discriminator, generator, max_sequence_len):
        self.discriminator = discriminator
        self.generator = generator
        self.max_sequence_len = max_sequence_len

    def build_model(self):
        """Wire generator into discriminator, print the summary, and return it.

        Sets ``discriminator.trainable = False`` first, so training the
        combined model updates only generator weights. Per Keras' documented
        compile-time semantics, a discriminator compiled before this flag was
        flipped keeps its own (trainable) state when trained directly.

        Returns:
            An uncompiled Keras ``Model`` named 'GAN', also stored on
            ``self.gan``.
        """
        self.discriminator.trainable = False
        noise = Input(shape=(self.max_sequence_len, 1))
        score = self.discriminator(self.generator(noise))
        self.gan = Model(noise, score, name='GAN')
        self.gan.summary()
        return self.gan

In [5]:
class GANTrainer:
    """Alternating discriminator / generator training loop for the text GAN."""

    def __init__(self):
        pass

    def train(self, X_train, discriminator, generator, gan, batch_size, epochs):
        """Train ``discriminator`` and, through ``gan``, ``generator``.

        Each epoch is a single step:
        1. the discriminator is trained on half a batch of real samples
           labelled 1 and half a batch of generated samples labelled 0;
        2. ``gan`` is trained on a full batch of noise labelled 1, so the
           generator learns to make the discriminator call its output real.
        One progress line is printed per epoch.

        Args:
            X_train: real samples indexed along axis 0. Axis 1 is the sequence
                length used for the generator's noise input. In this notebook
                it is shaped (rows, seq_len, 1, 1).
            discriminator: compiled model whose ``train_on_batch`` returns
                ``[loss, accuracy]`` (i.e. compiled with an accuracy metric).
            generator: model whose ``predict`` maps noise (n, seq_len, 1) to
                (n, seq_len).
            gan: compiled generator->discriminator model whose
                ``train_on_batch`` returns a scalar loss.
            batch_size: samples per generator step. The discriminator uses
                ``batch_size // 2`` real and as many fake samples.
            epochs: number of training steps.

        Raises:
            ValueError: if ``batch_size`` < 2, which would yield an empty
                half batch.
        """
        half_batch = int(batch_size) // 2
        if half_batch < 1:
            raise ValueError(f"batch_size must be at least 2, got {batch_size}")
        # Noise length is taken from the data (axis 1 of X_train) instead of a
        # notebook-global, so the method does not depend on hidden kernel state.
        seq_len = X_train.shape[1]

        # Label arrays are constant across epochs; build them once.
        real_labels = np.ones((half_batch, 1))
        fake_labels = np.zeros((half_batch, 1))
        gan_labels = np.ones((batch_size, 1))

        for epoch in range(epochs):
            ##########################
            # train the discriminator on half-real and half-fake data
            ##########################
            # random half batch of real data
            idx = np.random.randint(0, X_train.shape[0], half_batch)
            txts = X_train[idx]

            # half batch of generated data, reshaped (n, seq_len) ->
            # (n, seq_len, 1, 1) to match the real samples' layout
            noise = np.random.normal(0, 1, (half_batch, seq_len, 1))
            gen_txts = generator.predict(noise)
            gen_txts = np.expand_dims(gen_txts, axis=2)
            gen_txts = np.expand_dims(gen_txts, axis=3)

            # discriminator losses on real and fake data, averaged
            d_loss_real = discriminator.train_on_batch(txts, real_labels)
            d_loss_fake = discriminator.train_on_batch(gen_txts, fake_labels)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            ##########################
            # train the GAN, thereby training the generator (on full-batch data);
            # the discriminator is not updated here because its trainable flag
            # was set to False before the GAN was compiled
            ##########################
            noise = np.random.normal(0, 1, (batch_size, seq_len, 1))
            # the generator wants the discriminator to mistake its texts as
            # real, hence the all-ones labels
            g_loss = gan.train_on_batch(noise, gan_labels)

            # report progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

In [8]:
# data loading and preprocessing
# NOTE(review): pickle.load can execute arbitrary code; only load X.p from a
# trusted source.
with open("X.p", "rb") as pickle_file:  # closes the handle even on error
    data = pickle.load(pickle_file)
X_train = data
rows, columns, channels = X_train.shape
print(rows, columns, channels)
# expand the last dimension: (rows, columns, channels) -> (rows, columns, channels, 1)
X_train = np.expand_dims(X_train, axis=3)
# VanillaDiscriminator expects per-sample input shaped (max_sequence_len, 1, 1)
assert X_train.shape[2:] == (1, 1), f"unexpected sample shape {X_train.shape[1:]}"
print(X_train.shape)

7613 768 1
(7613, 768, 1, 1)


In [9]:
# parameter initializations
# Seed the RNGs this run draws from: numpy for batch sampling and noise, TF for
# weight init and dropout. This makes Restart & Run All repeatable (some GPU
# kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

max_sequence_len = columns  # sequence length of the loaded data (X_train.shape[1])
batch_size = 1024
epochs = 100
optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

In [10]:
discriminator = VanillaDiscriminator(max_sequence_len).build_model()
# Give the discriminator its own Adam instance with the same hyperparameters as
# `optimizer`. Sharing one optimizer object between separately trained models
# shares its step counter, and TF >= 2.11 optimizers raise when asked to update
# variables they were not built with.
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam.from_config(optimizer.get_config()),
                      metrics=['accuracy'])

Model: "VanillaDiscriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 768)               0         
_________________________________________________________________
dense (Dense)                (None, 512)               393728    
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 257       
Total params: 525,313
Trainable params: 525,313
Non-trainable params: 0
________________________________________

In [11]:
generator = VanillaGenerator(max_sequence_len).build_model()
# GANTrainer only calls generator.predict; the generator's weights are updated
# through the combined `gan` model, not through this compile.
generator.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
)

Model: "VanillaGenerator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 256)               264192    
_________________________________________________________________
dropout (Dropout)            (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 768)               197376    
Total params: 461,568
Trainable params: 461,568
Non-trainable params: 0
_________________________________________________________________


In [12]:
gan = TextGAN(discriminator, generator, max_sequence_len).build_model()
# Own Adam instance (same hyperparameters as `optimizer`). Sharing one optimizer
# object with the discriminator couples their step counters, and TF >= 2.11
# optimizers reject updating variables they were not built with.
gan.compile(loss='binary_crossentropy',
            optimizer=Adam.from_config(optimizer.get_config()))

Model: "GAN"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         [(None, 768, 1)]          0         
_________________________________________________________________
model_1 (Model)              (None, 768)               461568    
_________________________________________________________________
model (Model)                (None, 1)                 525313    
Total params: 986,881
Trainable params: 461,568
Non-trainable params: 525,313
_________________________________________________________________


In [13]:
# training: alternate discriminator and generator updates, one step per epoch
trainer = GANTrainer()
trainer.train(X_train, discriminator, generator, gan,
              batch_size=batch_size, epochs=epochs)

0 [D loss: 0.704702, acc.: 35.55%] [G loss: 0.691149]
1 [D loss: 0.491764, acc.: 50.00%] [G loss: 0.688128]
2 [D loss: 0.416639, acc.: 50.00%] [G loss: 0.683806]
3 [D loss: 0.390614, acc.: 50.00%] [G loss: 0.679571]
4 [D loss: 0.380743, acc.: 50.00%] [G loss: 0.675496]
5 [D loss: 0.377441, acc.: 50.00%] [G loss: 0.670860]
6 [D loss: 0.377320, acc.: 50.00%] [G loss: 0.666329]
7 [D loss: 0.378216, acc.: 50.00%] [G loss: 0.661325]
8 [D loss: 0.380268, acc.: 50.00%] [G loss: 0.655204]
9 [D loss: 0.384612, acc.: 50.00%] [G loss: 0.647454]
10 [D loss: 0.391916, acc.: 50.00%] [G loss: 0.634348]
11 [D loss: 0.417960, acc.: 50.00%] [G loss: 0.610074]
12 [D loss: 0.672615, acc.: 50.00%] [G loss: 0.524856]
13 [D loss: 0.706141, acc.: 50.00%] [G loss: 1.293922]
14 [D loss: 0.210587, acc.: 100.00%] [G loss: 2.240002]
15 [D loss: 0.180414, acc.: 100.00%] [G loss: 1.991547]
16 [D loss: 0.279523, acc.: 99.51%] [G loss: 1.067211]
17 [D loss: 0.330659, acc.: 93.26%] [G loss: 0.802134]
18 [D loss: 0.3360