In [1]:
# download the dataset from https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar

# !wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/wiki_crop.tar
# !tar -xvf wiki_crop.tar

In [2]:
# import the packages
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import time
from datetime import datetime
from keras import Input, Model
from keras.applications import InceptionResNetV2
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.layers import (
    Conv2D,
    Flatten,
    Dense,
    BatchNormalization,
    Reshape,
    concatenate,
    LeakyReLU,
    Lambda,
    Activation,
    UpSampling2D,
    Dropout
)
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras_preprocessing import image
from scipy.io import loadmat

In [3]:
def build_encoder():
    """
    Assemble the encoder: a CNN mapping a (64, 64, 3) image to a
    100-dimensional latent vector.

    The network is four stride-2 5x5 convolutions (32, 64, 128, 256 filters)
    followed by two dense layers (4096, then 100 units). Every convolution
    except the first is followed by batch normalization.

    Returns:
        An uncompiled Keras ``Model`` with one image input and one output.
    """
    image_in = Input(shape=(64, 64, 3))

    # the first convolutional block has no batch normalization
    features = Conv2D(
        filters=32, kernel_size=5, strides=2, padding="same"
    )(image_in)
    features = LeakyReLU(alpha=0.2)(features)

    # the remaining convolutional blocks share one layout
    for n_filters in (64, 128, 256):
        features = Conv2D(
            filters=n_filters, kernel_size=5, strides=2, padding="same"
        )(features)
        features = BatchNormalization()(features)
        features = LeakyReLU(alpha=0.2)(features)

    features = Flatten()(features)

    # hidden fully-connected layer
    features = Dense(4096)(features)
    features = BatchNormalization()(features)
    features = LeakyReLU(alpha=0.2)(features)

    # linear projection to the latent vector
    latent = Dense(100)(features)

    return Model(inputs=[image_in], outputs=[latent])

