In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, Reshape, LeakyReLU
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt

# Hyperparameters. latent_dim and image_size are also read as globals by the
# model builders; num_epochs and batch_size are passed to train_gan.
latent_dim: int = 100  # length of each random noise vector fed to the generator
image_size: int = 28  # MNIST images are image_size x image_size grayscale pixels
num_epochs: int = 5  # number of training epochs
batch_size: int = 128  # real images (and as many generated ones) per training step

def create_generator():
    """Build the (uncompiled) generator: latent noise vector -> 28x28x1 image.

    A bias-free Dense layer projects the noise to 7*7*256 values, which are
    reshaped into a 7x7x256 feature map. Two stride-2 transposed convolutions
    with ReLU upsample it 7 -> 14 -> 28, and a final stride-1 transposed
    convolution produces one tanh channel (values in [-1, 1]).
    The input size comes from the module-level ``latent_dim``.
    """
    return tf.keras.Sequential([
        Dense(7 * 7 * 256, use_bias=False, input_shape=(latent_dim,)),
        Reshape((7, 7, 256)),
        # 7x7 -> 14x14
        Conv2DTranspose(128, (3, 3), strides=2, padding='same', activation='relu'),
        # 14x14 -> 28x28
        Conv2DTranspose(64, (3, 3), strides=2, padding='same', activation='relu'),
        # 28x28x1 output squashed into [-1, 1]
        Conv2DTranspose(1, (3, 3), activation='tanh', padding='same'),
    ])

def create_discriminator():
    """Build the (uncompiled) discriminator: image -> score that it is real.

    Three 3x3, stride-2 convolutions (64, 128, 256 filters), each followed by
    LeakyReLU(alpha=0.2), then Flatten and a single sigmoid unit. The input
    shape is (image_size, image_size, 1), taken from the module-level
    ``image_size``.
    """
    stack = []
    for position, n_filters in enumerate((64, 128, 256)):
        # Only the first convolution declares the model's input shape.
        shape_kwargs = {'input_shape': (image_size, image_size, 1)} if position == 0 else {}
        stack.append(Conv2D(n_filters, (3, 3), strides=2, padding='same', **shape_kwargs))
        stack.append(LeakyReLU(alpha=0.2))
    stack.append(Flatten())
    stack.append(Dense(1, activation='sigmoid'))
    return tf.keras.Sequential(stack)

def create_gan(generator, discriminator):
    """Compile *discriminator*, freeze it, and return the compiled combined model.

    The discriminator is compiled on its own first and is then marked
    non-trainable. The intent is that training the returned
    noise -> generator -> discriminator model updates only the generator.
    NOTE(review): how a ``trainable`` change made after ``compile()`` is
    honoured differs between Keras versions.

    Both models use binary cross-entropy and Adam(learning_rate=0.0002,
    beta_1=0.5), each with its own optimizer instance.
    """
    def _fresh_adam():
        # A new optimizer per model so the two never share optimizer state.
        return Adam(learning_rate=0.0002, beta_1=0.5)

    discriminator.compile(optimizer=_fresh_adam(), loss='binary_crossentropy')
    discriminator.trainable = False

    noise_input = tf.keras.Input(shape=(latent_dim,))
    realness_score = discriminator(generator(noise_input))

    combined = Model(inputs=noise_input, outputs=realness_score)
    combined.compile(optimizer=_fresh_adam(), loss='binary_crossentropy')
    return combined

