<a id='1'></a>
# Import packages

In [2]:
from keras.layers import *
import keras.backend as K
import tensorflow as tf

In [2]:
from umeyama import umeyama
from image_augmentation import random_transform
from prefetch_generator import background # !pip install prefetch_generator

In [3]:
import os
import cv2
import glob
import time
import numpy as np
from scipy import ndimage
from pathlib import PurePath, Path
from random import randint, shuffle
from IPython.display import clear_output

import matplotlib.pyplot as plt
%matplotlib inline

<a id='4'></a>
# Config

In [4]:
# Learning phase 1 = training mode, used while training in this notebook.
K.set_learning_phase(1)
# K.set_learning_phase(0) # set to 0 in inference phase (video conversion)

In [5]:
# Input/Output resolution
RESOLUTION = 64 # 64x64, 128x128, 256x256
# Check membership rather than `% 64 == 0` so the assertion agrees with its message;
# a multiple-of-64 test would also accept sizes not listed above (e.g. 192).
assert RESOLUTION in (64, 128, 256), "RESOLUTION should be 64, 128, or 256."

# Batch size
batchSize = 8

# Use motion blurs (data augmentation)
# set True if training data contains images extracted from videos
use_da_motion_blur = False 

# Use eye-aware training
# require images generated from prep_binary_masks.ipynb
use_bm_eyes = True

# Probability of random color matching (data augmentation)
prob_random_color_match = 0.5

# Path to training images
img_dirA = './faceA'
img_dirB = './faceB'
img_dirA_bm_eyes = "./binary_masks/faceA_eyes"
img_dirB_bm_eyes = "./binary_masks/faceB_eyes"

# Path to saved model weights
models_dir = "./models"

In [6]:
# Architecture configuration: keyword arguments for FaceswapGANModel.
arch_config = dict(
    IMAGE_SHAPE=(RESOLUTION, RESOLUTION, 3),
    use_self_attn=True,
    norm="instancenorm",        # instancenorm, batchnorm, layernorm, groupnorm, none
    model_capacity="standard",  # standard, lite
)

In [7]:
# Loss function weights configuration
loss_weights = {
    'w_D': 0.1,                     # Discriminator
    'w_recon': 1.,                  # L1 reconstruction loss
    'w_edge': 0.1,                  # edge loss
    'w_eyes': 30.,                  # reconstruction and edge loss on eyes area
    'w_pl': (0.01, 0.1, 0.3, 0.1),  # perceptual loss (0.003, 0.03, 0.3, 0.3)
}

# Init. loss config (the training loop updates these entries on a schedule).
loss_config = {
    "gan_training": "mixup_LSGAN",  # "mixup_LSGAN" or "relativistic_avg_LSGAN"
    'use_PL': False,
    'use_mask_hinge_loss': False,
    'm_mask': 0.,
    'lr_factor': 1.,
    'use_cyclic_loss': False,
}

<a id='5'></a>
# Define models

In [8]:
from networks.faceswap_gan_model import FaceswapGANModel

In [9]:
# Build the faceswap-GAN model (including the path_* functions used later) from arch_config.
model = FaceswapGANModel(**arch_config)

<a id='6'></a>
# Load Model Weights

Weights file names:
```python
encoder.h5, decoder_A.h5, decoder_B.h5, netDA.h5, netDB.h5
```

In [12]:
# Resume from weights previously saved in models_dir.
# NOTE(review): presumably fails when no weights have been saved yet -- skip this cell for a fresh run.
model.load_weights(path=models_dir)

Model weights files are successfully loaded


### The following cells are for training, skip to [transform_face()](#tf) for inference.

# Define Losses and Build Training Functions

TODO: split into two methods

In [None]:
# https://github.com/rcmalli/keras-vggface
#!pip install keras_vggface
from keras_vggface.vggface import VGGFace

# VGGFace ResNet50
# Headless network (include_top=False) built for 224x224x3 inputs.
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

#vggface.summary()

# Register VGGFace with the model; presumably consumed by the perceptual loss that
# loss_config['use_PL'] enables. reset_session() repeats this after clearing the session.
model.build_pl_model(vggface_model=vggface)

In [14]:
# Build the training functions for the current loss_weights / loss_config.
# The training loop rebuilds them whenever loss_config changes.
model.build_train_functions(loss_weights=loss_weights, **loss_config)

