# **Data Augmentation with a (Conditional) GAN**

**The objective is to determine whether a dataset of synthetic images generated by a GAN can be used to effectively train a classifier, and how that compares to training the classifier on the original training set.**
**We also verify whether synthetic images can serve as extra training data alongside the original training set.**

Dataset: **CIFAR-10**


In [7]:
import numpy as np
import tensorflow as tf

# Download (or load from the Keras cache) CIFAR-10 as train/test splits.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# --- Preprocessing ---
# Map training pixels from [0, 255] to [-1, 1]. The generator ends in a tanh
# activation, so real and generated images then share the same value range.
# Only x_train is rescaled in this cell.
x_train = x_train.astype('float32')
x_train = (x_train - 127.5) / 127.5

# Sanity-check the array shapes.
print("Shape of training images:", x_train.shape)
print("Shape of training labels:", y_train.shape)

# Report which accelerators TensorFlow can see; placement is automatic.
gpus = tf.config.list_physical_devices('GPU')
if not gpus:
    print(" No GPU found. TensorFlow will use the CPU.")
else:
    print(f" GPU(s) found: {len(gpus)}")
    for gpu in gpus:
        print(f"  - {gpu}")

Shape of training images: (50000, 32, 32, 3)
Shape of training labels: (50000, 1)
 No GPU found. TensorFlow will use the CPU.


#### Building the Critic (WGAN-GP Discriminator)

In [8]:
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Concatenate, Embedding
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np

# --- Shared model constants ---
# CIFAR-10 images: 32x32 pixels with 3 (RGB) channels.
IMG_SHAPE = (32, 32, 3)
# CIFAR-10 has ten classes, so valid label indices are 0..9.
NUM_CLASSES = 10
# Length of the random noise vector fed to the generator.
LATENT_DIM = 128

