In [1]:
from pathlib import Path
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch

from core.models import DF, LIAE
from core.options import read_yaml
from core.loglib import load_weights
from DFLIMG import DFLIMG
from facelib import FaceType, LandmarksProcessor
from merger.MergeMasked import *

In [2]:
# --- Run configuration: fill these in before running the notebook ---------
project = ""
video = ""
subfolder = ""
# Deliberately `model_name`, not `model`: the model-loading cell binds the
# network object to `model`, and re-running this cell must not clobber it.
model_name = ""
outvideo = ""

# Input frames and the aligned face crops extracted from them.
input_path = Path(f"../video-input/{project}/{video}")
aligned_path = input_path / "aligned"

# Directory holding the saved weights and `model_opt.yaml`.
saved_models_path = Path(f"../workspace/{project}/{model_name}_torch")

# All outputs share one prefix; it is kept as a plain string so the resulting
# paths are built from exactly the same text as before.
output_root = f"../video-output/{project}/{subfolder}"

test_path = Path(f"{output_root}/{outvideo}_raw")
test_mask_path = Path(f"{output_root}/{outvideo}_raw_mask")
output_path = Path(f"{output_root}/{outvideo}2")
output_mask_path = Path(f"{output_root}/{outvideo}2_mask")

# Create every output directory up front (no error if it already exists).
for out_dir in (test_path, test_mask_path, output_path, output_mask_path):
    out_dir.mkdir(parents=True, exist_ok=True)

outputvideo_path = Path(f"{output_root}/output_{outvideo}.mp4")

In [3]:
device = "cuda:0"

# Hyper-parameters are read from the YAML file saved next to the weights.
model_dict = read_yaml(saved_models_path.joinpath("model_opt.yaml"))

# Both architectures take the same constructor arguments, so pick the class
# first and build it once. Any model_type not starting with 'df' is LIAE.
model_cls = DF if model_dict.model_type.startswith('df') else LIAE
model = model_cls(
    model_dict.resolution,
    model_dict.ae_dims,
    model_dict.e_dims,
    model_dict.d_dims,
    model_dict.d_mask_dims,
    likeness=model_dict.likeness,
    double_res=model_dict.double_res,
).to(device)

# NOTE(review): presumably restores the saved weights into `model` and returns
# the training log alongside it — confirm against core.loglib.load_weights.
model, log_history = load_weights(saved_models_path, model, finetune_start=False)

# (height, width, channels) shape handed to the merge helpers further down.
predictor_input_shape = (model_dict.resolution,) * 2 + (3,)
face_type = model_dict.face_type

Loading encoder
Loading inter_AB
Loading inter_B
Loading decoder


In [2]:
input_path_image_paths = sorted(input_path.glob("*.jpg"))
align_path_image_paths = sorted(aligned_path.glob("*.jpg"))

# Maps the stem of a source frame's filename to a one-element list holding
# the landmarks of the first aligned face found for that frame.
alignments = {}

for filepath in tqdm(align_path_image_paths):
    dflimg = DFLIMG.load(filepath)

    # Skip files that carry no DFL metadata.
    if dflimg is None or not dflimg.has_data():
        print(f"{filepath.name} is not a dfl image file")
        continue

    source_filename = dflimg.get_source_filename()
    if source_filename is None:
        continue

    source_filename_stem = Path(source_filename).stem
    # Only the first aligned face per source frame is kept; later ones are ignored.
    if source_filename_stem not in alignments:
        alignments[source_filename_stem] = [dflimg.get_source_landmarks()]

# One record per input frame; "landmarks_list" is None when no aligned face
# was found for that frame.
frames = [
    {"filepath": p, "landmarks_list": alignments.get(p.stem)}
    for p in input_path_image_paths
]

In [5]:
# Index of the frame to preview; change this to inspect a different frame.
frame_index = 3150
frame_info = frames[frame_index]

img_bgr_uint8 = cv2.imread(str(frame_info["filepath"]))
# cv2.imread signals failure by returning None instead of raising, so check
# here rather than crash on `.astype` with an unrelated-looking error.
if img_bgr_uint8 is None:
    raise FileNotFoundError(f"Could not read frame image: {frame_info['filepath']}")
img_bgr = img_bgr_uint8.astype(np.float32) / 255.0

# "landmarks_list" is None for frames that had no aligned face.
landmarks_list = frame_info["landmarks_list"]
if not landmarks_list:
    raise ValueError(
        f"Frame {frame_index} ({frame_info['filepath'].name}) has no aligned face landmarks"
    )
img_face_landmarks = landmarks_list[0]

# (width, height) of the full frame.
img_size = img_bgr.shape[1], img_bgr.shape[0]
img_face_mask_a = LandmarksProcessor.get_image_hull_mask(img_bgr.shape, img_face_landmarks)

# NOTE(review): presumably the working sizes and warp matrices for this face —
# confirm against merger.MergeMasked.get_size_and_mat.
size_dict, mat_dict = get_size_and_mat(predictor_input_shape, img_face_landmarks,
                                       1.0, face_type, use_sr=False)

