In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# Create directory for saving images
os.makedirs('gan_images', exist_ok=True)

# Define the Generator
def build_generator():
    model = Sequential([
        Dense(256, input_dim=100),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(1024),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(28 * 28 * 1, activation='tanh'),
        Reshape((28, 28, 1))
    ])
    return model

# Define the Discriminator
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')
    ])
    return model

# Compile the Models
def compile_models(generator, discriminator):
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])

    discriminator.trainable = False
    z = tf.keras.Input(shape=(100,))
    img = generator(z)
    valid = discriminator(img)

    combined = tf.keras.Model(z, valid)
    combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

    return combined

# Load and Preprocess Data using tf.data.Dataset
def load_data():
    (X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
    X_train = (X_train - 127.5) / 127.5  # Normalize images to [-1, 1]
    X_train = np.expand_dims(X_train, axis=-1)  # Add channel dimension
    dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(buffer_size=60000).batch(128)
    return dataset

# Training the GAN
def train(epochs, batch_size=128, save_interval=50):
    dataset = load_data()

    half_batch = batch_size // 2
    valid = np.ones((half_batch, 1))
    fake = np.zeros((half_batch, 1))

    for epoch in range(epochs):
        for imgs_batch in dataset:
            # Use the first half of the batch for real images
            imgs = imgs_batch[:half_batch]

            # Generate fake images
            noise = np.random.normal(0, 1, (half_batch, 100))
            gen_imgs = generator.predict(noise)

            # Train the discriminator
            d_loss_real = discriminator.train_on_batch(imgs, valid)
            d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Generate fake images for the generator training
            noise = np.random.normal(0, 1, (batch_size, 100))
            valid_y = np.ones((batch_size, 1))

            # Train the generator
            g_loss = combined.train_on_batch(noise, valid_y)

            if epoch % save_interval == 0:
                print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {d_loss[1]}] [G loss: {g_loss}]")
                save_imgs(epoch)
            break  # Break after processing one batch for demonstration purposes

# Save Generated Images
def save_imgs(epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, 100))
    gen_imgs = generator.predict(noise)

    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale images to [0, 1]

    fig, axs = plt.subplots(r, c)
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[i * c + j, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

# Initialize Models
generator = build_generator()
discriminator = build_discriminator()
combined = compile_models(generator, discriminator)

# Train the GAN
train(epochs=10000, batch_size=64, save_interval=1000)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 106ms/step
0 [D loss: 0.6084439754486084 | D accuracy: 0.640625] [G loss: [array(0.65144455, dtype=float32), array(0.65144455, dtype=float32), array(0.5, dtype=float32)]]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 190ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 28ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 36ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

KeyboardInterrupt: 

In [None]:
!apt-get install git
