In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [2]:
import tensorflow as tf
import keras
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

from IPython import display
tf.test.gpu_device_name()

'/device:GPU:0'

In [3]:
# BASE_PATH = '/content/drive/Colab Notebooks'
BASE_PATH = '.'
BUFFER_SIZE = 30000
BATCH_SIZE = 32
EPOCHS = 50
NOISE_DIM = 128
SEED = tf.random.normal([16, NOISE_DIM])

In [4]:
def make_generator_model():
  model = tf.keras.Sequential()

  model.add(layers.Dense(8*8*128, use_bias=False, input_shape=(NOISE_DIM,)))
  model.add(layers.Reshape((8, 8, 128)))

  model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(256, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2DTranspose(512, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  model.add(layers.LeakyReLU())

  model.add(layers.Conv2D(3, (5, 5), padding='same', use_bias=False, activation='sigmoid'))

  return model

In [5]:
def make_generator_model():
  generator = keras.Sequential(
    [
        keras.Input(shape=(NOISE_DIM,)),
        layers.Dense(8 * 8 * 128),
        layers.Reshape((8, 8, 128)),
        layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),
    ],
    name="generator",
  )

  return generator

In [6]:
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[64, 64, 3]))
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())

    model.add(layers.Flatten())
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(1))

    return model

In [7]:
def make_discriminator_model():
    discriminator = keras.Sequential(
    [
        keras.Input(shape=(64, 64, 3)),
        layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Flatten(),
        layers.Dropout(0.2),
        layers.Dense(1, activation="sigmoid"),
    ],
    name="discriminator",
  )

    return discriminator

In [8]:
class GAN(keras.Model):
  def __init__(self, generator, discriminator):
    super(GAN, self).__init__()
    self.generator = generator
    self.discriminator = discriminator
    self.generator_optimizers = tf.keras.optimizers.Adam(1e-4)
    self.discriminator_optimizers = tf.keras.optimizers.Adam(1e-4)

  def train_step(self, real_images):
    if isinstance(real_images, tuple):
      real_images = real_images[0]
    batch_size = tf.shape(real_images)[0]

    noise = tf.random.normal([batch_size, NOISE_DIM])
    generated_images = self.generator(noise, training=False)

    x = tf.concat([real_images, generated_images], axis=0)
    y = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
    y += 0.05 * tf.random.uniform(tf.shape(y))

    with tf.GradientTape() as tape:
      y_pred = self.discriminator(x, training=True)
      loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

    gradients = tape.gradient(loss, self.discriminator.trainable_variables)
    self.discriminator_optimizers.apply_gradients(zip(gradients, self.discriminator.trainable_variables))
    self.compiled_metrics.update_state(y, y_pred)
    # self.compiled_metrics.metrics[0].update_state(y, y_pred)

    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])
    y = tf.ones((BATCH_SIZE, 1))

    with tf.GradientTape() as tape:
      x = self.generator(noise, training=True)
      y_pred = self.discriminator(x, training=False)
      loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

    gradients = tape.gradient(loss, self.generator.trainable_variables)
    self.generator_optimizers.apply_gradients(zip(gradients, self.generator.trainable_variables))
    # self.compiled_metrics.metrics[1].update_state(y, y_pred)

    return {m.name: m.result() for m in self.metrics}

  def call(self, x):
    return self.generator(x)

In [9]:
def generate_and_save_images(model, epoch, x):
  y = model(x, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(16):
      plt.subplot(4, 4, i+1)
      plt.imshow(y[i, :, :, :])
      plt.axis('off')

  plt.savefig(f'{BASE_PATH}/images/image_at_epoch_{epoch}.png')
  plt.show()

In [10]:
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    "cats", label_mode=None, image_size=(64, 64), batch_size=BATCH_SIZE
)

dataset = dataset.map(lambda x: x / 255.0)

Found 15747 files belonging to 1 classes.


In [11]:
generator_model = make_generator_model()
discriminator_model = make_discriminator_model()

gan = GAN(generator_model, discriminator_model)
gan.compile(
    loss = tf.keras.losses.BinaryCrossentropy(),
    metrics = [tf.keras.metrics.BinaryAccuracy(name='discriminator'), tf.keras.metrics.BinaryAccuracy(name='generator')],
    run_eagerly=True
)

checkpoint_dir = BASE_PATH + '/training_checkpoints'
checkpoint = tf.train.Checkpoint(gan=gan)

ckptManager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=10)
ckptManager.restore_or_initialize()

In [13]:
for i in range(50):
  gan.fit(dataset, epochs=1)

  display.clear_output(wait=True)
  generate_and_save_images(model = gan.generator,
                           epoch = i,
                           x = SEED)

  if (i+1) % 25 == 0:
    ckptManager.save()

 31/493 [>.............................] - ETA: 2:57 - loss: 0.7107 - discriminator: 0.0000e+00 - generator: 0.0000e+00