In [10]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

In [11]:
# define generator model
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(28*28*1, use_bias=False, activation='tanh'))
    model.add(layers.Reshape((28, 28, 1)))

    return model

In [12]:
# define discriminator model
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Flatten(input_shape=(28, 28, 1)))
    model.add(layers.Dense(512))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Dense(256))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Dense(1))

    return model

In [13]:
# define generator and discriminator models
generator = make_generator_model()
discriminator = make_discriminator_model()

In [14]:
# define loss functions
#cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def cross_entropy(y_true, y_pred_logits, from_logits=True):
    if from_logits:
        y_pred = tf.sigmoid(y_pred_logits)
    else:
        y_pred = y_pred_logits
    
    # Clip values to avoid log(0) error
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    
    # Calculate binary crossentropy
    bce = -tf.reduce_mean(y_true * tf.math.log(tf.reduce_mean(y_pred)) + (1 - y_true) * tf.math.log(tf.reduce_mean(1 - y_pred)))
    
    return bce

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [15]:
# define optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [16]:
# define training loop
@tf.function
def train_step(images):
    noise = tf.random.normal([batch_size, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [17]:
# prepare dataset
(train_images, _), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
batch_size = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(60000).batch(batch_size)

In [18]:
# train the GAN
epochs = 10
for epoch in range(epochs):
    for batch in train_dataset:
        train_step(batch)
        
        # print the losses every 10 epochs
        noise = tf.random.normal([batch_size, 100])

        generated_images = generator(noise, training=False)

        real_output = discriminator(batch, training=False)
        fake_output = discriminator(generated_images, training=False)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    print(f"Epoch: {epoch+1}, Generator Loss: {gen_loss}, Discriminator Loss: {disc_loss}")

Epoch: 1, Generator Loss: 1.8221484422683716, Discriminator Loss: 0.19364988803863525
Epoch: 2, Generator Loss: 1.979777216911316, Discriminator Loss: 0.1817137598991394
Epoch: 3, Generator Loss: 2.4828176498413086, Discriminator Loss: 0.11585566401481628
Epoch: 4, Generator Loss: 2.6926114559173584, Discriminator Loss: 0.10206776857376099
Epoch: 5, Generator Loss: 1.8112430572509766, Discriminator Loss: 0.202473983168602
Epoch: 6, Generator Loss: 1.8614026308059692, Discriminator Loss: 0.221042200922966
Epoch: 7, Generator Loss: 2.5821361541748047, Discriminator Loss: 0.12662768363952637
Epoch: 8, Generator Loss: 2.8797495365142822, Discriminator Loss: 0.09308136999607086
Epoch: 9, Generator Loss: 2.1926252841949463, Discriminator Loss: 0.12632963061332703
Epoch: 10, Generator Loss: 2.619027853012085, Discriminator Loss: 0.08763390779495239
