In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load the MNIST training split together with its dataset metadata (`info`).
data, info = tfds.load("mnist", split="train", with_info=True)

# Visual sanity check: a 3 x 6 grid of training digits, drawn at 2x scale.
tfds.show_examples(data, info, rows=3, cols=6, plot_scale=2.0)

print(f"\nFeatures: {info.features}\nLoaded examples: {len(data)}")

In [None]:
# Inspect the structure (keys, shapes, dtypes) of a single raw dataset element.
data.element_spec

In [None]:
# Preprocessing: turn each raw dataset element into a training image.
def extract_and_normalize(item):
  """Extract the image from a dataset element and scale it to [-1, 1].

  Args:
    item: one dataset element, a dict with an "image" entry. Pixel values are
      assumed to span 0..255 (MNIST uint8) -- that is what the 127.5 constants
      below are built for. Any other keys are dropped.

  Returns:
    The image as float32 in [-1, 1], the same range as the generator's tanh
    output.
  """
  image = tf.cast(item["image"], tf.float32)
  return (image - 127.5) / 127.5  # 0 -> -1.0, 255 -> 1.0

# Input pipeline: normalize every example, cache the normalized images so the
# map runs only on the first pass, reshuffle on each pass over the cache,
# group into batches of 256 and prefetch so preprocessing overlaps training.
train_images = (
  data
  .map(extract_and_normalize)
  .cache()
  .shuffle(buffer_size=2048)
  .batch(256)
  .prefetch(buffer_size=tf.data.AUTOTUNE)
)

print(f"Number of batches: {len(train_images)}\nElement spec: {train_images.element_spec}")

In [None]:
# Generator: maps a 100-dim latent vector to a 28x28x1 image.
# Spatial path: Dense -> reshape to 7x7 -> upsample to 14x14 -> upsample to 28x28.
generator = tf.keras.Sequential(name="generator")
generator.add(tf.keras.layers.Input(shape=(100,)))
# 7 * 7 * 256 units, so the output can be reshaped into a 7x7 map with 256 channels.
generator.add(tf.keras.layers.Dense(7*7*256))
generator.add(tf.keras.layers.Reshape(target_shape=(7, 7, 256)))
generator.add(tf.keras.layers.UpSampling2D())  # 7x7 -> 14x14 (default size=(2, 2))
generator.add(tf.keras.layers.Conv2D(filters=128, kernel_size=3, padding='same', activation='relu'))
generator.add(tf.keras.layers.UpSampling2D())  # 14x14 -> 28x28
generator.add(tf.keras.layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'))
# One output channel; tanh keeps pixels in [-1, 1], the range the real images are scaled to.
generator.add(tf.keras.layers.Conv2D(filters=1, kernel_size=3, padding='same', activation='tanh'))

def output_model(data_element):
    """Convert an image tensor to a 28x28 (greyscale) image.

    Args:
        data_element: an image of shape (height, width, channels) or a batch
            of shape (batch, height, width, channels). RGB (3-channel) input is
            converted to greyscale; single-channel input (e.g. MNIST or this
            notebook's generator output) is used as-is.

    Returns:
        A float32 tensor resized to 28x28 spatially. It has one channel when the
        input had 1 or 3 channels; other channel counts are passed through.
    """
    image = data_element
    # rgb_to_grayscale requires exactly 3 channels on the last axis and fails on
    # already-greyscale images, so only convert genuine RGB input.
    if image.shape[-1] == 3:
        image = tf.image.rgb_to_grayscale(image)
    # `size` is [new_height, new_width]; the channel axis is preserved.
    return tf.image.resize(image, size=[28, 28])

# Print the generator's layer table (output shapes, parameter counts);
# line_length=120 widens the table so its columns do not wrap.
generator.summary(line_length=120)

In [None]:
# Discriminator: scores a 28x28x1 image as real or fake.
# It emits ONE raw logit per image (no sigmoid): the GAN losses use
# BinaryCrossentropy(from_logits=True), which applies the sigmoid internally.
discriminator = tf.keras.Sequential(name="discriminator")
discriminator.add(tf.keras.layers.Input(shape=[28, 28, 1]))
discriminator.add(tf.keras.layers.Conv2D(filters=128, kernel_size=3, activation='relu'))
discriminator.add(tf.keras.layers.MaxPooling2D())
discriminator.add(tf.keras.layers.Flatten())
# A single real/fake verdict per image. A wide output layer here (e.g. 2500
# units) would turn every image into thousands of independent verdicts and
# add tens of millions of parameters on top of the 13*13*128 flattened features.
discriminator.add(tf.keras.layers.Dense(units=1))

discriminator.summary()

In [None]:
# Binary cross-entropy shared by both GAN losses. from_logits=True means the
# inputs are raw discriminator scores; the sigmoid is applied inside the loss.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Note: every `decisions` argument below is a discriminator verdict
# ("is this image real?") expressed as logits, not probabilities.

def generator_loss(decisions):
  """Generator loss, given the discriminator's verdicts on generated images.

  The generator succeeds when its fakes are judged real, so every verdict is
  scored against a target of 1 ("real").
  """
  return cross_entropy(tf.ones_like(decisions), decisions)


def discriminator_loss(real_decisions, fake_decisions):
  """Discriminator loss over a batch of real images and a batch of fakes.

  Real images should be judged real (target 1) and generated images fake
  (target 0); the two cross-entropy terms are summed.
  """
  loss_on_real = cross_entropy(tf.ones_like(real_decisions), real_decisions)
  loss_on_fake = cross_entropy(tf.zeros_like(fake_decisions), fake_decisions)
  return loss_on_real + loss_on_fake

In [None]:
# One Adam optimizer per network, so each keeps its own optimizer state.
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=5e-4)

