In [8]:
# To run in Google Colab, open:
# https://colab.research.google.com/drive/1J9XeMLnjOuTYhGx_9jiEBGKYvnIlLV2n?usp=sharing

import tensorflow as tf
from tensorflow.keras import layers, models, losses
import numpy as np
import matplotlib.pyplot as plt

In [9]:

# Seed the TensorFlow and NumPy global RNGs before anything random happens
# (weight init, latent noise, batch sampling, label jitter) so runs are reproducible.
tf.random.set_seed(42)
np.random.seed(42)

# Only the MNIST training images are used; labels and the test split are discarded.
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()

# uint8 pixels in [0, 255] -> float32 in [0, 1], the same range as the generator's sigmoid output.
train_images = train_images.astype(np.float32) / 255.0

In [10]:

# Define generator model
def build_generator(latent_dim):
    """Build the generator: a latent vector of length ``latent_dim`` -> a 28x28x1 image.

    Shape flow (all transpose convs use padding='same'):
    Dense -> reshape to 7x7x256 -> 7x7x128 (stride 1) -> 14x14x64 (stride 2)
    -> 28x28x1 (stride 2, sigmoid, so pixel values lie in [0, 1]).

    Returns:
        An uncompiled ``Sequential`` model.
    """
    stack = [
        layers.Input(shape=(latent_dim,)),
        layers.Dense(7 * 7 * 256, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((7, 7, 256)),
    ]

    # Two transpose-conv stages, each followed by BatchNorm + LeakyReLU.
    for filters, stride in ((128, 1), (64, 2)):
        stack += [
            layers.Conv2DTranspose(filters, (5, 5), strides=(stride, stride), padding='same', use_bias=False),
            layers.BatchNormalization(),
            layers.LeakyReLU(),
        ]

    # Output head: single channel, sigmoid-bounded.
    stack.append(
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid')
    )
    return models.Sequential(stack)


In [11]:

# Define discriminator model
def build_discriminator(image_shape=(28, 28, 1)):
    """Build a convolutional discriminator that scores images as real vs. fake.

    Args:
        image_shape: Shape of one input image, channels-last. Defaults to
            ``(28, 28, 1)``, the MNIST shape that was previously hard-coded.

    Returns:
        An uncompiled ``Sequential`` model. Its final ``Dense(1)`` has no
        activation, so it emits a raw logit; pair it with a loss built with
        ``from_logits=True``.
    """
    model = models.Sequential([
        # Declare the input with an Input layer, as build_generator does, rather than
        # the legacy ``input_shape=`` layer kwarg (which Keras 3 warns about).
        layers.Input(shape=image_shape),
        # Two stride-2 convs halve each spatial dim twice (28 -> 14 -> 7 for the default).
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        layers.LeakyReLU(),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model


In [12]:

# Define GAN model
def build_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one model for the generator update.

    Side effect: sets ``discriminator.trainable = False`` on the object passed in
    (the same instance ends up inside the returned model, not a copy). Inside the
    combined model, therefore, only the generator's weights are meant to be updated.

    Returns:
        An uncompiled ``Sequential`` model: noise -> generator -> discriminator logit.
    """
    discriminator.trainable = False
    stages = [generator, discriminator]
    return models.Sequential(stages)

# Compile discriminator
# The discriminator is built and compiled here while it is still trainable; the
# build_gan() call further down then sets discriminator.trainable = False.
# NOTE(review): this ordering presumably relies on tf.keras (Keras 2) capturing
# `trainable` at compile time. discriminator.train_on_batch then keeps updating D,
# while `gan` (which contains the same, now-frozen discriminator) only updates the
# generator. Keras 3 does not snapshot `trainable` at compile time, so confirm that
# D's weights actually change during training under the installed Keras version.
discriminator = build_discriminator()
# Both models use Adam(learning_rate=2e-4, beta_1=0.5). from_logits=True because
# build_discriminator ends in a Dense(1) with no activation.
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
                      loss=losses.BinaryCrossentropy(from_logits=True))

# Compile GAN
# Length of the generator's input noise vector.
latent_dim = 100
generator = build_generator(latent_dim)
# Combined generator -> discriminator model, trained later to update the generator.
gan = build_gan(generator, discriminator)
gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
            loss=losses.BinaryCrossentropy(from_logits=True))

In [13]:
# Define helper function for generating images
def generate_images(generator, latent_dim, num_images=25):
    """Draw ``num_images`` samples from ``generator`` and display them in a grid.

    Args:
        generator: Keras model mapping a ``(num_images, latent_dim)`` noise batch to
            images. Each sample is displayed as ``image[:, :, 0]``, so a channels-last
            batch is assumed.
        latent_dim: Length of each standard-normal noise vector.
        num_images: How many samples to draw (>= 1). The grid is the smallest
            near-square layout that fits them; the default 25 gives a 5x5 grid.

    Raises:
        ValueError: If ``num_images`` < 1.
    """
    import math

    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    noise = tf.random.normal([num_images, latent_dim])
    # verbose=0 suppresses the progress bar predict() would otherwise print on every call.
    generated_images = generator.predict(noise, verbose=0)

    # Size the grid from num_images. It used to be hard-coded to 5x5, so
    # num_images > 25 raised inside plt.subplot. 2 inches per cell reproduces
    # the old 10x10 figure for the default of 25.
    n_cols = math.ceil(math.sqrt(num_images))
    n_rows = math.ceil(num_images / n_cols)
    _, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # also blanks any unused trailing cells of the grid
    for ax, image in zip(axes.flat, generated_images):
        # Fixed [0, 1] display range. This assumes sigmoid-range output, as
        # build_generator produces, and avoids contrast-stretching each image separately.
        ax.imshow(image[:, :, 0], cmap='gray', vmin=0.0, vmax=1.0)
    plt.show()

In [18]:


# Training loop
def train_gan(gan, discriminator, generator, train_images, latent_dim, epochs=10, batch_size=128,
              label_noise=0.05, log_every=10):
    """Train a GAN with alternating discriminator / generator steps.

    Each step:
      1. trains ``discriminator`` on ``batch_size`` real images (sampled with
         replacement) plus ``batch_size`` generated images, using noisy 1/0 targets;
      2. trains the generator through ``gan`` so that the discriminator labels
         fresh fakes as real (target 1).

    Args:
        gan: Compiled generator -> discriminator model, used for the generator step.
        discriminator: Compiled discriminator, trained directly on real/fake batches.
        generator: Model mapping ``(batch_size, latent_dim)`` noise to images shaped
            like the rows of ``train_images``.
        train_images: Array of real images; rows are sampled with ``np.random.randint``.
        latent_dim: Length of each noise vector.
        epochs: Number of epochs, each of ``len(train_images) // batch_size`` steps.
        batch_size: Number of real (and, separately, fake) images per step.
        label_noise: Maximum uniform jitter applied to the discriminator targets.
            Jitter moves targets toward the other class (real 1 -> [1 - label_noise, 1],
            fake 0 -> [0, label_noise]), so every target stays a valid probability.
        log_every: Print the latest losses and show samples every ``log_every``
            epochs, and always after the final epoch.

    Returns:
        ``(generator_losses, discriminator_losses)``: one entry per training step
        (per batch), not per epoch.

    Raises:
        ValueError: If ``train_images`` has fewer rows than ``batch_size`` (no step
            would run), or if ``log_every`` < 1.
    """
    steps_per_epoch = train_images.shape[0] // batch_size
    if steps_per_epoch < 1:
        raise ValueError(
            f"train_images has {train_images.shape[0]} rows, fewer than batch_size={batch_size}; "
            "no training step would run")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    # Hard targets: reals first, then fakes (same order as combined_images below).
    hard_labels = np.concatenate([np.ones((batch_size, 1)), np.zeros((batch_size, 1))])
    # In the generator step, the discriminator is asked to call fresh fakes "real".
    misleading_labels = np.ones((batch_size, 1))

    generator_losses = []
    discriminator_losses = []
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            # --- Discriminator step ---
            noise = tf.random.normal([batch_size, latent_dim])
            # Direct inference call instead of predict(): predict() has per-call overhead
            # and printed a progress bar for every batch.
            generated_images = np.asarray(generator(noise, training=False))
            real_images = train_images[np.random.randint(0, train_images.shape[0], batch_size)]
            combined_images = np.concatenate([real_images, generated_images])
            # |hard - jitter| maps 1 -> 1 - jitter and 0 -> jitter, keeping targets in [0, 1].
            # The previous code added jitter to both classes, pushing real targets above 1.
            # For targets above 1, from-logits BCE has no minimum and keeps rewarding larger logits.
            jitter = label_noise * np.random.random(hard_labels.shape)
            labels = np.abs(hard_labels - jitter)
            discriminator_loss = discriminator.train_on_batch(combined_images, labels)

            # --- Generator step ---
            # NOTE(review): assumes `discriminator` is frozen inside `gan`
            # (see build_gan) so only generator weights move here.
            noise = tf.random.normal([batch_size, latent_dim])
            generator_loss = gan.train_on_batch(noise, misleading_labels)

            generator_losses.append(generator_loss)
            discriminator_losses.append(discriminator_loss)

        # Report on the interval and always after the last epoch. Previously, with
        # epochs not a multiple of 10, nothing was ever printed.
        if (epoch + 1) % log_every == 0 or epoch + 1 == epochs:
            print(f"Epoch {epoch + 1}/{epochs}, Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}")
            generate_images(generator, latent_dim)

    return generator_losses, discriminator_losses

# Train GAN
# Read the noise size from the generator's Input layer instead of repeating the
# literal 100, so this cell cannot drift out of sync with the model definition.
latent_dim = generator.input_shape[-1]
epochs = 10
generator_losses, discriminator_losses = train_gan(
    gan,
    discriminator,
    generator,
    train_images.reshape(-1, 28, 28, 1),  # add the trailing channel axis the models expect
    latent_dim,
    epochs,
)






KeyboardInterrupt: 

In [None]:

# Plot loss curves
# train_gan appends one loss per batch (inside its inner loop), so the x-axis
# counts training steps, not epochs as the old label claimed.
fig, ax = plt.subplots()
ax.plot(generator_losses, label='Generator Loss')
ax.plot(discriminator_losses, label='Discriminator Loss')
ax.set_xlabel('Training step (batch)')
ax.set_ylabel('Loss')
ax.set_title('GAN training losses')
ax.legend()
plt.show()
