In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt


In [2]:
def build_generator(latent_dim):
    model = tf.keras.Sequential([
        layers.Dense(128, activation='relu', input_dim=latent_dim),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(256, activation='relu'),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(momentum=0.8),
        layers.Dense(28 * 28, activation='tanh'),
        layers.Reshape((28, 28, 1))
    ])
    return model


In [3]:
def build_discriminator():
    model = tf.keras.Sequential([
        layers.Flatten(input_shape=(28, 28, 1)),
        layers.Dense(512, activation='relu'),
        layers.Dense(256, activation='relu'),
        layers.Dense(1, activation='sigmoid')
    ])
    return model


In [4]:
def compile_gan(generator, discriminator):
    discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002), 
                          loss='binary_crossentropy', metrics=['accuracy'])
    discriminator.trainable = False  # Freeze discriminator weights for GAN training
    gan = tf.keras.Sequential([generator, discriminator])
    gan.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002), loss='binary_crossentropy')
    return gan


In [5]:
def load_data():
    (x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    x_train = x_train / 127.5 - 1.0  # Normalize to [-1, 1]
    x_train = np.expand_dims(x_train, axis=-1)
    return x_train


In [7]:
def train_gan(generator, discriminator, gan, data, epochs, batch_size, latent_dim):
    half_batch = batch_size // 2
    
    for epoch in range(epochs):
        # Train discriminator
        idx = np.random.randint(0, data.shape[0], half_batch)
        real_imgs = data[idx]
        noise = np.random.normal(0, 1, (half_batch, latent_dim))
        fake_imgs = generator.predict(noise)
        
        real_labels = np.ones((half_batch, 1))
        fake_labels = np.zeros((half_batch, 1))
        
        d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
        d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train generator
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        valid_labels = np.ones((batch_size, 1))
        g_loss = gan.train_on_batch(noise, valid_labels)
        
        # Display progress
        print(f"{epoch + 1}/{epochs}, Discriminator Loss: {d_loss[0]}, Generator Loss: {g_loss}")
        
        # Save samples periodically
        if epoch % 100 == 0:
            sample_images(generator, epoch, latent_dim)

def sample_images(generator, epoch, latent_dim):
    noise = np.random.normal(0, 1, (25, latent_dim))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5  # Rescale to [0, 1]

    fig, axes = plt.subplots(5, 5, figsize=(5, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(generated_images[i, :, :, 0], cmap='gray')
        ax.axis('off')
    plt.savefig(f"generated_images_epoch_{epoch}.png")
    plt.close()


In [8]:
latent_dim = 10
epochs = 100
batch_size = 64

data = load_data()
generator = build_generator(latent_dim)
discriminator = build_discriminator()
gan = compile_gan(generator, discriminator)

train_gan(generator, discriminator, gan, data, epochs, batch_size, latent_dim)


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(**kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 155ms/step




1/100, Discriminator Loss: 0.6920542120933533, Generator Loss: [array(0.68450713, dtype=float32), array(0.68450713, dtype=float32), array(0.734375, dtype=float32)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 110ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
2/100, Discriminator Loss: 0.7005037069320679, Generator Loss: [array(0.6988026, dtype=float32), array(0.6988026, dtype=float32), array(0.59375, dtype=float32)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
3/100, Discriminator Loss: 0.7185268998146057, Generator Loss: [array(0.71478313, dtype=float32), array(0.71478313, dtype=float32), array(0.5833333, dtype=float32)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
4/100, Discriminator Loss: 0.723709225654602, Generator Loss: [array(0.72235096, dtype=float32), array(0.72235096, dtype=float32), array(0.51953125, dtype=float32)]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0