In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization

# Load the MNIST training images. The labels and the test split are
# discarded because the GAN only needs samples of real images.
(X_train, _), (_, _) = mnist.load_data()

# Rescale with (x - 127.5) / 127.5, which maps the 0..255 pixel range onto
# [-1, 1] (the range of a tanh output), then flatten every image into a
# 784-value row vector.
X_train = X_train.astype(np.float32)
X_train = (X_train - 127.5) / 127.5
X_train = X_train.reshape(-1, 28 * 28)

# Generator: turns a 100-dimensional noise vector into a 784-value image.
# Three Dense -> LeakyReLU(0.2) -> BatchNormalization stages of increasing
# width (128, 256, 512) feed a final tanh layer, so outputs lie in [-1, 1].
generator = Sequential()

generator.add(Dense(128, input_dim=100))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())

generator.add(Dense(256))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())

generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(BatchNormalization())

generator.add(Dense(784, activation='tanh'))

# Discriminator: takes a flattened 784-value image and outputs a single
# sigmoid score. Two Dense -> LeakyReLU(0.2) stages (512, 256 units)
# precede the one-unit output layer.
discriminator = Sequential()

discriminator.add(Dense(512, input_dim=784))
discriminator.add(LeakyReLU(0.2))

discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))

discriminator.add(Dense(1, activation='sigmoid'))

# Compile the stand-alone discriminator with binary cross-entropy and Adam
# (string shortcut, i.e. the optimizer's default settings).
discriminator.compile(optimizer='adam', loss='binary_crossentropy')

# Combined network: noise -> generator -> discriminator.
# The discriminator is flagged non-trainable *before* the combined model is
# built and compiled, so the statement order below matters.
# NOTE(review): the stand-alone discriminator was compiled earlier, while it
# was still trainable; confirm for the installed Keras version that its own
# train_on_batch() calls keep updating its weights after this flag flips.
discriminator.trainable = False

gan_input = generator.input
generated_images = generator(gan_input)
gan_output = discriminator(generated_images)

gan = Model(inputs=gan_input, outputs=gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# Make sure the output directory for sample images exists.
# exist_ok=True keeps this idempotent when the cell is re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs('generated_images', exist_ok=True)

# Training
# NOTE: each "epoch" here is one mini-batch update step, not a full pass
# over X_train.
epochs, batch_size = 10000, 64
half_batch = batch_size // 2  # derived so it cannot drift from batch_size

# The target labels never change, so build them once outside the loop.
real_labels = np.ones((half_batch, 1))
fake_labels = np.zeros((half_batch, 1))

for epoch in range(epochs):
    # Draw a random batch of real images (indices sampled with replacement)
    # and an equally sized batch of Gaussian noise vectors.
    idx = np.random.randint(0, X_train.shape[0], half_batch)
    real_images = X_train[idx]
    noise = np.random.normal(0, 1, (half_batch, 100))

    # verbose=0: predict() otherwise prints a progress bar on every call,
    # which floods the cell output over thousands of iterations.
    fake_images = generator.predict(noise, verbose=0)

    # Discriminator step: real images labelled 1, generated images labelled 0.
    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Generator step: push the same noise through the combined model with
    # "real" labels, i.e. reward the generator for fooling the discriminator.
    gan_loss = gan.train_on_batch(noise, real_labels)

    if epoch % 100 == 0:
        print(f"Epoch: {epoch}, D Loss: {d_loss}, G Loss: {gan_loss}")

    if epoch % 1000 == 0:
        # Save one generated sample so progress can be inspected visually.
        sample_noise = np.random.normal(0, 1, (1, 100))
        img = generator.predict(sample_noise, verbose=0).reshape(28, 28)
        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.savefig(f"generated_images/gan_image_{epoch}.png")
        plt.close()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch: 5100, D Loss: 0.5074338316917419, G Loss: 1.601813554763794
Epoch: 5200, D Loss: 0.35013996064662933, G Loss: 3.0591654777526855
Epoch: 5300, D Loss: 0.49765220284461975, G Loss: 2.80460786819458
Epoch: 5400, D Loss: 0.3805611655116081, G Loss: 2.084256172180176
Epoch: 5500, D Loss: 0.3550235778093338, G Loss: 2.8296520709991455
Epoch: 5600, D Loss: 0.32926318049430847, G Loss: 2.8968544006347656
Epoch: 5700, D Loss: 0.32019439339637756, G Loss: 2.688912868499756
Epoch: 5800, D Loss: 0.39469318091869354, G Loss: 2.638209342956543
Epoch: 5900, D Loss: 0.5745624899864197, G Loss: 3.1747641563415527
Epoch: 6000, D Loss: 0.28732632100582123, G Loss: 3.21665096282959
Epoch: 6100, D Loss: 0.27410974353551865, G Loss: 3.2724459171295166
Epoch: 6200, D Loss: 0.13448522984981537, G Loss: 3.4539709091186523
Epoch: 6300, D Loss: 0.2780590206384659, G Loss: 2.925626754760742
Epoch: 6400, D Loss: 0.3958194926381111, G Loss: 2.5

In [None]:
import os

# Show the kernel's current working directory; relative paths such as
# 'generated_images' are resolved against it.
current_dir = os.getcwd()
print(current_dir)

/content
