In [None]:
import torch
import pickle
import torch.nn as nn
import numpy as np

class PixelNorm(nn.Module):
    """Rescale the input so its root-mean-square along the last dimension is ~1.

    The mean of squares is taken over ``dim=-1`` and a constant ``1e-6`` is
    added before the reciprocal square root, so an all-zero vector maps to
    zeros instead of producing a division by zero.
    """

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-6)
        return x * inv_rms
    
class DepthToSpace(nn.Module):
    """Rearrange channel blocks of a ``(b, c, h, w)`` tensor into space.

    The channel axis is read as ``(size, size, c // size**2)`` and the two
    ``size`` factors are interleaved into the height and width axes, giving
    ``(b, c // size**2, h * size, w * size)``. In other words
    ``out[b, o, h*size + i, w*size + j] == x[b, (i*size + j)*oc + o, h, w]``.
    Note this channel ordering differs from ``torch.nn.PixelShuffle``.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        batch, channels, height, width = x.shape
        block = self.size
        out_channels = channels // (block * block)

        # (b, i, j, oc, h, w) -> (b, oc, h, i, w, j) -> (b, oc, h*block, w*block)
        split = x.view(batch, block, block, out_channels, height, width)
        interleaved = split.permute(0, 3, 4, 1, 5, 2).contiguous()
        return interleaved.view(batch, out_channels, height * block, width * block)
    
class Res(nn.Module):
    """Residual block with a constant channel count.

    Applies conv3x3 -> LeakyReLU(0.2) -> conv3x3, adds the result to the
    input and passes the sum through a final LeakyReLU(0.2).
    """

    def __init__(self, n_ch):
        super().__init__()
        # Layers are created in forward order so parameter initialisation and
        # state-dict ordering ("conv.0", "conv.2") stay stable.
        first_conv = nn.Conv2d(n_ch, n_ch, kernel_size=3, padding=1)
        activation = nn.LeakyReLU(0.2)
        second_conv = nn.Conv2d(n_ch, n_ch, kernel_size=3, padding=1)
        self.conv = nn.Sequential(first_conv, activation, second_conv)
        self.fuse = nn.LeakyReLU(0.2)

    def forward(self, x):
        residual = self.conv(x)
        return self.fuse(x + residual)
    
class OutConv(nn.Module):
    """Four parallel 3-channel output convolutions concatenated on the channel axis.

    The first head is a 1x1 convolution; the remaining three are 3x3
    convolutions with padding 1. The result has 12 channels and the same
    spatial size as the input.
    """

    def __init__(self, n_ch):
        super().__init__()
        heads = [nn.Conv2d(n_ch, 3, kernel_size=1)]
        heads.extend(nn.Conv2d(n_ch, 3, kernel_size=3, padding=1) for _ in range(3))
        self.out_conv = nn.ModuleList(heads)

    def forward(self, x):
        head_outputs = [head(x) for head in self.out_conv]
        return torch.cat(head_outputs, dim=1)

class Encoder(nn.Module):
    """Convolutional image encoder producing a flat, pixel-normalised code.

    Five stride-2 5x5 convolutions (each followed by LeakyReLU(0.1)) reduce
    the spatial size by 32x while widening channels
    3 -> e_ch -> 2*e_ch -> 4*e_ch -> 8*e_ch -> 8*e_ch. A ``Res`` block follows
    the first and the last stage; the result is flattened and passed through
    ``PixelNorm``.

    Args:
        e_ch: base channel count of the first stage.
    """

    def __init__(self, e_ch):
        super().__init__()

        def downscale(in_ch, out_ch):
            # One stage: stride-2 5x5 convolution + in-place LeakyReLU(0.1).
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=2),
                nn.LeakyReLU(0.1, inplace=True),
            ]

        # Layers are appended in the exact positional order the weight
        # transfer relies on (state-dict order follows Sequential indices).
        layers = []
        layers += downscale(3, e_ch * 1)
        layers.append(Res(e_ch * 1))
        layers += downscale(e_ch * 1, e_ch * 2)
        layers += downscale(e_ch * 2, e_ch * 4)
        layers += downscale(e_ch * 4, e_ch * 8)
        layers += downscale(e_ch * 8, e_ch * 8)
        layers.append(Res(e_ch * 8))
        layers.append(nn.Flatten())
        layers.append(PixelNorm())

        self.image_encoder = nn.Sequential(*layers)

    def forward(self, x):
        code = self.image_encoder(x)
        return code
    
class Decoder(nn.Module):
    """Decode a latent feature map into an image and a single-channel mask.

    Both branches are built from "upscale" stages: conv3x3 to ``4 * out_ch``
    channels, LeakyReLU(0.1), then ``DepthToSpace(2)`` which doubles the
    spatial size and leaves ``out_ch`` channels.

    * ``image_decoder``: four upscale stages, each followed by a ``Res``
      block, then ``OutConv`` (12 channels) + ``DepthToSpace(2)`` + sigmoid,
      i.e. a 3-channel output at 32x the latent resolution.
    * ``mask_decoder``: five upscale stages, then a 1x1 convolution to one
      channel + sigmoid, also at 32x the latent resolution.

    Args:
        ae_dim: channel count of the latent input.
        d_ch: base channel count of the image branch.
        m_ch: base channel count of the mask branch.
    """

    def __init__(self, ae_dim, d_ch, m_ch):
        super().__init__()

        def upscale(in_ch, out_ch):
            # conv emits 4x the target channels; DepthToSpace(2) folds them into space.
            return [
                nn.Conv2d(in_ch, out_ch * 4, kernel_size=3, padding=1),
                nn.LeakyReLU(0.1, inplace=True),
                DepthToSpace(2),
            ]

        image_stages = (
            (ae_dim, d_ch * 8),
            (d_ch * 8, d_ch * 8),
            (d_ch * 8, d_ch * 4),
            (d_ch * 4, d_ch * 2),
        )
        image_layers = []
        for in_ch, out_ch in image_stages:
            image_layers += upscale(in_ch, out_ch)
            image_layers.append(Res(out_ch))
        image_layers += [OutConv(d_ch * 2), DepthToSpace(2), nn.Sigmoid()]
        self.image_decoder = nn.Sequential(*image_layers)

        mask_stages = (
            (ae_dim, m_ch * 8),
            (m_ch * 8, m_ch * 8),
            (m_ch * 8, m_ch * 4),
            (m_ch * 4, m_ch * 2),
            (m_ch * 2, m_ch * 1),
        )
        mask_layers = []
        for in_ch, out_ch in mask_stages:
            mask_layers += upscale(in_ch, out_ch)
        mask_layers += [nn.Conv2d(m_ch * 1, 1, kernel_size=1), nn.Sigmoid()]
        self.mask_decoder = nn.Sequential(*mask_layers)

    def forward(self, z):
        image = self.image_decoder(z)
        mask = self.mask_decoder(z)
        return image, mask

In [None]:
def load_weights(model_path):
    """Load a weights file and return the stored name -> array mapping.

    Two layouts are handled:

    * A pickle stream whose first object is an ``int`` N followed by N pickled
      dicts. The first key/value pair of each dict is collected into a single
      dict, in file order.
    * Anything else (a single pickled object, or a real ``.npy`` / ``.npz``
      file) is delegated to ``np.load(..., allow_pickle=True)`` and its
      result is returned unchanged.

    Args:
        model_path: path of the weights file.

    Returns:
        A dict for the counted-pickle layout, otherwise whatever ``np.load``
        returns for the file.
    """
    # SECURITY: both pickle.load and np.load(allow_pickle=True) can execute
    # arbitrary code embedded in the file. Only load weights from trusted sources.
    with open(model_path, 'rb') as f:
        try:
            first_item = pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError, ImportError, IndexError, ValueError):
            # Not a pickle stream (e.g. a native .npy/.npz file): let numpy read it.
            first_item = None

        if isinstance(first_item, int):
            weights = {}
            for _ in range(first_item):
                entry = pickle.load(f)
                # Only the first item of every pickled dict is used.
                name, value = next(iter(entry.items()))
                weights[name] = value
            return weights

    return np.load(model_path, allow_pickle=True)

In [None]:
def transfer_encoder(tf_encoder_path):
    """Build an ``Encoder`` and fill it with weights from a TF checkpoint.

    The base channel count is read from the last axis of
    ``"down1/conv1/weight:0"``. TF tensors are matched to the PyTorch
    state-dict entries purely by position, so both sequences must have the
    same length and order. 4-D kernels are transposed with ``(3, 2, 0, 1)``
    (TF ``H, W, in, out`` -> PyTorch ``out, in, H, W``).

    Args:
        tf_encoder_path: path of the encoder weights file.

    Returns:
        The ``Encoder`` with the transferred weights loaded.

    Raises:
        ValueError: if the checkpoint and the model hold a different number
            of tensors (positional matching would otherwise silently leave
            some layers at their random initialisation).
        AssertionError: if a matched pair of tensors differs in shape.
    """
    weights = load_weights(tf_encoder_path)
    encoder_dims = weights["down1/conv1/weight:0"].shape[-1]
    print(f"encoder dims: {encoder_dims}")

    encoder = Encoder(encoder_dims)
    encoder_state_dict = encoder.state_dict()

    tf_tensors = list(weights.values())
    if len(tf_tensors) != len(encoder_state_dict):
        raise ValueError(
            f"encoder checkpoint holds {len(tf_tensors)} tensors "
            f"but the model expects {len(encoder_state_dict)}"
        )

    for (layer_name, params), tf_params in zip(encoder_state_dict.items(), tf_tensors):
        if tf_params.ndim == 4:
            tf_params = tf_params.transpose(3, 2, 0, 1)

        assert params.shape == tf_params.shape, (
            f"{layer_name}: torch shape {tuple(params.shape)} != tf shape {tuple(tf_params.shape)}"
        )
        # Re-assigning an existing key does not resize the dict, so this is safe while iterating.
        encoder_state_dict[layer_name] = torch.from_numpy(tf_params)

    encoder.load_state_dict(encoder_state_dict)
    return encoder


def transfer_inter(tf_inter_path):
    """Build the two-layer dense "inter" network from a TF checkpoint.

    Layer sizes are read from ``"dense1/weight:0"`` and ``"dense2/weight:0"``,
    whose axes are used as ``(in, out)``. TF tensors are matched to the
    PyTorch state-dict entries by position; 2-D matrices are transposed to
    ``nn.Linear``'s ``(out, in)`` layout.

    Args:
        tf_inter_path: path of the inter weights file.

    Returns:
        ``nn.Sequential(Linear, Linear)`` with the transferred weights loaded.

    Raises:
        ValueError: if the checkpoint and the model hold a different number
            of tensors (positional matching would otherwise silently leave
            some layers at their random initialisation).
        AssertionError: if a matched pair of tensors differs in shape.
    """
    weights = load_weights(tf_inter_path)
    in_features  = weights['dense1/weight:0'].shape[0]
    ae_dims      = weights['dense1/weight:0'].shape[1]
    out_features = weights['dense2/weight:0'].shape[1]
    print(f"inter dims: {in_features} -> {ae_dims} -> {out_features}")

    inter = nn.Sequential(nn.Linear(in_features, ae_dims), nn.Linear(ae_dims, out_features))
    inter_state_dict = inter.state_dict()

    tf_tensors = list(weights.values())
    if len(tf_tensors) != len(inter_state_dict):
        raise ValueError(
            f"inter checkpoint holds {len(tf_tensors)} tensors "
            f"but the model expects {len(inter_state_dict)}"
        )

    for (layer_name, params), tf_params in zip(inter_state_dict.items(), tf_tensors):
        if tf_params.ndim == 2:
            tf_params = tf_params.transpose(1, 0)

        assert params.shape == tf_params.shape, (
            f"{layer_name}: torch shape {tuple(params.shape)} != tf shape {tuple(tf_params.shape)}"
        )
        # Re-assigning an existing key does not resize the dict, so this is safe while iterating.
        inter_state_dict[layer_name] = torch.from_numpy(tf_params)

    inter.load_state_dict(inter_state_dict)
    return inter


def transfer_decoder(tf_decoder_path):
    """Build a ``Decoder`` and fill it with weights from a TF checkpoint.

    The decoder sizes are inferred from the checkpoint instead of being
    hard-coded: the first image / mask upscale kernels have TF layout
    ``(H, W, in, out)`` with ``in == ae_dim`` and ``out == ch * 8 * 4``
    (see ``Decoder``), so ``ch = out // 32``. TF tensors are looked up by
    name following ``decoder_layer_order``, which lists them in the order of
    the PyTorch state dict. 4-D kernels are transposed with ``(3, 2, 0, 1)``.

    Args:
        tf_decoder_path: path of the decoder weights file.

    Returns:
        The ``Decoder`` with the transferred weights loaded.

    Raises:
        ValueError: if the model's tensor count differs from the name list.
        KeyError: if a listed tensor name is missing from the checkpoint.
        AssertionError: if a matched pair of tensors differs in shape.
    """
    weights = load_weights(tf_decoder_path)

    image_kernel = weights["upscale0/conv1/weight:0"]
    mask_kernel = weights["upscalem0/conv1/weight:0"]
    ae_dim = image_kernel.shape[2]
    d_ch = image_kernel.shape[3] // 32  # first image conv emits d_ch*8*4 channels
    m_ch = mask_kernel.shape[3] // 32   # first mask conv emits m_ch*8*4 channels

    decoder = Decoder(ae_dim, d_ch, m_ch)
    decoder_state_dict = decoder.state_dict()

    decoder_layer_order = [
        "upscale0/conv1/weight:0", "upscale0/conv1/bias:0", "res0/conv1/weight:0", "res0/conv1/bias:0", "res0/conv2/weight:0", "res0/conv2/bias:0",
        "upscale1/conv1/weight:0", "upscale1/conv1/bias:0", "res1/conv1/weight:0", "res1/conv1/bias:0", "res1/conv2/weight:0", "res1/conv2/bias:0",
        "upscale2/conv1/weight:0", "upscale2/conv1/bias:0", "res2/conv1/weight:0", "res2/conv1/bias:0", "res2/conv2/weight:0", "res2/conv2/bias:0",
        "upscale3/conv1/weight:0", "upscale3/conv1/bias:0", "res3/conv1/weight:0", "res3/conv1/bias:0", "res3/conv2/weight:0", "res3/conv2/bias:0",
        "out_conv/weight:0", "out_conv/bias:0", "out_conv1/weight:0", "out_conv1/bias:0", "out_conv2/weight:0", "out_conv2/bias:0", "out_conv3/weight:0", "out_conv3/bias:0",
        "upscalem0/conv1/weight:0", "upscalem0/conv1/bias:0", "upscalem1/conv1/weight:0", "upscalem1/conv1/bias:0", "upscalem2/conv1/weight:0", "upscalem2/conv1/bias:0",
        "upscalem3/conv1/weight:0", "upscalem3/conv1/bias:0", "upscalem4/conv1/weight:0", "upscalem4/conv1/bias:0", "out_convm/weight:0", "out_convm/bias:0",
    ]

    if len(decoder_layer_order) != len(decoder_state_dict):
        raise ValueError(
            f"decoder name list holds {len(decoder_layer_order)} tensors "
            f"but the model expects {len(decoder_state_dict)}"
        )

    for (layer_name, params), tf_name in zip(decoder_state_dict.items(), decoder_layer_order):
        tf_params = weights[tf_name]

        if tf_params.ndim == 4:
            tf_params = tf_params.transpose(3, 2, 0, 1)

        assert params.shape == tf_params.shape, (
            f"{layer_name} <- {tf_name}: torch shape {tuple(params.shape)} != tf shape {tuple(tf_params.shape)}"
        )
        # Re-assigning an existing key does not resize the dict, so this is safe while iterating.
        decoder_state_dict[layer_name] = torch.from_numpy(tf_params)

    decoder.load_state_dict(decoder_state_dict)
    return decoder

In [None]:
from pathlib import Path

class IDTransfer(nn.Module):
    """Encoder -> inter -> decoder pipeline assembled from transferred TF weights.

    Args:
        encoder_path: encoder weights file passed to ``transfer_encoder``.
        inter_path: inter weights file passed to ``transfer_inter``.
        decoder_path: decoder weights file passed to ``transfer_decoder``.

    The defaults keep the notebook's original placeholder paths, so calling
    ``IDTransfer()`` behaves as before.
    """

    def __init__(self, encoder_path=Path(r"..."), inter_path=Path(r"..."), decoder_path=Path(r"...")):
        super(IDTransfer, self).__init__()
        self.encoder  = transfer_encoder(encoder_path)
        self.inter_AB = transfer_inter(inter_path)
        self.decoder  = transfer_decoder(decoder_path)

    def forward(self, x):
        """Return ``(image, mask)`` as produced by the decoder for input batch ``x``."""
        # The inter output is reshaped to a fixed (1024, 7, 7) feature map.
        # NOTE(review): these sizes are hard-coded; they presumably match a
        # 224-px model (224 / 32 = 7) — confirm before using other checkpoints.
        x = self.inter_AB(self.encoder(x)).reshape(-1, 1024, 7, 7)
        # The same code is fed twice, concatenated on the channel axis, to
        # form the decoder input.
        return self.decoder(torch.concat((x, x), axis=1))

In [None]:
import cv2
import numpy as np
import numexpr as ne

def reinhard_color_transfer(target : np.ndarray, source : np.ndarray, target_mask : np.ndarray = None, source_mask : np.ndarray = None, mask_cutoff=0.5) -> np.ndarray:
    """
    Transfer color using rct method.

        target      np.ndarray H W 3C   (BGR)   np.float32
        source      np.ndarray H W 3C   (BGR)   np.float32

        target_mask(None)   np.ndarray H W 1C  np.float32
        source_mask(None)   np.ndarray H W 1C  np.float32

        mask_cutoff(0.5)    float

    masks are used to limit the space where color statistics will be computed to adjust the target

    Only channel 0 of each mask is read. Pixels whose mask value is below
    mask_cutoff are set to [0,0,0] in a copy of the LAB image before the
    per-channel mean / std are computed; the adjustment itself is applied to
    the whole (unmasked) target.
    NOTE(review): masked-out pixels are zeroed rather than excluded, so they
    still contribute to the statistics — confirm this is intended.

    Target standard deviations are floored at float32 machine epsilon so a
    constant or fully masked-out channel cannot cause a division by zero
    (which would otherwise yield NaN / inf pixels).

    Returns the adjusted target converted back to BGR.

    reference: Color Transfer between Images https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf
    """
    source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
    target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)

    source_input = source
    if source_mask is not None:
        source_input = source_input.copy()
        source_input[source_mask[...,0] < mask_cutoff] = [0,0,0]

    target_input = target
    if target_mask is not None:
        target_input = target_input.copy()
        target_input[target_mask[...,0] < mask_cutoff] = [0,0,0]

    target_l_mean, target_l_std = target_input[...,0].mean(), target_input[...,0].std()
    target_a_mean, target_a_std = target_input[...,1].mean(), target_input[...,1].std()
    target_b_mean, target_b_std = target_input[...,2].mean(), target_input[...,2].std()

    source_l_mean, source_l_std = source_input[...,0].mean(), source_input[...,0].std()
    source_a_mean, source_a_std = source_input[...,1].mean(), source_input[...,1].std()
    source_b_mean, source_b_std = source_input[...,2].mean(), source_input[...,2].std()

    # Guard the divisions below. np.maximum on two float32 scalars stays
    # float32, so the numexpr results keep the dtype cv2.cvtColor expects.
    min_std = np.finfo(np.float32).eps
    target_l_std = np.maximum(target_l_std, min_std)
    target_a_std = np.maximum(target_a_std, min_std)
    target_b_std = np.maximum(target_b_std, min_std)

    # not as in the paper: scale by the standard deviations using reciprocal of paper proposed factor
    # (numexpr resolves the names in each expression from this function's local variables)
    target_l = target[...,0]
    target_l = ne.evaluate('(target_l - target_l_mean) * source_l_std / target_l_std + source_l_mean')

    target_a = target[...,1]
    target_a = ne.evaluate('(target_a - target_a_mean) * source_a_std / target_a_std + source_a_mean')

    target_b = target[...,2]
    target_b = ne.evaluate('(target_b - target_b_mean) * source_b_std / target_b_std + source_b_mean')

    # Clamp each channel to the LAB range used here before converting back.
    np.clip(target_l,    0, 100, out=target_l)
    np.clip(target_a, -127, 127, out=target_a)
    np.clip(target_b, -127, 127, out=target_b)

    return cv2.cvtColor(np.stack([target_l,target_a,target_b], -1), cv2.COLOR_LAB2BGR)

In [None]:
import cv2
import subprocess
import numpy as np
import torch
import os

from tqdm import tqdm
from pathlib import Path
from face_parser import FaceParser
from face_detector import FaceDetector

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) instead of failing on the `.to("cuda")` call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE  = torch.bfloat16

# Global model instances used by processFrame().
face_detector = FaceDetector(models_dir="./")
id_transfer = IDTransfer().to(device=DEVICE, dtype=DTYPE)
# NOTE(review): `parser` is not referenced by any of the cells visible here.
parser = FaceParser()

In [None]:
def processFrame(frame, enhancer = None, detectiom_res: int = 512, model_res: int = 224, num_pixel_shifts: int = 2):
    """Swap the detected face in a single frame and blend it back in.

    Steps: detect/align the face, resize it to ``model_res * num_pixel_shifts``,
    split it into ``num_pixel_shifts**2`` interleaved sub-images of size
    ``model_res``, run them through ``id_transfer`` as one batch, re-interleave
    the predictions, colour-match the prediction to the detected face, warp
    image and mask back with the inverse of ``M`` and alpha-blend into the frame.

    Args:
        frame: input image. Presumably a BGR uint8 array as produced by
            ``cv2.VideoCapture.read`` — confirm against the caller.
        enhancer: optional object exposing ``enhance(image, blend=...)``;
            skipped when falsy.
        detectiom_res: side length requested from the face detector and used
            for the image that is warped back.
        model_res: input/output side length of the network.
        num_pixel_shifts: sub-sampling factor per axis.

    Returns:
        The input ``frame`` unchanged when the detector returns no transform
        (``M is None``), otherwise a new uint8 frame with the blended face.
    """
    process_res = model_res * num_pixel_shifts
    detected_face, M = face_detector.get_face(frame=frame, image_size=detectiom_res, zoom_out_factor=1)

    if M is None:
        return frame

    detected_face = cv2.resize(detected_face, (process_res, process_res))
    input_face = torch.tensor(detected_face, device=DEVICE, dtype=DTYPE).permute(2, 0, 1)
    # (3, H, W) -> (shift_y, shift_x, 3, model_res, model_res): every sub-image
    # takes each num_pixel_shifts-th pixel with a different (y, x) offset.
    input_face = input_face.reshape(3, model_res, num_pixel_shifts, model_res, num_pixel_shifts).permute(2,4,0,1,3)
    input_face = input_face.reshape(num_pixel_shifts*num_pixel_shifts, 3, model_res, model_res) / 255.

    with torch.no_grad():
        pred_img, pred_msk = id_transfer(input_face)

    # Inverse of the split above: re-interleave the sub-images into one full-size image.
    pred_img = pred_img.view(num_pixel_shifts, num_pixel_shifts, 3, model_res, model_res).permute(2,3,0,4,1).reshape(3,process_res, process_res)
    pred_msk = pred_msk.view(num_pixel_shifts, num_pixel_shifts, 1, model_res, model_res).permute(2,3,0,4,1).reshape(1,process_res, process_res)

    pred_img = pred_img                .to(device="cpu", dtype=torch.float32).permute(1, 2, 0).numpy()
    pred_msk = pred_msk.repeat(3, 1, 1).to(device="cpu", dtype=torch.float32).permute(1, 2, 0).numpy()

    # Soften the mask edge so the blend has no hard seam.
    pred_msk = cv2.GaussianBlur(pred_msk, (11, 11), 0)

    pred_img = reinhard_color_transfer(pred_img, (detected_face/255.).astype(np.float32), target_mask=pred_msk, source_mask=pred_msk)
    pred_img = pred_img * 255.

    if enhancer:
        pred_img = enhancer.enhance(pred_img, blend=0.4)

    pred_img = cv2.resize(pred_img, (detectiom_res, detectiom_res))
    pred_msk = cv2.resize(pred_msk, (detectiom_res, detectiom_res))

    # The inverse transform is shared by the image and the mask: compute it once.
    inverse_M = cv2.invertAffineTransform(M)
    frame_size = frame.shape[1::-1]  # (width, height) as expected by warpAffine
    out_mask  = cv2.warpAffine(pred_msk, inverse_M, dsize=frame_size, borderValue=(0,0,0))
    warped_face = cv2.warpAffine(pred_img, inverse_M, dsize=frame_size)

    blended = out_mask*warped_face + (1-out_mask)*frame
    # Clip before the cast: out-of-range floats would otherwise wrap around in uint8.
    return np.clip(blended, 0, 255).astype(np.uint8)


def process(input_video: Path, output_path: Path, enhancer, detectiom_res: int = 512, model_res: int = 224, num_pixel_shifts: int = 2):
    """Run ``processFrame`` over every frame of a video and write the result.

    The processed frames are first encoded to ``output_path`` with the
    ``mp4v`` codec. ffmpeg then stream-copies that file together with the
    original input (``-shortest -c copy``) into a temporary file next to
    ``output_path``, which finally replaces ``output_path``.

    Args:
        input_video: source video file.
        output_path: destination file; overwritten.
        enhancer: forwarded to ``processFrame``.
        detectiom_res, model_res, num_pixel_shifts: forwarded to ``processFrame``.

    Raises:
        IOError: if the input video cannot be opened.
        subprocess.CalledProcessError: if ffmpeg exits with an error. The
            encoded (not yet muxed) video is then left in place at
            ``output_path`` instead of being deleted.
    """
    output_path = Path(output_path)

    input_stream = cv2.VideoCapture(filename=str(input_video))
    if not input_stream.isOpened():
        input_stream.release()
        raise IOError(f"Could not open video: {input_video}")

    output_stream = None
    try:
        output_stream = cv2.VideoWriter(
            filename  = str(output_path),
            fourcc    = cv2.VideoWriter_fourcc(*"mp4v"),
            # Keep fractional rates (e.g. 29.97): truncating to int changes the
            # video duration relative to the copied input streams.
            fps       = input_stream.get(cv2.CAP_PROP_FPS),
            frameSize = (int(input_stream.get(cv2.CAP_PROP_FRAME_WIDTH)), int(input_stream.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            )

        # The reported frame count is only used for the progress bar; reading
        # continues until the stream is exhausted so no frames are dropped
        # when the container reports a count that is too low.
        total_frames = int(input_stream.get(cv2.CAP_PROP_FRAME_COUNT))
        with tqdm(total=total_frames if total_frames > 0 else None) as progress:
            while True:
                ret, frame = input_stream.read()
                if not ret:
                    break
                frame = processFrame(frame, enhancer, detectiom_res, model_res, num_pixel_shifts)
                output_stream.write(frame)
                progress.update(1)
    finally:
        # Explicit release flushes and closes the output file before ffmpeg reads it.
        if output_stream is not None:
            output_stream.release()
        input_stream.release()

    muxed_path = output_path.with_name(f"{output_path.stem}.muxed.mp4")
    # Argument list (no shell): paths containing spaces or shell metacharacters
    # are passed through verbatim, and the call works on POSIX as well as Windows.
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(output_path), "-i", str(input_video), "-shortest", "-c", "copy", str(muxed_path)],
        check=True,
    )
    # Path.replace overwrites the destination, so no separate delete is needed.
    muxed_path.replace(output_path)

In [None]:
# Process every regular file of the input folder in a deterministic (sorted)
# order. Sub-directories such as `.ipynb_checkpoints` are skipped.
for video in sorted(Path(r"...").iterdir()):
    if not video.is_file():
        continue
    process(
        input_video=video,
        output_path=Path("...")/video.name,
        enhancer=None,
        detectiom_res=512,
        model_res=224,
        num_pixel_shifts=1
    )