# GANs for Data Augmentation in Image Classification

Training a robust image classifier requires a large and diverse dataset, but real-world datasets are often small, imbalanced, or too narrow in variety. This project addresses that limitation by using Generative Adversarial Networks (GANs) to augment existing image datasets. The goal is to generate synthetic images realistic enough to be hard to distinguish from real ones, increasing the dataset's size and diversity. This augmentation is expected to improve the accuracy and generalization of image classification models, especially when data is limited or imbalanced. The project explores how effective GAN-generated images are for augmentation and how they affect classifier performance. The ultimate objective is a methodology that applies to a range of image classification tasks, so that models are trained on data that better represents the complexity of real-world visual data.


In [1]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np

In [2]:
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
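
The scaling in the cell above maps pixel values from [0, 255] to [-1, 1], which matches the range of the `tanh` output layer used by the generator. A minimal sketch of that mapping and its inverse follows; the helper names `to_tanh_range` and `to_uint8_range` and the small stand-in array are assumptions for illustration, not part of the notebook.

```python
import numpy as np

def to_tanh_range(images_uint8):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (images_uint8.astype(np.float32) - 127.5) / 127.5

def to_uint8_range(images_scaled):
    """Invert the scaling: map values in [-1, 1] back to uint8 pixels."""
    pixels = images_scaled * 127.5 + 127.5
    return np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

# Stand-in pixel values instead of MNIST, so the sketch runs without a download.
sample_pixels = np.array([[0, 127, 128, 255]], dtype=np.uint8)
scaled = to_tanh_range(sample_pixels)
restored = to_uint8_range(scaled)

print(scaled.min(), scaled.max())              # extremes of the scaled range
print(np.array_equal(restored, sample_pixels))  # round trip recovers the pixels
```

The inverse mapping is what one would apply to generator outputs before displaying or saving them as ordinary images.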


In [3]:
# Define the generator: maps a 100-dimensional noise vector to a 28x28x1 image in [-1, 1]
def build_generator():
    model = Sequential()
    model.add(Dense(256, input_dim=100))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(784, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    noise = Input(shape=(100,))
    img = model(noise)
    return Model(noise, img)
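
The generator returns batches of shape `(N, 28, 28, 1)` with values in [-1, 1]. To inspect such a batch visually, it can help to tile the images into a single 2-D array. The sketch below is one way to do that with NumPy alone; `tile_images` is a hypothetical helper, and the random array is a stand-in for real generator output.

```python
import numpy as np

def tile_images(batch, rows, cols):
    """Arrange rows*cols images of shape (H, W, 1) into one (rows*H, cols*W) array."""
    n, h, w, c = batch.shape
    if n < rows * cols or c != 1:
        raise ValueError("need at least rows*cols single-channel images")
    grid = batch[: rows * cols, :, :, 0].reshape(rows, cols, h, w)
    # Move each image's height axis next to its row index before flattening.
    return grid.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# Stand-in for a generator output batch: random values in [-1, 1].
rng = np.random.default_rng(0)
stand_in = rng.uniform(-1.0, 1.0, size=(16, 28, 28, 1)).astype(np.float32)

grid = tile_images(stand_in, rows=4, cols=4)
grid01 = 0.5 * grid + 0.5  # rescale [-1, 1] to [0, 1] for display
print(grid.shape)
```

With a real batch, `grid01` could be passed to `plt.imshow(grid01, cmap='gray')` to view all sixteen samples at once.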

In [4]:
# Define the discriminator
def build_discriminator():
    model = Sequential()
    model.add(Flatten(input_shape=(28, 28, 1)))
    model.add(Dense(1024))
    model.add(LeakyReLU(0.2))
    model.add(Dense(512))
    model.add(LeakyReLU(0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(0.2))
    model.add(Dense(1, activation='sigmoid'))
    img = Input(shape=(28, 28, 1))
    validity = model(img)
    return Model(img, validity)

In [5]:
# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])


In [6]:
# Build the generator
generator = build_generator()
z = Input(shape=(100,))
img = generator(z)

In [8]:
# For the combined model we will only train the generator
discriminator.trainable = False
validity = discriminator(img)

In [9]:
# The combined model (stacked generator and discriminator)
# Trains the generator to fool the discriminator
combined = Model(z, validity)
combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

In [11]:
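# Each pass of this loop is one training step on a single batch, not a full pass over the data
for epoch in range(30):
    # Sample a random batch of 128 real images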
    idx = np.random.randint(0, X_train.shape[0], 128)
    imgs = X_train[idx]

    # Sample noise and generate a batch of new images
    noise = np.random.normal(0, 1, (128, 100))
    gen_imgs = generator.predict(noise)

    # Train the discriminator (real classified as ones and generated as zeros)
    d_loss_real, d_acc_real = discriminator.train_on_batch(imgs, np.ones((128, 1)))
    d_loss_fake, d_acc_fake = discriminator.train_on_batch(gen_imgs, np.zeros((128, 1)))
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
    d_acc = 0.5 * np.add(d_acc_real, d_acc_fake)

    # Train the generator (wants discriminator to mistake images as real)
    g_loss = combined.train_on_batch(noise, np.ones((128, 1)))

    # Log progress every 1000 iterations; with only 30 iterations this prints once, at iteration 0
    if epoch % 1000 == 0:
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss, 100*d_acc, g_loss))


0 [D loss: 0.003639, acc.: 100.00%] [G loss: 7.131785]
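
Both losses in the line above are binary cross-entropy values. For the generator the target label is 1, so its loss is the batch mean of -log(p), where p is the probability the discriminator assigns to a generated image being real. As a rough reading aid, the sketch below converts the logged generator loss into the single probability that would produce it if every fake in the batch received the same score. That uniform-score case is an assumption made purely for illustration, and the clipping constant `eps` is likewise an assumption.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, with predictions clipped away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    losses = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return float(np.mean(losses))

# Hypothetical case: the discriminator gives every fake the same probability p
# of being real. With target label 1, the generator loss is then -log(p).
g_loss_logged = 7.131785
p_equivalent = float(np.exp(-g_loss_logged))
print(p_equivalent)

# Round trip: feeding that p back in reproduces the logged generator loss.
check = binary_crossentropy(np.ones(128), np.full(128, p_equivalent))
print(check)
```

Under that assumption the equivalent probability is roughly 0.0008, meaning the discriminator rejects the generated images almost with certainty at this step. A single logged step says little about how training will proceed.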
