In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Dense, Reshape, Flatten, Dropout, LeakyReLU, Conv2DTranspose, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import os


In [None]:
# Load the MNIST training images; labels and the official test split are not used by the GAN.
(X_train, _), (_, _) = mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] as float32 so real images share the range
# of the generator's tanh output; otherwise the critic could tell real from fake
# on value range alone.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)  # add a trailing channel axis
# Hold out 20% of the images; only X_train is used for training below.
X_train, X_test = train_test_split(X_train, test_size=0.2, random_state=42)

In [None]:
# Sanity check of the training array shape (samples, height, width, channels).
train_shape = X_train.shape
print(train_shape)

In [None]:
# Generator model
def build_generator(z_dim):
    """Build the convolutional generator.

    A dense projection of the latent vector to 7x7x256 is reshaped into a
    feature map. Two stride-2 transposed convolutions then upsample it
    (7 -> 14 -> 28), and a final stride-1 transposed convolution with tanh
    produces a single-channel image.

    Args:
        z_dim: length of the input noise vector.

    Returns:
        A Keras ``Sequential`` model.
    """
    def upsample_block(filters, strides):
        # Transposed convolution -> batch norm -> leaky ReLU.
        return [
            Conv2DTranspose(filters, kernel_size=4, strides=strides, padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
        ]

    stack = [
        Dense(7 * 7 * 256, input_dim=z_dim),
        Reshape((7, 7, 256)),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        *upsample_block(128, 2),
        *upsample_block(64, 2),
        Conv2DTranspose(1, kernel_size=4, strides=1, padding='same', activation='tanh'),
    ]
    return Sequential(stack)

# Discriminator model
def build_discriminator(img_shape):
    """Build the WGAN-GP critic.

    The critic has two stride-2 convolutions with leaky ReLU and dropout,
    then a single linear output unit. The output is an unbounded real-valued
    score, not a probability. The Wasserstein objective in the training steps
    is a difference of mean critic scores. A sigmoid would squash the score
    into (0, 1), saturate its gradients and distort the gradient-norm penalty.

    Args:
        img_shape: shape of a single input image, e.g. (28, 28, 1).

    Returns:
        A Keras ``Sequential`` model producing one score per input image.
    """
    model = Sequential()
    model.add(Conv2D(64, kernel_size=4, strides=2, padding='same', input_shape=img_shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Conv2D(128, kernel_size=4, strides=2, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))

    model.add(Flatten())
    # Linear output: WGAN critics must not use a sigmoid.
    model.add(Dense(1))
    return model


In [None]:
# GAN hyperparameters
batch_size = 256  # samples per training step
z_dim = 100       # latent (noise) vector length


def wasserstein_loss(y_true, y_pred):
    """Keras-style loss: mean of labels times critic scores (not used by the custom train steps)."""
    return K.mean(y_true * y_pred)


lambda_gp = 10  # gradient-penalty weight
# Interpolation weights for the gradient penalty: drawn once here, one per
# sample of a `batch_size` batch.
epsilon = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0, maxval=1)

# Build the two networks
generator = build_generator(z_dim)
discriminator = build_discriminator((28, 28, 1))


In [None]:
generator.summary()

In [None]:
def gradient_penalty(discriminator, real_imgs, fake_imgs, epsilon=None, batch_size=None):
    """WGAN-GP gradient penalty: mean over the batch of (||grad D(x_hat)||_2 - 1)^2.

    ``x_hat = epsilon * real + (1 - epsilon) * fake`` are points on straight
    lines between paired real and fake samples.

    Args:
        discriminator: critic model; called with ``training=True``.
        real_imgs: batch of real images. The norm is taken over axes 1-3, so
            the batch must be rank 4 (batch first).
        fake_imgs: generated images, same shape as ``real_imgs``.
        epsilon: interpolation weights broadcastable to the image batch, e.g.
            shape (batch, 1, 1, 1). If None, a fresh weight per sample is drawn
            uniformly from [0, 1).
        batch_size: unused; kept so existing positional callers keep working.

    Returns:
        Scalar float32 tensor.
    """
    real_imgs = tf.cast(real_imgs, tf.float32)
    fake_imgs = tf.cast(fake_imgs, tf.float32)
    if epsilon is None:
        epsilon = tf.random.uniform(shape=[tf.shape(real_imgs)[0], 1, 1, 1], minval=0.0, maxval=1.0)
    epsilon = tf.cast(epsilon, tf.float32)
    interpolated_imgs = epsilon * real_imgs + (1.0 - epsilon) * fake_imgs

    with tf.GradientTape() as tape:
        tape.watch(interpolated_imgs)
        pred = discriminator(interpolated_imgs, training=True)

    gradients = tape.gradient(pred, interpolated_imgs)
    # The small constant keeps d/dx sqrt(x) finite when a gradient is exactly
    # zero; otherwise the second-order backward pass yields NaN.
    grad_norms = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-12)
    return tf.reduce_mean(tf.square(grad_norms - 1.0))