@tf.function
def train_step(real_images):
  """Run one adversarial update on a batch of real images.

  One fake is generated per real image from 100-dim Gaussian noise. The
  discriminator scores both batches, and each network is updated from its own
  tape: generator first, then discriminator.

  Returns:
    (gen_loss, disc_loss) for this batch.
  """
  noise = tf.random.normal(shape=(tf.shape(real_images)[0], 100))

  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    fake_images = generator(noise, training=True)
    real_decisions = discriminator(real_images, training=True)
    fake_decisions = discriminator(fake_images, training=True)
    gen_loss = generator_loss(fake_decisions)
    disc_loss = discriminator_loss(real_decisions, fake_decisions)

  # Python-level loop over a fixed tuple: unrolled at trace time, so the
  # generator's gradients are computed and applied before the discriminator's.
  for tape, loss, model, optimizer in (
      (gen_tape, gen_loss, generator, gen_optimizer),
      (disc_tape, disc_loss, discriminator, disc_optimizer),
  ):
    weights = model.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(loss, weights), weights))

  return gen_loss, disc_loss

In [None]:
# Number of full passes over the training set.
NUM_EPOCHS = 25

# Fixed latent vectors, drawn once so generated samples can be compared across
# epochs. NOTE(review): not consumed in this cell -- presumably intended for
# plotting progress, e.g. generator(seed_inputs, training=False).
seed_inputs = tf.random.normal(shape=(6, 100))

for epoch in range(NUM_EPOCHS):
    # Loss history for the current epoch (reset every epoch).
    gen_losses = []
    disc_losses = []
    # Running sums make each progress update O(1) instead of re-summing the
    # whole history after every batch.
    gen_loss_sum = 0.0
    disc_loss_sum = 0.0

    for batch in train_images:
        gen_loss, disc_loss = train_step(batch)
        gen_losses.append(gen_loss)
        disc_losses.append(disc_loss)
        gen_loss_sum += float(gen_loss)
        disc_loss_sum += float(disc_loss)

        avg_gen_loss = gen_loss_sum / len(gen_losses)
        avg_disc_loss = disc_loss_sum / len(disc_losses)
        # '\r' redraws the progress line in place; plain floats avoid printing
        # the full tf.Tensor repr.
        print(f'\rEpoch {epoch + 1}/{NUM_EPOCHS} | Average Gen Loss: {avg_gen_loss:.4f} | Average Disc Loss: {avg_disc_loss:.4f}', end='', flush=True)

    # Terminate the line so this epoch's final averages are not overwritten by
    # the next epoch's '\r' updates.
    print()