def train_gan(generator, discriminator, gan_model, dataset, epochs, batch_size, latent_dim):
    """Train the GAN by alternating one discriminator step and one generator step.

    Each epoch runs ``dataset.shape[0] // batch_size`` steps (at least one).
    Each step does the following:

    1. Draw ``batch_size`` real images from *dataset* at random (with
       replacement) and generate ``batch_size`` fake images from Gaussian
       noise. The discriminator then takes one step on the combined batch
       with noisy targets: real in [0.95, 1], fake in [0, 0.05].
    2. Take one step of *gan_model* on fresh noise with all-ones targets,
       pushing the generator toward images the discriminator scores as real.

    The losses from the last step of each epoch are printed.

    Args:
        generator: model mapping ``(batch_size, latent_dim)`` noise to images
            with the same per-sample shape as *dataset*.
        discriminator: compiled model scoring images; also used inside *gan_model*.
        gan_model: compiled model feeding generator output into *discriminator*.
        dataset: array of real images, one sample per index along axis 0.
        epochs: number of epochs to run.
        batch_size: number of real (and, separately, fake) images per step.
        latent_dim: length of each noise vector.

    Raises:
        ValueError: if *dataset* is empty.
    """
    num_samples = dataset.shape[0]
    if num_samples == 0:
        raise ValueError("dataset must contain at least one sample")
    # Real images are sampled with replacement, so a dataset smaller than one
    # batch still works. Running at least one step also guarantees the losses
    # printed at the end of each epoch are defined.
    steps_per_epoch = max(1, num_samples // batch_size)
    # The generator step always asks the discriminator to answer "real" (1).
    misleading_targets = np.ones((batch_size, 1))

    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            # --- Discriminator step ---
            noise = tf.random.normal(shape=(batch_size, latent_dim))
            # Call the model directly rather than predict(): predict() is meant
            # for large inputs and adds per-call overhead (and a progress bar)
            # when run on a single batch inside a loop.
            fake_images = np.asarray(generator(noise, training=False))
            real_images = dataset[np.random.randint(0, num_samples, batch_size)]
            combined_images = np.concatenate([real_images, fake_images])

            # Label noise points inward, so binary cross-entropy targets stay
            # valid probabilities: real -> [0.95, 1], fake -> [0, 0.05].
            real_labels = 1.0 - 0.05 * np.random.random((batch_size, 1))
            fake_labels = 0.05 * np.random.random((batch_size, 1))
            labels = np.concatenate([real_labels, fake_labels])

            # Set trainability explicitly around each update instead of relying on
            # the state captured at compile() time. That compile-time behaviour
            # differs between Keras versions, and the discriminator may already
            # be frozen on entry.
            discriminator.trainable = True
            discriminator_loss = discriminator.train_on_batch(combined_images, labels)
            discriminator.trainable = False

            # --- Generator step (discriminator frozen) ---
            noise = tf.random.normal(shape=(batch_size, latent_dim))
            generator_loss = gan_model.train_on_batch(noise, misleading_targets)

        print(f'Epoch {epoch + 1}, Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}')

def plot_generated_images(generator, latent_dim, examples=10, figsize=(10, 10), value_range=(-1.0, 1.0)):
    """Sample *examples* images from *generator* and show them in a single row.

    Args:
        generator: model mapping ``(examples, latent_dim)`` standard-normal noise
            to a batch of images; the last axis is the channel axis and only
            channel 0 is drawn.
        latent_dim: length of each noise vector.
        examples: number of images to draw.
        figsize: matplotlib figure size in inches.
        value_range: ``(vmin, vmax)`` used for every image's gray colormap. The
            default matches a tanh-output generator. Pass ``None`` to let
            matplotlib rescale each image to its own min/max (the old behaviour).
    """
    vmin, vmax = value_range if value_range is not None else (None, None)
    noise = np.random.normal(0, 1, size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(1, examples, i + 1)
        # Fixed color limits show the pixel values as they are. Per-image
        # autoscaling would stretch even a flat, washed-out sample to full
        # black-to-white contrast.
        plt.imshow(generated_images[i, :, :, 0], cmap='gray', vmin=vmin, vmax=vmax)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Load the MNIST training images; labels and the test split are not needed.
(train_images, _), (_, _) = mnist.load_data()
train_images = train_images.reshape(-1, image_size, image_size, 1)
# Scale pixel values 0..255 to [-1, 1] to match the generator's tanh output.
# With [0, 1] scaling, fake images could contain negative values that real
# images never have, giving the discriminator a trivial shortcut.
train_images = (train_images.astype('float32') - 127.5) / 127.5

# Create models
generator = create_generator()
discriminator = create_discriminator()
gan_model = create_gan(generator, discriminator)

# Train GAN
train_gan(generator, discriminator, gan_model, train_images, num_epochs, batch_size, latent_dim)

# Display generated images
plot_generated_images(generator, latent_dim)






Epoch 1, Discriminator Loss: 0.5266859531402588, Generator Loss: 1.0042624473571777






Epoch 2, Discriminator Loss: 0.6771271228790283, Generator Loss: 0.399431973695755




