In [12]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, Reshape, Flatten
from tensorflow.keras.optimizers import Adam
import os

# Load and preprocess the MNIST dataset
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5  # Normalize to [-1, 1]
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)  # Reshape to (N, 28, 28, 1)

def build_generator():
    model = Sequential([
        Dense(256, input_dim=100, activation=LeakyReLU(0.2)),
        Dense(512, activation=LeakyReLU(0.2)),
        Dense(1024, activation=LeakyReLU(0.2)),
        Dense(784, activation='tanh'),  # Output an image of size 28x28 (flattened)
        Reshape((28, 28))  # Reshape to (28, 28) for the discriminator
    ])
    return model

def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28)),  # Input shape as (28, 28)
        Dense(512, activation=LeakyReLU(0.2)),
        Dense(256, activation=LeakyReLU(0.2)),
        Dense(1, activation='sigmoid')  # Output a probability
    ])
    return model

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

gan = Sequential([generator, discriminator])
discriminator.trainable = False  # Freeze the discriminator when training the generator
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

def train_gan(epochs, batch_size):
    for epoch in range(epochs):
        # Train the discriminator
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        real_images = X_train[idx].reshape(-1, 28, 28)  # Reshape back to (28, 28)
        noise = np.random.normal(0, 1, (batch_size, 100))
        generated_images = generator.predict(noise)

        d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator
        noise = np.random.normal(0, 1, (batch_size, 100))
        g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

        # Print the progress
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]:.2f}%] [G loss: {g_loss}]")

        # Save images at specific intervals
        if epoch % 2 == 0:  # Change save interval as needed
            save_images(epoch, generator)

def save_images(epoch, generator, examples=25, dim=(5, 5), figsize=(5, 5)):
    noise = np.random.normal(0, 1, (examples, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5  # Rescale to [0, 1]

    plt.figure(figsize=figsize)
    for i in range(examples):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i, :, :], cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

# Create output directory to save generated images
os.makedirs('gan_images', exist_ok=True)

# Train the GAN for 10 epochs
train_gan(epochs=10, batch_size=64)


[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step 




0 [D loss: 0.9596980810165405, acc.: 20.70%] [G loss: [array(0.8753231, dtype=float32), array(0.8753231, dtype=float32), array(0.2890625, dtype=float32)]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 161ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step
1 [D loss: 0.9526920318603516, acc.: 19.14%] [G loss: [array(0.9576504, dtype=float32), array(0.9576504, dtype=float32), array(0.1640625, dtype=float32)]]
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
2 [D loss: 1.0227059125900269, acc.: 12.60%] [G loss: [array(1.0522028, dtype=float32), array(1.0522028, dtype=float32), array(0.11458334, dtype=float32)]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 89ms/step
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step
3 [D loss: 1.0951828956604004, acc.: 10.25%] [G loss: [array(1.1361688, dtype=float32), array(1.1361688, dtype=float32), array(0.09570312, dtype=float32)]]
[1m2/2[0m [32m