In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import re

%matplotlib inline

In [None]:
# --- Data configuration ---
IMAGE_SIZE_NO_CROP = 128  # Side length images are resized to when loaded from disk
IMAGE_SIZE = 64  # Target side length for the crop+resize step; NOTE(review): that step is commented out in process(), so this is not used by the active pipeline
BATCH_SIZE = 64  # Number of images per batch
DATA_PATH = "/kaggle/input/celeba-dataset/img_align_celeba"  # Root directory handed to the dataset loader
RANDOM_SEED = 42  # Seed for TensorFlow's global RNG and for dataset shuffling at load time

# --- Model configuration ---
Z_DIM = 128  # Dimension of face's manifold (length of the generator's latent input vector)
GENERATOR_DENSE_SIZE = 512  # Channel count of the 4x4 feature map produced by the generator's first dense layer
OUTPUT_CHANNELS = 3  # Number of colour channels in the images (RGB)

tf.random.set_seed(RANDOM_SEED)

Checking available GPUs

In [None]:
# Print the physical GPU devices visible to TensorFlow, then the GPU device name.
print(tf.config.experimental.list_physical_devices("GPU"))
print(tf.test.gpu_device_name())

# Prepare dataset

Here we will check and prepare our data. We need the faces only. Images in the dataset are centered on eyes, so we will crop faces utilizing that fact.

I've found caching is extremely useful in this task. The whole dataset can be put into memory if you have >12GB RAM. Prefetching will also help us to utilize resources better.

Sometimes `image_dataset_from_directory` is extremely slow. Also, Kaggle won't let us cache everything in memory and will kill the kernel during training, which is frustrating.

Nevertheless training on the whole dataset will take some time (first epoch with BATCH_SIZE=64 takes ~1800 seconds to finish with GPU accelerator here, ~1300 seconds for BATCH_SIZE=512, after caching it's ~300 seconds per epoch).

In [None]:
# Count the entries in the nested image folder DATA_PATH/img_align_celeba.
num_images = len(os.listdir(os.path.join(DATA_PATH, "img_align_celeba")))
print(f"Num images: {num_images}")

In [None]:
# Load the images as unlabeled RGB batches of BATCH_SIZE, resized to
# IMAGE_SIZE_NO_CROP x IMAGE_SIZE_NO_CROP; RANDOM_SEED is passed as the loader's seed.
celeb_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_PATH,
    label_mode=None,
    color_mode="rgb",
    batch_size=BATCH_SIZE,
    image_size=(IMAGE_SIZE_NO_CROP, IMAGE_SIZE_NO_CROP),
    seed=RANDOM_SEED,
)

In [None]:
# File name for the optional tf.data cache; the only use (a `.cache(filename=CACHE_FILE)`
# call in the dataset pipeline) is currently commented out.
CACHE_FILE = "cache"

def crop_face(image):
    """Crop a fixed, proportionally-placed window out of a batch of images.

    The window starts 35% down and 27% across the image and spans 45% of the
    height and 45% of the width.

    Args:
        image: Batched image tensor; axis 1 is read as height and axis 2 as
            width (static shapes are required because they are converted
            with `int`).

    Returns:
        The cropped tensor produced by `tf.image.crop_to_bounding_box`.
    """
    height, width = image.shape[1], image.shape[2]

    # Offsets and target sizes are each derived from their own axis. The
    # previous version mixed height and width, which only happened to work
    # for square inputs.
    offset_height = int(height * 0.35)
    offset_width = int(width * 0.27)
    target_height = int(height * 0.45)
    target_width = int(width * 0.45)

    # Argument order: image, offset_height, offset_width, target_height, target_width.
    return tf.image.crop_to_bounding_box(
        image, offset_height, offset_width, target_height, target_width
    )


def process(image):
    """Shift and scale pixel values with (x - 127.5) / 127.5 and cast to float32.

    Assumes incoming pixel values lie in [0, 255], which maps them to
    [-1, 1] -- TODO confirm against the loader.

    The eye-centred face crop (`crop_face`) and the subsequent resize to
    IMAGE_SIZE x IMAGE_SIZE are currently disabled, so images keep the size
    they were loaded with.
    """
    rescaled = (image - 127.5) / 127.5
    return tf.cast(rescaled, tf.float32)


