# Usage

**To train a model**: Run 1 ~ 10.

**To load model weights**: Run 1 and 4 ~ 7.

**To use trained model to swap a single face image**: Run "**To load model weights**" and 11.

**To use trained model to create a video clips**: Run "**To load model weights**", 12 and 13 (or 14).


## Index
1. [Import Packages](#1)
2. [Install Requirements (optional)](#2)
3. [Import VGGFace (optional)](#3)
4. [Config](#4)
5. [Define Models](#5)
6. [Load Models](#6)
7. [Define Inputs/outputs Variables](#7)
8. [Define Loss Function](#8)
9. [Utils for loading/displaying images](#9)
10. [Start Training](#10)
11. [Helper Function: face_swap()](#11)
12. [Import Packages for Making Video Clips](#12)
13. [Make Video Clips w/o Face Alignment](#13)
14. [Make video clips w/ face alignment](#14)

<a id='1'></a>
# 1. Import packages

In [None]:
from keras.models import Sequential, Model
from keras.layers import *
from keras.layers.advanced_activations import LeakyReLU
from keras.activations import relu
from keras.initializers import RandomNormal
from keras.applications import *
import keras.backend as K
from keras.layers.core import Layer
from keras.engine import InputSpec
from keras import initializers
from tensorflow.contrib.distributions import Beta
import tensorflow as tf
from keras.optimizers import Adam

In [2]:
from image_augmentation import random_transform
from image_augmentation import random_warp
from umeyama import umeyama
from utils import get_image_paths, load_images, stack_images
from pixel_shuffler import PixelShuffler

In [3]:
import time
import numpy as np
from PIL import Image
import cv2
import glob
from random import randint, shuffle
from IPython.display import clear_output
from IPython.display import display
import matplotlib.pyplot as plt
%matplotlib inline

<a id='2'></a>
# 2. Install requirements

## ========== CAUTION ========== 

If you are running this jupyter on local machine. Please read [this blog](http://jakevdp.github.io/blog/2017/12/05/installing-python-packages-from-jupyter/) before running the following cells which pip install packages.

In [None]:
# https://github.com/rcmalli/keras-vggface
# Skip this cell if you don't want to use perceptual loss
#!pip install keras_vggface

In [None]:
# https://github.com/ageitgey/face_recognition
#!pip install face_recognition

We only import ```face_recognition``` and ```moviepy``` when making videos. They are not necessary in training GAN models.

In [None]:
#!pip install moviepy

<a id='3'></a>
# 3. Import VGGFace

In [None]:
from keras_vggface.vggface import VGGFace

In [None]:
vggface = VGGFace(include_top=False, model='resnet50', input_shape=(224, 224, 3))

In [None]:
#vggface.summary()

<a id='4'></a>
# 4. Config

mixup paper: https://arxiv.org/abs/1710.09412

Default training data directories: `./faceA/` and `./faceB/`

In [4]:
K.set_learning_phase(1)

In [5]:
channel_axis=-1
channel_first = False

In [6]:
IMAGE_SHAPE = (64, 64, 3)
nc_in = 3 # number of input channels of generators
nc_D_inp = 6 # number of input channels of discriminators

use_perceptual_loss = True # This should NOT be changed.
use_lsgan = True
use_self_attn = False
use_instancenorm = False
use_mixup = True
mixup_alpha = 0.2 # 0.2
w_l2 = 1e-4 # weight decay

# Adding motion blurs as data augmentation
# set True if training data contains images extracted from videos
use_da_motion_blur = False 

batchSize = 8
lrD = 1e-4 # Discriminator learning rate
lrG = 1e-4 # Generator learning rate

# Path of training images
img_dirA = './faceA/*.*'
img_dirB = './faceB/*.*'

<a id='5'></a>
# 5. Define models

In [7]:
conv_init = RandomNormal(0, 0.02)
gamma_init = RandomNormal(1., 0.02) # for batch normalization

In [23]:
class Scale(Layer):
    '''
    Code borrows from https://github.com/flyyufelix/cnn_finetune
    '''
    def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):
        self.axis = axis
        self.gamma_init = initializers.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]

        # Compatibility with TensorFlow >= 1.0.0
        self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name))
        self.trainable_weights = [self.gamma]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        return self.gamma * x

    def get_config(self):
        config = {"axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def self_attn_block(inp, nc):
    '''
    Code borrows from https://github.com/taki0112/Self-Attention-GAN-Tensorflow
    '''
    assert nc//8 > 0, f"Input channels must be >= 8, but got nc={nc}"
    x = inp
    shape_x = x.get_shape().as_list()
    
    f = Conv2D(nc//8, 1, kernel_initializer=conv_init)(x)
    g = Conv2D(nc//8, 1, kernel_initializer=conv_init)(x)
    h = Conv2D(nc, 1, kernel_initializer=conv_init)(x)
    
    shape_f = f.get_shape().as_list()
    shape_g = g.get_shape().as_list()
    shape_h = h.get_shape().as_list()
    flat_f = Reshape((-1, shape_f[-1]))(f)
    flat_g = Reshape((-1, shape_g[-1]))(g)
    flat_h = Reshape((-1, shape_h[-1]))(h)   
    
    s = Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([flat_g, flat_f])

    beta = Softmax(axis=-1)(s)
    o = Lambda(lambda x: tf.matmul(x[0], x[1]))([beta, flat_h])
    o = Reshape(shape_x[1:])(o)
    o = Scale()(o)
    
    out = add([o, inp])
    return out

In [10]:
#def batchnorm():
#    return BatchNormalization(momentum=0.9, axis=channel_axis, epsilon=1.01e-5, gamma_initializer = gamma_init)

def conv_block(input_tensor, f):
    x = input_tensor
    x = Conv2D(f, kernel_size=3, strides=2, kernel_regularizer=regularizers.l2(w_l2),  
               kernel_initializer=conv_init, use_bias=False, padding="same")(x)
    x = Activation("relu")(x)
    return x

def conv_block_d(input_tensor, f, use_instance_norm=False):
    x = input_tensor
    x = Conv2D(f, kernel_size=4, strides=2, kernel_regularizer=regularizers.l2(w_l2), 
               kernel_initializer=conv_init, use_bias=False, padding="same")(x)
    x = LeakyReLU(alpha=0.2)(x)
    return x

def res_block(input_tensor, f):
    x = input_tensor
    x = Conv2D(f, kernel_size=3, kernel_regularizer=regularizers.l2(w_l2), 
               kernel_initializer=conv_init, use_bias=False, padding="same")(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(f, kernel_size=3, kernel_regularizer=regularizers.l2(w_l2), 
               kernel_initializer=conv_init, use_bias=False, padding="same")(x)
    x = add([x, input_tensor])
    x = LeakyReLU(alpha=0.2)(x)
    return x

# Legacy
#def upscale_block(input_tensor, f):
#    x = input_tensor
#    x = Conv2DTranspose(f, kernel_size=3, strides=2, use_bias=False, kernel_initializer=conv_init)(x) 
#    x = LeakyReLU(alpha=0.2)(x)
#    return x

def upscale_ps(filters, use_norm=True):
    def block(x):
        x = Conv2D(filters*4, kernel_size=3, kernel_regularizer=regularizers.l2(w_l2), 
                   kernel_initializer=RandomNormal(0, 0.02), padding='same')(x)
        x = LeakyReLU(0.2)(x)
        x = PixelShuffler()(x)
        return x
    return block

def Discriminator(nc_in, input_size=64):
    inp = Input(shape=(input_size, input_size, nc_in))
    #x = GaussianNoise(0.05)(inp)
    x = conv_block_d(inp, 64, False)
    x = conv_block_d(x, 128, False)
    x = self_attn_block(x, 128) if use_self_attn else x
    x = conv_block_d(x, 256, False)
    x = self_attn_block(x, 256) if use_self_attn else x
    out = Conv2D(1, kernel_size=4, kernel_initializer=conv_init, use_bias=False, padding="same")(x)   
    return Model(inputs=[inp], outputs=out)

def Encoder(nc_in=3, input_size=64):
    inp = Input(shape=(input_size, input_size, nc_in))
    x = Conv2D(64, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(inp)
    x = conv_block(x,128)
    x = conv_block(x,256)
    x = self_attn_block(x, 256) if use_self_attn else x
    x = conv_block(x,512) 
    x = self_attn_block(x, 512) if use_self_attn else x
    x = conv_block(x,1024)
    x = Dense(1024)(Flatten()(x))
    x = Dense(4*4*1024)(x)
    x = Reshape((4, 4, 1024))(x)
    out = upscale_ps(512)(x)
    return Model(inputs=inp, outputs=out)

# Legacy, left for someone to try if interested
#def Decoder(nc_in=512, input_size=8):
#    inp = Input(shape=(input_size, input_size, nc_in))   
#    x = upscale_block(inp, 256)
#    x = Cropping2D(((0,1),(0,1)))(x)
#    x = upscale_block(x, 128)
#    x = res_block(x, 128)
#    x = Cropping2D(((0,1),(0,1)))(x)
#    x = upscale_block(x, 64)
#    x = res_block(x, 64)
#    x = res_block(x, 64)
#    x = Cropping2D(((0,1),(0,1)))(x)
#    x = Conv2D(3, kernel_size=5, kernel_initializer=conv_init, use_bias=False, padding="same")(x)
#    out = Activation("tanh")(x)
#    return Model(inputs=inp, outputs=out)

def Decoder_ps(nc_in=512, input_size=8):
    input_ = Input(shape=(input_size, input_size, nc_in))
    x = input_
    x = upscale_ps(256)(x)
    x = upscale_ps(128)(x)
    x = self_attn_block(x, 128) if use_self_attn else x
    x = upscale_ps(64)(x)
    x = res_block(x, 64)
    x = self_attn_block(x, 64) if use_self_attn else x
    #x = Conv2D(4, kernel_size=5, padding='same')(x)   
    alpha = Conv2D(1, kernel_size=5, padding='same', activation="sigmoid")(x)
    rgb = Conv2D(3, kernel_size=5, padding='same', activation="tanh")(x)
    out = concatenate([alpha, rgb])
    return Model(input_, out )    

In [11]:
encoder = Encoder()
decoder_A = Decoder_ps()
decoder_B = Decoder_ps()

x = Input(shape=IMAGE_SHAPE)

netGA = Model(x, decoder_A(encoder(x)))
netGB = Model(x, decoder_B(encoder(x)))

In [12]:
netDA = Discriminator(nc_D_inp)
netDB = Discriminator(nc_D_inp)

<a id='6'></a>
# 6. Load Models

In [11]:
try:
    encoder.load_weights("models/encoder.h5")
    decoder_A.load_weights("models/decoder_A.h5")
    decoder_B.load_weights("models/decoder_B.h5")
    netDA.load_weights("models/netDA.h5") 
    netDB.load_weights("models/netDB.h5") 
    print ("Model weights files are successfully loaded")
except:
    print ("Error occurs during loading weights file.")
    pass

model loaded.


<a id='7'></a>
# 7. Define Inputs/Outputs Variables

    distorted_A: A (batch_size, 64, 64, 3) tensor, input of generator_A (netGA).
    distorted_B: A (batch_size, 64, 64, 3) tensor, input of generator_B (netGB).
    fake_A: (batch_size, 64, 64, 3) tensor, output of generator_A (netGA).
    fake_B: (batch_size, 64, 64, 3) tensor, output of generator_B (netGB).
    mask_A: (batch_size, 64, 64, 1) tensor, mask output of generator_A (netGA).
    mask_B: (batch_size, 64, 64, 1) tensor, mask output of generator_B (netGB).
    path_A: A function that takes distorted_A as input and outputs fake_A.
    path_B: A function that takes distorted_B as input and outputs fake_B.
    path_mask_A: A function that takes distorted_A as input and outputs mask_A.
    path_mask_B: A function that takes distorted_B as input and outputs mask_B.
    path_abgr_A: A function that takes distorted_A as input and outputs concat([mask_A, fake_A]).
    path_abgr_B: A function that takes distorted_B as input and outputs concat([mask_B, fake_B]).
    real_A: A (batch_size, 64, 64, 3) tensor, target images for generator_A given input distorted_A.
    real_B: A (batch_size, 64, 64, 3) tensor, target images for generator_B given input distorted_B.

In [13]:
def cycle_variables(netG):
    distorted_input = netG.inputs[0]
    fake_output = netG.outputs[0]
    alpha = Lambda(lambda x: x[:,:,:, :1])(fake_output)
    rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_output)
    
    masked_fake_output = alpha * rgb + (1-alpha) * distorted_input 

    fn_generate = K.function([distorted_input], [masked_fake_output])
    fn_mask = K.function([distorted_input], [concatenate([alpha, alpha, alpha])])
    fn_abgr = K.function([distorted_input], [concatenate([alpha, rgb])])
    return distorted_input, fake_output, alpha, fn_generate, fn_mask, fn_abgr

In [14]:
distorted_A, fake_A, mask_A, path_A, path_mask_A, path_abgr_A = cycle_variables(netGA)
distorted_B, fake_B, mask_B, path_B, path_mask_B, path_abgr_B = cycle_variables(netGB)
real_A = Input(shape=IMAGE_SHAPE)
real_B = Input(shape=IMAGE_SHAPE)

<a id='8'></a>
# 8. Define Loss Function

### Loss function hyper parameters configuration

In [15]:
# Hyper params for generators
w_D = 0.1 # Discriminator
w_recon = 1. # L1 reconstruvtion loss
w_edge = 1. # edge loss
w_pl1 = (0.01, 0.1, 0.2, 0.02) # perceptual loss 1 
w_pl2 = (0.003, 0.03, 0.1, 0.01) # perceptual loss 2 

# Alpha mask regularizations
#m_mask = 0.5 # Margin value of alpha mask hinge loss
w_mask = 0.1 # hinge loss
w_mask_fo = 0.01 # Alpha mask total variation loss

In [16]:
def first_order(x, axis=1):
    img_nrows = x.shape[1]
    img_ncols = x.shape[2]
    if axis == 1:
        return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :])
    elif axis == 2:
        return K.abs(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :])
    else:
        return None   

In [17]:
if use_lsgan:
    loss_fn = lambda output, target : K.mean(K.abs(K.square(output-target)))
else:
    loss_fn = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))

In [18]:
def define_loss(netD, real, fake_argb, distorted, vggface_feat=None):   
    alpha = Lambda(lambda x: x[:,:,:, :1])(fake_argb)
    fake_rgb = Lambda(lambda x: x[:,:,:, 1:])(fake_argb)
    fake = alpha * fake_rgb + (1-alpha) * distorted
    
    if use_mixup:
        dist = Beta(mixup_alpha, mixup_alpha)
        lam = dist.sample()
        mixup = lam * concatenate([real, distorted]) + (1 - lam) * concatenate([fake, distorted])        
        output_mixup = netD(mixup)
        loss_D = loss_fn(output_mixup, lam * K.ones_like(output_mixup)) 
        output_fake = netD(concatenate([fake, distorted])) # dummy
        loss_G = w_D * loss_fn(output_mixup, (1 - lam) * K.ones_like(output_mixup))
    else:
        output_real = netD(concatenate([real, distorted])) # positive sample
        output_fake = netD(concatenate([fake, distorted])) # negative sample   
        loss_D_real = loss_fn(output_real, K.ones_like(output_real))    
        loss_D_fake = loss_fn(output_fake, K.zeros_like(output_fake))   
        loss_D = loss_D_real + loss_D_fake
        loss_G = w_D * loss_fn(output_fake, K.ones_like(output_fake))  
    
    # Reconstruction loss
    loss_G += w_recon * K.mean(K.abs(fake_rgb - real))
    
    # Edge loss (similar with total variation loss) 
    loss_G += w_edge * K.mean(K.abs(first_order(fake_rgb, axis=1) - first_order(real, axis=1)))
    loss_G += w_edge * K.mean(K.abs(first_order(fake_rgb, axis=2) - first_order(real, axis=2)))
    
    
    # Perceptual Loss
    if not vggface_feat is None:
        def preprocess_vggface(x):
            x = (x + 1)/2 * 255 # channel order: BGR
            x -= [91.4953, 103.8827, 131.0912]
            return x
        pl_params = w_pl1
        real_sz224 = tf.image.resize_images(real, [224, 224])
        real_sz224 = Lambda(preprocess_vggface)(real_sz224)
        
        # Perceptial loss for raw output
        fake_sz224 = tf.image.resize_images(fake_rgb, [224, 224]) 
        fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)        
        real_feat112, real_feat55, real_feat28, real_feat7 = vggface_feat(real_sz224)
        fake_feat112, fake_feat55, fake_feat28, fake_feat7  = vggface_feat(fake_sz224)    
        loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
        loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
        loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
        loss_G += pl_params[3] * K.mean(K.abs(fake_feat112 - real_feat112))
        
        # Perceptial loss for masked output
        pl_params = w_pl2
        fake_sz224 = tf.image.resize_images(fake, [224, 224]) 
        fake_sz224 = Lambda(preprocess_vggface)(fake_sz224)
        fake_feat112, fake_feat55, fake_feat28, fake_feat7  = vggface_feat(fake_sz224)    
        loss_G += pl_params[0] * K.mean(K.abs(fake_feat7 - real_feat7))
        loss_G += pl_params[1] * K.mean(K.abs(fake_feat28 - real_feat28))
        loss_G += pl_params[2] * K.mean(K.abs(fake_feat55 - real_feat55))
        loss_G += pl_params[3] * K.mean(K.abs(fake_feat112 - real_feat112))
    
    return loss_D, loss_G

In [None]:
# ========== Define Perceptual Loss Model==========
if use_perceptual_loss:
    vggface.trainable = False
    out_size112 = vggface.layers[1].output
    out_size55 = vggface.layers[36].output
    out_size28 = vggface.layers[78].output
    out_size7 = vggface.layers[-2].output
    vggface_feat = Model(vggface.input, [out_size112, out_size55, out_size28, out_size7])
    vggface_feat.trainable = False
else:
    vggface_feat = None

In [21]:
netDA_train = netGA_train = netDB_train = netGB_train = None

In [22]:
def build_training_functions(use_PL=False, use_mask_hinge_loss=False, m_mask=0.5, lr_factor=1):
    global netGA, netDA, real_A, fake_A, distorted_A, mask_A
    global netGB, netDB, real_B, fake_B, distorted_B, mask_B
    global netDA_train, netGA_train, netDB_train, netGB_train
    global vggface_feat
    global w_mask, w_mask_fo
    
    if use_PL:
        loss_DA, loss_GA = define_loss(netDA, real_A, fake_A, distorted_A, vggface_feat)
        loss_DB, loss_GB = define_loss(netDB, real_B, fake_B, distorted_B, vggface_feat)
    else:
        loss_DA, loss_GA = define_loss(netDA, real_A, fake_A, distorted_A, vggface_feat=None)
        loss_DB, loss_GB = define_loss(netDB, real_B, fake_B, distorted_B, vggface_feat=None)

    # Alpha mask loss
    if not use_mask_hinge_loss:
        loss_GA += 1e-3 * K.mean(K.abs(mask_A))
        loss_GB += 1e-3 * K.mean(K.abs(mask_B))
    else:
        loss_GA += w_mask * K.mean(K.maximum(0., m_mask - mask_A))
        loss_GB += w_mask * K.mean(K.maximum(0., m_mask - mask_B))
        
    # Alpha mask total variation loss
    loss_GA += w_mask_fo * K.mean(first_order(mask_A, axis=1))
    loss_GA += w_mask_fo * K.mean(first_order(mask_A, axis=2))
    loss_GB += w_mask_fo * K.mean(first_order(mask_B, axis=1))
    loss_GB += w_mask_fo * K.mean(first_order(mask_B, axis=2))
    
    # L2 weight decay
    # https://github.com/keras-team/keras/issues/2662
    for loss_tensor in netGA.losses:
        loss_GA += loss_tensor
    for loss_tensor in netGB.losses:
        loss_GB += loss_tensor
    for loss_tensor in netDA.losses:
        loss_DA += loss_tensor
    for loss_tensor in netDB.losses:
        loss_DB += loss_tensor
    
    weightsDA = netDA.trainable_weights
    weightsGA = netGA.trainable_weights
    weightsDB = netDB.trainable_weights
    weightsGB = netGB.trainable_weights

    # Adam(..).get_updates(...)
    training_updates = Adam(lr=lrD, beta_1=0.5).get_updates(weightsDA,[],loss_DA)
    netDA_train = K.function([distorted_A, real_A],[loss_DA], training_updates)
    training_updates = Adam(lr=lrG*lr_factor, beta_1=0.5).get_updates(weightsGA,[], loss_GA)
    netGA_train = K.function([distorted_A, real_A], [loss_GA], training_updates)

    training_updates = Adam(lr=lrD, beta_1=0.5).get_updates(weightsDB,[],loss_DB)
    netDB_train = K.function([distorted_B, real_B],[loss_DB], training_updates)
    training_updates = Adam(lr=lrG*lr_factor, beta_1=0.5).get_updates(weightsGB,[], loss_GB)
    netGB_train = K.function([distorted_B, real_B], [loss_GB], training_updates)
    
    print ("Loss configuration:")
    print ("use_PL = " + str(use_PL))
    print ("use_mask_hinge_loss = " + str(use_mask_hinge_loss))
    print ("m_mask = " + str(m_mask))    

<a id='9'></a>
# 9. Utils For Loading/Displaying Images

In [23]:
from scipy import ndimage

In [24]:
def get_motion_blur_kernel(sz=7):
    rot_angle = np.random.uniform(-180,180)
    kernel = np.zeros((sz,sz))
    kernel[int((sz-1)//2), :] = np.ones(sz)
    kernel = ndimage.interpolation.rotate(kernel, rot_angle, reshape=False)
    kernel = np.clip(kernel, 0, 1)
    normalize_factor = 1 / np.sum(kernel)
    kernel = kernel * normalize_factor
    return kernel

def motion_blur(images, sz=7):
    # images is a list [image2, image2, ...]
    blur_sz = np.random.choice([5, 7, 9, 11])
    kernel_motion_blur = get_motion_blur_kernel(blur_sz)
    for i, image in enumerate(images):
        images[i] = cv2.filter2D(image, -1, kernel_motion_blur).astype(np.float64)
    return images

In [25]:
def load_data(file_pattern):
    return glob.glob(file_pattern)
  
def random_warp_rev(image):
    assert image.shape == (256,256,3)
    rand_coverage = np.random.randint(25) + 80 # random warping coverage
    rand_scale = np.random.uniform(5., 6.2) # random warping scale
    range_ = np.linspace(128-rand_coverage, 128+rand_coverage, 5)
    mapx = np.broadcast_to(range_, (5,5))
    mapy = mapx.T
    mapx = mapx + np.random.normal(size=(5,5), scale=rand_scale)
    mapy = mapy + np.random.normal(size=(5,5), scale=rand_scale)
    interp_mapx = cv2.resize(mapx, (80,80))[8:72,8:72].astype('float32')
    interp_mapy = cv2.resize(mapy, (80,80))[8:72,8:72].astype('float32')
    warped_image = cv2.remap(image, interp_mapx, interp_mapy, cv2.INTER_LINEAR)
    src_points = np.stack([mapx.ravel(), mapy.ravel()], axis=-1)
    dst_points = np.mgrid[0:65:16,0:65:16].T.reshape(-1,2)
    mat = umeyama(src_points, dst_points, True)[0:2]
    target_image = cv2.warpAffine(image, mat, (64,64))
    return warped_image, target_image

random_transform_args = {
    'rotation_range': 10,
    'zoom_range': 0.1,
    'shift_range': 0.05,
    'random_flip': 0.5,
    }
def read_image(fn, random_transform_args=random_transform_args):
    image = cv2.imread(fn)
    image = cv2.resize(image, (256,256)) / 255 * 2 - 1
    image = random_transform(image, **random_transform_args)
    warped_img, target_img = random_warp_rev(image)
    
    # Motion blur data augmentation:
    # we want the model to learn to preserve motion blurs of input images
    if np.random.uniform() < 0.25 and use_da_motion_blur: 
        warped_img, target_img = motion_blur([warped_img, target_img])
    
    return warped_img, target_img

In [26]:
# A generator function that yields epoch, batchsize of warped_img and batchsize of target_img
def minibatch(data, batchsize):
    length = len(data)
    epoch = i = 0
    tmpsize = None  
    shuffle(data)
    while True:
        size = tmpsize if tmpsize else batchsize
        if i+size > length:
            shuffle(data)
            i = 0
            epoch+=1        
        rtn = np.float32([read_image(data[j]) for j in range(i,i+size)])
        i+=size
        tmpsize = yield epoch, rtn[:,0,:,:,:], rtn[:,1,:,:,:]       

def minibatchAB(dataA, batchsize):
    batchA = minibatch(dataA, batchsize)
    tmpsize = None    
    while True:        
        ep1, warped_img, target_img = batchA.send(tmpsize)
        tmpsize = yield ep1, warped_img, target_img

In [27]:
def showG(test_A, test_B, path_A, path_B):
    figure_A = np.stack([
        test_A,
        np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])),
        ], axis=1 )
    figure_B = np.stack([
        test_B,
        np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])),
        ], axis=1 )

    figure = np.concatenate([figure_A, figure_B], axis=0 )
    figure = figure.reshape((4,7) + figure.shape[1:])
    figure = stack_images(figure)
    figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
    figure = cv2.cvtColor(figure, cv2.COLOR_BGR2RGB)

    display(Image.fromarray(figure))
    
def showG_mask(test_A, test_B, path_A, path_B):
    figure_A = np.stack([
        test_A,
        (np.squeeze(np.array([path_A([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
        (np.squeeze(np.array([path_B([test_A[i:i+1]]) for i in range(test_A.shape[0])])))*2-1,
        ], axis=1 )
    figure_B = np.stack([
        test_B,
        (np.squeeze(np.array([path_B([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
        (np.squeeze(np.array([path_A([test_B[i:i+1]]) for i in range(test_B.shape[0])])))*2-1,
        ], axis=1 )

    figure = np.concatenate([figure_A, figure_B], axis=0 )
    figure = figure.reshape((4,7) + figure.shape[1:])
    figure = stack_images(figure)
    figure = np.clip((figure + 1) * 255 / 2, 0, 255).astype('uint8')
    figure = cv2.cvtColor(figure, cv2.COLOR_BGR2RGB)

    display(Image.fromarray(figure))

<a id='10'></a>
# 10. Start Training

Show results and save model weights every `display_iters` iterations.

In [43]:
!mkdir models # create ./models directory

mkdir: cannot create directory ‘models’: File exists


In [44]:
# Get filenames
train_A = load_data(img_dirA)
train_B = load_data(img_dirB)

assert len(train_A), "No image found in " + str(img_dirA)
assert len(train_B), "No image found in " + str(img_dirB)

In [28]:
def show_loss_config(loss_config):
    for config, value in loss_config.items():
        print(str(config) + " = " + str(value))

In [29]:
# Init. loss config.
loss_config = {}
loss_config['use_PL'] = False
loss_config['use_mask_hinge_loss'] = False
loss_config['m_mask'] = 0.5
loss_config['lr_factor'] = 1.

In [None]:
t0 = time.time()
gen_iterations = 0
epoch = 0
errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0

display_iters = 300
train_batchA = minibatchAB(train_A, batchSize)
train_batchB = minibatchAB(train_B, batchSize)

# ========== Change TOTAL_ITERS to desired iterations  ========== 
TOTAL_ITERS = 40000
#iter_dec_swap = TOTAL_ITERS - (np.minimum(len(train_A)*15, len(train_B))*15) // batchSize
#if iter_dec_swap <= (9*TOTAL_ITERS//10 - display_iters//2):
#    iter_dec_swap = 9*TOTAL_ITERS//10 - display_iters//2

while gen_iterations <= TOTAL_ITERS: 
    epoch, warped_A, target_A = next(train_batchA) 
    epoch, warped_B, target_B = next(train_batchB) 
    
    # Loss function automation
    if gen_iterations == 0:
        build_training_functions(**loss_config)
    elif gen_iterations == (TOTAL_ITERS//5 - display_iters//2):
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = False
        loss_config['m_mask'] = 0.0
        build_training_functions(**loss_config)
    elif gen_iterations == (TOTAL_ITERS//5 + TOTAL_ITERS//10 - display_iters//2):
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = True
        loss_config['m_mask'] = 0.5
        build_training_functions(**loss_config)
    elif gen_iterations == (2*TOTAL_ITERS//5 - display_iters//2):
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = True
        loss_config['m_mask'] = 0.25
        build_training_functions(**loss_config)
    elif gen_iterations == (TOTAL_ITERS//2 - display_iters//2):
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = True
        loss_config['m_mask'] = 0.4
        build_training_functions(**loss_config)
    elif gen_iterations == (2*TOTAL_ITERS//3 - display_iters//2):
        clear_output()
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = False
        loss_config['m_mask'] = 0.1
        loss_config['lr_factor'] = 0.3
        build_training_functions(**loss_config)
    elif gen_iterations == (9*TOTAL_ITERS//10 - display_iters//2):
        clear_output()
        decoder_A.load_weights("models/decoder_B.h5")
        decoder_B.load_weights("models/decoder_A.h5")
        loss_config['use_PL'] = True
        loss_config['use_mask_hinge_loss'] = True
        loss_config['m_mask'] = 0.1
        loss_config['lr_factor'] = 0.3
        build_training_functions(**loss_config)
    
    # Train dicriminators for one batch
    if gen_iterations % 1 == 0:
        errDA  = netDA_train([warped_A, target_A])
        errDB  = netDB_train([warped_B, target_B])
    errDA_sum +=errDA[0]
    errDB_sum +=errDB[0]
    
    if gen_iterations == 5:
        print ("working.")

    # Train generators for one batch
    errGA = netGA_train([warped_A, target_A])
    errGB = netGB_train([warped_B, target_B])
    errGA_sum += errGA[0]
    errGB_sum += errGB[0]
    gen_iterations+=1
    
    if gen_iterations % display_iters == 0:
        if gen_iterations % (display_iters) == 0:
            clear_output()
        show_loss_config(loss_config)
        print('[iter %d] Loss_DA: %f Loss_DB: %f Loss_GA: %f Loss_GB: %f time: %f'
        % (gen_iterations, errDA_sum/display_iters, errDB_sum/display_iters,
           errGA_sum/display_iters, errGB_sum/display_iters, time.time()-t0))   
        
        # get new batch of images and generate results for visualization
        _, wA, tA = train_batchA.send(14)  
        _, wB, tB = train_batchB.send(14)
        showG(tA, tB, path_A, path_B)   
        showG(wA, wB, path_A, path_B)         
        showG_mask(tA, tB, path_mask_A, path_mask_B)           
        errGA_sum = errGB_sum = errDA_sum = errDB_sum = 0
        
        # Save models
        encoder.save_weights("models/encoder.h5")
        decoder_A.save_weights("models/decoder_A.h5")
        decoder_B.save_weights("models/decoder_B.h5")
        netDA.save_weights("models/netDA.h5")
        netDB.save_weights("models/netDB.h5")

<a id='11'></a>
# 11. Helper Function: face_swap()
This function is provided for those who don't have enough VRAM to run dlib's CNN and GAN model at the same time.

    INPUTS:
        img: A RGB face image of any size.
        path_func: a function that is either path_abgr_A or path_abgr_B.
    OUPUTS:
        result_img: A RGB swapped face image after masking.
        result_mask: A single channel uint8 mask image.

In [21]:
def swap_face(img, path_func):
    input_size = img.shape
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # generator expects BGR input    
    ae_input = cv2.resize(img, (64,64))/255. * 2 - 1        
    
    result = np.squeeze(np.array([path_func([[ae_input]])]))
    result_a = result[:,:,0] * 255
    result_bgr = np.clip( (result[:,:,1:] + 1) * 255 / 2, 0, 255 )
    result_a = np.expand_dims(result_a, axis=2)
    result = (result_a/255 * result_bgr + (1 - result_a/255) * ((ae_input + 1) * 255 / 2)).astype('uint8')
       
    result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) 
    result = cv2.resize(result, (input_size[1],input_size[0]))
    result_a = np.expand_dims(cv2.resize(result_a, (input_size[1],input_size[0])), axis=2)
    return result, result_a

In [22]:
direction = "BtoA" # default trainsforming faceB to faceA

if direction is "AtoB":
    path_func = path_abgr_B
elif direction is "BtoA":
    path_func = path_abgr_A
else:
    print ("direction should be either AtoB or BtoA")

In [23]:
input_img = plt.imread("./TEST_FACE.jpg")

In [None]:
plt.imshow(input_img)

In [25]:
result_img, result_mask = swap_face(input_img, path_func)

In [None]:
plt.imshow(result_img)

In [None]:
plt.imshow(result_mask[:, :, 0])

<a id='12'></a>
# 12. Make video clips

Given a video as input, the following cells will detect face for each frame using dlib's cnn model. And use trained GAN model to transform detected face into target face. Then output a video with swapped faces.

In [16]:
# Download ffmpeg if needed, which is required by moviepy.

#import imageio
#imageio.plugins.ffmpeg.download()

Imageio: 'ffmpeg.linux64' was not found on your computer; downloading it now.
Try 1. Download from https://github.com/imageio/imageio-binaries/raw/master/ffmpeg/ffmpeg.linux64 (27.2 MB)
Downloading: 8192/28549024 bytes (0.02220032/28549024 bytes (7.8%5873664/28549024 bytes (20.69568256/28549024 bytes (33.513271040/28549024 bytes (46.5%16973824/28549024 bytes (59.5%20660224/28549024 bytes (72.4%24363008/28549024 bytes (85.3%27885568/28549024 bytes (97.7%28549024/28549024 bytes (100.0%)
  Done
File saved as /root/.imageio/ffmpeg/ffmpeg.linux64.


In [17]:
import face_recognition
from moviepy.editor import VideoFileClip

<a id='13'></a>
# 13. Make video clips w/o face alignment

### Default transform: face B to face A

In [25]:
use_smoothed_mask = True
use_smoothed_bbox = True

def kalmanfilter_init(noise_coef):
    kf = cv2.KalmanFilter(4,2)
    kf.measurementMatrix = np.array([[1,0,0,0],[0,1,0,0]], np.float32)
    kf.transitionMatrix = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]], np.float32)
    kf.processNoiseCov = noise_coef * np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], np.float32)
    return kf

def is_higher_than_480p(x):
    return (x.shape[0] * x.shape[1]) >= (858*480)

def is_higher_than_720p(x):
    return (x.shape[0] * x.shape[1]) >= (1280*720)

def is_higher_than_1080p(x):
    return (x.shape[0] * x.shape[1]) >= (1920*1080)

def calibrate_coord(faces, video_scaling_factor):
    for i, (x0, y1, x1, y0) in enumerate(faces):
        faces[i] = (x0*video_scaling_factor, y1*video_scaling_factor, 
                    x1*video_scaling_factor, y0*video_scaling_factor)
    return faces

def get_faces_bbox(image, model="cnn"):  
    if is_higher_than_1080p(image):
        video_scaling_factor = 4 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    elif is_higher_than_720p(image):
        video_scaling_factor = 3 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)  
    elif is_higher_than_480p(image):
        video_scaling_factor = 2 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    elif manually_downscale:
        video_scaling_factor = manually_downscale_factor
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    else:
        faces = face_recognition.face_locations(image, model=model)
    return faces

def get_smoothed_coord(x0, x1, y0, y1, shape, ratio=0.65):
    global prev_x0, prev_x1, prev_y0, prev_y1
    global frames
    if not use_kalman_filter:
        x0 = int(ratio * prev_x0 + (1-ratio) * x0)
        x1 = int(ratio * prev_x1 + (1-ratio) * x1)
        y1 = int(ratio * prev_y1 + (1-ratio) * y1)
        y0 = int(ratio * prev_y0 + (1-ratio) * y0)
    else:
        x0y0 = np.array([x0, y0]).astype(np.float32)
        x1y1 = np.array([x1, y1]).astype(np.float32)
        if frames == 0:
            for i in range(200):
                kf0.predict()
                kf1.predict()
        kf0.correct(x0y0)
        pred_x0y0 = kf0.predict()
        kf1.correct(x1y1)
        pred_x1y1 = kf1.predict()
        x0 = np.max([0, pred_x0y0[0][0]]).astype(np.int)
        x1 = np.min([shape[0], pred_x1y1[0][0]]).astype(np.int)
        y0 = np.max([0, pred_x0y0[1][0]]).astype(np.int)
        y1 = np.min([shape[1], pred_x1y1[1][0]]).astype(np.int)
        if x0 == x1 or y0 == y1:
            x0, y0, x1, y1 = prev_x0, prev_y0, prev_x1, prev_y1
    return x0, x1, y0, y1    
    
def set_global_coord(x0, x1, y0, y1):
    global prev_x0, prev_x1, prev_y0, prev_y1
    prev_x0 = x0
    prev_x1 = x1
    prev_y1 = y1
    prev_y0 = y0
    
def generate_face(ae_input, path_abgr, roi_size, roi_image):
    result = np.squeeze(np.array([path_abgr([[ae_input]])]))
    result_a = result[:,:,0] * 255
    result_bgr = np.clip( (result[:,:,1:] + 1) * 255 / 2, 0, 255 )
    result_a = cv2.GaussianBlur(result_a ,(7,7),6)
    result_a = np.expand_dims(result_a, axis=2)
    result = (result_a/255 * result_bgr + (1 - result_a/255) * ((ae_input + 1) * 255 / 2)).astype('uint8')
    if use_color_correction:
        result = color_hist_match(result, roi_image)
    result = cv2.cvtColor(result.astype(np.uint8), cv2.COLOR_BGR2RGB)
    result = cv2.resize(result, (roi_size[1],roi_size[0]))
    result_a = np.expand_dims(cv2.resize(result_a, (roi_size[1],roi_size[0])), axis=2)
    return result, result_a

def get_init_mask_map(image):
    return np.zeros_like(image)

def get_init_comb_img(input_img):
    comb_img = np.zeros([input_img.shape[0], input_img.shape[1]*2,input_img.shape[2]])
    comb_img[:, :input_img.shape[1], :] = input_img
    comb_img[:, input_img.shape[1]:, :] = input_img
    return comb_img    

def get_init_triple_img(input_img, no_face=False):
    if no_face:
        triple_img = np.zeros([input_img.shape[0], input_img.shape[1]*3,input_img.shape[2]])
        triple_img[:, :input_img.shape[1], :] = input_img
        triple_img[:, input_img.shape[1]:input_img.shape[1]*2, :] = input_img      
        triple_img[:, input_img.shape[1]*2:, :] = (input_img * .15).astype('uint8')  
        return triple_img
    else:
        triple_img = np.zeros([input_img.shape[0], input_img.shape[1]*3,input_img.shape[2]])
        return triple_img

def get_mask(roi_image, h, w):
    mask = np.zeros_like(roi_image)
    mask[h//15:-h//15,w//15:-w//15,:] = 255
    mask = cv2.GaussianBlur(mask,(15,15),10)
    return mask

def hist_match(source, template):
    # Code borrow from:
    # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
    oldshape = source.shape
    source = source.ravel()
    template = template.ravel()
    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)

def color_hist_match(src_im, tar_im):
    #src_im = cv2.cvtColor(src_im, cv2.COLOR_BGR2Lab)
    #tar_im = cv2.cvtColor(tar_im, cv2.COLOR_BGR2Lab)
    matched_R = hist_match(src_im[:,:,0], tar_im[:,:,0])
    matched_G = hist_match(src_im[:,:,1], tar_im[:,:,1])
    matched_B = hist_match(src_im[:,:,2], tar_im[:,:,2])
    matched = np.stack((matched_R, matched_G, matched_B), axis=2).astype(np.float64)
    return matched

def process_video(input_img):   
    # modify this line to reduce input size
    #input_img = input_img[:, input_img.shape[1]//3:2*input_img.shape[1]//3,:] 
    image = input_img
    faces = get_faces_bbox(image, model="cnn")
    
    if len(faces) == 0:
        comb_img = get_init_comb_img(input_img)
        triple_img = get_init_triple_img(input_img, no_face=True)
        
    mask_map = get_init_mask_map(image)
    comb_img = get_init_comb_img(input_img)
    global prev_x0, prev_x1, prev_y0, prev_y1
    global frames    
    for (x0, y1, x1, y0) in faces:        
        # smoothing bounding box
        if use_smoothed_bbox:
            if frames != 0:
                x0, x1, y0, y1 = get_smoothed_coord(x0, x1, y0, y1, 
                                                    image.shape, 
                                                    ratio=0.65 if use_kalman_filter else bbox_moving_avg_coef)
                set_global_coord(x0, x1, y0, y1)
                frames += 1
            else:
                set_global_coord(x0, x1, y0, y1)
                _ = get_smoothed_coord(x0, x1, y0, y1, image.shape)
                frames += 1
        h = x1 - x0
        w = y1 - y0
            
        cv2_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        roi_image = cv2_img[x0+h//15:x1-h//15,y0+w//15:y1-w//15,:]
        roi_size = roi_image.shape  
        
        ae_input = cv2.resize(roi_image, (64,64))/255. * 2 - 1        
        result, result_a = generate_face(ae_input, path_abgr_A, roi_size, roi_image)
        mask_map[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:] = result_a
        mask_map = np.clip(mask_map + .15 * input_img, 0, 255 )     
        
        if use_smoothed_mask:
            mask = get_mask(roi_image, h, w)
            roi_rgb = cv2.cvtColor(roi_image, cv2.COLOR_BGR2RGB)
            smoothed_result = mask/255 * result + (1-mask/255) * roi_rgb
            comb_img[x0+h//15:x1-h//15, input_img.shape[1]+y0+w//15:input_img.shape[1]+y1-w//15,:] = smoothed_result
        else:
            comb_img[x0+h//15:x1-h//15, input_img.shape[1]+y0+w//15:input_img.shape[1]+y1-w//15,:] = result
            
        triple_img = get_init_triple_img(input_img)
        triple_img[:, :input_img.shape[1]*2, :] = comb_img
        triple_img[:, input_img.shape[1]*2:, :] = mask_map
    
    # ========== Change the following line for different output type==========
    # return comb_img[:, input_img.shape[1]:, :]  # return only result image
    # return comb_img  # return input and result image combined as one
    return triple_img #return input,result and mask heatmap image combined as one

**Description**
```python
    video_scaling_offset = 0 # Increase by 1 if OOM happens.
    manually_downscale = False # Set True if increasing offset doesn't help
    manually_downscale_factor = int(2) # Increase by 1 if OOM still happens.
    use_color_correction = False # Option for color corretion
```

In [None]:
use_kalman_filter = True

if use_kalman_filter:
    noise_coef = 5e-3 # Increase by 10x if tracking is slow. 
    kf0 = kalmanfilter_init(noise_coef)
    kf1 = kalmanfilter_init(noise_coef)
else:
    bbox_moving_avg_coef = 0.65

In [36]:
# Variables for smoothing bounding box
global prev_x0, prev_x1, prev_y0, prev_y1
global frames
prev_x0 = prev_x1 = prev_y0 = prev_y1 = 0
frames = 0
video_scaling_offset = 0 
manually_downscale = False
manually_downscale_factor = int(2) # should be an positive integer
use_color_correction = False

output = 'OUTPUT_VIDEO.mp4'
clip1 = VideoFileClip("INPUT_VIDEO.mp4")
clip = clip1.fl_image(process_video)#.subclip(11, 13) #NOTE: this function expects color images!!
%time clip.write_videofile(output, audio=False)

[MoviePy] >>>> Building video tmp_sh_test_clipped3.mp4
[MoviePy] Writing video tmp_sh_test_clipped3.mp4


100%|█████████▉| 540/541 [01:50<00:00,  4.92it/s]


[MoviePy] Done.
[MoviePy] >>>> Video ready: tmp_sh_test_clipped3.mp4 

CPU times: user 1min 33s, sys: 17.1 s, total: 1min 50s
Wall time: 1min 51s


### gc.collect() sometimes solves memory error

In [111]:
import gc
gc.collect()

603

<a id='14'></a>
# 14. Make video clips w/ face alignment

### Default transform: face B to face A

The code is not refined. Also I can't tell if face alignment improves the result.

Code reference: https://github.com/nlhkh/face-alignment-dlib

In [30]:
use_smoothed_mask = True
apply_face_aln = True
use_poisson_blending = False # SeamlessClone is NOT recommended for video.
use_comp_video = True # output a comparison video before/after face swap
use_smoothed_bbox = True

def kalmanfilter_init(noise_coef):
    kf = cv2.KalmanFilter(4,2)
    kf.measurementMatrix = np.array([[1,0,0,0],[0,1,0,0]], np.float32)
    kf.transitionMatrix = np.array([[1,0,1,0],[0,1,0,1],[0,0,1,0],[0,0,0,1]], np.float32)
    kf.processNoiseCov = noise_coef * np.array([[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]], np.float32)
    return kf

def is_higher_than_480p(x):
    return (x.shape[0] * x.shape[1]) >= (858*480)

def is_higher_than_720p(x):
    return (x.shape[0] * x.shape[1]) >= (1280*720)

def is_higher_than_1080p(x):
    return (x.shape[0] * x.shape[1]) >= (1920*1080)

def calibrate_coord(faces, video_scaling_factor):
    for i, (x0, y1, x1, y0) in enumerate(faces):
        faces[i] = (x0*video_scaling_factor, y1*video_scaling_factor, 
                    x1*video_scaling_factor, y0*video_scaling_factor)
    return faces

def get_faces_bbox(image, model="cnn"):  
    if is_higher_than_1080p(image):
        video_scaling_factor = 4 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    elif is_higher_than_720p(image):
        video_scaling_factor = 3 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)  
    elif is_higher_than_480p(image):
        video_scaling_factor = 2 + video_scaling_offset
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    elif manually_downscale:
        video_scaling_factor = manually_downscale_factor
        resized_image = cv2.resize(image, 
                                   (image.shape[1]//video_scaling_factor, image.shape[0]//video_scaling_factor))
        faces = face_recognition.face_locations(resized_image, model=model)
        faces = calibrate_coord(faces, video_scaling_factor)
    else:
        faces = face_recognition.face_locations(image, model=model)
    return faces

def get_smoothed_coord(x0, x1, y0, y1, shape, ratio=0.65):
    global prev_x0, prev_x1, prev_y0, prev_y1
    global frames
    if not use_kalman_filter:
        x0 = int(ratio * prev_x0 + (1-ratio) * x0)
        x1 = int(ratio * prev_x1 + (1-ratio) * x1)
        y1 = int(ratio * prev_y1 + (1-ratio) * y1)
        y0 = int(ratio * prev_y0 + (1-ratio) * y0)
    else:
        x0y0 = np.array([x0, y0]).astype(np.float32)
        x1y1 = np.array([x1, y1]).astype(np.float32)
        if frames == 0:
            for i in range(200):
                kf0.predict()
                kf1.predict()
        kf0.correct(x0y0)
        pred_x0y0 = kf0.predict()
        kf1.correct(x1y1)
        pred_x1y1 = kf1.predict()
        x0 = np.max([0, pred_x0y0[0][0]]).astype(np.int)
        x1 = np.min([shape[0], pred_x1y1[0][0]]).astype(np.int)
        y0 = np.max([0, pred_x0y0[1][0]]).astype(np.int)
        y1 = np.min([shape[1], pred_x1y1[1][0]]).astype(np.int)
        if x0 == x1 or y0 == y1:
            x0, y0, x1, y1 = prev_x0, prev_y0, prev_x1, prev_y1
    return x0, x1, y0, y1    
    
def set_global_coord(x0, x1, y0, y1):
    global prev_x0, prev_x1, prev_y0, prev_y1
    prev_x0 = x0
    prev_x1 = x1
    prev_y1 = y1
    prev_y0 = y0
    
def extract_eye_center(shape):
    xs = 0
    ys = 0
    for pnt in shape:
        xs += pnt[0]
        ys += pnt[1]
    return ((xs//6), ys//6)

def get_rotation_matrix(p1, p2):
    angle = angle_between_2_points(p1, p2)
    x1, y1 = p1
    x2, y2 = p2
    xc = (x1 + x2) // 2
    yc = (y1 + y2) // 2
    M = cv2.getRotationMatrix2D((xc, yc), angle, 1)
    return M, (xc, yc), angle

def angle_between_2_points(p1, p2):
    x1, y1 = p1
    x2, y2 = p2
    if x1 == x2:
        return 90
    tan = (y2 - y1) / (x2 - x1)
    return np.degrees(np.arctan(tan))

def get_rotated_img(img, det):
    #print (det, img.shape)
    shape = face_recognition.face_landmarks(img, det)
    pnts_left_eye = shape[0]["left_eye"]
    pnts_right_eye = shape[0]["right_eye"]
    if len(pnts_left_eye) == 0 or len(pnts_right_eye) == 0:
        return img, None, None    
    le_center = extract_eye_center(shape[0]["left_eye"])
    re_center = extract_eye_center(shape[0]["right_eye"])
    M, center, angle = get_rotation_matrix(le_center, re_center)
    M_inv = cv2.getRotationMatrix2D(center, -1*angle, 1)    
    rotated = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]), flags=cv2.INTER_CUBIC)    
    return rotated, M, M_inv, center

def hist_match(source, template):
    # Code borrow from:
    # https://stackoverflow.com/questions/32655686/histogram-matching-of-two-images-in-python-2-x
    oldshape = source.shape
    source = source.ravel()
    template = template.ravel()
    s_values, bin_idx, s_counts = np.unique(source, return_inverse=True,
                                            return_counts=True)
    t_values, t_counts = np.unique(template, return_counts=True)

    s_quantiles = np.cumsum(s_counts).astype(np.float64)
    s_quantiles /= s_quantiles[-1]
    t_quantiles = np.cumsum(t_counts).astype(np.float64)
    t_quantiles /= t_quantiles[-1]
    interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)

    return interp_t_values[bin_idx].reshape(oldshape)

def color_hist_match(src_im, tar_im):
    matched_R = hist_match(src_im[:,:,0], tar_im[:,:,0])
    matched_G = hist_match(src_im[:,:,1], tar_im[:,:,1])
    matched_B = hist_match(src_im[:,:,2], tar_im[:,:,2])
    matched = np.stack((matched_R, matched_G, matched_B), axis=2).astype(np.float64)
    return matched

def process_video(input_img):   
    image = input_img
    # ========== Decrease image size if getting memory error ==========
    #image = input_img[:3*input_img.shape[0]//4, :, :]
    #image = cv2.resize(image, (image.shape[1]//2,image.shape[0]//2))
    orig_image = np.array(image)
    faces = get_faces_bbox(image, model="cnn")
    
    if len(faces) == 0:
        comb_img = np.zeros([orig_image.shape[0], orig_image.shape[1]*2,orig_image.shape[2]])
        comb_img[:, :orig_image.shape[1], :] = orig_image
        comb_img[:, orig_image.shape[1]:, :] = orig_image
        if use_comp_video:
            return comb_img
        else:
            return image
    
    global prev_x0, prev_x1, prev_y0, prev_y1
    global frames
    for (x0, y1, x1, y0) in faces:        
        # smoothing bounding box
        if use_smoothed_bbox:
            if frames != 0:
                x0, x1, y0, y1 = get_smoothed_coord(x0, x1, y0, y1, 
                                                    image.shape, 
                                                    ratio=0.65 if use_kalman_filter else bbox_moving_avg_coef)
                set_global_coord(x0, x1, y0, y1)
                frames += 1
            else:
                set_global_coord(x0, x1, y0, y1)
                _ = get_smoothed_coord(x0, x1, y0, y1, image.shape)
                frames += 1      
        h = x1 - x0
        w = y1 - y0
                
        if apply_face_aln:
            do_back_rot = True
            image, M, M_inv, center = get_rotated_img(image, [(x0, y1, x1, y0)])
            if M is None:
                do_back_rot = False
        
        cv2_img = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 
        roi_image = cv2_img[x0+h//15:x1-h//15, y0+w//15:y1-w//15, :]
        roi_size = roi_image.shape            
        
        if use_smoothed_mask:
            mask = np.zeros_like(roi_image)
            #print (roi_image.shape, mask.shape)
            mask[h//15:-h//15,w//15:-w//15,:] = 255
            mask = cv2.GaussianBlur(mask,(15,15),10)
            roi_image_rgb = cv2.cvtColor(roi_image, cv2.COLOR_BGR2RGB)
        
        ae_input = cv2.resize(roi_image, (64,64))/255. * 2 - 1        
        result = np.squeeze(np.array([path_abgr_A([[ae_input]])]))
        result_a = result[:,:,0] * 255
        result_bgr = np.clip( (result[:,:,1:] + 1) * 255 / 2, 0, 255 )
        result_a = cv2.GaussianBlur(result_a ,(7,7),6)
        result_a = np.expand_dims(result_a, axis=2)
        result = (result_a/255 * result_bgr + (1 - result_a/255) * ((ae_input + 1) * 255 / 2)).astype('uint8')
        if use_color_correction:
            result = color_hist_match(result, roi_image)
        result = cv2.cvtColor(result.astype(np.uint8), cv2.COLOR_BGR2RGB)
        result = cv2.resize(result, (roi_size[1],roi_size[0]))        
        result_img = np.array(orig_image)
        
        if use_smoothed_mask and not use_poisson_blending:
            image[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:] = mask/255*result + (1-mask/255)*roi_image_rgb
        elif use_poisson_blending:
            c = (y0+w//2, x0+h//2)
            image = cv2.seamlessClone(result, image, mask, c, cv2.NORMAL_CLONE)     
            
        if do_back_rot:
            image = cv2.warpAffine(image, M_inv, (image.shape[1], image.shape[0]), flags=cv2.INTER_CUBIC)
            result_img[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:] = image[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:]
        else:
            result_img[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:] = image[x0+h//15:x1-h//15, y0+w//15:y1-w//15,:]   

        if use_comp_video:
            comb_img = np.zeros([orig_image.shape[0], orig_image.shape[1]*2,orig_image.shape[2]])
            comb_img[:, :orig_image.shape[1], :] = orig_image
            comb_img[:, orig_image.shape[1]:, :] = result_img
            
    if use_comp_video:
        return comb_img
    else:
        return result_img

**Description**
```python
    video_scaling_offset = 0 # Increase by 1 if OOM happens.
    manually_downscale = False # Set True if increasing offset doesn't help
    manually_downscale_factor = int(2) # Increase by 1 if OOM still happens.
    use_color_correction = False # Option for color corretion
```

In [None]:
use_kalman_filter = True

if use_kalman_filter:
    noise_coef = 5e-3 # Increase by 10x if tracking is slow. 
    kf0 = kalmanfilter_init(noise_coef)
    kf1 = kalmanfilter_init(noise_coef)
else:
    bbox_moving_avg_coef = 0.65

In [None]:
# Variables for smoothing bounding box
global prev_x0, prev_x1, prev_y0, prev_y1
global frames
prev_x0 = prev_x1 = prev_y0 = prev_y1 = 0
frames = 0
video_scaling_offset = 0 
manually_downscale = False
manually_downscale_factor = int(2) # should be an positive integer
use_color_correction = False

output = 'OUTPUT_VIDEO.mp4'
clip1 = VideoFileClip("TEST_VIDEO.mp4")
# .subclip(START_SEC, END_SEC) for testing
clip = clip1.fl_image(process_video)#.subclip(1, 5) #NOTE: this function expects color images!!
%time clip.write_videofile(output, audio=False)