In [None]:
import torch
import torch.nn as nn

class PixelNorm(nn.Module):
    """Normalise the input to unit root-mean-square along its last dimension."""

    def forward(self, x):
        """Return ``x`` divided by the RMS of its last dimension.

        A small epsilon (1e-8) is added under the square root so an all-zero
        vector maps to zeros instead of producing a division by zero.
        """
        mean_square = torch.mean(x**2, dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + 1e-8)
        return x / rms
    
class Res(nn.Module):
    """Residual block: conv3x3 -> LeakyReLU -> conv3x3, added to the input.

    The channel count and spatial size are preserved (3x3 kernels, padding 1).
    """

    def __init__(self, n_ch):
        """Build the block for feature maps with ``n_ch`` channels."""
        super().__init__()
        # Layers are created in the same order as they run, so parameter names
        # stay ``conv.0.*`` and ``conv.2.*``.
        first_conv = nn.Conv2d(n_ch, n_ch, kernel_size=3, padding=1)
        inner_act = nn.LeakyReLU(0.2)
        second_conv = nn.Conv2d(n_ch, n_ch, kernel_size=3, padding=1)
        self.conv = nn.Sequential(first_conv, inner_act, second_conv)
        self.fuse = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Return ``LeakyReLU(x + conv(x))``."""
        residual = self.conv(x)
        return self.fuse(x + residual)
    
class OutConv(nn.Module):
    """Four parallel 3-channel output convolutions concatenated on the channel axis.

    The first head uses a 1x1 kernel, the remaining three use 3x3 kernels with
    padding 1, so all heads keep the spatial size. The result has 12 channels.
    """

    def __init__(self, n_ch):
        """Create the four heads for an input with ``n_ch`` channels."""
        super().__init__()
        heads = [nn.Conv2d(n_ch, 3, kernel_size=1)]
        for _ in range(3):
            heads.append(nn.Conv2d(n_ch, 3, kernel_size=3, padding=1))
        self.out_conv = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs along dim 1."""
        head_outputs = []
        for head in self.out_conv:
            head_outputs.append(head(x))
        return torch.cat(head_outputs, dim=1)