# Input pipeline: normalise each loaded batch, split it into single images,
# shuffle with a 1024-element buffer, repeat indefinitely, re-batch into full
# batches of BATCH_SIZE (the remainder is dropped) and prefetch.
# Because of .repeat() the dataset never ends, so fit() needs steps_per_epoch.
# NOTE: this rebinds `celeb_dataset`; re-running this cell without re-running
# the loading cell would apply process() a second time to already-normalised data.
celeb_dataset = (
    celeb_dataset
    .map(process)
    .unbatch()
    .shuffle(1024)
    .repeat()
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
#         .cache(filename=CACHE_FILE)  # disabled: optional cache to CACHE_FILE
)

In [None]:
def display_samples(ds, row, col):
    """Plot the first `row * col` elements of `ds` on a `row` x `col` grid.

    Each element is mapped with `x * 0.5 + 0.5` before being handed to
    `plt.imshow`, and axes are hidden.

    Args:
        ds: Iterable yielding one image per element.
        row: Number of grid rows.
        col: Number of grid columns.
    """
    samples = iter(ds)
    fig_height = int(15 * row / col)
    plt.figure(figsize=(15, fig_height))

    # Subplot indices are 1-based.
    for cell in range(1, row * col + 1):
        sample = next(samples)
        plt.subplot(row, col, cell)
        plt.axis('off')
        plt.imshow(sample * 0.5 + 0.5)

    plt.show()

In [None]:
# Unbatch so that every element is a single image, then preview a 4 x 6 grid.
display_samples(celeb_dataset.unbatch(), 4, 6)

# Generator

In [None]:
def down_sample(filters, size, apply_instancenorm=True):
    """Build a stride-2 convolution block: Conv2D -> [InstanceNorm] -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_instancenorm: When true, insert an InstanceNormalization layer
            between the convolution and the activation.

    Returns:
        A `keras.Sequential` containing the block's layers.
    """
    block = keras.Sequential()

    conv = layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    block.add(conv)

    if apply_instancenorm:
        norm = tfa.layers.InstanceNormalization(
            gamma_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        )
        block.add(norm)

    block.add(layers.LeakyReLU())
    return block

def up_sample(filters, size, apply_dropout=False):
    """Build a stride-2 transposed-convolution block.

    Layer order: Conv2DTranspose -> InstanceNorm -> [Dropout(0.5)] -> ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: When true, insert a Dropout(0.5) layer before the ReLU.

    Returns:
        A `keras.Sequential` containing the block's layers.
    """
    block = keras.Sequential()

    deconv = layers.Conv2DTranspose(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )
    block.add(deconv)

    norm = tfa.layers.InstanceNormalization(
        gamma_initializer=keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
    )
    block.add(norm)

    if apply_dropout:
        block.add(layers.Dropout(0.5))

    block.add(layers.ReLU())
    return block

