In [10]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

In [11]:
# define generator model
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Dense(32 * 32 * 3, use_bias=False, activation='tanh'))
    model.add(layers.Reshape((32, 32, 3)))

    return model

In [12]:
# define discriminator model
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[32, 32, 3]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

In [13]:
# define generator and discriminator models
generator = make_generator_model()
discriminator = make_discriminator_model()

In [14]:
# define loss functions
#cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def cross_entropy(y_true, y_pred_logits, from_logits=True):
    if from_logits:
        y_pred = tf.sigmoid(y_pred_logits)
    else:
        y_pred = y_pred_logits
    
    # Clip values to avoid log(0) error
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    
    # Calculate binary crossentropy
    bce = -tf.reduce_mean(y_true * tf.math.log(tf.reduce_mean(y_pred)) + (1 - y_true) * tf.math.log(tf.reduce_mean(1 - y_pred)))
    
    return bce

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

In [15]:
#define optimizers
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

In [16]:
# define training loop
@tf.function
def train_step(images):
    noise = tf.random.normal([batch_size, 100])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [17]:
# prepare dataset
(train_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
train_images = train_images.astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
batch_size = 256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(50000).batch(batch_size)

In [18]:
# train the GAN
epochs = 10
for epoch in range(epochs):
    for batch in train_dataset:
        train_step(batch)

        # print the losses every epoch
        noise = tf.random.normal([batch_size, 100])

        generated_images = generator(noise, training=False)

        real_output = discriminator(batch, training=False)
        fake_output = discriminator(generated_images, training=False)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    print(f"Epoch: {epoch+1}, Generator Loss: {gen_loss}, Discriminator Loss: {disc_loss}")

2023-04-23 12:45:13.645500: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_3/dropout_2/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2023-04-23 12:45:23.384575: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:954] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape insequential_3/dropout_2/dropout/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch: 1, Generator Loss: 1.9165880680084229, Discriminator Loss: 0.23474445939064026
Epoch: 2, Generator Loss: 3.1865994930267334, Discriminator Loss: 0.5697989463806152
Epoch: 3, Generator Loss: 1.7766401767730713, Discriminator Loss: 0.2567160129547119
Epoch: 4, Generator Loss: 2.1520726680755615, Discriminator Loss: 0.17120583355426788
Epoch: 5, Generator Loss: 1.8998396396636963, Discriminator Loss: 0.33121994137763977
Epoch: 6, Generator Loss: 1.655651330947876, Discriminator Loss: 0.6732302904129028
Epoch: 7, Generator Loss: 1.4891929626464844, Discriminator Loss: 0.5664693117141724
Epoch: 8, Generator Loss: 1.2592154741287231, Discriminator Loss: 0.8274476528167725
Epoch: 9, Generator Loss: 0.6781526207923889, Discriminator Loss: 0.8258140683174133
Epoch: 10, Generator Loss: 0.8363693952560425, Discriminator Loss: 0.6060590147972107
