In [2]:
import tqdm

def nop(it, *_ignored_args, **_ignored_kwargs):
    """Pass-through stand-in for a progress-bar wrapper.

    Args:
        it: The object to hand back (typically the iterable being looped over).
        *_ignored_args: Accepted for call compatibility and discarded.
        **_ignored_kwargs: Accepted for call compatibility and discarded.

    Returns:
        ``it`` itself, unchanged and unwrapped.
    """
    return it

# Keep a reference to the genuine tqdm class, then replace the module attribute
# with the pass-through above. Any code that looks up `tqdm.tqdm` (or runs
# `from tqdm import tqdm`) after this point gets `nop`, so wrapped loops simply
# iterate without a progress bar. This must run before the project imports below
# for them to pick up the replacement.
real_tqdm = tqdm.tqdm
tqdm.tqdm = nop

import time
import os
import glob
import pickle
from typing import Union

import numpy as np
# Make `np.bool` resolve to `np.bool_` before the remaining imports execute.
# NOTE(review): presumably a shim for a dependency imported below that still
# references the old `np.bool` name — confirm which one and drop this when possible.
np.bool = np.bool_
import cv2
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch.nn.functional as F

from utils.inference.image_processing import crop_face, get_final_image, show_images, normalize_and_torch, normalize_and_torch_batch
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement, crop_frames_and_get_transforms, resize_frames
from utils.inference.core import model_inference, transform_target_to_torch
from utils.inference.faceshifter_run import faceshifter_batch, faceshifter_batch_zattrs
from network.AEI_Net import AEI_Net
from coordinate_reg.image_infer import Handler
from insightface_func.face_detect_crop_multi import Face_detect_crop
from arcface_model.iresnet import iresnet100
from models.pix2pix_model import Pix2PixModel
from models.config_sr import TestOptions

  mirr = onp.mirr
  npv = onp.npv
  pmt = onp.pmt
  ppmt = onp.ppmt
  pv = onp.pv
  rate = onp.rate
  cur_np_ver = LooseVersion(_np.__version__)
  np_1_17_ver = LooseVersion('1.17')
  cur_np_ver = LooseVersion(_np.__version__)
  np_1_15_ver = LooseVersion('1.15')


