# 🤪 AEI_NET - CelebA Faces

In this notebook, we'll walk through the steps required to train your own AEI_NET on the CelebA faces dataset.

In [None]:
# Pretraining could take a while, but it only needs to be run once.

In [None]:
# Import the pretraining entry point and run it with use_fixed_image=True.
from preprocess_inswapper import do_inswapper_pretraining
do_inswapper_pretraining(use_fixed_image=True)

In [None]:
%load_ext autoreload
%autoreload 2


# Limit TensorFlow's share of GPU memory (0.35 is passed below; presumably a fraction, i.e. 35% -- confirm against gpu_memory)
from gpu_memory import limit_gpu_memory 
limit_gpu_memory(0.35)

import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import optimizers


from notebooks.utils import display

In [None]:
#from tensorflow.keras import mixed_precision
#mixed_precision.set_global_policy('mixed_float16')

## 0. Parameters <a name="parameters"></a>

In [None]:
IMAGE_SIZE = 64  # side length (pixels) that images are resized to when loaded
CHANNELS = 3  # NOTE(review): not referenced in the visible cells of this notebook
BATCH_SIZE = 16  # batch size used when batching train_data
NUM_FEATURES = 128  # NOTE(review): not referenced in the visible cells of this notebook
Z_DIM = 200  # NOTE(review): not referenced in the visible cells of this notebook
LEARNING_RATE = 0.0005  # learning rate handed to the Adam optimizer
EPOCHS = 1  # epochs per model.fit call in the training loop
BETA = 2000  # NOTE(review): not referenced in the visible cells of this notebook
LOAD_MODEL = True  # when True, weights are loaded from ./models/aei_net before training
TAKE_BATCHES = 500  # number of batches passed to each model.fit call

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# NOTE(review): Face is imported but not used in the visible cells.
from face_analysis import FaceAnalysis
from inswapper import INSwapper
from face import Face

# NOTE(review): PROVIDERS is defined but never passed to anything in the visible cells.
PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']

# Face analyser; get_target calls face_analyser.get(...) on each image.
face_analyser = FaceAnalysis()
face_analyser.prepare(ctx_id=0, det_size=(640, 640))
# Swapper loaded from a fixed model path; get_target multiplies embeddings by its emap matrix.
inswapper = INSwapper('/root/.insightface/models/inswapper_128.onnx')
emap = inswapper.emap