class Encoder(nn.Module):
    """Convolutional image encoder producing a flat, pixel-normalised code.

    Five stride-2 5x5 convolutions (each followed by LeakyReLU(0.1)) halve the
    spatial size at every stage; a ``Res`` block follows the first and the last
    stage. The feature map is then flattened and passed through ``PixelNorm``.
    """

    def __init__(self, n_ch):
        """Build the encoder; ``n_ch`` is the base channel width."""
        super().__init__()

        def down(c_in, c_out):
            # One down-sampling stage: stride-2 conv followed by its activation.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.1, inplace=True),
            ]

        # The layer order (and therefore the Sequential indices) is kept stable.
        layers = []
        layers.extend(down(3, n_ch*1))
        layers.append(Res(n_ch*1))
        layers.extend(down(n_ch*1, n_ch*2))
        layers.extend(down(n_ch*2, n_ch*4))
        layers.extend(down(n_ch*4, n_ch*8))
        layers.extend(down(n_ch*8, n_ch*8))
        layers.append(Res(n_ch*8))
        layers.append(nn.Flatten())
        layers.append(PixelNorm())

        self.image_encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Encode a batch of 3-channel images into flat normalised vectors."""
        encoded = self.image_encoder(x)
        return encoded
    
class Decoder(nn.Module):
    """Two-headed decoder that turns a latent feature map into an image and a mask.

    Both heads up-sample by a factor of 32 through five ``PixelShuffle(2)``
    stages and end with a sigmoid, so outputs lie in ``[0, 1]``.

    Args:
        i_ch: Base channel width of the image head.
        m_ch: Base channel width of the mask head.
        in_ch: Number of channels of the latent input ``z``. Defaults to 512,
            the value that was previously hard-coded.
    """

    def __init__(self, i_ch, m_ch, in_ch=512):
        super().__init__()

        # Image head: every up-sampling stage is refined by a residual block.
        # OutConv emits 12 channels, which the last PixelShuffle folds into 3.
        self.image_decoder = nn.Sequential(
            *self._upsample(in_ch,  i_ch*8), Res(i_ch*8),
            *self._upsample(i_ch*8, i_ch*8), Res(i_ch*8),
            *self._upsample(i_ch*8, i_ch*4), Res(i_ch*4),
            *self._upsample(i_ch*4, i_ch*2), Res(i_ch*2),
            OutConv(i_ch*2), nn.PixelShuffle(2), nn.Sigmoid()
        )

        # Mask head: same up-sampling ladder without residual blocks.
        # The final 1x1 conv emits 4 channels, folded into 1 by PixelShuffle.
        self.mask_decoder = nn.Sequential(
            *self._upsample(in_ch,  m_ch*8),
            *self._upsample(m_ch*8, m_ch*8),
            *self._upsample(m_ch*8, m_ch*4),
            *self._upsample(m_ch*4, m_ch*2),
            nn.Conv2d(m_ch*2, 1*4, kernel_size=1), nn.PixelShuffle(2), nn.Sigmoid()
        )

    @staticmethod
    def _upsample(c_in, c_out):
        """Return the layers of one 2x up-sampling stage yielding ``c_out`` channels."""
        # The conv produces 4x the target channels; PixelShuffle(2) trades them
        # for a 2x larger spatial grid.
        return [
            nn.Conv2d(c_in, c_out*4, kernel_size=3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.PixelShuffle(2),
        ]

    def forward(self, z):
        """Decode ``z`` and return ``(image, mask)``.

        ``image`` has 3 channels and ``mask`` has 1 channel; both are 32x the
        spatial size of ``z``.
        """
        return self.image_decoder(z), self.mask_decoder(z)


# Inference configuration shared by the cells below.
dtype  = torch.float32
# Use the first GPU when available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
image_size   = 512  # side length of the aligned face crop requested from the detector
process_size = 128  # side length of each tile fed to the encoder

# `process` splits the crop into (image_size // process_size)**2 interleaved
# tiles and reshapes the code to a (process_size // 32)-sized grid.
assert image_size % process_size == 0, "image_size must be a multiple of process_size"
assert process_size % 32 == 0, "process_size must be a multiple of 32"

# The checkpoints are used as complete modules (`.to(...)` is called directly
# on the loaded object), so they need full unpickling:
#   * `weights_only=False` is passed explicitly because PyTorch >= 2.6 defaults
#     to `weights_only=True`, which rejects pickled modules.
#   * `map_location` lets checkpoints saved on a GPU load on a CPU-only machine.
# NOTE(security): full unpickling can execute arbitrary code - only load
# checkpoint files from a trusted source.
# `.eval()` puts the modules in inference mode.
encoder = torch.load(r"...", map_location=device, weights_only=False).to(device=device, dtype=dtype).eval()
decoder = torch.load(r"...", map_location=device, weights_only=False).to(device=device, dtype=dtype).eval()
mlp     = torch.load(r"...", map_location=device, weights_only=False).to(device=device, dtype=dtype).eval()

In [None]:
import cv2
import subprocess
import numpy as np

from tqdm import tqdm
from pathlib import Path
from face_parser import FaceParser
from face_detector import FaceDetector

# Face-parsing helper from the project-local `face_parser` module.
# NOTE(review): `parser` is not referenced by any other visible cell - confirm it is still needed.
parser = FaceParser()
# Face detector used by `process`, which calls `face_detector.get_face(...)` on every frame.
# presumably `models_dir` is where the detector looks for its model files - verify against FaceDetector.
face_detector = FaceDetector(models_dir="./")


def _swap_face(aligned_face):
    """Run encoder -> mlp -> decoder on one aligned face crop.

    Args:
        aligned_face: BGR image array holding ``image_size * image_size * 3``
            values, as returned by ``face_detector.get_face``.

    Returns:
        ``(face_bgr, mask)`` where ``face_bgr`` is the predicted face as a
        ``uint8`` BGR array of shape ``(image_size, image_size, 3)`` and
        ``mask`` is a ``float32`` array of the same shape with values in
        ``[0, 1]`` (the single mask channel repeated three times).
    """
    scale = image_size // process_size

    face = cv2.cvtColor(aligned_face, cv2.COLOR_BGR2RGB)
    face = torch.tensor(face, device=device, dtype=dtype).permute(2, 0, 1)
    # Pixel-unshuffle: pixel (p*scale + i, q*scale + j) goes to tile (i, j) at
    # position (p, q), giving scale*scale tiles of process_size x process_size.
    face = face.reshape(3, process_size, scale, process_size, scale).permute(2, 4, 0, 1, 3)
    # Stay in the configured `dtype` so the input always matches the models.
    face = face.reshape(scale*scale, 3, process_size, process_size) / 255.

    with torch.no_grad():
        codes = mlp(encoder(face)).view(-1, 256, process_size//32, process_size//32)
        # The decoder input is the code concatenated with itself on the channel axis.
        pred_img, pred_msk = decoder(torch.concat([codes, codes], dim=1))

    # Inverse of the tiling above: interleave the tiles back into one image.
    pred_img = pred_img.reshape(scale, scale, 3, process_size, process_size).permute(2, 3, 0, 4, 1).reshape(3, image_size, image_size)
    pred_msk = pred_msk.reshape(scale, scale, 1, process_size, process_size).permute(2, 3, 0, 4, 1).reshape(1, image_size, image_size)

    face_rgb = (pred_img.clamp(0, 1)*255).to(device="cpu", dtype=torch.uint8).permute(1, 2, 0).numpy()
    # OpenCV needs float32 here regardless of the model dtype.
    mask = pred_msk.repeat(3, 1, 1).clamp(0, 1).to(device="cpu", dtype=torch.float32).permute(1, 2, 0).numpy()

    # The model works in RGB; OpenCV writes BGR.
    return cv2.cvtColor(face_rgb, cv2.COLOR_RGB2BGR), mask


def _mux_source_audio(video_only: Path, source: Path):
    """Stream-copy ``video_only`` together with ``source`` via ffmpeg, in place.

    The muxed result replaces ``video_only``. ffmpeg's default stream selection
    and the ``-shortest`` / ``-c copy`` options are unchanged from before.
    """
    # Temporary file next to the final output, so the final replace never
    # crosses filesystems and concurrent runs cannot clash on a shared name.
    # The ".mp4" suffix keeps ffmpeg writing an MP4 container.
    muxed = video_only.with_name(video_only.stem + ".muxed.mp4")
    # Argument list (no shell): works on every platform and is safe for paths
    # containing spaces or shell metacharacters.
    command = [
        "ffmpeg", "-y",
        "-i", str(video_only),
        "-i", str(source),
        "-shortest", "-c", "copy",
        str(muxed),
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError:
        # Keep the video-only result; just drop any partial mux output.
        if muxed.exists():
            muxed.unlink()
        raise
    muxed.replace(video_only)


def process(input_video: Path, output_path: Path):
    """Swap the detected face in every frame of a video and write the result.

    Frames in which ``face_detector.get_face`` returns no transform are copied
    unchanged. After all frames are written, ffmpeg muxes the processed video
    with the streams of the original file (e.g. its audio).

    Args:
        input_video: Path of the video to read.
        output_path: Path of the video to write; its parent directory is
            created when missing.

    Raises:
        IOError: If the input video or the output writer cannot be opened.
        subprocess.CalledProcessError: If ffmpeg exits with an error. The
            video-only output is left at ``output_path`` in that case.
    """
    input_stream = cv2.VideoCapture(filename=str(input_video))
    if not input_stream.isOpened():
        input_stream.release()
        raise IOError(f"Could not open input video: {input_video}")

    # VideoWriter fails silently when the target directory does not exist.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_stream = cv2.VideoWriter(
        filename  = str(output_path),
        fourcc    = cv2.VideoWriter_fourcc(*"mp4v"),
        # Keep fractional rates (e.g. 29.97); truncating to int changes the
        # video duration and desynchronises the audio.
        fps       = input_stream.get(cv2.CAP_PROP_FPS),
        frameSize = (int(input_stream.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_stream.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        )

    try:
        if not output_stream.isOpened():
            raise IOError(f"Could not open output video for writing: {output_path}")

        # The container's frame count can be missing or inaccurate, so it only
        # sizes the progress bar; the loop itself runs until reading fails.
        reported_frames = int(input_stream.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=reported_frames if reported_frames > 0 else None) as progress:
            while True:
                ret, frame = input_stream.read()
                if not ret:
                    break

                out_frame = frame
                input_face, M = face_detector.get_face(frame=frame, image_size=image_size, zoom_out_factor=1)

                if M is not None:
                    output_face, pred_msk = _swap_face(input_face)

                    # M maps frame -> aligned crop; its inverse pastes the crop back.
                    inverse_M = cv2.invertAffineTransform(M)
                    frame_wh  = frame.shape[1::-1]

                    out_mask  = cv2.warpAffine(pred_msk   , inverse_M, dsize=frame_wh, borderValue=(0,0,0))
                    out_face  = cv2.warpAffine(output_face, inverse_M, dsize=frame_wh)
                    # Alpha-blend the predicted face over the original frame.
                    out_frame = (out_mask*out_face + (1-out_mask)*frame).astype(np.uint8)

                output_stream.write(out_frame)
                progress.update(1)
    finally:
        # Always release the streams, even when a frame fails mid-way.
        output_stream.release()
        input_stream.release()

    _mux_source_audio(output_path, input_video)

In [None]:
# Process every file of the input directory into ./outs, keeping the file names.
output_dir = Path("outs")
# The video writer cannot create missing directories itself.
output_dir.mkdir(parents=True, exist_ok=True)

# Sorted for a deterministic processing order; iterdir() order is arbitrary.
for video in sorted(Path(r"...").iterdir()):
    if not video.is_file():
        # Sub-directories cannot be opened as videos.
        continue
    process(
        input_video=video,
        output_path=output_dir/video.name
        )