<a id='9'></a>
# DataLoader

TODO: write a DataLoader class

In [15]:
# Motion blurs as data augmentation
def get_motion_blur_kernel(sz=7):
    """Return a random linear motion-blur kernel.

    A horizontal line of ones through the centre row of an ``sz`` x ``sz`` grid is
    rotated by a random angle drawn uniformly from [-180, 180) degrees, clipped to
    [0, 1] and normalised so its entries sum to 1 (blurring preserves brightness).

    Parameters:
        sz: kernel side length in pixels (odd values keep the line centred).
    Returns:
        float64 array of shape (sz, sz) whose entries sum to 1.
    """
    rot_angle = np.random.uniform(-180, 180)
    kernel = np.zeros((sz, sz))
    kernel[(sz - 1) // 2, :] = 1.0
    # `ndimage.interpolation` is a deprecated alias namespace; use ndimage.rotate directly.
    kernel = ndimage.rotate(kernel, rot_angle, reshape=False)
    # Interpolation can overshoot/undershoot slightly; keep weights non-negative.
    kernel = np.clip(kernel, 0, 1)
    return kernel / np.sum(kernel)

def motion_blur(images, sz=7):
    """Apply one shared random motion blur to every image in ``images``.

    The kernel size is drawn at random from {5, 7, 9, 11}; the ``sz`` argument is
    currently ignored. The list is modified in place (each entry is replaced by its
    blurred float64 version) and is also returned.
    """
    kernel_size = np.random.choice([5, 7, 9, 11])
    blur_kernel = get_motion_blur_kernel(kernel_size)
    for idx in range(len(images)):
        blurred = cv2.filter2D(images[idx], -1, blur_kernel)
        images[idx] = blurred.astype(np.float64)
    return images

In [None]:
# Utils for loading data
def load_data(file_pattern):
    """Return the list of paths matching the glob ``file_pattern`` (in glob's order)."""
    matches = list(glob.iglob(file_pattern))
    return matches
  
def random_warp_rev(image, res=RESOLUTION):
    """Produce a randomly warped input and an aligned target from one 256x256 sample.

    A 5x5 grid of sample points centred on (128, 128) is jittered with Gaussian
    noise and upsampled into dense remap coordinates.

    Parameters:
        image: array of shape (256, 256, 6) (asserted); read_image() passes the
            image channels followed by the eye-mask channels.
        res: output side length; must be at least 64. The crop and grid arithmetic
            below is written in units of res // 64.
    Returns:
        (warped_image, target_image), both res x res with image's channels.
        warped_image: image sampled through the jittered grid (non-rigid warp).
        target_image: image under the single affine transform that umeyama()
            fits from the jittered grid points to a regular grid. This is
            presumably a least-squares similarity with scaling (third argument
            True) -- TODO confirm against umeyama.py.
    """
    assert image.shape == (256,256,6)
    res_scale = res//64
    assert res_scale >= 1, f"Resolution should be >= 64. Recieved {res}."
    # Dense maps are built at 80*res_scale and cropped to [8, 72)*res_scale,
    # i.e. exactly res pixels per side.
    interp_param = 80 * res_scale
    interp_slice = slice(interp_param//10,9*interp_param//10)
    # 0, 16, 32, 48, 64 (times res_scale): a regular 5x5 destination grid spanning the output.
    dst_pnts_slice = slice(0,65*res_scale,16*res_scale)
    
    rand_coverage = np.random.randint(25) + 80 # random warping coverage: half-width in [80, 104] px
    rand_scale = np.random.uniform(5., 6.2) # random warping scale: std of the grid jitter
    
    # Regular 5x5 grid of source coordinates centred on the 256x256 image, then jittered.
    range_ = np.linspace(128-rand_coverage, 128+rand_coverage, 5)
    mapx = np.broadcast_to(range_, (5,5))
    mapy = mapx.T
    mapx = mapx + np.random.normal(size=(5,5), scale=rand_scale)
    mapy = mapy + np.random.normal(size=(5,5), scale=rand_scale)
    # Upsample the 5x5 jittered grid into per-pixel sampling coordinates and crop the centre.
    interp_mapx = cv2.resize(mapx, (interp_param,interp_param))[interp_slice,interp_slice].astype('float32')
    interp_mapy = cv2.resize(mapy, (interp_param,interp_param))[interp_slice,interp_slice].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    # Fit a global transform from the jittered grid points to the regular grid; keep the 2x3 affine part.
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[dst_pnts_slice,dst_pnts_slice].T.reshape(-1,2)
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (res,res))
    return warped_image, target_image

def random_color_match(image):
    """Re-colour ``image`` so its colour statistics match a random training image.

    A random file from the global ``fns_all_trn_data`` is read. Both images are
    resized to 256x256. Per-channel mean/std, computed over the central crop (a
    60 px border is excluded), are used to shift and scale ``image`` towards the
    target's statistics. The result is then shifted/scaled into [0, 255].

    Returns:
        float32 256x256 image in [0, 255];
        ``image`` unchanged if the target file cannot be read;
        the resized source image if any of its channels has (near-)zero std.
    """
    global fns_all_trn_data
    rand_idx = np.random.randint(len(fns_all_trn_data))
    fn_match = fns_all_trn_data[rand_idx]
    tar_img = cv2.imread(fn_match)
    if tar_img is None:
        print(f"Failed reading image {fn_match} in random_color_match().")
        return image
    r = 60  # border (px) excluded from the statistics; presumably focuses them on the face
    src_img = cv2.resize(image, (256,256))
    tar_img = cv2.resize(tar_img, (256,256))
    mt = np.mean(tar_img[r:-r,r:-r,:], axis=(0,1))
    st = np.std(tar_img[r:-r,r:-r,:], axis=(0,1))
    ms = np.mean(src_img[r:-r,r:-r,:], axis=(0,1))
    ss = np.std(src_img[r:-r,r:-r,:], axis=(0,1))
    # Check every channel: comparing `ss.any()` (a single bool) with the threshold
    # would only catch the case where all channels are constant.
    if np.any(ss <= 1e-7):
        return src_img
    result = st * (src_img.astype(np.float32) - ms) / (ss+1e-7) + mt
    if result.min() < 0:
        result = result - result.min()
    if result.max() > 255:
        result = 255.0/result.max()*result
    # Always return float32 (the arithmetic above promotes to float64).
    return result.astype(np.float32)

# Default keyword arguments for random_transform() used by read_image().
random_transform_args = dict(
    rotation_range=10,
    zoom_range=0.1,
    shift_range=0.05,
    random_flip=0.5,
)
def read_image(fn, dir_bm_eyes=None, random_transform_args=random_transform_args):
    """Load one training sample and return an augmented (warped, target, eye-mask) triple.

    Parameters:
        fn: path of the face image.
        dir_bm_eyes: directory holding an eye binary mask with the same file name
            as ``fn``. Required only when ``use_bm_eyes`` is True.
        random_transform_args: keyword arguments forwarded to ``random_transform``.
    Returns:
        (warped_img, target_img, bm_eyes), each with 3 channels:
        warped_img, target_img: BGR images nominally in [-1, 1];
        bm_eyes: the eye mask aligned with target_img, in [0, 1]
            (all zeros when eye masks are disabled).
    Raises:
        ValueError: if ``use_bm_eyes`` is True and ``dir_bm_eyes`` is None.
        IOError: if the image or its eye mask cannot be read.
    """
    # The eye-mask directory is only needed when eye-aware training is enabled.
    if use_bm_eyes and dir_bm_eyes is None:
        raise ValueError("dir_bm_eyes is not set.")

    raw_fn = PurePath(fn).parts[-1]
    image = cv2.imread(fn)
    if image is None: raise IOError(f"Failed reading image {fn}.")
    if np.random.uniform() <= prob_random_color_match:
        image = random_color_match(image)
    image = cv2.resize(image, (256,256)) / 255 * 2 - 1

    if use_bm_eyes:
        bm_eyes = cv2.imread(f"{dir_bm_eyes}/{raw_fn}")
        if bm_eyes is None:
            raise IOError(f"Failed reading binary mask {dir_bm_eyes}/{raw_fn}.")
        bm_eyes = cv2.resize(bm_eyes, (256,256)) / 255.
    else:
        bm_eyes = np.zeros_like(image)

    # Stack image and mask channels so both receive identical geometric augmentation.
    image = np.concatenate([image, bm_eyes], axis=-1)
    image = random_transform(image, **random_transform_args)
    warped_img, target_img = random_warp_rev(image)

    bm_eyes = target_img[...,3:]
    warped_img = warped_img[...,:3]
    target_img = target_img[...,:3]

    # Motion blur data augmentation:
    # we want the model to learn to preserve motion blurs of input images.
    # The uniform draw happens first so the RNG stream does not depend on use_da_motion_blur.
    if np.random.uniform() < 0.25 and use_da_motion_blur:
        warped_img, target_img = motion_blur([warped_img, target_img])

    return warped_img, target_img, bm_eyes

In [17]:
# A generator function that yields epoch and data
@background(32)
def minibatch(data, batchsize, dir_bm_eyes):
    """Yield (epoch, warped, target, eye_masks) batches forever, reshuffling each epoch.

    Runs as a background prefetching generator via ``@background(32)``. ``data`` is
    copied before shuffling, so the caller's list (e.g. train_A) is never mutated
    from the background thread.

    NOTE(review): values sent into this generator (``tmpsize``) are presumably not
    forwarded by the background wrapper, so in practice every batch has ``batchsize``
    samples.
    """
    data = list(data)  # private copy: shuffling must not reorder the caller's list
    length = len(data)
    epoch = i = 0
    tmpsize = None
    shuffle(data)
    while True:
        size = tmpsize if tmpsize else batchsize
        if i + size > length:
            # Not enough samples left for a full batch: start a new epoch.
            shuffle(data)
            i = 0
            epoch += 1
        rtn = np.float32([read_image(data[j], dir_bm_eyes) for j in range(i, i + size)])
        i += size
        # rtn has shape (size, 3, H, W, C): split the (warped, target, mask) triples.
        tmpsize = yield epoch, rtn[:,0,:,:,:], rtn[:,1,:,:,:], rtn[:,2,:,:,:]

def create_minibatch(data, batchsize, dir_bm_eyes):
    """Plain-generator wrapper around ``minibatch`` yielding (epoch, warped, target, eye_masks).

    The batch size is validated here, in the consumer's thread, on the first
    ``next()``. ``minibatch`` itself runs in a background thread, where an error
    may not propagate to the caller.
    (This is a redundant function, to be written in to a DataLoader class.)

    Raises:
        ValueError: if ``batchsize`` is not in [1, len(data)].
    """
    if not 0 < batchsize <= len(data):
        raise ValueError(f"batchsize must be in [1, {len(data)}], got {batchsize}.")
    batch = minibatch(data, batchsize, dir_bm_eyes)
    while True:
        yield next(batch)

# Visualizer

TODO: write a Visualizer class

In [18]:
from utils import showG, showG_mask, showG_eyes

<a id='10'></a>
# Start Training
TODO: make training script compact

In [19]:
# Create the weights directory (models_dir) if it does not exist yet, so that changing
# models_dir in the config cell is respected.
Path(models_dir).mkdir(parents=True, exist_ok=True)

In [20]:
# Get filenames
train_A = load_data(img_dirA+"/*.*")
train_B = load_data(img_dirB+"/*.*")

# Pool of all training images; random_color_match() draws colour targets from it.
# (A `global` statement is a no-op at notebook top level, so none is needed.)
fns_all_trn_data = train_A + train_B

assert len(train_A), "No image found in " + str(img_dirA)
assert len(train_B), "No image found in " + str(img_dirB)
print("Number of images in folder A: " + str(len(train_A)))
print("Number of images in folder B: " + str(len(train_B)))

def _check_eye_masks(image_fns, img_dir, mask_dir):
    """Assert every image in ``image_fns`` has an eye mask with the same file name in ``mask_dir``.

    read_image() looks masks up by file name, so matching counts alone is not enough.
    """
    mask_names = {PurePath(fn).name for fn in glob.glob(mask_dir + "/*.*")}
    assert mask_names, "No binary mask found in " + str(mask_dir)
    missing = [PurePath(fn).name for fn in image_fns if PurePath(fn).name not in mask_names]
    assert not missing, (
        f"{len(missing)} image(s) in {img_dir} have no binary mask in {mask_dir} "
        f"(e.g. {missing[:3]}). Can be caused by any non-image file in the folder.")

if use_bm_eyes:
    _check_eye_masks(train_A, img_dirA, img_dirA_bm_eyes)
    _check_eye_masks(train_B, img_dirB, img_dirB_bm_eyes)

Number of images in folder A: 376
Number of images in folder B: 318


In [21]:
def show_loss_config(loss_config):
    """Print every loss-configuration entry as ``name = value``, one per line."""
    lines = [f"{name} = {setting}" for name, setting in loss_config.items()]
    if lines:
        print("\n".join(lines))

In [None]:
# Display random binary masks of eyes next to their aligned target faces.
train_batchA = create_minibatch(train_A, batchSize, img_dirA_bm_eyes)
train_batchB = create_minibatch(train_B, batchSize, img_dirB_bm_eyes)
# Each batch is (epoch, warped, target, eye_masks); only the last two are shown.
tA, bmA = next(train_batchA)[2:]
tB, bmB = next(train_batchB)[2:]
showG_eyes(tA, tB, bmA, bmB, batchSize)

In [None]:
def reset_session(save_path):
    """Rebuild the model in a fresh Keras session, carrying the weights over via disk.

    Saves the current weights to ``save_path``, drops the global ``model`` and
    ``vggface``, and clears the Keras session. It then recreates the model from
    ``arch_config``, reloads the saved weights, and recreates VGGFace and
    re-registers it with ``build_pl_model``. Training functions are NOT rebuilt;
    callers must call ``model.build_train_functions`` afterwards.

    NOTE(review): presumably done to release the graph accumulated by earlier
    build_train_functions calls. Raises NameError if ``vggface`` was never created
    (i.e. the VGGFace cell was skipped).
    """
    global model, vggface
    model.save_weights(path=save_path)
    del model
    del vggface
    K.clear_session()
    model = FaceswapGANModel(**arch_config)
    model.load_weights(path=save_path)
    vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))
    model.build_pl_model(vggface_model=vggface)

