In [1]:
# %%capture
# !pip install gdown --force-reinstall
# !gdown --id 1TEOjzSBN6UZc1WQwb_huEhdFl1XyZel4
# !unzip KDEF_CROPPED_ALIGNED.zip

In [2]:
# %%capture
# !gdown --id 1Bd87admxOZvbIOAyTkGEntsEz3fyMt7H

In [3]:
# %%capture
# !pip install mlxtend
# !pip install dlib
# !pip install scikit-image

In [4]:
import torch
torch.multiprocessing.set_start_method('spawn')

In [8]:
import imageio
import matplotlib.pyplot as plt
from mlxtend.image import extract_face_landmarks
import cv2
import glob
import os
import numpy as np
import tqdm
import re
import pickle

# Dataset Preparation

In [10]:
!rm -rf /media/soroushh/Storage2/matrices
!mkdir -p /media/soroushh/Storage2/matrices/training
!mkdir -p /media/soroushh/Storage2/matrices/evaluation

from magface_model import iresnet100, load_dict_inf

magface = iresnet100(pretrained=False, num_classes=512)
magface = load_dict_inf(magface, "./magface_epoch_00025.pth")
magface = magface.to("cuda")
magface.eval()

def get_landmarks(image_paths, image_size):
    database = {}
    for image_path in tqdm.tqdm(image_paths):
        filename = image_path.split("/")[-1].split(".")[0]
        if not bool(re.match(r"\w{2}\d{2}\w{2}\w+", filename)):
            continue

        name = re.findall(r"(\w{2}\d{2})\w{2}\w+", filename)[0]
        pose_label = re.findall(r"\w{2}\d{2}\w{2}(\w+)", filename)[0]
        emotion_label = re.findall(r"\w{2}\d{2}(\w{2})\w+", filename)[0]

        # img = cv2.imread(image_path)
        img = cv2.imread(image_path.replace("CROPPED_ALIGNED", "KDEF")) # get real images from KDEF for landmarks
        img = cv2.resize(img, (image_size, image_size))
        img = cv2.copyMakeBorder(img, 10, 10, 10, 10, cv2.BORDER_CONSTANT)
        pose_img = get_pose_image(img, image_size)

        if pose_img is None:
            continue

        person_db = database.get(name, [])
        person_db.append((emotion_label, pose_label, pose_img, image_path))
        database[name] = person_db
        
    return database

def get_pose_image(img, image_size):
    landmarks = extract_face_landmarks(img)

    if np.all(landmarks == 0) or landmarks is None:
        return None

    landmarks[:, 0][landmarks[:, 0] >= image_size] = image_size - 1
    landmarks[:, 0][landmarks[:, 0] < 0] = 0
    landmarks[:, 1][landmarks[:, 1] >= image_size] = image_size - 1
    landmarks[:, 1][landmarks[:, 1] < 0] = 0

    # landmarks = landmarks[[30, 40, 46, 48, 56]]
    # print(landmarks.shape)

    pose_img = np.zeros((image_size, image_size))
    pose_img[landmarks[:, 1], landmarks[:, 0]] = 1
    pose_img = pose_img[:, :, np.newaxis]
    pose_img = pose_img[landmarks[:, 1].min():landmarks[:, 1].max(), landmarks[:, 0].min():landmarks[:, 0].max()]
    pose_img = cv2.resize(pose_img, (image_size, image_size))
    pose_img = np.where(pose_img > 0.5, 1, 0)

    # plt.imshow(pose_img)
    # plt.show()

    pose_img = pose_img[:, :, np.newaxis]

    return pose_img

def get_magface_embedding(image_path):
    img = cv2.imread(image_path)

    if img.shape[:2] != [112, 112]:
        img = cv2.resize(img, (112, 112))

    img = img[np.newaxis] / 255.
    img = torch.Tensor(img).to("cuda").to(torch.float32)
    img = img.permute(0, 3, 1, 2)
    embedding = magface(img)
    embedding = embedding.detach().cpu().numpy()
    embedding = np.squeeze(embedding)

    return embedding

def prepare_final_database(database):
    db_keys = list(database.keys())

    final_database = []
    for person1_name in db_keys:
        person1 = database[person1_name]
        for triple_to_reconstruct_index, triple_to_reconstruct in enumerate(person1):
            pose_img_to_reconstruct = triple_to_reconstruct[2]

            available_poses_input = []
            available_poses_output = []
            for person2_name in db_keys:
                person2 = database[person2_name]
                available_poses_input = list(filter(lambda index: person2[index][0] == triple_to_reconstruct[0], range(len(person2))))
                available_poses_output = list(filter(lambda index: person2[index][0] == triple_to_reconstruct[0] and person2[index][1] == triple_to_reconstruct[1], range(len(person2))))

                for triple_to_change_input_index in available_poses_input:
                    for triple_to_change_output_index in available_poses_output:
                        final_database.append((person1_name, triple_to_reconstruct_index, person2_name, triple_to_change_input_index, triple_to_change_output_index))

                if len(final_database) >= 60000:
                    break

            if len(final_database) >= 60000:
                break

        if len(final_database) >= 60000:
            break
            
    return final_database

def save_files(database, final_database, type_):
    for idx in tqdm.tqdm(range(len(final_database))):
        tup = final_database[idx]

        if os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_pose_img_to_reconstruct.npy') and os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_input.npy') and os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_output.npy'):
            continue

        pose_img_to_reconstruct = database[tup[0]][tup[1]][2]
        image_path = database[tup[2]][tup[3]][3]
        embedding_input = get_magface_embedding(image_path)
        image_path = database[tup[2]][tup[4]][3]
        embedding_output = get_magface_embedding(image_path)

        with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_pose_img_to_reconstruct.npy', 'wb') as f:
            np.save(f, pose_img_to_reconstruct)

        with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_input.npy', 'wb') as f:
            np.save(f, embedding_input)

        with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_output.npy', 'wb') as f:
            np.save(f, embedding_output)

    with open(f'/media/soroushh/Storage2/database_{type_}.pickle', 'wb') as handle:
        pickle.dump(database, handle, protocol=pickle.HIGHEST_PROTOCOL)


people_paths = glob.glob(os.path.join("./CROPPED_ALIGNED/*"))
people_paths = np.array(people_paths)
np.random.shuffle(people_paths)
people_paths = people_paths.tolist()

index = int(0.8 * len(people_paths))
trainset = people_paths[:index]
evalset = people_paths[index:]

image_paths = [img_path for person_path in trainset for img_path in glob.glob(person_path + "/*.JPG")]
database = get_landmarks(image_paths, 112)
final_database = prepare_final_database(database)
save_files(database, final_database, "training")

image_paths = [img_path for person_path in evalset for img_path in glob.glob(person_path + "/*.JPG")]
database = get_landmarks(image_paths, 112)
final_database = prepare_final_database(database)
save_files(database, final_database, "evaluation")

=> loading pth from ./magface_epoch_00025.pth ...


100%|██████████| 3903/3903 [01:09<00:00, 56.01it/s]
100%|██████████| 60001/60001 [1:10:53<00:00, 14.11it/s]
100%|██████████| 977/977 [00:17<00:00, 55.77it/s]
100%|██████████| 49115/49115 [45:55<00:00, 17.82it/s]  