In [None]:
def Generator():
    """Build the generator network.

    Maps a latent vector of length Z_DIM to an image with OUTPUT_CHANNELS
    channels. A dense layer produces a 4x4 feature map which is then doubled
    in resolution five times (four `up_sample` blocks plus the final
    transposed convolution), giving a 128x128 output. The last layer uses
    `tanh`, so outputs lie in [-1, 1].

    Returns:
        A `keras.Model` taking a `(batch, Z_DIM)` input.
    """
    inputs = layers.Input(shape=(Z_DIM, ))

    up_stack = [
        layers.Dense(4 * 4 * GENERATOR_DENSE_SIZE, use_bias=False),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Reshape((4, 4, GENERATOR_DENSE_SIZE)),# (size, 4, 4, GENERATOR_DENSE_SIZE)
        up_sample(512, 4, apply_dropout=True),       # (size, 8, 8, 512)
        up_sample(256, 4, apply_dropout=True),       # (size, 16, 16, 256)
        up_sample(128, 4),                           # (size, 32, 32, 128)
        up_sample(64, 4),                            # (size, 64, 64, 64)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    # Use the shared OUTPUT_CHANNELS constant rather than a literal 3 so the
    # channel count is configured in one place.
    last = layers.Conv2DTranspose(
        OUTPUT_CHANNELS, 4, strides=2, padding='same',
        kernel_initializer=initializer, activation='tanh',
    )  # (size, 128, 128, OUTPUT_CHANNELS)

    x = inputs

    # Plain sequential stack: each layer feeds the next (there are no skip connections).
    for up in up_stack:
        x = up(x)

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

# Headless discriminator (shared part)

In [None]:
def Discriminator():
    """Build the shared (headless) part of the discriminator.

    Takes a 128x128x3 image, applies three stride-2 `down_sample` blocks
    (the first without instance normalisation), then a zero-padded 4x4
    convolution with instance norm and LeakyReLU, and a final zero padding.

    Returns:
        A `tf.keras.Model` whose output has shape (size, 17, 17, 512).
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = layers.Input(shape=[128, 128, 3], name='input_image')

    # (filters, use instance norm): 128 -> 64 -> 32 -> 16 spatially.
    features = image_in
    for n_filters, use_norm in ((64, False), (128, True), (256, True)):
        features = down_sample(n_filters, 4, use_norm)(features)

    features = layers.ZeroPadding2D()(features)  # (size, 18, 18, 256)
    features = layers.Conv2D(
        512, 4, strides=1, kernel_initializer=kernel_init, use_bias=False
    )(features)  # (size, 15, 15, 512)
    features = tfa.layers.InstanceNormalization(
        gamma_initializer=norm_gamma_init
    )(features)
    features = layers.LeakyReLU()(features)
    features = layers.ZeroPadding2D()(features)  # (size, 17, 17, 512)

    return tf.keras.Model(inputs=image_in, outputs=features)

# Head for two-objective discriminator

In [None]:
def DHead():
    """Build one discriminator head.

    A single 4x4, stride-1 convolution with one output filter applied to a
    (17, 17, 512) feature map, producing a (size, 14, 14, 1) map of scores.

    Returns:
        A `tf.keras.Model` wrapping that convolution.
    """
    features_in = layers.Input(shape=[17, 17, 512], name='input_image')

    score_conv = layers.Conv2D(
        1,
        4,
        strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
    )
    scores = score_conv(features_in)  # (size, 14, 14, 1)

    return tf.keras.Model(inputs=features_in, outputs=scores)

In [None]:
generator = Generator() 
discriminator = Discriminator() # shared feature extractor applied to both real and generated images
dHead1 = DHead() # Head scored by loss pair 1 (generator_loss1 / discriminator_loss1, the hinge losses)
dHead2 = DHead() # Head scored by loss pair 2 (generator_loss2 / discriminator_loss2, the BCE losses)

# DCGAN with 2-objective discriminator

In [None]:
class DCGan(tf.keras.Model):
    """DCGAN whose discriminator is a shared trunk followed by two heads.

    Each head is trained with its own loss pair (`*_loss_fn1` for `dhead1`,
    `*_loss_fn2` for `dhead2`). The generator is trained on the sum of both
    generator losses; the trunk is trained on the sum of both head losses.
    """

    def __init__(
        self,
        generator,
        discriminator,
        dhead1,
        dhead2
    ):
        """Store the sub-models.

        Args:
            generator: Model mapping `(batch, Z_DIM)` noise to images.
            discriminator: Shared trunk mapping images to feature maps.
            dhead1: Head applied to trunk features, scored by loss pair 1.
            dhead2: Head applied to trunk features, scored by loss pair 2.
        """
        super(DCGan, self).__init__()
        self.gen = generator
        self.disc = discriminator
        self.dhead1 = dhead1
        self.dhead2 = dhead2

    def compile(
        self,
        gen_optimizer,
        disc_optimizer,
        gen_loss_fn1,
        gen_loss_fn2,
        disc_loss_fn1,
        disc_loss_fn2,
        aug_fn
    ):
        """Attach optimizers, loss functions and the augmentation function.

        Args:
            gen_optimizer: Optimizer for the generator's variables.
            disc_optimizer: Optimizer for the trunk's and both heads' variables.
            gen_loss_fn1: `f(fake_scores)` generator loss for head 1.
            gen_loss_fn2: `f(fake_scores)` generator loss for head 2.
            disc_loss_fn1: `f(real_scores, fake_scores)` loss for head 1.
            disc_loss_fn2: `f(real_scores, fake_scores)` loss for head 2.
            aug_fn: Augmentation applied to the concatenated real+fake batch
                before it reaches the discriminator.
        """
        super(DCGan, self).compile()
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.gen_loss_fn1 = gen_loss_fn1
        self.gen_loss_fn2 = gen_loss_fn2
        self.disc_loss_fn1 = disc_loss_fn1
        self.disc_loss_fn2 = disc_loss_fn2
        self.aug_fn = aug_fn

        self.step_num = 0

    def train_step(self, real_image):
        """Run one optimisation step for the generator and the discriminator.

        Args:
            real_image: Batch of real images.

        Returns:
            Dict of scalar metrics: both head losses, the total generator and
            discriminator losses, and the batch-mean score of each head on
            real and generated images.
        """
        batch_size = tf.shape(real_image)[0]
        noise = tf.random.uniform([batch_size, Z_DIM])

        with tf.GradientTape(persistent=True) as tape:

            # generates fake images from generator
            fake_image = self.gen(noise, training=True)

            # Diffaugment: real and fake images are augmented together and
            # then split back apart.
            both_image = tf.concat([real_image, fake_image], axis=0)

            aug_image = self.aug_fn(both_image)

            aug_real_image = aug_image[:batch_size]
            aug_fake_image = aug_image[batch_size:]

            # Run the shared trunk once per input and feed the same features
            # to both heads (previously the trunk was evaluated twice for
            # each input, once per head).
            fake_features = self.disc(aug_fake_image, training=True)
            real_features = self.disc(aug_real_image, training=True)

            # two-objective discriminator
            disc_fake_image1 = self.dhead1(fake_features, training=True)
            disc_real_image1 = self.dhead1(real_features, training=True)
            disc_fake_image2 = self.dhead2(fake_features, training=True)
            disc_real_image2 = self.dhead2(real_features, training=True)

            gen_loss1 = self.gen_loss_fn1(disc_fake_image1)
            head_loss1 = self.disc_loss_fn1(disc_real_image1, disc_fake_image1)
            gen_loss2 = self.gen_loss_fn2(disc_fake_image2)
            head_loss2 = self.disc_loss_fn2(disc_real_image2, disc_fake_image2)

            total_gen_loss = (gen_loss1 + gen_loss2) * 0.4
            total_disc_loss = head_loss1 + head_loss2

        # Each head's variables only influence its own head loss, so
        # differentiating the total loss w.r.t. trunk + heads yields the same
        # per-variable gradients as three separate gradient calls.
        disc_variables = (
            self.disc.trainable_variables
            + self.dhead1.trainable_variables
            + self.dhead2.trainable_variables
        )

        # Calculate the gradients for generator and discriminator
        generator_gradients = tape.gradient(total_gen_loss,
                                            self.gen.trainable_variables)
        discriminator_gradients = tape.gradient(total_disc_loss,
                                                disc_variables)

        # A persistent tape holds its resources until it is released.
        del tape

        # Apply the gradients. The discriminator optimizer is invoked exactly
        # once per step so that its iteration counter advances once per step
        # and it always sees the same variable list.
        self.gen_optimizer.apply_gradients(zip(generator_gradients,
                                               self.gen.trainable_variables))
        self.disc_optimizer.apply_gradients(zip(discriminator_gradients,
                                                disc_variables))

        # Metrics are reduced to scalars; the raw head outputs are
        # (batch, 14, 14, 1) score maps.
        return {
            "head_loss1": head_loss1,
            "head_loss2": head_loss2,
            "disc_real_image": tf.reduce_mean(disc_real_image1),
            "disc_fake_image": tf.reduce_mean(disc_fake_image1),
            "disc_real_image2": tf.reduce_mean(disc_real_image2),
            "disc_fake_image2": tf.reduce_mean(disc_fake_image2),
            "gen_loss": total_gen_loss,
            "disc_loss": total_disc_loss
            }


In [None]:
def discriminator_loss1(real, generated):
    """Hinge loss for the discriminator.

    Computes mean(0.5 * (max(0, 1 - real) + max(0, 1 + generated))).

    Args:
        real: Discriminator scores for real images.
        generated: Discriminator scores for generated images.

    Returns:
        Scalar loss tensor.
    """
    loss_on_real = tf.math.maximum(
        tf.zeros_like(real), tf.ones_like(real) - real
    )
    loss_on_fake = tf.math.maximum(
        tf.zeros_like(generated), generated + tf.ones_like(generated)
    )

    combined = loss_on_real + loss_on_fake
    return tf.reduce_mean(combined * 0.5)

def discriminator_loss2(real, generated):
    """Binary cross-entropy loss (on logits) for the discriminator.

    Note the label convention used here: generated images are labelled 1 and
    real images are labelled 0, with label smoothing of 0.05.

    Args:
        real: Discriminator logits for real images.
        generated: Discriminator logits for generated images.

    Returns:
        Scalar loss tensor: mean of 0.5 * (real loss + generated loss).
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE,
        label_smoothing=0.05,
    )

    loss_on_fake = bce(tf.ones_like(generated), generated)
    loss_on_real = bce(tf.zeros_like(real), real)

    combined = loss_on_real + loss_on_fake
    return tf.reduce_mean(combined * 0.5)