def build_critic():
    """Build the label-conditioned WGAN critic.

    The class label is embedded (50-d), projected with a Dense layer to one
    full-resolution plane and concatenated onto the image as an extra channel.
    Three stride-2 4x4 convolutions with LeakyReLU(0.2) then downsample the
    stack. A final linear Dense(1) emits a raw, unbounded score. There is no
    sigmoid because a Wasserstein critic scores samples rather than
    classifying them.

    Returns:
        A Keras ``Model`` named "critic" that takes ``[image, label]`` and
        returns one scalar score per sample.
    """
    height, width = IMG_SHAPE[0], IMG_SHAPE[1]

    image_in = Input(shape=IMG_SHAPE)
    class_in = Input(shape=(1,))

    # Label -> embedding -> one (height, width, 1) "label channel".
    label_plane = Embedding(NUM_CLASSES, 50)(class_in)
    label_plane = Dense(height * width)(label_plane)
    label_plane = Reshape((height, width, 1))(label_plane)

    features = Concatenate()([image_in, label_plane])

    # Three downsampling stages with increasing width.
    for filters in (64, 128, 256):
        features = Conv2D(filters, kernel_size=4, strides=2, padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    features = Flatten()(features)
    score = Dense(1)(features)  # linear output: raw critic score

    return Model([image_in, class_in], score, name="critic")

#### Building the Generator

In [9]:
def build_generator():
    """Build the label-conditioned generator.

    The noise vector is projected to a 4x4x256 feature map. The class label is
    embedded (50-d) and projected to one 4x4 plane, which is concatenated as an
    extra channel. Three stride-2 transposed convolutions upsample 4 -> 8 ->
    16 -> 32. A final 5x5 convolution with tanh yields 3-channel images with
    values in [-1, 1].

    Returns:
        A Keras ``Model`` named "generator" that takes ``[noise, label]`` and
        returns a batch of images.
    """
    seed_side = 4  # spatial size of the initial feature map

    z_in = Input(shape=(LATENT_DIM,))
    class_in = Input(shape=(1,))

    # Label -> embedding -> a single (4, 4, 1) conditioning plane.
    cond = Embedding(NUM_CLASSES, 50)(class_in)
    cond = Dense(seed_side * seed_side)(cond)
    cond = Reshape((seed_side, seed_side, 1))(cond)

    # Noise -> (4, 4, 256) feature map.
    feat = Dense(256 * seed_side * seed_side, activation='relu')(z_in)
    feat = Reshape((seed_side, seed_side, 256))(feat)

    feat = Concatenate()([feat, cond])

    # Each stride-2 transposed convolution doubles the spatial size.
    for _ in range(3):
        feat = Conv2DTranspose(128, kernel_size=4, strides=2, padding='same', activation='relu')(feat)

    # RGB output squashed to [-1, 1] by tanh.
    rgb = Conv2D(3, kernel_size=5, padding='same', activation='tanh')(feat)

    return Model([z_in, class_in], rgb, name="generator")

#### Defining the Conditional WGAN-GP Model

In [11]:
class WGAN(tf.keras.Model):
    """Conditional Wasserstein GAN trained with a gradient penalty (WGAN-GP).

    Wraps a label-conditioned critic and generator, both called as
    ``model([x, labels])``. ``train_step`` is overridden so ``fit`` runs the
    WGAN-GP schedule: ``critic_extra_steps`` critic updates, then one
    generator update per batch.

    Args:
        critic: Model mapping ``[images, labels]`` to an unbounded score.
        generator: Model mapping ``[noise, labels]`` to images.
        latent_dim: Length of the noise vector sampled for the generator.
        critic_extra_steps: Critic updates per generator update (>= 1).
        gp_weight: Weight (lambda) of the gradient-penalty term.

    Raises:
        ValueError: If ``critic_extra_steps`` is smaller than 1.
    """

    def __init__(self, critic, generator, latent_dim, critic_extra_steps=5, gp_weight=10.0):
        super().__init__()
        if critic_extra_steps < 1:
            # train_step reports the last critic loss; with zero steps none exists.
            raise ValueError(f"critic_extra_steps must be >= 1, got {critic_extra_steps}")
        self.critic = critic
        self.generator = generator
        self.latent_dim = latent_dim
        self.c_extra_steps = critic_extra_steps
        self.gp_weight = gp_weight

    def compile(self, c_optimizer, g_optimizer):
        """Attach one optimizer per network and create the loss trackers."""
        super().compile()
        self.c_optimizer = c_optimizer
        self.g_optimizer = g_optimizer
        # No standard Keras loss is used, so the two losses are tracked manually.
        self.c_loss_metric = tf.keras.metrics.Mean(name="c_loss")
        self.g_loss_metric = tf.keras.metrics.Mean(name="g_loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch start.
        return [self.c_loss_metric, self.g_loss_metric]

    def gradient_penalty(self, batch_size, real_images, fake_images, labels):
        """Return the WGAN-GP gradient penalty for one batch.

        Each real/fake pair is mixed at a random point on the segment between
        them. The penalty is the mean squared deviation of the critic's
        per-sample gradient norm at those points from 1.

        Args:
            batch_size: Number of samples in the batch (scalar tensor or int).
            real_images: Batch of real images.
            fake_images: Batch of generated images, same shape as real_images.
            labels: Class labels passed to the critic alongside the images.

        Returns:
            Scalar penalty tensor.
        """
        # The interpolation coefficient must lie in [0, 1] so each point sits on
        # the real-fake segment (Gulrajani et al., 2017, Algorithm 1). A normal
        # draw would place many points outside it and penalize the wrong region.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.critic([interpolated, labels], training=True)

        # Gradient of the critic score with respect to the interpolated inputs.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # Per-sample L2 norm; the epsilon keeps sqrt differentiable at exactly 0.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)

    def train_step(self, data):
        """Run ``c_extra_steps`` critic updates and one generator update.

        Args:
            data: A ``(real_images, labels)`` batch from the input dataset.

        Returns:
            Dict with the running means of ``c_loss`` and ``g_loss``.
        """
        real_images, labels = data
        batch_size = tf.shape(real_images)[0]

        # --- Train the critic ---
        # NOTE: the same real batch (and its labels) is reused for every critic step.
        for _ in range(self.c_extra_steps):
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
            with tf.GradientTape() as tape:
                fake_images = self.generator([random_latent_vectors, labels], training=True)
                real_output = self.critic([real_images, labels], training=True)
                fake_output = self.critic([fake_images, labels], training=True)

                # Wasserstein critic loss: E[D(fake)] - E[D(real)].
                c_cost = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
                gp = self.gradient_penalty(batch_size, real_images, fake_images, labels)
                c_loss = c_cost + gp * self.gp_weight

            c_gradient = tape.gradient(c_loss, self.critic.trainable_variables)
            self.c_optimizer.apply_gradients(zip(c_gradient, self.critic.trainable_variables))

        # --- Train the generator ---
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, labels], training=True)
            gen_img_output = self.critic([generated_images, labels], training=True)
            # Maximizing the critic's score on fakes == minimizing its negation.
            g_loss = -tf.reduce_mean(gen_img_output)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(zip(gen_gradient, self.generator.trainable_variables))

        # Only the last critic step's loss is recorded for this batch.
        self.c_loss_metric.update_state(c_loss)
        self.g_loss_metric.update_state(g_loss)

        return {"c_loss": self.c_loss_metric.result(), "g_loss": self.g_loss_metric.result()}