In [None]:
# Function to plot generated images
def plot_generated_images(generator, epoch, z_dim, examples=16, dim=(4, 4), figsize=(10, 10)):
    """Sample images from ``generator``, save them as a grid and display it.

    Args:
        generator: Keras model whose ``predict`` maps (examples, z_dim) noise
            to images of shape (examples, H, W) or (examples, H, W, 1).
        epoch: used only in the output filename.
        z_dim: latent vector length.
        examples: number of images to sample.
        dim: (rows, cols) of the grid; must have at least ``examples`` cells.
        figsize: matplotlib figure size in inches.

    Raises:
        ValueError: if the ``dim`` grid has fewer than ``examples`` cells.

    Side effects:
        Writes ``gan_output_sample_{epoch}.png`` to the current directory and
        calls ``plt.show()``.
    """
    rows, cols = dim
    if examples > rows * cols:
        raise ValueError(f"dim={dim} has only {rows * cols} cells for {examples} examples")

    noise = np.random.normal(0, 1, size=[examples, z_dim])
    # verbose=0 suppresses the per-call progress bar.
    generated_images = np.asarray(generator.predict(noise, verbose=0))
    # Drop a single trailing channel axis; H and W come from the output
    # instead of being hard-coded.
    height, width = generated_images.shape[1:3]
    generated_images = generated_images.reshape(examples, height, width)

    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    for ax in axes.flat:
        ax.axis('off')
    for ax, image in zip(axes.flat, generated_images):
        ax.imshow(image, interpolation='nearest', cmap='gray')
    fig.tight_layout()
    fig.savefig(f'gan_output_sample_{epoch}.png')
    # Display right away. With Jupyter's inline backend this also closes the
    # figure, so one figure per epoch does not accumulate.
    plt.show()

In [None]:
# Both networks use identical Adam settings (lr 2e-4, beta_1 0.5, beta_2 0.9).
_adam_settings = dict(learning_rate=2e-4, beta_1=0.5, beta_2=0.9)
gen_optimizer = Adam(**_adam_settings)
critic_optimizer = Adam(**_adam_settings)

In [None]:
def show_losses(critic_losses, gen_losses):
    """Plot critic and generator loss histories on a shared batch axis.

    The training loop records one critic loss per batch but a generator loss
    only after several critic steps, so ``gen_losses`` is shorter. To keep the
    curves comparable, the k-th generator loss (1-based) is placed at batch
    ``k * len(critic_losses) / len(gen_losses) - 1``. This is an approximation
    of where each generator step happened (exact when the ratio is constant).

    Args:
        critic_losses: sequence of scalar losses (floats or scalar tensors).
        gen_losses: sequence of scalar losses (floats or scalar tensors).
    """
    critic_values = np.asarray(critic_losses, dtype=float)
    gen_values = np.asarray(gen_losses, dtype=float)

    batch_x = np.arange(len(critic_values))
    steps_per_gen = len(critic_values) / len(gen_values) if len(gen_values) else 0.0
    gen_x = np.arange(1, len(gen_values) + 1) * steps_per_gen - 1

    fig, ax = plt.subplots(figsize=(7, 3))
    ax.plot(batch_x, critic_values, label='Critic Loss')
    ax.plot(gen_x, gen_values, label='Generator Loss')
    ax.set_xlabel('Batch')
    ax.set_ylabel('Loss')
    ax.legend()
    plt.show()