In [22]:
# Start training
t0 = time.time()
gen_iterations = 0

display_iters = 300
backup_iters = 5000
TOTAL_ITERS = 40000

# Names of the generator loss components, in the order train_one_batch_G returns them.
LOSS_KEYS = ['ttl', 'adv', 'recon', 'edge', 'pl']
errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
errGAs = dict.fromkeys(LOSS_KEYS, 0)  # dicts keep insertion order
errGBs = dict.fromkeys(LOSS_KEYS, 0)

# Loss-function schedule: (iteration, loss_config updates, swap decoders first?).
# Each stage fires half a display interval before a fraction of TOTAL_ITERS.
# Entries are checked in order and only the first match applies, like an if/elif chain.
half_display = display_iters//2
LOSS_SCHEDULE = [
    (TOTAL_ITERS//5 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=False, m_mask=0.0), False),
    (TOTAL_ITERS//5 + TOTAL_ITERS//10 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=True, m_mask=0.5), False),
    (2*TOTAL_ITERS//5 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=True, m_mask=0.2), False),
    (TOTAL_ITERS//2 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=True, m_mask=0.4), False),
    (2*TOTAL_ITERS//3 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=False, m_mask=0., lr_factor=0.3), False),
    (8*TOTAL_ITERS//10 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=True, m_mask=0.1, lr_factor=0.3), True),
    (9*TOTAL_ITERS//10 - half_display,
     dict(use_PL=True, use_mask_hinge_loss=False, m_mask=0.0, lr_factor=0.1), False),
]

def apply_loss_stage(updates, swap_decoders=False):
    """Switch training to a new loss configuration.

    Steps:
    1. Optionally swap the A/B decoders, loading each from the other's weight file
       in models_dir.
    2. Apply ``updates`` to the global loss_config.
    3. Reset the Keras session; reset_session rebinds the global ``model``.
    4. Build new training functions on the fresh model.
    """
    clear_output()
    if swap_decoders:
        model.decoder_A.load_weights(f"{models_dir}/decoder_B.h5") # swap decoders
        model.decoder_B.load_weights(f"{models_dir}/decoder_A.h5") # swap decoders
    loss_config.update(updates)
    reset_session(models_dir)
    print("Building new loss functions...")
    show_loss_config(loss_config)
    model.build_train_functions(loss_weights=loss_weights, **loss_config)
    print("Done.")

def mean_over_display(total):
    """Average an accumulated loss over display_iters.

    A total that accumulated as an array (the perceptual loss may) reports its
    first element.
    """
    return np.ravel(total)[0] / display_iters

train_batchA = create_minibatch(train_A, batchSize, img_dirA_bm_eyes)
train_batchB = create_minibatch(train_B, batchSize, img_dirB_bm_eyes)

while gen_iterations <= TOTAL_ITERS:
    data_A = next(train_batchA)
    data_B = next(train_batchB)

    # Loss function automation
    for stage_iter, updates, swap_decoders in LOSS_SCHEDULE:
        if gen_iterations == stage_iter:
            apply_loss_stage(updates, swap_decoders)
            break

    # Train discriminators for one batch
    errDA, errDB = model.train_one_batch_D(data_A=data_A, data_B=data_B)
    errDA_sum += errDA[0]
    errDB_sum += errDB[0]

    # Train generators for one batch
    errGA, errGB = model.train_one_batch_G(data_A=data_A, data_B=data_B)
    errGA_sum += errGA[0]
    errGB_sum += errGB[0]
    for i, k in enumerate(LOSS_KEYS):
        errGAs[k] += errGA[i]
        errGBs[k] += errGB[i]
    gen_iterations += 1

    # Visualization
    if gen_iterations % display_iters == 0:
        clear_output()

        # Display loss information
        show_loss_config(loss_config)
        print("----------")
        print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
              % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,
                 errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))
        print("----------")
        print("Generator loss details:")
        loss_sections = [('Adversarial loss', 'adv'), ('Reconstruction loss', 'recon'), ('Edge loss', 'edge')]
        if loss_config['use_PL']:
            loss_sections.append(('Perceptual loss', 'pl'))
        for title, key in loss_sections:
            print(f'[{title}]')
            print(f'GA: {mean_over_display(errGAs[key]):.4f} GB: {mean_over_display(errGBs[key]):.4f}')

        # Display images
        print("----------")
        _, wA, tA, _ = next(train_batchA)
        _, wB, tB, _ = next(train_batchB)
        print("Transformed (masked) results:")
        showG(tA, tB, model.path_A, model.path_B, batchSize)
        print("Masks:")
        showG_mask(tA, tB, model.path_mask_A, model.path_mask_B, batchSize)
        print("Reconstruction results:")
        showG(wA, wB, model.path_bgr_A, model.path_bgr_B, batchSize)
        errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
        errGAs = dict.fromkeys(LOSS_KEYS, 0)
        errGBs = dict.fromkeys(LOSS_KEYS, 0)

        # Save models
        model.save_weights(path=models_dir)

    # Backup models
    if gen_iterations % backup_iters == 0:
        bkup_dir = f"{models_dir}/backup_iter{gen_iterations}"
        Path(bkup_dir).mkdir(parents=True, exist_ok=True)
        model.save_weights(path=bkup_dir)

<a id='tf'></a>
# Single Image Transformation

In [1]:
def shift_n_scale(src_img, tar_img):
    """
    A color correction method.

    Shifts and scales each channel of ``src_img`` so its mean/std match those of
    ``tar_img``. The result is then rescaled so its maximum is 255 and clipped to
    [0, 255]; negative values would otherwise wrap when written into uint8 images.

    Parameters:
        src_img: HxWxC image to correct.
        tar_img: image whose per-channel statistics are matched (same channel count).
    Returns:
        float32 array shaped like src_img, or ``src_img`` itself (unchanged) if any
        of its channels is (near-)constant.
    """
    mt = np.mean(tar_img, axis=(0,1))
    st = np.std(tar_img, axis=(0,1))
    ms = np.mean(src_img, axis=(0,1))
    ss = np.std(src_img, axis=(0,1))
    # Check every channel: comparing `ss.any()` (a single bool) with the threshold
    # would only catch the case where all channels are constant.
    if np.any(ss <= 1e-7):
        return src_img
    result = st * (src_img.astype(np.float32) - ms) / (ss+1e-7) + mt
    peak = result.max()
    if peak > 0:
        result = 255.0 / peak * result
    return np.clip(result, 0, 255).astype(np.float32)

def transform_face(inp_img, direction="AtoB", roi_coef=15, color_correction=False):
    """
    Swap the face in an image with the trained model.

    Steps:
    1. Take the central region of interest (ROI): the image minus a margin of
       size//roi_coef on every side.
    2. Resize the ROI to RESOLUTION x RESOLUTION and pass it through the generator
       of the target identity.
    3. Alpha-blend the output with the original ROI using the predicted mask.
    4. Feather the result back into the input image.

    Parameters:
        inp_img: A RGB face image of any size (presumably uint8 in [0, 255]).
        direction:  A string that is either AtoB or BtoA
        roi_coef: Margin divisor; a larger value keeps a larger center area.
        color_correction: boolean, whether use color correction or not
    Returns:
        result_img: A RGB swapped face image after masking.
        result_mask: The alpha mask which corresponds to the result_img.
    Raises:
        ValueError: if direction is neither "AtoB" nor "BtoA".
    """
    def get_feather_edges_mask(img, roi_coef):
        """Return a float32 mask in [0, 1]: ~1 inside a size//roi_coef margin, blurred edges."""
        img_h, img_w = img.shape[:2]
        mask = np.zeros_like(img)
        x_, y_ = img_h//roi_coef, img_w//roi_coef
        # Explicit end indices: `x_:-x_` would select nothing when the margin is 0.
        mask[x_:img_h - x_, y_:img_w - y_, :] = 255
        mask = cv2.GaussianBlur(mask, (15, 15), 10).astype(np.float32) / 255
        return mask

    if direction == "AtoB":
        path_func = model.path_abgr_B
    elif direction == "BtoA":
        path_func = model.path_abgr_A
    else:
        raise ValueError(f"direction should be either AtoB or BtoA, recieved {direction}.")

    # pre-process input image
    img_bgr = cv2.cvtColor(inp_img, cv2.COLOR_RGB2BGR)
    img_h, img_w = img_bgr.shape[:2]
    roi_x, roi_y = img_h//roi_coef, img_w//roi_coef
    # Explicit end indices so a zero margin (image smaller than roi_coef) keeps the
    # whole image instead of producing an empty ROI.
    roi_rows = slice(roi_x, img_h - roi_x)
    roi_cols = slice(roi_y, img_w - roi_y)
    roi = img_bgr[roi_rows, roi_cols, :] # BGR, [0, 255]
    roi_size = roi.shape
    ae_input = cv2.resize(roi, (RESOLUTION,RESOLUTION)) / 255. * 2 - 1 # BGR, [-1, 1]

    # post-process transformed roi image
    ae_output = np.squeeze(np.array([path_func([[ae_input]])]))
    # Channel 0 is used as the alpha mask (scaled to [0, 255] here).
    # Channels 1: are BGR, mapped from [-1, 1] back to [0, 255].
    ae_output_a = ae_output[:,:,0] * 255
    ae_output_a = cv2.resize(ae_output_a, (roi_size[1],roi_size[0]))[...,np.newaxis]
    ae_output_bgr = np.clip( (ae_output[:,:,1:] + 1) * 255 / 2, 0, 255)
    ae_output_bgr = cv2.resize(ae_output_bgr, (roi_size[1],roi_size[0]))
    ae_output_masked = (ae_output_a/255 * ae_output_bgr + (1 - ae_output_a/255) * roi).astype('uint8') # BGR, [0, 255]
    if color_correction:
        ae_output_masked = shift_n_scale(ae_output_masked, roi)

    # merge transformed output back to input image
    blend_mask = get_feather_edges_mask(roi, roi_coef)
    blended_img = blend_mask * ae_output_masked + (1 - blend_mask) * roi
    result = img_bgr
    # Clip before the implicit cast into the image dtype so out-of-range values cannot wrap.
    result[roi_rows, roi_cols, :] = np.clip(blended_img, 0, 255)
    result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    result_alpha = np.zeros_like(img_bgr)
    result_alpha[roi_rows, roi_cols, :] = blend_mask * ae_output_a
    return result, result_alpha

In [20]:
# Load the test image; [..., :3] keeps the first three (RGB) channels, dropping alpha if present.
input_img = plt.imread("./TEST_IMAGE.jpg")[...,:3]

In [None]:
# Preview the input image.
plt.imshow(input_img)

In [22]:
# Swap using the A-side generator (direction="BtoA" selects model.path_abgr_A);
# returns the RGB result and its alpha mask.
result_img, result_mask = transform_face(input_img, direction="BtoA", roi_coef=15)

In [None]:
# Show the swapped result.
plt.imshow(result_img)

In [None]:
# Show the alpha mask; its three channels hold identical values, so channel 0 suffices.
plt.imshow(result_mask[..., 0])