#### Training Loop

In [None]:
import matplotlib.pyplot as plt

# Create a callback to periodically save generated images
class ImageSampler(tf.keras.callbacks.Callback):
    """Keras callback that periodically plots label-conditioned generator samples.

    Whenever the (0-based) epoch index is a multiple of ``every``, it samples
    ``num_img`` noise vectors and conditions them on labels ``0, 1, 2, ...``.
    The labels wrap modulo ``num_classes`` so they stay valid embedding indices.
    It then runs ``self.model.generator`` and shows the results in one
    matplotlib row.

    Args:
        num_img: Number of images generated and plotted per sampling event.
        latent_dim: Length of each noise vector passed to the generator.
        num_classes: Number of valid class labels.
        every: Sampling period in epochs.
    """

    def __init__(self, num_img=10, latent_dim=128, num_classes=10, every=5):
        # Base-class init sets up the attributes Keras expects on every callback.
        super().__init__()
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.every = every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.every != 0:
            return

        noise = tf.random.normal(shape=(self.num_img, self.latent_dim))
        # Wrap labels so num_img > num_classes never indexes past the embedding.
        labels = (np.arange(self.num_img) % self.num_classes).reshape(-1, 1)

        generated_images = self.model.generator.predict([noise, tf.constant(labels)])

        # Rescale images from [-1, 1] to [0, 1] for plotting.
        generated_images = (generated_images + 1) / 2.0

        # squeeze=False keeps a 2-D axes array, so indexing also works for num_img == 1.
        fig, axs = plt.subplots(1, self.num_img, figsize=(15, 3), squeeze=False)
        for i, ax in enumerate(axs[0]):
            ax.imshow(generated_images[i])
            ax.set_title(f"Class: {labels[i, 0]}")
            ax.axis('off')
        plt.show()
        plt.close(fig)

# --- Hyperparameters ---
EPOCHS = 1000  # passes over the training set; watch the sample plots and stop early if quality plateaus
BATCH_SIZE = 64
SHUFFLE_BUFFER = 10_000  # number of examples the shuffle buffer draws from

# --- Build and Compile the WGAN-GP ---
critic = build_critic()
generator = build_generator()

wgan = WGAN(critic=critic, generator=generator, latent_dim=LATENT_DIM)

wgan.compile(
    c_optimizer=Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9),
    g_optimizer=Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9),
)

# --- Start Training ---
# Callback that periodically plots generated samples.
sampler_callback = ImageSampler(latent_dim=LATENT_DIM)

# Input pipeline: reshuffle every epoch so the networks don't see identical
# batches in identical order, and prefetch to overlap input prep with training.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# fit() drives WGAN.train_step for every batch.
wgan.fit(train_dataset, epochs=EPOCHS, callbacks=[sampler_callback])