In [None]:
import numpy as np
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

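Functions to normalize the data to the range [-1, 1] (the output range of the generator's tanh layer) and to map normalized data back to the original scale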

In [None]:
# Per-(sample, channel) statistics recorded by `normalize`; `denormalize` is
# meant to be called with max_vector / min_vector to undo the scaling.
# NOTE: these lists accumulate across calls -- clear them before re-normalizing.
min_vector = []
max_vector = []
median_vector = []
def normalize(dataset):
    """Rescale every (sample, channel) slice of `dataset` to [-1, 1], in place.

    For each index pair (i, k) the slice dataset[i, k, ...] is transformed with
    2 * (x - min) / (max - min) - 1, where min/max are taken over that slice.
    The per-slice max, min and median are appended (i-major, k-minor order) to
    the module-level max_vector, min_vector and median_vector lists.

    Constant slices (max == min) are mapped to 0, the centre of the range,
    instead of the NaNs a 0/0 division would produce.

    Parameters
    ----------
    dataset : np.ndarray
        Floating-point array with at least 3 dimensions; modified in place.

    Raises
    ------
    TypeError
        If `dataset` is not floating point: in-place assignment of the scaled
        values would silently truncate them to integers.
    """
    if not np.issubdtype(dataset.dtype, np.floating):
        raise TypeError(
            f"normalize() needs a floating-point array, got {dataset.dtype}; "
            "cast it first, e.g. with .astype(np.float32)"
        )

    # Reduce over every axis after (sample, channel), keeping dims so the
    # statistics broadcast back against `dataset`.
    reduce_axes = tuple(range(2, dataset.ndim))
    max_ = np.max(dataset, axis=reduce_axes, keepdims=True)
    min_ = np.min(dataset, axis=reduce_axes, keepdims=True)
    median = np.median(dataset, axis=reduce_axes, keepdims=True)

    # ravel() is C-order, i.e. the same i-major order as nested i/k loops.
    max_vector.extend(max_.ravel())
    min_vector.extend(min_.ravel())
    median_vector.extend(median.ravel())

    span = max_ - min_
    constant = span == 0
    # Divide by 1 where the slice is constant to avoid 0/0, then overwrite those
    # positions with 0.
    scaled = 2 * ((dataset - min_) / np.where(constant, 1, span)) - 1
    dataset[...] = np.where(constant, 0, scaled)

def denormalize(M_vec, m_vec, dataset):
    """Map [-1, 1]-scaled data back to a data range, in place.

    The target range is [mean(m_vec), mean(M_vec)]; the same averaged bounds
    are applied to every (sample, channel) slice, so this only approximately
    inverts a per-slice min/max normalization.

    Parameters
    ----------
    M_vec, m_vec : sequence of numbers
        Collected maxima and minima; each is averaged into a single bound.
    dataset : np.ndarray
        Array indexed as dataset[i, k, :]; modified in place.
    """
    upper = sum(M_vec) / len(M_vec)
    lower = sum(m_vec) / len(m_vec)

    for sample in range(dataset.shape[0]):
        for channel in range(dataset.shape[1]):
            shifted = dataset[sample, channel, :] + 1
            dataset[sample, channel, :] = lower + (upper - lower) * shifted / 2

Data loading

In [None]:
# Preparing the data (loading and normalization).
# Only index 1 of the second axis is kept.
# NOTE(review): what that index selects is not visible here -- confirm.
data = np.load('TOD.npy')
test = np.load('TOD_test.npy')

dataset = data[:,1,:,:].copy()  # copy so normalize() does not modify `data`
# NOTE(review): `test` is sliced but not normalized; scale it (e.g. with the
# training statistics) before comparing it with generator output.
test = test[:,1,:,:]

# normalize() appends to these module-level lists; clear them so re-running
# this cell keeps only the statistics of the current `dataset`.
min_vector.clear()
max_vector.clear()
median_vector.clear()

normalize(dataset)
dataset.shape

In [None]:
#-----------------#
# HYPERPARAMETERS #
#-----------------#

# Seed NumPy and TensorFlow before any model is built so weight init, latent
# noise and penalty interpolation are reproducible across runs (GPU kernels
# may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

data_shape = dataset.shape[1:]  # per-sample shape fed to the critic
noise_dim = 500                 # length of the generator's latent vector
epochs = 100
BATCH_SIZE = 64

***DISCRIMINATOR***

In [None]:
def get_discriminator_model():
    """Build the WGAN critic for inputs of shape `data_shape`.

    The network has three Conv1D stages (kernel 12, swish, channels_first),
    each followed by LayerNormalization, with 16/32/64 filters and strides
    1/2/3. These feed Flatten, Dropout(0.5) and a Dense(1) layer with no
    activation, so the score is an unbounded real number as a Wasserstein
    critic requires.

    Returns
    -------
    keras.Sequential
        The uncompiled critic model.
    """
    # (filters, strides) of each Conv1D + LayerNormalization stage, in order.
    conv_stages = ((16, 1), (32, 2), (64, 3))

    critic = Sequential()
    critic.add(layers.Input(shape = data_shape))

    for n_filters, stride in conv_stages:
        critic.add(
            layers.Conv1D(
                n_filters,
                12,
                strides = stride,
                data_format='channels_first',
                activation = 'swish',
            )
        )
        critic.add(layers.LayerNormalization())

    critic.add(layers.Flatten())
    critic.add(layers.Dropout(0.5))
    critic.add(layers.Dense(1))

    return critic

d_model = get_discriminator_model()
d_model.summary()

***GENERATOR***

In [None]:
def get_generator_model():
    """Build the generator mapping a latent vector to a synthetic sample.

    Steps:
    - The (noise_dim,) latent vector is reshaped to one channel of noise_dim
      steps (channels_first).
    - It is upsampled by Conv1DTranspose layers with strides 1, 2 and 2
      ('same' padding).
    - A final tanh Dense layer maps the last axis to data_shape[1] values.

    The output therefore lies in (-1, 1), like the normalized data.

    Returns
    -------
    keras.Sequential
        Uncompiled model with output shape (data_shape[0], data_shape[1]),
        i.e. the critic's input shape.
    """
    model = Sequential()

    # Keras expects the input shape as a tuple; a bare int is rejected.
    model.add(layers.Input(shape = (noise_dim,)))
    model.add(layers.Reshape((1, noise_dim)))

    # NOTE(review): BatchNormalization defaults to axis=-1, which with
    # channels_first data normalizes the time axis, not the channels
    # (axis=1) -- confirm this is intended.
    model.add(layers.Conv1DTranspose(4 , 12, strides = 1, padding = 'same', data_format='channels_first', activation = 'swish'))
    model.add(layers.BatchNormalization())

    model.add(layers.Conv1DTranspose(8 , 12, strides = 2, padding = 'same', data_format='channels_first', activation = 'swish'))
    model.add(layers.BatchNormalization())

    # The channel count must equal the real data's channel axis so that fake
    # and real batches have the same shape for the critic and the penalty.
    model.add(layers.Conv1DTranspose(data_shape[0], 12, strides = 2, padding = 'same', data_format='channels_first', activation = 'swish'))

    # Dense acts on the last axis, resizing it to the real sample length.
    model.add(layers.Dense(data_shape[1], activation='tanh'))

    return model

g_model = get_generator_model()
g_model.summary()

Definition of the Wasserstein GAN with gradient penalty (WGAN-GP) model class

In [None]:
class WGAN(keras.Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP).

    Each `train_step` performs `discriminator_extra_steps` critic updates
    followed by one generator update.

    Parameters
    ----------
    discriminator : keras.Model
        Critic; assumed to return one unbounded score per sample.
    generator : keras.Model
        Maps (batch, latent_dim) noise to samples with the real data's shape.
    latent_dim : int
        Length of the noise vector fed to the generator.
    discriminator_extra_steps : int
        Number of critic updates per generator update (must be >= 1).
    gp_weight : float
        Multiplier of the gradient penalty in the critic loss.
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super().__init__()
        # train_step reports the loss of the last critic step, so at least
        # one critic step is required.
        if discriminator_extra_steps < 1:
            raise ValueError(
                f"discriminator_extra_steps must be >= 1, got {discriminator_extra_steps}"
            )
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss functions.

        `d_loss_fn` is called as d_loss_fn(real_=real_scores, fake_=fake_scores)
        and `g_loss_fn` as g_loss_fn(fake_scores).
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_data, fake_data):
        """Return mean((||grad D(x_hat)||_2 - 1)^2) over the batch.

        x_hat = real + alpha * (fake - real), with one alpha ~ U[0, 1] per
        sample, i.e. random points on the segments joining real and fake
        samples. Assumes 3-D inputs (batch, axis 1, axis 2).
        """
        # Interpolation coefficient must be uniform on [0, 1]; a normal draw
        # would place many points outside the real/fake segment.
        alpha = tf.random.uniform([batch_size, 1, 1], 0.0, 1.0)
        diff = fake_data - real_data
        interpolated = real_data + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the critic output for the interpolated samples.
            pred = self.discriminator(interpolated, training=True)

        # 2. Gradients of the critic output w.r.t. the interpolated samples.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Per-sample L2 norm; the epsilon keeps sqrt differentiable (no
        #    NaN gradient) when a gradient row is exactly zero.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]) + 1e-12)
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, real_data):
        """Run one WGAN-GP update on a batch of real samples.

        Returns
        -------
        dict
            "d_loss": critic loss (including the weighted penalty) of the last
            critic step; "g_loss": generator loss.
        """
        # Keras may pass (x,) / (x, y) tuples; only x is used.
        if isinstance(real_data, tuple):
            real_data = real_data[0]

        batch_size = tf.shape(real_data)[0]

        #-------------------------#
        # Train the DISCRIMINATOR #
        #-------------------------#

        for _ in range(self.d_steps):

            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

            with tf.GradientTape() as tape:

                # Generate fake samples from the latent vectors.
                fake_data = self.generator(random_latent_vectors, training=True)

                fake_logits = self.discriminator(fake_data, training=True)
                real_logits = self.discriminator(real_data, training=True)

                # Wasserstein critic loss.
                d_cost = self.d_loss_fn(real_=real_logits, fake_=fake_logits)
                # Computed inside the tape so the penalty is differentiated
                # w.r.t. the critic weights as well.
                gp = self.gradient_penalty(batch_size, real_data, fake_data)

                d_loss = d_cost + gp * self.gp_weight

            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        #---------------------#
        # Train the GENERATOR #
        #---------------------#

        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))

        with tf.GradientTape() as tape:
            generated_data = self.generator(random_latent_vectors, training=True)
            gen_img_logits = self.discriminator(generated_data, training=True)
            g_loss = self.g_loss_fn(gen_img_logits)

        # Only the generator's variables are updated here.
        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [None]:
# Generator and critic share the same Adam settings but keep separate state.
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0005,
    beta_1=0.5,
)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0005,
    beta_1=0.5,
)

def discriminator_loss(real_, fake_):
    """Wasserstein critic loss: mean fake score minus mean real score.

    Minimizing it pushes critic scores of real samples up and of fake
    samples down.
    """
    real_score, fake_score = tf.reduce_mean(real_), tf.reduce_mean(fake_)
    return fake_score - real_score

def generator_loss(fake_):
    """Wasserstein generator loss: the negated mean critic score of fakes."""
    mean_fake_score = tf.reduce_mean(fake_)
    # Minimizing the negation raises the critic's score for generated samples.
    return -mean_fake_score


# Assemble the WGAN-GP model: 4 critic updates per generator update and a
# gradient penalty weighted by 10.
wgan = WGAN(
    discriminator=d_model,
    generator=g_model,
    latent_dim=noise_dim,
    discriminator_extra_steps=4,
    gp_weight=10.0,
)

# Attach the optimizers and the Wasserstein loss functions.
wgan.compile(
    d_optimizer=discriminator_optimizer,
    d_loss_fn=discriminator_loss,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
)

Start the training

In [None]:
# history.history collects the "d_loss"/"g_loss" values returned by
# train_step, one entry per epoch.
history = wgan.fit(
    dataset,
    epochs=epochs,
    batch_size=BATCH_SIZE,
)

Visualize the generator and discriminator losses over the training epochs

In [None]:
# Keras' History holds one value per epoch, so the x axis counts epochs.
g_loss = history.history['g_loss']
d_loss = history.history['d_loss']
steps = range(1,len(g_loss)+1)

fig, ax = plt.subplots(figsize=(15,5))
ax.plot(steps, g_loss, label = 'g_loss')
ax.plot(steps, d_loss, label = 'd_loss')

ax.set_title("Training Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()