In [None]:
@tf.function
def train_critic(discriminator, critic_optimizer, real_imgs, fake_imgs):
    """Run one WGAN-GP critic update and return the critic loss.

    loss = mean(D(fake)) - mean(D(real)) + lambda_gp * GP, where GP is the
    gradient penalty at random real/fake interpolates. ``lambda_gp`` and
    ``gradient_penalty`` are module-level names.

    Args:
        discriminator: critic model updated in place.
        critic_optimizer: optimizer applied to the critic's trainable variables.
        real_imgs: rank-4 batch of real images.
        fake_imgs: generated images, same shape as ``real_imgs``.

    Returns:
        Scalar loss tensor.
    """
    real_imgs = tf.cast(real_imgs, tf.float32)
    fake_imgs = tf.cast(fake_imgs, tf.float32)
    n_samples = tf.shape(real_imgs)[0]
    # Draw fresh interpolation weights every step, one per sample, sized from the
    # actual batch. A tensor drawn once outside would reuse the same weights
    # forever and only fit one batch size.
    eps = tf.random.uniform(shape=[n_samples, 1, 1, 1], minval=0.0, maxval=1.0)

    with tf.GradientTape() as tape:
        # training=True matches the GP and generator steps (dropout active).
        real_scores = discriminator(real_imgs, training=True)
        fake_scores = discriminator(fake_imgs, training=True)
        # GP is computed inside this tape so its dependence on the critic weights
        # is differentiated as well.
        gp = gradient_penalty(discriminator, real_imgs, fake_imgs, eps, n_samples)
        critic_loss = tf.reduce_mean(fake_scores) - tf.reduce_mean(real_scores) + lambda_gp * gp
    gradients = tape.gradient(critic_loss, discriminator.trainable_variables)
    critic_optimizer.apply_gradients(zip(gradients, discriminator.trainable_variables))
    return critic_loss

@tf.function
def train_generator(generator, discriminator, gen_optimizer):
    """Run one generator update against the current critic; return its loss.

    loss = -mean(D(G(z))) for a fresh batch of standard-normal latent vectors.
    The batch has shape (``batch_size``, ``z_dim``), both module-level
    configuration values.

    Args:
        generator: model updated in place.
        discriminator: critic used to score generated images (not updated here).
        gen_optimizer: optimizer applied to the generator's trainable variables.

    Returns:
        Scalar loss tensor.
    """
    # Sample with TF ops. NumPy calls inside a tf.function execute only while
    # tracing, which would freeze one z into the graph for every step.
    z = tf.random.normal(shape=[batch_size, z_dim])

    with tf.GradientTape() as tape:
        # training=True so the generator's BatchNormalization uses batch statistics.
        fake_imgs = generator(z, training=True)
        critic_pred = discriminator(fake_imgs, training=True)
        g_loss = -tf.reduce_mean(critic_pred)
    gradients = tape.gradient(g_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))
    return g_loss