def generator_loss1(generated):
    """Generator loss paired with the hinge discriminator loss.

    Returns the mean of the negated discriminator scores for generated images.
    """
    negated_scores = -generated
    return tf.reduce_mean(negated_scores)

def generator_loss2(generated):
    """Generator loss paired with the BCE discriminator loss.

    Binary cross-entropy (on logits, no label smoothing) between the scores
    for generated images and an all-zeros target, averaged to a scalar.
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    per_element = bce(tf.zeros_like(generated), generated)
    return tf.reduce_mean(per_element)

In [None]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# from https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_tf.py


def DiffAugment(x, policy='', channels_first=False):
    """Apply the augmentations named in `policy` to `x`, in order.

    Args:
        x: Input to augment.
        policy: Comma-separated keys of AUGMENT_FNS; every function listed
            under each key is applied in sequence. An empty string returns
            `x` untouched.
        channels_first: When true, `x` is transposed with [0, 2, 3, 1] before
            augmenting and with [0, 3, 1, 2] afterwards.

    Returns:
        The augmented `x`.

    Raises:
        KeyError: If a name in `policy` is not a key of AUGMENT_FNS.
    """
    if not policy:
        return x

    if channels_first:
        x = tf.transpose(x, [0, 2, 3, 1])

    # Lazily flatten policy names into the individual augmentation functions.
    augmentations = (
        fn for name in policy.split(',') for fn in AUGMENT_FNS[name]
    )
    for augment in augmentations:
        x = augment(x)

    if channels_first:
        x = tf.transpose(x, [0, 3, 1, 2])

    return x


def rand_brightness(x):
    """Add one random offset per batch element, broadcast over the other axes.

    The offset is a `tf.random.uniform` sample (default range) minus 0.5.
    """
    per_image_shape = [tf.shape(x)[0], 1, 1, 1]
    offset = tf.random.uniform(per_image_shape) - 0.5
    return x + offset


def rand_saturation(x):
    """Scale each pixel's deviation from its channel mean by a random factor.

    One factor per batch element is drawn as `tf.random.uniform(...) * 2`.
    The channel mean is taken over axis 3 with `tf.reduce_mean`, so the
    result is correct for any number of channels (the previous
    `reduce_sum * 0.333...` form was only valid for exactly three).
    """
    magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2
    x_mean = tf.reduce_mean(x, axis=3, keepdims=True)
    x = (x - x_mean) * magnitude + x_mean
    return x


def rand_contrast(x):
    """Scale each image's deviation from its own mean by a random factor.

    One factor per batch element is drawn as `tf.random.uniform(...) + 0.5`.

    The per-image mean is computed with `tf.reduce_mean` over axes 1-3. The
    previous implementation multiplied the sum by the constant 5.086e-6,
    i.e. 1 / (256 * 256 * 3), which is only a mean for 256x256x3 inputs and
    gives a value about four times too small for 128x128x3 images.
    """
    magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5
    x_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
    x = (x - x_mean) * magnitude + x_mean
    return x

def rand_translation(x, ratio=0.125):
    """Shift every batch element by a random integer offset along axes 1 and 2.

    Positions shifted in from outside the image are filled with zeros (they
    read from the zero padding added by `tf.pad`).

    Args:
        x: 4-D tensor; axes 1 and 2 are treated as the two spatial axes
            (presumably NHWC -- confirm against callers).
        ratio: Maximum shift as a fraction of the size of each spatial axis.

    Returns:
        Tensor of the same shape as `x`.
    """
    batch_size = tf.shape(x)[0]
    image_size = tf.shape(x)[1:3]
    # Maximum shift per axis: size * ratio rounded to the nearest integer.
    shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
    # One integer offset per batch element and axis, drawn from [-shift, shift].
    translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
    translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
    # Source indices into the padded tensor (+1 accounts for the leading pad);
    # clipping to [0, size + 1] sends out-of-range positions to a padding slot.
    grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
    grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
    # Shift along axis 1: pad that axis by one on each side, then gather per batch element.
    x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
    # Shift along axis 2: swap axes 1 and 2, repeat the pad + gather, swap back.
    x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
    return x


def rand_cutout(x, ratio=0.5):
    """Zero out one randomly placed rectangular patch in every batch element.

    Args:
        x: 4-D tensor; axes 1 and 2 are treated as the two spatial axes
            (presumably NHWC -- confirm against callers).
        ratio: Patch size as a fraction of the size of each spatial axis.

    Returns:
        `x` multiplied by a per-element 0/1 mask broadcast over the last axis.
    """
    batch_size = tf.shape(x)[0]
    image_size = tf.shape(x)[1:3]
    # Patch size per axis: size * ratio rounded to the nearest integer.
    cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
    # Random patch centre per batch element; the upper bound is extended by one
    # when the patch size is even.
    offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
    offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
    # (batch, i, j) index triples for every cell of the patch, centred on the offsets.
    grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
    cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1)
    mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])
    # Clamp indices into the valid range so patches overlapping the border stay in bounds.
    cutout_grid = tf.maximum(cutout_grid, 0)
    cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
    # scatter_nd accumulates ones at the patch cells (clamped duplicates add up),
    # so 1 - scatter is floored at 0 to obtain a 0/1 mask.
    mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0)
    x = x * tf.expand_dims(mask, axis=3)
    return x


# Maps each policy name accepted by DiffAugment to the list of augmentation
# functions applied, in order, for that policy.
AUGMENT_FNS = {
    'color': [rand_brightness, rand_saturation, rand_contrast],
    'translation': [rand_translation],
    'cutout': [rand_cutout],
}
def aug_fn(image):
    """Apply the colour, translation and cutout policies to `image` via DiffAugment."""
    augmented = DiffAugment(image, policy="color,translation,cutout")
    return augmented

In [None]:
# Separate Adam optimizers (learning rate 2e-4, beta_1 = 0.5) for the generator
# and for the discriminator side.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Training

In [None]:
EPOCHS = 30  # Number of epochs passed to fit()
NUM_SAMPLES_TO_GENERATE = 8  # Number of fixed latent vectors used for the preview images
NUM_CHECKPOINT = 10  # NOTE(review): not referenced anywhere in the visible notebook


# This fixed batch of latent vectors is reused over time, so it is easier
# to visualize progress (e.g. in an animated GIF).
seed = tf.random.uniform([NUM_SAMPLES_TO_GENERATE, Z_DIM])

In [None]:
def generate_and_save_images(model, epoch, test_input):
    """Generate images from `test_input`, plot them in a grid and save a PNG.

    The grid has four columns and as many rows as needed (at least two, which
    reproduces the original 2x4 layout for up to eight samples). The figure
    is written to `image_at_epoch_XXXX.png` in the working directory, shown,
    and then closed.

    Args:
        model: Generator called as `model(test_input, training=False)`; its
            outputs are mapped to pixels with `x * 127.5 + 127.5`.
        epoch: Epoch number used in the output file name.
        test_input: Batch of latent vectors.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    n_samples = predictions.shape[0]
    n_cols = 4
    # Ceiling division; never fewer than two rows so the figure matches the
    # original layout whenever there are at most eight samples.
    n_rows = max(2, -(-n_samples // n_cols))

    fig = plt.figure(figsize=(2 * n_cols, 2 * n_rows), constrained_layout=True)

    for i in range(n_samples):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow((predictions[i].numpy() * 127.5 + 127.5).astype("uint8"))
        plt.axis("off")

    plt.savefig("image_at_epoch_{:04d}.png".format(epoch))
    plt.show()
    # Release the figure explicitly so figures do not pile up across epochs.
    plt.close(fig)

In [None]:
class CheckpointCallback(tf.keras.callbacks.Callback):
    """Keras callback that calls `manager.save()` at the end of every epoch."""

    def __init__(self, manager):
        """Remember the object whose `save()` method is invoked each epoch."""
        super().__init__()
        self.manager = manager

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint; `epoch` and `logs` are not used."""
        self.manager.save()
        
class SampleTestCallback(tf.keras.callbacks.Callback):
    """Keras callback that renders generator samples every `num_epoch` epochs."""

    def __init__(self, test_sample, num_epoch):
        """Store the fixed generator input and the sampling interval.

        Args:
            test_sample: Input passed to the generator for every preview.
            num_epoch: Previews are produced when `epoch % num_epoch == 0`.
        """
        super().__init__()
        self.test_sample = test_sample
        self.num_epoch = num_epoch

    def on_epoch_end(self, epoch, logs=None):
        """Generate and save preview images on epochs that hit the interval."""
        if epoch % self.num_epoch != 0:
            return
        generate_and_save_images(self.model.gen, epoch, self.test_sample)

In [None]:
# Assemble the GAN from the generator, the shared discriminator trunk and the two heads.
dc_gan_model = DCGan(
        generator, discriminator, dHead1,  dHead2
    )

In [None]:
# Wire up optimizers, losses and augmentation. Loss pair 1 (hinge) is used with
# the first head, loss pair 2 (binary cross-entropy) with the second head.
dc_gan_model.compile(
    gen_optimizer = generator_optimizer,
    disc_optimizer = discriminator_optimizer,
    gen_loss_fn1 = generator_loss1,
    gen_loss_fn2 = generator_loss2,
    disc_loss_fn1 = discriminator_loss1,
    disc_loss_fn2 = discriminator_loss2,
    aug_fn = aug_fn ,
)

In [None]:
# Checkpoints are written to the current working directory.
checkpoint_path = "./"
# Track both optimizers and all four sub-models so training can be resumed.
ckpt = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=dc_gan_model.gen,
    discriminator=dc_gan_model.disc,
    dhead1=dc_gan_model.dhead1,
    dhead2=dc_gan_model.dhead2
)
# max_to_keep=3 is passed to the manager as the checkpoint retention setting.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=3)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print ('Latest checkpoint restored!!')

In [None]:
# Save a checkpoint after every epoch and render preview images from the fixed
# `seed` vectors every epoch (interval of 1).
checkpoint_cb = CheckpointCallback(ckpt_manager)
test_cb = SampleTestCallback(seed, 1)

In [None]:
# celeb_dataset repeats indefinitely, so the epoch length has to be given
# explicitly. Derive it from the image count and batch size instead of
# hard-coding it (one epoch = one pass over the full batches).
dc_gan_model.fit(
    celeb_dataset,
    epochs=EPOCHS,
    steps_per_epoch=num_images // BATCH_SIZE,
    callbacks=[checkpoint_cb, test_cb],
)