In [None]:
def load_and_preprocess_image(file_path):
    """Read the JPEG at ``file_path`` and resize it to IMAGE_SIZE x IMAGE_SIZE.

    The file is decoded with three channels. Pixel values are not rescaled
    here; the resized image is returned as-is.
    """
    raw_bytes = tf.io.read_file(file_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    return tf.image.resize(decoded, [IMAGE_SIZE, IMAGE_SIZE])

# Directory holding the CelebA JPEGs and a glob pattern that matches them.
data_dir = "/app/data/celeba-dataset/img_align_celeba/img_align_celeba/"
file_pattern = f"{data_dir}*.jpg"
# Number of .jpg files on disk; the training cell uses it to compute the batch count.
file_count = len([f for f in os.listdir(data_dir) if f.endswith('.jpg')])
# shuffle=False: files are listed without shuffling.
list_ds = tf.data.Dataset.list_files(file_pattern, shuffle=False)

# Decode and resize each file, then batch and prefetch.
train_data = list_ds.map(load_and_preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
# NOTE(review): no drop_remainder is set, so the final batch may hold fewer than BATCH_SIZE images.
train_data = train_data.batch(BATCH_SIZE)
train_data = train_data.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
# Preprocess the data
def preprocess(img):
    """Cast ``img`` to float32 and divide by 255."""
    as_float = tf.cast(img, "float32")
    return as_float / 255.0

In [None]:
# Load the fixed image 999999.jpg. load_and_preprocess_image only decodes and
# resizes it; pixel values are not divided by 255 at this point.
fixed_img_from_path = "/app/data/celeba-dataset/img_align_celeba/img_align_celeba/999999.jpg"
fixed_img_from = load_and_preprocess_image(fixed_img_from_path)
# When True, prepare_inputs uses fixed_img_from as the "from" image for every sample.
use_fixed_image = True

In [None]:
def _noise_fallback(image_shape):
    """Build the (image, embedding) pair returned when no face swap is possible.

    Args:
        image_shape: shape of the image to generate.

    Returns:
        A float32 Gaussian-noise image of ``image_shape`` with values limited to
        [0, 255] and truncated to whole numbers, and a float32 random vector of
        length 512.
    """
    noise = np.random.normal(loc=127.5, scale=50.0, size=image_shape)
    # Clip to [0, 255]; the uint8 round trip truncates to integer pixel levels.
    noise = np.clip(noise, 0, 255).astype(np.uint8)
    embed = np.random.normal(size=(512,)).astype(np.float32)
    return noise.astype(np.float32), embed


def get_target(img_into, img_from):
    """Swap the face found in ``img_from`` onto ``img_into``.

    Both images are padded with a 50-pixel border of value 255, passed through
    the face analyser and swapper inside ``tf.py_function``, and the border is
    cropped from the result again.

    Args:
        img_into: image tensor (height, width, channels) that receives the face.
        img_from: image tensor (height, width, channels) that provides the face.

    Returns:
        ``(Y_target, embed)``: the swapped image with the border removed, and
        the emap-projected, L2-normalised embedding of the source face. If
        either image has no detected face, or any step raises, a noise image
        and a random embedding are returned instead.
    """
    def process_image(img_into_tensor, img_from_tensor):
        # Taken before the try block so the except handler can always build a
        # fallback, even if the failure happens before any array exists.
        fallback_shape = tuple(img_into_tensor.shape.as_list())
        try:
            # Convert tensors to NumPy arrays for the detector / swapper
            img_into_np = img_into_tensor.numpy()
            img_from_np = img_from_tensor.numpy()

            # Detect faces and order them by bbox[0]; the first one is used
            faces_into_sorted = sorted(face_analyser.get(img_into_np), key=lambda x: x.bbox[0])
            faces_from_sorted = sorted(face_analyser.get(img_from_np), key=lambda x: x.bbox[0])
            if not (faces_into_sorted and faces_from_sorted):
                return _noise_fallback(fallback_shape)

            face_into = faces_into_sorted[0]
            face_from = faces_from_sorted[0]
            result = inswapper.get(img_into_np, face_into, face_from, paste_back=True)
            # Project the source embedding through emap and re-normalise it
            embed = np.dot(face_from.normed_embedding, emap)
            embed /= np.linalg.norm(embed)
            # Explicit float32 cast to match the Tout declared for tf.py_function
            return result.astype(np.float32), embed.astype(np.float32)
        except Exception as e:
            # Deliberate best-effort: report the problem and keep the pipeline running
            print(f"Error while in process_image:\n{e}")
            return _noise_fallback(fallback_shape)

    border_size = 50
    paddings = [[border_size, border_size], [border_size, border_size], [0, 0]]

    img_into_padded = tf.pad(
        img_into,
        paddings=paddings,
        mode='CONSTANT',
        constant_values=255,
    )

    img_from_padded = tf.pad(
        img_from,
        paddings=paddings,
        mode='CONSTANT',
        constant_values=255,
    )

    # Wrap the processing function with tf.py_function
    Y_target_padded, embed = tf.py_function(func=process_image, inp=[img_into_padded, img_from_padded], Tout=(tf.float32, tf.float32))
    # Remove the border that was added before processing
    Y_target = Y_target_padded[
        border_size:-border_size,
        border_size:-border_size,
        :
    ]

    return Y_target, embed

    
#functions.preprocess_face(img=image_path, target_size=(112, 112), enforce_detection=False)

# Build the model inputs ((image, embed)) and the swap target for one batch
def prepare_inputs(img):
    """Turn a batch of images into ``((image, embed), Y_target)`` for training.

    Args:
        img: batch of images, shape (batch, height, width, channels), with
            pixel values not yet divided by 255.

    Returns:
        ``((img_processed, embed), Y_target)`` where ``img_processed`` is the
        input batch scaled by 1/255, ``embed`` holds one 512-value embedding per
        sample, and ``Y_target`` is the face-swapped batch scaled by 1/255.
    """
    # Scale the input images
    img_processed = preprocess(img)

    # Use the actual batch size rather than BATCH_SIZE: the dataset is batched
    # without drop_remainder, so the last batch can be shorter and fixed-size
    # indices would run past its end.
    batch_size = tf.shape(img)[0]
    shuffled_indices = tf.random.shuffle(tf.range(batch_size))
    img_random = tf.gather(img, shuffled_indices)
    indices = tf.range(batch_size)

    def get_target_pair(idx):
        img_i = img[idx]          # Original image
        img_j = img_random[idx]   # Randomly selected image
        if use_fixed_image:
            img_j = fixed_img_from
        return get_target(img_i, img_j)

    # get_target runs through tf.py_function, so it is mapped per sample
    Y_target, embed = tf.map_fn(get_target_pair, indices, dtype=(tf.float32, tf.float32))

    # py_function outputs carry no static shape; restore it, keeping the batch
    # dimension identical to the input's (which may be unknown).
    Y_target.set_shape(img.shape)
    embed.set_shape([img.shape[0], 512])

    Y_target = preprocess(Y_target)

    return ((img_processed, embed), Y_target)


# Attach swap targets and embeddings to every batch, then prefetch the
# prepared batches to improve latency.
train = (
    train_data
    .map(prepare_inputs, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Pull the first prepared batch and convert inputs and targets to NumPy for inspection.
((img_batch, embed_batch), Y_target) = next(iter(train))
train_sample = img_batch.numpy()
result_sample = Y_target.numpy()

In [None]:
# Show some faces from the training set
display(train_sample, 8, cmap=None)
# Show the matching swap targets
display(result_sample, 8, cmap=None)
# Debug output: values of the first target, then the embedding batch shape and contents
print(result_sample[0])
print(embed_batch.shape)
print(embed_batch)

## 2. Build the AEI_NET <a name="build"></a>

In [None]:
# Build the model via the get_model factory imported from aei_net.
from aei_net import get_model
model = get_model()

## 3. Train the AEI_NET <a name="train"></a>

In [None]:
# Compile the AEI_NET model: MSE loss on the first output, no loss on the second
optimizer = optimizers.Adam(learning_rate=LEARNING_RATE)
model.compile(optimizer=optimizer, loss=['mse', None])  

In [None]:
# Fetch the three training callbacks; get_callbacks is given the prepared training dataset.
from training_callbacks import get_callbacks
model_checkpoint_callback, tensorboard_callback, image_generator = get_callbacks(train)

In [None]:
# Load old weights if required
if LOAD_MODEL:
    model.load_weights("./models/aei_net")
    # Runs a single batch through the model; the result is not read afterwards.
    # NOTE(review): presumably a warm-up / sanity check -- confirm.
    tmp = model.predict(train.take(1))

In [None]:
# Ceiling divisions: number of batches in the dataset, then the number of
# TAKE_BATCHES-sized chunks needed to cover them.
total_elements = file_count
total_batches = -(-total_elements // BATCH_SIZE)
total_loops = -(-total_batches // TAKE_BATCHES)

print(f"Total batches: {total_batches}")
# The value printed here is the number of chunks (fit calls); each runs for EPOCHS epochs.
print(f"Total epochs needed: {total_loops}")

for i in range(total_loops):
    print(f"{i + 1} of {total_loops}...")

    # Train on the i-th chunk of TAKE_BATCHES batches.
    # NOTE(review): skip() is applied after the prepare_inputs map, so skipped
    # batches may still be computed before being discarded -- confirm, and
    # consider skipping on train_data before mapping if so.
    model.fit(
        train.skip(i*TAKE_BATCHES).take(TAKE_BATCHES),
        epochs=EPOCHS,
        callbacks=[
            model_checkpoint_callback,
            #tensorboard_callback,
            image_generator,
        ],
    )

    # Saved inside the loop, i.e. after every chunk rather than once at the end
    model.save("./models/aei_net")

In [None]:
# Launch TensorBoard on ./logs.
# NOTE(review): tensorboard_callback is commented out in the model.fit call, so this run may not write any logs.
%load_ext tensorboard
%tensorboard --logdir=logs

## 4. Reconstruct using the AEI_NET <a name="reconstruct"></a>

In [None]:
# Take one batch from the training pipeline (no separate test set is defined here)
((img_batch, embed_batch), Y_target) = next(iter(train))
example_images = img_batch.numpy()
batch_size = tf.shape(img_batch)[0]
# Use the embeddings produced for this batch. They carry 512 values per sample,
# matching what the model is trained on; the previous random [batch_size, 256]
# tensor did not match that size.
example_embed = embed_batch.numpy()

In [None]:
# Create model predictions and display them; predict returns two outputs
Y_pred, z_attr_pred = model.predict([example_images, example_embed])
# Multiply by 255, clip to [0, 255] and cast to uint8 for display
predicted_images = np.clip(Y_pred * 255, 0, 255).astype(np.uint8)
print("Example real faces")
display(example_images)
print("Reconstructions")
display(predicted_images)