In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Flatten
from tensorflow.keras.optimizers import Adam

In [None]:
# Load only the MNIST training images; labels and the test split are discarded (`_`).
(X_train, _), (_, _) = mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] so real images share the range of the
# generator's tanh output. Casting to float32 first avoids the float64 array that
# a bare `/ 127.5` on uint8 would produce (2x the memory of Keras' default float32).
X_train = X_train.astype(np.float32) / 127.5 - 1.0
X_train = X_train.reshape(-1, 784)  # Flatten each 28x28 image to a 784-vector
assert X_train.min() >= -1.0 and X_train.max() <= 1.0, "normalization out of range"

In [None]:
# Generator: an MLP mapping a 100-dim noise vector to a 28x28 image.
# The final tanh keeps pixels in [-1, 1], the same range X_train is scaled to.
generator = Sequential()
generator.add(Dense(128, input_dim=100))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Dense(256))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Dense(784, activation='tanh'))  # 784 = 28 * 28 pixels
generator.add(Reshape((28, 28)))
generator.summary()

In [None]:
# Discriminator: an MLP scoring a 28x28 image with a sigmoid probability of being real.
discriminator = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(256),
    LeakyReLU(alpha=0.2),
    Dense(128),
    LeakyReLU(alpha=0.2),
    Dense(1, activation='sigmoid')
])
# Use the imported Adam explicitly with the DCGAN-recommended settings
# (lr=2e-4, beta_1=0.5) instead of the 'adam' string, which gives Keras defaults
# (lr=1e-3, beta_1=0.9) that are commonly reported to make GAN training less stable.
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
    metrics=['accuracy'],
)
discriminator.summary()

In [None]:
# Combined model (noise -> generator -> discriminator), used only to train the generator.
# The discriminator is frozen *after* it was compiled on its own. Under tf.keras 2.x
# semantics `trainable` is captured at compile time, so discriminator.train_on_batch
# still updates D while gan.train_on_batch updates only G.
# NOTE(review): Keras 3 treats `trainable` changes differently — confirm D still learns there.
discriminator.trainable = False
gan = Sequential([generator, discriminator])
# Same DCGAN-style Adam settings as the discriminator, rather than the 'adam' defaults.
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=2e-4, beta_1=0.5))

In [None]:
# Adversarial training loop. Each "epoch" below is ONE minibatch step for the
# discriminator plus ONE step for the generator, not a full pass over X_train.
EPOCHS = 100
batch_size = 128
LATENT_DIM = 100        # must match the generator's input_dim
LOG_INTERVAL = 10       # print losses every N steps (and on the last step)
SAMPLE_INTERVAL = 50    # plot a sample grid every N steps (and on the last step)
SEED = 42

# Seeds NumPy sampling only; Keras weight init / TF ops are not seeded here.
rng = np.random.default_rng(SEED)

# Target labels never change between steps, so build them once.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
valid_y = np.ones((batch_size, 1))  # generator wants the discriminator to say "real"


def plot_samples(gen, latent_dim, noise_rng, n_rows=5, n_cols=5, title=None):
    """Draw and show an n_rows x n_cols grid of images produced by `gen`.

    Args:
        gen: model mapping (n, latent_dim) noise to images with pixels in [-1, 1].
        latent_dim: length of each noise vector.
        noise_rng: numpy Generator used to sample the noise.
        n_rows, n_cols: grid layout.
        title: optional figure title.

    Returns:
        The generated images rescaled to [0, 1].
    """
    noise = noise_rng.normal(0, 1, (n_rows * n_cols, latent_dim))
    imgs = 0.5 * gen.predict(noise, verbose=0) + 0.5  # [-1, 1] -> [0, 1]
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
    for ax, img in zip(axs.flat, imgs):
        # Fixed vmin/vmax so every panel shares the same grey scale.
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    if title is not None:
        fig.suptitle(title)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures on non-inline backends
    return imgs


for epoch in range(EPOCHS):
    is_last = epoch == EPOCHS - 1

    # ---------------------
    # Train Discriminator
    # ---------------------
    idx = rng.integers(0, X_train.shape[0], batch_size)
    real_imgs = X_train[idx].reshape(batch_size, 28, 28)

    noise = rng.normal(0, 1, (batch_size, LATENT_DIM))
    fake_imgs = generator.predict(noise, verbose=0)  # verbose=0: no progress bar every step

    d_loss_real = discriminator.train_on_batch(real_imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # mean of [loss, accuracy]

    # ---------------------
    # Train Generator
    # ---------------------
    noise = rng.normal(0, 1, (batch_size, LATENT_DIM))
    g_loss = gan.train_on_batch(noise, valid_y)

    # Print progress (the original `% 100` with 100 steps only ever fired at step 0)
    if epoch % LOG_INTERVAL == 0 or is_last:
        print(f"{epoch} [D loss: {d_loss[0]:.4f}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss:.4f}]")

    # Save sample images (the original `% 600` with 100 steps only ever fired at step 0)
    if epoch % SAMPLE_INTERVAL == 0 or is_last:
        gen_imgs = plot_samples(generator, LATENT_DIM, rng, title=f"step {epoch}")