# Training function
def train_wgan(generator, discriminator, epochs=100, batch_size=64, n_critic_train = 3):
    """Train a WGAN-GP on the module-level ``X_train`` array.

    Every mini-batch updates the critic once. After every ``n_critic_train``
    critic updates the generator is updated once; the counter restarts each
    epoch. At the end of each epoch the loop prints the mean losses of that
    epoch, saves a sample grid with ``plot_generated_images`` and plots the
    loss curves with ``show_losses``. The optimizers are the module-level
    ``critic_optimizer`` and ``gen_optimizer``.

    Args:
        generator: generator model.
        discriminator: critic model.
        epochs: number of passes over ``X_train``.
        batch_size: images per critic step; a trailing partial batch is dropped.
        n_critic_train: critic updates per generator update.

    Returns:
        (critic_losses, gen_losses): lists of floats, one critic loss per batch
        and one generator loss per generator update.

    Raises:
        ValueError: if ``batch_size`` exceeds the number of training images.
    """
    critic_losses, gen_losses = [], []

    num_batches = X_train.shape[0] // batch_size
    if num_batches == 0:
        raise ValueError(f"batch_size={batch_size} exceeds the {X_train.shape[0]} training images")

    for epoch in range(epochs):
        print(f"Epoch {epoch}")
        # Shuffle through an index permutation instead of reordering X_train in place.
        order = np.random.permutation(X_train.shape[0])
        critic_start, gen_start = len(critic_losses), len(gen_losses)
        n_critic_count = 1
        for batch in range(num_batches):
            batch_idx = order[batch * batch_size:(batch + 1) * batch_size]
            real_imgs = X_train[batch_idx]
            z = np.random.randn(batch_size, z_dim).astype(np.float32)
            # Call the model directly rather than generator.predict. This keeps
            # the same (training) mode as the generator update and avoids a
            # per-batch progress bar and predict's overhead.
            fake_imgs = generator(z, training=True)

            # Train critic
            critic_loss = train_critic(discriminator, critic_optimizer, real_imgs, fake_imgs)
            critic_losses.append(float(critic_loss))
            n_critic_count += 1

            if n_critic_count > n_critic_train:
                # Train generator
                g_loss = train_generator(generator, discriminator, gen_optimizer)
                gen_losses.append(float(g_loss))
                n_critic_count = 1

        # Report this epoch's mean losses (not the running mean over all epochs).
        epoch_critic = critic_losses[critic_start:]
        epoch_gen = gen_losses[gen_start:]
        gen_mean = np.mean(epoch_gen) if epoch_gen else float('nan')
        print(f"D Loss: {np.mean(epoch_critic)}, G Loss: {gen_mean}")
        plot_generated_images(generator, epoch, z_dim)
        show_losses(critic_losses, gen_losses)

    return critic_losses, gen_losses


In [None]:
# Train the WGAN-GP: 10 epochs, 3 critic updates per generator update.
D_losses, g_losses = train_wgan(
    generator,
    discriminator,
    epochs=10,
    batch_size=batch_size,
    n_critic_train=3,
)

# Persist the trained generator to an HDF5 file.
generator.save('simple_gan_generator.h5')

In [None]:

def generate_images(generator, z_points):
    """Return ``generator.predict(z_points)``: one generated sample per latent point.

    Args:
        generator: any object exposing a ``predict`` method (a Keras model here).
        z_points: batch of latent vectors.

    Returns:
        Whatever ``generator.predict`` returns.
    """
    return generator.predict(z_points)

# Sample a grid of images from the trained generator.
# z_dim comes from the configuration cell, so it always matches the latent size
# the generator was built with.
num_samples = 16  # Number of images to generate
z_points = np.random.randn(num_samples, z_dim)

# Generate images from these points
generated_images = generate_images(generator, z_points)

# Near-square grid sized from num_samples (4x4 for 16), so changing
# num_samples cannot index past the generated batch.
n_cols = int(np.ceil(np.sqrt(num_samples)))
n_rows = int(np.ceil(num_samples / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
for ax in axes.flat:
    ax.axis('off')
for ax, image in zip(axes.flat, generated_images):
    # Drop the single channel axis for imshow: (H, W, 1) -> (H, W).
    ax.imshow(image.reshape(image.shape[0], image.shape[1]), cmap='gray')

plt.tight_layout()
plt.show()