In [None]:
import cv2
import torch
import time
import os
import glob
import warnings
warnings.filterwarnings("ignore")

import numpy as np
np.bool = np.bool_

from utils.inference.image_processing import crop_face, get_final_image, show_images
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement
from utils.inference.core import model_inference

from network.AEI_Net import AEI_Net
from coordinate_reg.image_infer import Handler
from insightface_func.face_detect_crop_multi import Face_detect_crop
from arcface_model.iresnet import iresnet100
from models.pix2pix_model import Pix2PixModel
from models.config_sr import TestOptions


# Set use_sr=True to run the super-resolution model on the swapped faces,
# or use_sr=False to skip it.
use_sr = True

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is
# initialised in this process; setting it after G.cuda()/netArc.cuda() (as the
# code previously did) has no effect. It is therefore set before any model is
# built. setdefault keeps a device selection the user has already exported.
if use_sr:
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# face detection / cropping model (passed to crop_face and get_target)
app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

# main model for generation; weights are loaded on CPU, then moved to GPU in fp16
G = AEI_Net(backbone='unet', num_blocks=2, c_id=512)
G.eval()
G.load_state_dict(torch.load('weights/G_unet_2blocks.pth', map_location=torch.device('cpu')))
G = G.cuda()
G = G.half()

# arcface model to get face embedding; loaded on CPU like G, then moved to GPU
netArc = iresnet100(fp16=False)
netArc.load_state_dict(torch.load('arcface_model/backbone.pth', map_location=torch.device('cpu')))
netArc = netArc.cuda()
netArc.eval()

# model to get face landmarks
handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

# optional super-resolution model for the swapped faces (see use_sr above)
if use_sr:
    torch.backends.cudnn.benchmark = True
    opt = TestOptions()
    model = Pix2PixModel(opt)
    # NOTE(review): train() rather than eval() is kept as in the original;
    # presumably intentional for this model's normalization layers — confirm.
    model.netG.train()

# Easy Function

In [None]:
def swap_face(source: str,
              target: str,
              output_path: str,
              image_to_image=True,
              show_image=True,
              with_audio=False,
              batch_size: int = 60):
    """Swap the face from a source photo onto a target image or video.

    Uses the module-level models created in the setup cell: ``app``,
    ``netArc``, ``G``, ``handler``, and ``model`` when ``use_sr`` is True.

    Tips: prefer short target videos, because processing time grows with
    length. Use a clear photo of one person as the source; a selfie works well.

    Args:
        source: path to the image whose face is transferred.
        target: path to the target image (``image_to_image=True``) or video.
        output_path: directory for the result; created if it does not exist.
            The file is named ``<source stem>_<target stem>`` with a ``.jpg``
            (image mode) or ``.mp4`` (video mode) extension.
        image_to_image: treat ``target`` as an image when True, else as a video.
        show_image: in image mode, display source crop, target and result.
        with_audio: in video mode, add the target video's audio to the result
            via ``add_audio_from_another_video``.
        batch_size: batch size (``BS``) passed to ``model_inference``.

    Returns:
        None. Problems are reported with ``print`` followed by an early return
        rather than by raising. These include an unreadable file, no face
        found, and a failed image write.
    """
    crop_size = 224  # don't change this; presumably the crop size the models were trained with

    source_stem = os.path.splitext(os.path.basename(source))[0]
    target_stem = os.path.splitext(os.path.basename(target))[0]
    # os.path.join avoids writing to "/" when output_path is empty
    output_name = os.path.join(output_path, f"{source_stem}_{target_stem}")

    source_full = cv2.imread(source)
    if source_full is None:  # cv2.imread returns None instead of raising
        print(f"Could not read source image: {source}")
        return

    try:
        source_crop = crop_face(source_full, app, crop_size)[0]
        # reverse the channel order (OpenCV loads images as BGR)
        source_crops = [source_crop[:, :, ::-1]]
        print("Everything is ok!")
    except TypeError:
        # presumably crop_face returns None when no face is found, so indexing fails
        print("Bad source images")
        return

    if image_to_image:
        target_full = cv2.imread(target)
        if target_full is None:
            print(f"Could not read target image: {target}")
            return
        full_frames = [target_full]
    else:
        full_frames, fps = read_video(target)

    target_crops = get_target(full_frames, app, crop_size)
    if target_crops is None or len(target_crops) == 0:
        print("no face detected in target image/video")
        return

    final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(
        full_frames,
        source_crops,
        target_crops,
        netArc,
        G,
        app,
        set_target=False,
        crop_size=crop_size,
        BS=batch_size)

    if use_sr:
        final_frames_list = face_enhancement(final_frames_list, model)

    # make sure the destination exists, otherwise cv2.imwrite silently fails
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    if image_to_image:
        result = get_final_image(final_frames_list, crop_frames_list, full_frames[0],
                                 tfm_array_list, handler)
        if show_image:
            # flip the source crop back to BGR for display alongside the BGR frames
            show_images([source_crops[0][:, :, ::-1], target_full, result],
                        ['Source Image', 'Target Image', 'Swapped Image'],
                        figsize=(20, 15))
        image_path = f"{output_name}.jpg"
        if not cv2.imwrite(image_path, result):
            print(f"Failed to write {image_path}")
    else:
        video_path = f"{output_name}.mp4"
        get_final_video(final_frames_list,
                        crop_frames_list,
                        full_frames,
                        tfm_array_list,
                        video_path,
                        fps,
                        handler)
        if with_audio:
            add_audio_from_another_video(target, video_path, "audio")


In [None]:
# Image -> image: put the face from IMG_1379.jpg onto beckham.jpg and display the result.
swap_face(
    source="examples/images/IMG_1379.jpg",
    target="examples/images/beckham.jpg",
    output_path="examples/results",
    image_to_image=True,
    show_image=True,
)

In [None]:
# Image -> video: put the face from beckham.jpg onto every frame of nggyup.mp4,
# keeping the original video's audio track.
swap_face(
    source="examples/images/beckham.jpg",
    target="examples/videos/nggyup.mp4",
    output_path="examples/results",
    image_to_image=False,
    with_audio=True,
)