# Warp the frame and its hull mask with the "face" matrix at the "output" size.
dst_face_bgr      = warp_and_clip(img_bgr,         mat_dict["face"], size_dict["output"])
dst_face_mask_a_0 = warp_and_clip(img_face_mask_a, mat_dict["face"], size_dict["output"])

In [8]:
with torch.no_grad():
    # HWC image -> CHW, add a leading batch axis, move to the model's device.
    warped_dst = torch.from_numpy(dst_face_bgr).permute(2, 0, 1).unsqueeze(0).to(device)
    # single_forward returns three batched tensors; for each, take the first
    # item, move channels last and bring it back to the CPU as a NumPy array.
    prd_face_bgr, prd_face_mask_a_0, prd_face_dst_mask_a_0 = (
        out[0].permute(1, 2, 0).detach().cpu().numpy()
        for out in model.single_forward(warped_dst)
    )

# The interpolation flag must be passed by keyword: the third positional
# parameter of cv2.resize is `dst`, so a positional cv2.INTER_CUBIC is not
# used as the interpolation mode and the default interpolation is applied.
prd_face_bgr = cv2.resize(prd_face_bgr, (size_dict["input"], size_dict["input"]),
                          interpolation=cv2.INTER_CUBIC)

# Warp the predicted face and the destination face back to full-frame size.
prd_full = warp_and_clip(prd_face_bgr, mat_dict["face_output"], img_size, inverse=True)
dst_full = warp_and_clip(dst_face_bgr, mat_dict["face_output"], img_size, inverse=True)

In [9]:
mask_side = size_dict["mask"]
output_side = size_dict["output"]

# Combine the hull mask with the two predicted masks and keep channel 0.
# NOTE(review): the meaning of mask_mode=3 is defined by get_normfacemask — confirm.
wrk_face_mask_a_0 = get_normfacemask(
    dst_face_mask_a_0,
    prd_face_mask_a_0,
    prd_face_dst_mask_a_0,
    mask_mode=3,
    output_size=output_side,
)[:, :, 0]

# Bring the working mask to the "mask" size if it is not there already.
if wrk_face_mask_a_0.shape[0] != mask_side:
    wrk_face_mask_a_0 = cv2.resize(
        wrk_face_mask_a_0, (mask_side, mask_side), interpolation=cv2.INTER_CUBIC
    )

# Erode/blur the mask, then clamp it to [0, 1].
wrk_face_mask_b_0 = np.clip(
    get_erodeblurmask(wrk_face_mask_a_0, ero=30, blur=80, input_size=size_dict["input"]),
    0.0,
    1.0,
)

# Full-frame mask. This must be computed BEFORE the resize below, because it
# uses the mask at its current size.
img_face_mask_a = warp_and_clip(
    wrk_face_mask_b_0,
    mat_dict["face_mask"],
    img_size,
    use_clip=True,
    inverse=True,
    remove_noise=True,
)

# Bring the blurred mask to the "output" size if it is not there already.
if wrk_face_mask_b_0.shape[0] != output_side:
    wrk_face_mask_b_0 = cv2.resize(
        wrk_face_mask_b_0, (output_side, output_side), interpolation=cv2.INTER_CUBIC
    )

# Binary "area" mask with a trailing channel axis: 1.0 wherever the blurred
# mask is positive.
wrk_face_mask_area = wrk_face_mask_b_0[..., None].copy()
wrk_face_mask_area[wrk_face_mask_area > 0] = 1.0

In [3]:
fig, ax = plt.subplots(1, 3, figsize=(18, 6))

# OpenCV images are BGR; reverse the channel axis so matplotlib shows RGB.
ax[0].imshow(dst_face_bgr[:, :, ::-1])
ax[0].set_title("Aligned destination face")

# Float RGB data must lie in [0, 1] for imshow; clip explicitly so any
# out-of-range prediction values do not trigger a clipping warning.
ax[1].imshow(np.clip(prd_face_bgr[:, :, ::-1], 0.0, 1.0))
ax[1].set_title("Predicted face")

ax[2].imshow(wrk_face_mask_a_0, cmap="gray")
ax[2].set_title("Working face mask")

plt.show()

In [4]:
# Colour-match a copy of the prediction to the destination face inside the
# binary area mask; the copy leaves `prd_face_bgr` itself unchanged.
# NOTE(review): the meaning of color_transfer_mode=2 is defined by get_colortransfer — confirm.
prd_face_bgr_c = get_colortransfer(
    prd_face_bgr.copy(),
    dst_face_bgr,
    wrk_face_mask_area,
    color_transfer_mode=2,
)

# Warp the colour-matched face back to full-frame size.
out_img = warp_and_clip(prd_face_bgr_c, mat_dict["face_output"], img_size, inverse=True)

# Blend: original frame weighted by (1 - mask), swapped face weighted by mask.
background_weight = 1 - img_face_mask_a
out_img = img_bgr * background_weight + out_img * img_face_mask_a

# Crop the merged frame back to the aligned face for a close-up preview.
out_face = warp_and_clip(out_img, mat_dict["face"], size_dict["output"], use_clip=False)

# BGR -> RGB for display.
plt.imshow(out_face[:, :, ::-1])