In [None]:
class DoublePCA:
    """Two-stage projection: one fitted transformer per layer, then a joint one.

    Stage 1 applies ``pca1_list[i]`` to the i-th per-layer embedding. The
    stage-1 outputs are concatenated along axis 1 and passed through ``pca2``.
    All transformers are unpickled from ``root_path`` and only need to expose
    ``transform`` / ``inverse_transform`` (presumably fitted scikit-learn PCA
    objects, judging by the file names — confirm against the code that wrote
    the pickles).

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, layer_num=8, root_path="./pca_pkl"):
        """Load the per-layer transformers, the joint transformer and its min/max.

        Args:
            layer_num: Number of per-layer transformers to load (files are
                indexed ``0 .. layer_num - 1``).
            root_path: Directory containing the ``Altered_zattr*`` pickle files.

        Raises:
            OSError: If any of the expected pickle files cannot be opened.
        """
        self.layer_num = layer_num

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only point root_path at pickles produced by a trusted source.
        self.pca1_list = []  # stage-1 transformers, one per layer
        for z_i in range(self.layer_num):
            with open(f"{root_path}/Altered_zattr{z_i}PCA.pkl", "rb") as file:
                self.pca1_list.append(pickle.load(file))
        with open(f"{root_path}/Altered_zattr_doublePCA.pkl", "rb") as file:
            self.pca2 = pickle.load(file)  # stage-2 (joint) transformer
        with open(f"{root_path}/Altered_zattr_doublePCAMinMax.pkl", "rb") as file:
            minmax = pickle.load(file)
        # Stored "min"/"max" entries of the pickled dict; not used by the methods here.
        self.pca2_min = minmax["min"]
        self.pca2_max = minmax["max"]

    def transform(self, z_embeds):
        """Encode per-layer embeddings into the joint stage-2 space.

        Args:
            z_embeds: Indexable with ``layer_num`` entries; entry ``i`` is fed
                to ``pca1_list[i].transform``.

        Returns:
            Whatever ``pca2.transform`` returns for the axis-1 concatenation of
            the stage-1 outputs.
        """
        p1emb_array = [self.pca1_list[z_i].transform(z_embeds[z_i]) for z_i in range(self.layer_num)]
        p1emb_array = np.concatenate(p1emb_array, axis=1)
        return self.pca2.transform(p1emb_array)

    def inverse_transform(self, p2emb_array):
        """Decode stage-2 codes back into a list of per-layer embeddings.

        The width of each layer's stage-1 code is derived from the decoded
        array (last axis divided by ``layer_num``) instead of being fixed at
        128, so transformers fitted with a different component count decode
        correctly. All layers are assumed to share the same stage-1 width.

        Args:
            p2emb_array: Codes in the stage-2 space, as produced by ``transform``.

        Returns:
            A list of ``layer_num`` arrays, entry ``i`` being
            ``pca1_list[i].inverse_transform`` of that layer's slice.

        Raises:
            ValueError: If the decoded width is not a multiple of ``layer_num``.
        """
        p1emb_array = self.pca2.inverse_transform(p2emb_array)
        per_layer_dim, leftover = divmod(p1emb_array.shape[-1], self.layer_num)
        if leftover:
            raise ValueError(
                f"stage-2 inverse has width {p1emb_array.shape[-1]}, "
                f"which cannot be split evenly across {self.layer_num} layers"
            )
        p1emb_array = p1emb_array.reshape([-1, self.layer_num, per_layer_dim])
        z_embeds = [self.pca1_list[z_i].inverse_transform(p1emb_array[:, z_i]) for z_i in range(self.layer_num)]
        return z_embeds

    def calcul_z_embed_diff(self, target, original):
        """Return the per-layer difference between two decoded stage-2 codes.

        Args:
            target: Stage-2 codes to decode as the minuend.
            original: Stage-2 codes to decode as the subtrahend; must have the
                same leading dimension as ``target``.

        Returns:
            A list of ``layer_num`` arrays: ``decoded(target)[i] - decoded(original)[i]``.
        """
        assert target.shape[0] == original.shape[0], "target and original must have the same number of rows"
        target = self.inverse_transform(target)
        original = self.inverse_transform(original)
        return [target[i] - original[i] for i in range(self.layer_num)]



class FaceSwapper():
    """Face swapper that loads all models once and caches intermediate results.

    ``__init__`` builds the face detector, the generator ``G`` (moved to CUDA and
    converted to half precision), the ArcFace embedder, the landmark handler and
    the enhancement model. ``swap`` runs one source image against one target
    image and leaves the intermediate artefacts on the instance (``source``,
    ``full_frames``, ``crop_frames_list``, ``tfm_array_list``,
    ``final_z_outputs``) so they can be inspected afterwards.
    """

    def __init__(self):
        """Load every model from its hard-coded, working-directory-relative path.

        Requires a CUDA device: the generator and the ArcFace network are moved
        to the GPU here.
        """
        self.crop_size = 224

        # face detector / cropper
        self.app = Face_detect_crop(name='antelope', root='./insightface_func/models')
        self.app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

        # main model for generation
        self.G = AEI_Net(backbone='unet', num_blocks=2, c_id=512)
        self.G.eval()
        self.G.load_state_dict(torch.load('weights/G_unet_2blocks.pth', map_location=torch.device('cpu')))
        # These two lines previously referenced a bare `G`, which does not exist
        # in this scope; the model lives on the instance.
        self.G = self.G.cuda()
        self.G = self.G.half()

        # arcface model to get face embedding
        self.netArc = iresnet100(fp16=False)
        self.netArc.load_state_dict(torch.load('arcface_model/backbone.pth'))
        self.netArc = self.netArc.cuda()
        self.netArc.eval()

        # model to get face landmarks
        self.handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

        # enhancement model; swap() always passes it to face_enhancement()
        # NOTE(review): CUDA_VISIBLE_DEVICES is set after the .cuda() calls above,
        # so it probably no longer influences device selection in this process —
        # confirm whether it is still needed.
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        torch.backends.cudnn.benchmark = True
        opt = TestOptions()
        #opt.which_epoch ='10_7'
        self.model = Pix2PixModel(opt)
        # NOTE(review): the generator is deliberately left in train() mode here,
        # as in the original code; confirm this is intended for inference.
        self.model.netG.train()

        # temporary storage for intermediate results of the last swap() call
        self.source = None
        self.full_frames = None
        self.orig_target_embed = None
        self.final_z_outputs = None
        self.crop_frames_list = None
        self.tfm_array_list = None

    @staticmethod
    def _load_target_frames(target, is_tgt_video):
        """Normalise ``target`` (path, single frame, or frame list) to a list of frames."""
        if isinstance(target, str):
            if is_tgt_video:
                # read_video also returns the fps, which this method does not need.
                full_frames, _fps = read_video(target)
                return full_frames
            target_full = cv2.imread(target)
            if target_full is None:
                raise ValueError(f"could not read target image: {target}")
            return [target_full]
        # Already decoded: a video is passed as its list of frames, an image is wrapped.
        return target if is_tgt_video else [target]

    def swap(self, source: Union[np.ndarray, str], target: Union[np.ndarray, str], is_tgt_video=False, BS=60):
        """Swap the face from ``source`` onto the face(s) found in ``target``.

        ``source`` and ``target`` are assumed to be the output of ``cv2.imread``,
        i.e. BGR ndarrays, or paths to read with it. When the target is a video
        it is expected as a list of frame ndarrays ([t, H, W, C]) or a video path.

        Args:
            source: Source image (BGR ndarray) or a path to it.
            target: Target image / list of frames, or a path to an image / video.
            is_tgt_video: Whether ``target`` is a video rather than a single image.
            BS: Number of target crops pushed through the generator per batch.

        Returns:
            The value returned by ``get_final_image`` for the first target frame.

        Raises:
            ValueError: If a path is given and ``cv2.imread`` cannot read it.
            NotImplementedError: If ``is_tgt_video`` is true; the video output
                path is not implemented. It is raised only after the
                intermediate results have been stored on the instance.
        """
        if isinstance(source, str):
            source_path = source
            source = cv2.imread(source_path)
            if source is None:
                raise ValueError(f"could not read source image: {source_path}")
        self.source = source

        full_frames = self._load_target_frames(target, is_tgt_video)
        self.full_frames = full_frames

        # Embed the target face, then collect the matching crops and their transforms.
        cropped_target = get_target(full_frames, self.app, self.crop_size)
        target_norm = normalize_and_torch_batch(np.array(cropped_target))
        target_embeds = self.netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))
        crop_frames_list, tfm_array_list = crop_frames_and_get_transforms(
            full_frames,
            target_embeds,
            self.app,
            self.netArc,
            self.crop_size,
            set_target=False,
            similarity_th=0.15,
        )
        self.crop_frames_list = crop_frames_list
        self.tfm_array_list = tfm_array_list

        # The single source image is wrapped in a list: iterating the ndarray
        # itself would walk over its pixel rows instead of over images.
        # NOTE(review): the source is embedded exactly as given (no face crop is
        # applied in this method) — confirm callers pass a suitably cropped face.
        source_embeds = []
        for source_curr in [self.source]:
            source_curr = normalize_and_torch(source_curr)
            source_embeds.append(self.netArc(F.interpolate(source_curr, scale_factor=0.5, mode='bilinear', align_corners=True)))

        final_frames_list = []
        self.final_z_outputs = []
        for crop_frames, _tfm_array, source_embed in zip(self.crop_frames_list, self.tfm_array_list, source_embeds):
            # Resize cropped frames and get the vector telling which frames contained a face
            resized_frs, present = resize_frames(crop_frames)
            resized_frs = np.array(resized_frs)

            # The generator was converted with .half() in __init__, so both the
            # target batch and the source embedding are fed in half precision.
            target_batch_rs = transform_target_to_torch(resized_frs, half=True)
            source_embed = source_embed.half()

            # run model
            size = target_batch_rs.shape[0]
            model_output = []
            z_output = []
            for i in range(0, size, BS):
                zattrs = self.G.get_attr(target_batch_rs[i:i+BS])
                Y_st = faceshifter_batch_zattrs(source_embed, zattrs, BS, self.G)
                model_output.append(Y_st)
                z_output.append(zattrs)
            torch.cuda.empty_cache()
            model_output = np.concatenate(model_output)
            # NOTE(review): this requires get_attr to return a single tensor per
            # batch; if it returns a tuple of per-layer tensors the concat needs
            # to be done per layer — confirm against AEI_Net.get_attr.
            z_output = torch.concat(z_output, dim=0)

            # Re-expand to one entry per frame; frames without a face get an empty placeholder.
            final_frames = []
            final_zs = []
            idx_fs = 0
            for pres in present:
                if pres == 1:
                    final_frames.append(model_output[idx_fs])
                    final_zs.append(z_output[idx_fs])
                    idx_fs += 1
                else:
                    final_frames.append([])
                    final_zs.append([])
            final_frames_list.append(final_frames)
            self.final_z_outputs.append(final_zs)

        final_frames_list = face_enhancement(final_frames_list, self.model)

        if is_tgt_video:
            # An explicit raise (not `assert False`) so the guard survives `python -O`.
            raise NotImplementedError("not implemented")
        result = get_final_image(final_frames_list, crop_frames_list, self.full_frames[0], self.tfm_array_list, self.handler)
        return result



In [19]:
# Scratch check: isinstance() accepts a tuple of types and is true if any one of them matches.
isinstance("", (str, list))

True