In [4]:
def build_generator():
    """
    Assemble the conditional generator.

    Inputs are a 100-dimensional latent vector and a 6-dimensional label
    vector. They are concatenated, passed through two dense blocks, reshaped
    to (8, 8, 256) and upsampled three times by a factor of 2, producing a
    (64, 64, 3) tensor with a tanh activation.

    Returns:
        An uncompiled Keras ``Model`` taking ``[z, label]`` as inputs.
    """
    latent_dims = 100
    num_classes = 6

    input_z_noise = Input(shape=(latent_dims,))
    input_label = Input(shape=(num_classes,))

    # condition the generator by joining the noise vector and the class label
    gen = concatenate([input_z_noise, input_label])

    # dense block 1
    gen = Dense(2048, input_dim=latent_dims + num_classes)(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Dropout(0.2)(gen)

    # dense block 2
    gen = Dense(256 * 8 * 8)(gen)
    gen = BatchNormalization()(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Dropout(0.2)(gen)

    gen = Reshape((8, 8, 256))(gen)

    # the first two upsampling blocks share one layout
    for n_filters in (128, 64):
        gen = UpSampling2D(size=(2, 2))(gen)
        gen = Conv2D(filters=n_filters, kernel_size=5, padding="same")(gen)
        gen = BatchNormalization(momentum=0.8)(gen)
        gen = LeakyReLU(alpha=0.2)(gen)

    # the final upsampling block maps to 3 channels with a tanh activation
    gen = UpSampling2D(size=(2, 2))(gen)
    gen = Conv2D(filters=3, kernel_size=5, padding="same")(gen)
    gen = Activation("tanh")(gen)

    return Model(inputs=[input_z_noise, input_label], outputs=[gen])

In [5]:
def expand_label_input(x):
    """
    Broadcast a batch of label vectors over a 32x32 spatial grid.

    Two singleton axes are inserted after the batch axis and the result is
    tiled 32 times along each of them, so a (batch, 6) label tensor becomes
    (batch, 32, 32, 6).
    """
    with_spatial_axes = K.expand_dims(K.expand_dims(x, axis=1), axis=1)
    return K.tile(with_spatial_axes, [1, 32, 32, 1])

The discriminator network is a CNN.

In [6]:
def build_discriminator():
    """
    Assemble the conditional CNN discriminator.

    It takes a (64, 64, 3) image and a 6-dimensional label vector. The label
    is tiled to (32, 32, 6) and concatenated channel-wise with the output of
    the first stride-2 convolution. Three more stride-2 convolution blocks
    and a single sigmoid unit follow.

    Returns:
        An uncompiled Keras ``Model`` taking ``[image, label]`` as inputs.
    """
    image_input = Input(shape=(64, 64, 3))
    label_input = Input(shape=(6,))

    # image-only convolution: 64x64 -> 32x32
    disc = Conv2D(64, kernel_size=3, strides=2, padding="same")(image_input)
    disc = LeakyReLU(alpha=0.2)(disc)

    # inject the label as extra feature maps
    label_maps = Lambda(expand_label_input)(label_input)
    disc = concatenate([disc, label_maps], axis=3)

    # three identical conv / batch-norm / leaky-ReLU blocks
    for n_filters in (128, 256, 512):
        disc = Conv2D(
            n_filters, kernel_size=3, strides=2, padding="same"
        )(disc)
        disc = BatchNormalization()(disc)
        disc = LeakyReLU(alpha=0.2)(disc)

    disc = Flatten()(disc)

    # single sigmoid output unit
    validity = Dense(1, activation="sigmoid")(disc)

    return Model(inputs=[image_input, label_input], outputs=[validity])

In [8]:
def build_fr_model(input_shape):
    """
    Build the face recognition model.

    An ImageNet-initialised InceptionResNetV2 backbone (no classification
    head, global average pooling) is followed by a 128-unit dense layer, and
    the resulting embedding is L2-normalized along its last axis.

    Args:
        input_shape: shape of the input images, passed both to the backbone
            and to the model's input layer.

    Returns:
        An uncompiled Keras ``Model`` mapping an image to its normalized
        128-dimensional embedding.
    """
    # InceptionResNetV2 backbone without its top classifier
    resnet_model = InceptionResNetV2(
        include_top=False,
        weights="imagenet",
        input_shape=input_shape,
        pooling="avg"
    )
    image_input = resnet_model.input

    # the backbone output must be fetched BEFORE it is fed to the dense
    # layer; reading it afterwards raised UnboundLocalError
    x = resnet_model.layers[-1].output
    out = Dense(128)(x)
    embedder_model = Model(inputs=[image_input], outputs=[out])

    input_layer = Input(shape=input_shape)

    x = embedder_model(input_layer)
    output = Lambda(lambda x: K.l2_normalize(x, axis=-1))(x)

    model = Model(inputs=[input_layer], outputs=[output])
    return model

In [7]:
def build_fr_combined_network(encoder, generator, fr_model, resize_factor=2):
    """
    Build the combined encoder -> generator -> face recognition network.

    Args:
        encoder: model mapping a (64, 64, 3) image to a latent vector.
        generator: model mapping ``[latent, label]`` to an image.
        fr_model: face recognition model; its ``trainable`` flag is set to
            False here (the passed object is modified).
        resize_factor: integer factor by which generated images are upscaled
            in height and width before being passed to ``fr_model``.
            Defaults to 2, the value previously hard-coded.

    Returns:
        An uncompiled Keras ``Model`` taking ``[image, label]`` and returning
        the embeddings produced by ``fr_model``.
    """
    # freeze the weights of the model responsible for facial recognition
    fr_model.trainable = False

    input_image = Input(shape=(64, 64, 3))
    input_label = Input(shape=(6,))

    # encode the image to a latent vector representation
    latent0 = encoder(input_image)

    # generate artificial images
    gen_images = generator([latent0, input_label])

    # resize the generated images for the face recognition model; the lambda
    # operates on its own argument instead of closing over ``gen_images``
    resized_images = Lambda(
        lambda x: K.resize_images(
            x,
            height_factor=resize_factor,
            width_factor=resize_factor,
            data_format="channels_last"
        )
    )(gen_images)
    embeddings = fr_model(resized_images)

    # create the model and return it
    model = Model(inputs=[input_image, input_label], outputs=[embeddings])
    return model

In [9]:
"""
Utility functions
"""

def build_image_resizer():
    """
    Build a parameter-free model that upscales (64, 64, 3) images by an
    integer factor of 192 / 64 = 3 in height and width, giving (192, 192, 3).
    """
    scale = 192 // 64
    small_images = Input(shape=(64, 64, 3))

    def _upscale(batch):
        # resize a channels-last image batch by the same factor on both axes
        return K.resize_images(
            batch,
            height_factor=scale,
            width_factor=scale,
            data_format="channels_last"
        )

    large_images = Lambda(_upscale)(small_images)

    return Model(inputs=[small_images], outputs=[large_images])


def calculate_age(taken, dob):
    """
    Estimate a person's age in the year a photo was taken.

    Args:
        taken: year in which the photo was taken.
        dob: date of birth as a serial date number; it is truncated to an
            integer and shifted by 366 days to obtain a Python ordinal
            (clamped to a minimum of 1).

    Returns:
        ``taken - birth_year`` when the birth month is before July,
        otherwise one year less.
    """
    ordinal = int(dob) - 366
    if ordinal < 1:
        # datetime.fromordinal rejects values below 1
        ordinal = 1
    birth_date = datetime.fromordinal(ordinal)

    years = taken - birth_date.year
    return years if birth_date.month < 7 else years - 1

In [10]:
def load_data(wiki_dir, dataset='wiki'):
    """
    Read image paths and the corresponding ages from the metadata file
    ``<wiki_dir>/<dataset>.mat``.

    Args:
        wiki_dir: directory containing the ``.mat`` metadata file.
        dataset: name of the file (without extension) and of the top-level
            key inside it.

    Returns:
        A tuple ``(images, age_list)``: the image paths and, at the same
        index, the age computed by ``calculate_age``.
    """
    meta = loadmat(os.path.join(wiki_dir, "{}.mat".format(dataset)))
    record = meta[dataset][0, 0]

    # image paths, serial dates of birth and years the photos were taken
    full_path = record["full_path"][0]
    dob = record["dob"][0]
    photo_taken = record["photo_taken"][0]

    # one age per date of birth
    age = [calculate_age(photo_taken[i], dob[i]) for i in range(len(dob))]

    # pair each path with the age stored at the same position
    images = [path_entry[0] for path_entry in full_path]
    age_list = [age[position] for position in range(len(full_path))]

    return images, age_list

Convert each numerical age value to an age category.

In [11]:
def age_to_category(age_list):
    """
    Function to convert the age's numerical value to a category.

    The ranges are arbitrarily chosen and can be changed later:

        0: age <= 18      3: 39 < age <= 49
        1: 18 < age <= 29 4: 49 < age <= 59
        2: 29 < age <= 39 5: age > 59

    The mapping is total, so the result always has one entry per input and
    stays index-aligned with it. Previously an age that matched no branch
    (non-positive, or strictly between 59 and 60) either raised
    ``UnboundLocalError`` or silently reused the previous item's category.
    Non-positive ages now fall into category 0.

    Args:
        age_list: iterable of numeric ages.

    Returns:
        A list of integer categories in the range 0-5.
    """
    # inclusive upper bound of every category except the last, open-ended one
    upper_bounds = (18, 29, 39, 49, 59)

    print(f"age_list length: {len(age_list)}")

    age_list_cat = []
    for age in age_list:
        # the category is the number of upper bounds the age exceeds
        cat = int(sum(age > bound for bound in upper_bounds))
        age_list_cat.append(cat)

    return age_list_cat

In [12]:
def load_images(data_dir, image_paths, image_shape):
    """
    Function to load all images and create an ndarray containing all images.

    Images are collected in a list and concatenated once at the end, instead
    of re-allocating the whole array for every image.

    Args:
        data_dir: directory the image paths are relative to.
        image_paths: iterable of image paths.
        image_shape: target size passed to ``image.load_img``.

    Returns:
        An ndarray with the loaded images stacked along axis 0, or ``None``
        when no image could be loaded.

    Note:
        An image that fails to load is reported and skipped, so the result
        may contain fewer entries than ``image_paths`` and is then no longer
        index-aligned with it.
    """
    loaded = []
    for ind, image_path in enumerate(image_paths):
        print(f"index: {ind} and image path: {image_path}")
        try:
            # load the image
            loaded_image = image.load_img(
                os.path.join(data_dir, image_path),
                target_size=image_shape
            )

            # convert the PIL image to a NumPy ndarray
            loaded_image = image.img_to_array(loaded_image)

            # add another (batch) dimension
            loaded_image = np.expand_dims(loaded_image, axis=0)

            # keep skipping images that could not be concatenated with the
            # ones already loaded
            if loaded and loaded_image.shape != loaded[0].shape:
                raise ValueError(
                    f"shape {loaded_image.shape} does not match "
                    f"{loaded[0].shape}"
                )

            loaded.append(loaded_image)
        except Exception as e:
            print(f"Error {e} at index {ind}")

    if not loaded:
        return None

    # concatenate all the images into an array in a single pass
    return np.concatenate(loaded, axis=0)

In [13]:
def euclidean_distance_loss(y_true, y_pred):
    """
    Euclidean distance between predictions and targets, reduced over the
    last axis: https://en.wikipedia.org/wiki/Euclidean_distance
    """
    squared_diff = K.square(y_pred - y_true)
    return K.sqrt(K.sum(squared_diff, axis=-1))


def write_log(callback, name, value, batch_no):
    """
    Write one scalar to the summary writer held by ``callback``.

    Args:
        callback: object exposing a ``writer`` attribute with
            ``add_summary`` and ``flush`` methods.
        name: tag under which the scalar is recorded.
        value: scalar value to record.
        batch_no: step index passed to ``add_summary``.
    """
    scalar_summary = tf.Summary()
    entry = scalar_summary.value.add()
    entry.tag = name
    entry.simple_value = value

    writer = callback.writer
    writer.add_summary(scalar_summary, batch_no)
    writer.flush()

def save_rgb_img(img, path):
    """
    Save an RGB image to the directory

    The parent directory of ``path`` is created when it does not exist,
    because ``savefig`` does not create missing directories.

    Args:
        img: image data accepted by ``Axes.imshow``.
            NOTE(review): matplotlib clips float RGB data to [0, 1]; callers
            holding tanh-range images should rescale them first.
        path: destination file path.
    """
    target_dir = os.path.dirname(path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(img)
    ax.axis("off")
    ax.set_title("Image")

    # save and close this specific figure rather than the "current" one
    fig.savefig(path)
    plt.close(fig)

In [15]:
"""
Define the training hyperparameters
"""
data_dir = "./data"
wiki_dir = os.path.join(data_dir, "wiki_crop")
epochs = 500
batch_size = 2
image_shape = (64, 64, 3)
z_shape = 100
TRAIN_GAN = True
TRAIN_ENCODER = False
TRAIN_GAN_WITH_FR = False
fr_image_shape = (192, 192, 3)

"""
Define the optimizers
"""
# NOTE: epsilon=10e-8 equals 1e-7
# optimizer for the discriminator
dis_optimizer = Adam(
    lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=10e-8
)
# optimizer for the generator
gen_optimizer = Adam(
    lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=10e-8
)
# optimizer for the GAN network
adversarial_optimizer = Adam(
    lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=10e-8
)

"""
Build and compile the networks
"""
# the discriminator
discriminator = build_discriminator()
discriminator.compile(
    loss=["binary_crossentropy"],
    optimizer=dis_optimizer
)
# frozen only inside the adversarial model compiled below; the standalone
# discriminator was compiled while still trainable
discriminator.trainable = False

# the generator
generator = build_generator()
generator.compile(
    loss=["binary_crossentropy"],
    optimizer=gen_optimizer
)

# the GAN
input_z_noise = Input(shape=(100,))
input_label = Input(shape=(6,))
recons_images = generator([input_z_noise, input_label])
valid = discriminator([recons_images, input_label])
adversarial_model = Model(
    inputs=[input_z_noise, input_label],
    outputs=[valid]
)
adversarial_model.compile(
    loss=["binary_crossentropy"],
    optimizer=gen_optimizer
)

# initialize the tensorboard
tensorboard = TensorBoard(log_dir=f"logs/{time.time()}")
tensorboard.set_model(generator)
tensorboard.set_model(discriminator)

# load the dataset
images, age_list = load_data(wiki_dir=wiki_dir, dataset="wiki")
print(f"Number of images = {len(images)}")
print(f"age_list size = {len(age_list)}")

# convert the numeric age to categorical age
age_cat = np.array(age_to_category(age_list))
final_age_cat = np.reshape(age_cat, [len(age_cat), 1])

# number of distinct categories actually present in the data
classes = len(set(age_cat))

# the generator and the discriminator take 6-dimensional labels, so the
# one-hot width is fixed at 6 even when a category is missing from the data
y = to_categorical(final_age_cat, num_classes=6)

loaded_images = load_images(
    wiki_dir,
    images,
    (image_shape[0], image_shape[1])
)

# load_images skips unreadable files, which shifts images against labels
num_loaded = 0 if loaded_images is None else len(loaded_images)
if num_loaded != len(y):
    print(
        f"Warning: loaded {num_loaded} images for {len(y)} labels; "
        "images and labels are no longer index-aligned"
    )

# one-sided label smoothing: real targets are 0.9, fake targets stay at 0
# (the former `np.zeros(...) * 0.1` was also all zeros)
real_labels = np.ones((batch_size, 1), dtype=np.float32) * 0.9
fake_labels = np.zeros((batch_size, 1), dtype=np.float32)

"""
Train the generator and the discriminator
"""
if TRAIN_GAN:
    # sample images are written here; savefig does not create directories
    os.makedirs("./results", exist_ok=True)

    # target for the generator step: every generated image should be
    # classified as real; shaped (batch_size, 1) like the discriminator output
    generator_targets = np.ones((batch_size, 1), dtype=np.float32)

    for epoch in range(epochs):
        print(f"Epoch = {epoch}")

        gen_losses = []
        dis_losses = []

        number_of_batches = len(loaded_images) // batch_size
        print(f"Number of batches = {number_of_batches}")
        for index in range(number_of_batches):
            print(f"Batch = {index + 1}")

            images_batch = loaded_images[
                index * batch_size: (index + 1) * batch_size
            ]
            # normalize the images to [-1, 1], the range of the generator's tanh
            images_batch = images_batch / 127.5 - 1.0
            images_batch = images_batch.astype(np.float32)

            y_batch = y[index * batch_size:(index + 1) * batch_size]

            # generate the noise vector
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))

            """
            Train the discriminator network
            """

            # generate fake images
            initial_recon_images = generator.predict_on_batch(
                [z_noise, y_batch]
            )

            d_loss_real = discriminator.train_on_batch(
                [images_batch, y_batch],
                real_labels
            )
            d_loss_fake = discriminator.train_on_batch(
                [initial_recon_images, y_batch],
                fake_labels
            )

            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            print(f"Discriminator loss = {d_loss}")

            """
            Train the generator network
            """

            # generate the noise vector
            z_noise2 = np.random.normal(
                0,
                1,
                size=(batch_size, z_shape)
            )

            # generate the labels for the second input of the generator
            random_labels = np.random.randint(
                0,
                6,
                batch_size
            ).reshape(-1, 1)
            random_labels = to_categorical(random_labels, 6)

            g_loss = adversarial_model.train_on_batch(
                [z_noise2, random_labels],
                generator_targets
            )

            print(f"Generator loss = {g_loss}")

            gen_losses.append(g_loss)
            dis_losses.append(d_loss)

        # write the losses to Tensorboard
        write_log(tensorboard, "g_loss", np.mean(gen_losses), epoch)
        write_log(tensorboard, "d_loss", np.mean(dis_losses), epoch)

        """
        Generate images after every 10th epoch
        """
        if epoch % 10 == 0:
            y_batch = y[0: batch_size]
            z_noise = np.random.normal(
                0,
                1,
                size=(batch_size, z_shape)
            )

            gen_images = generator.predict_on_batch(
                [z_noise, y_batch]
            )

            for ind, img in enumerate(gen_images[:5]):
                # map the tanh output from [-1, 1] to [0, 1]; imshow clips
                # float RGB data outside [0, 1]
                save_rgb_img(
                    (img + 1.0) / 2.0,
                    path=f"./results/img_{epoch}_{ind}.png"
                )

    # save the trained networks
    try:
        generator.save_weights("generator.h5")
        discriminator.save_weights("discriminator.h5")
    except Exception as e:
        print(f"Error {e} encountered")

"""
Train the encoder
"""

if TRAIN_ENCODER:
    # build and compile the encoder
    encoder = build_encoder()
    encoder.compile(
        loss=euclidean_distance_loss,
        optimizer="adam"
    )

    # load the generator's weights
    try:
        generator.load_weights("generator.h5")
    except Exception as e:
        print(f"Error {e} encountered")

    z_i = np.random.normal(0, 1, size=(5000, z_shape))

    # the generator takes 6-dimensional labels; the one-hot width is fixed
    # instead of being derived from the classes that happen to be sampled
    num_classes = 6

    # random labels are kept in their own variable so that `y`, the real age
    # labels aligned with `loaded_images`, is not overwritten
    encoder_labels = np.random.randint(
        low=0,
        high=num_classes,
        size=(5000,),
        dtype=np.int64
    )
    encoder_labels = np.reshape(encoder_labels, [len(encoder_labels), 1])
    encoder_labels = to_categorical(encoder_labels, num_classes=num_classes)

    for epoch in range(epochs):
        print(f"Epoch = {epoch}")

        encoder_losses = []
        number_of_batches = z_i.shape[0] // batch_size
        print(f"Number of batches = {number_of_batches}")
        for index in range(number_of_batches):
            print(f"Batch = {index + 1}")

            z_batch = z_i[
                index * batch_size: (index + 1) * batch_size
            ]
            y_batch = encoder_labels[
                index * batch_size: (index + 1) * batch_size
            ]

            generated_images = generator.predict_on_batch(
                [z_batch, y_batch]
            )

            # train the encoder to recover the latent vector from the image
            encoder_loss = encoder.train_on_batch(
                generated_images,
                z_batch
            )
            print(f"Encoder loss = {encoder_loss}")

            encoder_losses.append(encoder_loss)

        # write the encoder loss to Tensorboard
        write_log(
            tensorboard,
            "encoder_loss",
            np.mean(encoder_losses),
            epoch
        )

    # save the encoder for further use
    encoder.save_weights("encoder.h5")

"""
Optimize the encoder and the generator
"""
if TRAIN_GAN_WITH_FR:

    # load the encoder network
    encoder = build_encoder()
    encoder.load_weights("encoder.h5")

    # load the generator network
    generator.load_weights("generator.h5")

    image_resizer = build_image_resizer()
    image_resizer.compile(
        loss=["binary_crossentropy"],
        optimizer="adam"
    )

    # face recognition model
    fr_model = build_fr_model(input_shape=fr_image_shape)
    fr_model.compile(
        loss=["binary_crossentropy"],
        optimizer="adam"
    )

    # freeze the face recognition model's weights
    fr_model.trainable = False

    # input layers
    input_image = Input(shape=(64, 64, 3))
    input_label = Input(shape=(6,))

    # use the encoder and, then, the generator from its output
    latent0 = encoder(input_image)
    gen_images = generator([latent0, input_label])

    # resize images to the desired shape; the lambda operates on its own
    # argument instead of closing over `gen_images`
    resized_images = Lambda(
        lambda x: K.resize_images(
            x,
            height_factor=3,
            width_factor=3,
            data_format="channels_last"
        )
    )(gen_images)
    embeddings = fr_model(resized_images)

    # create a GAN and specify its inputs and outputs
    fr_adversarial_model = Model(
        inputs=[input_image, input_label],
        outputs=[embeddings]
    )

    # compile the GAN
    fr_adversarial_model.compile(
        loss=euclidean_distance_loss,
        optimizer=adversarial_optimizer
    )

    # sample images are written here; savefig does not create directories
    os.makedirs("./results", exist_ok=True)

    for epoch in range(epochs):
        print(f"Epoch = {epoch}")

        reconstruction_losses = []

        number_of_batches = len(loaded_images) // batch_size
        print(f"Number of batches = {number_of_batches}")
        for index in range(number_of_batches):
            print(f"Batch = {index + 1}")

            images_batch = loaded_images[
                index * batch_size:(index + 1) * batch_size
            ]
            # normalize the images
            images_batch = images_batch / 127.5 - 1.0
            images_batch = images_batch.astype(np.float32)

            y_batch = y[index * batch_size:(index + 1) * batch_size]

            images_batch_resized = image_resizer.predict_on_batch(
                images_batch
            )

            # embeddings of the real images are the regression targets
            real_embeddings = fr_model.predict_on_batch(
                images_batch_resized
            )

            reconstruction_loss = fr_adversarial_model.train_on_batch(
                [images_batch, y_batch],
                real_embeddings
            )

            print(f"Reconstruction loss = {reconstruction_loss}")
            reconstruction_losses.append(reconstruction_loss)

        # write the reconstruction loss to Tensorboard
        write_log(
            tensorboard,
            "reconstruction_loss",
            np.mean(reconstruction_losses),
            epoch
        )

        """
        Generate images
        """
        if epoch % 10 == 0:
            y_batch = y[0: batch_size]
            z_noise = np.random.normal(
                0, 1, size=(batch_size, z_shape)
            )

            gen_images = generator.predict_on_batch([z_noise, y_batch])

            for ind, img in enumerate(gen_images[:5]):
                # map the tanh output from [-1, 1] to [0, 1]; imshow clips
                # float RGB data outside [0, 1]
                save_rgb_img(
                    (img + 1.0) / 2.0,
                    path=f"./results/img_opt_{epoch}_{ind}.png"
                )

    # save the improved weights for both models
    generator.save_weights("generator_optimized.h5")
    encoder.save_weights("encoder_optimized.h5")