In [None]:
import tqdm

def nop(it, *a, **k):
    """Drop-in replacement for ``tqdm.tqdm`` that draws no progress bar.

    The wrapped iterable is handed back untouched; any extra positional or
    keyword arguments (``desc``, ``total``, ...) are accepted and ignored.
    """
    passthrough = it
    return passthrough

# Keep a handle to the real progress-bar class (presumably so it can be restored
# with `tqdm.tqdm = real_tqdm`), then replace it with the pass-through `nop`.
# Every lookup of `tqdm.tqdm` made after this point -- including
# `from tqdm import tqdm` in later cells -- gets `nop`, which returns the iterable
# unchanged, so no progress bars are drawn.
real_tqdm = tqdm.tqdm
tqdm.tqdm = nop

import time
import os
import glob
import pickle

import numpy as np
np.bool = np.bool_
import cv2
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import torch.nn.functional as F

from utils.inference.image_processing import crop_face, get_final_image, show_images, normalize_and_torch, normalize_and_torch_batch
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement, crop_frames_and_get_transforms, resize_frames
from utils.inference.core import model_inference, transform_target_to_torch
from utils.inference.faceshifter_run import faceshifter_batch
from network.AEI_Net import AEI_Net
from coordinate_reg.image_infer import Handler
from insightface_func.face_detect_crop_multi import Face_detect_crop
from arcface_model.iresnet import iresnet100
from models.pix2pix_model import Pix2PixModel
from models.config_sr import TestOptions





### Load Models

In [None]:
# model to make superres of face, set use_sr=True if you want to use super resolution or use_sr=False if you don't
use_sr = True

if use_sr:
    # CUDA reads CUDA_VISIBLE_DEVICES only once, when it is first initialised in this
    # process. It therefore has to be set before any model is placed on the GPU
    # (assigning it after `G.cuda()` has no effect).
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# face detector / cropper (ctx_id=0 -- presumably GPU 0; det_thresh is the detection threshold)
app = Face_detect_crop(name='antelope', root='./insightface_func/models')
app.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

# main model for generation: weights loaded on CPU, then moved to GPU in fp16
G = AEI_Net(backbone='unet', num_blocks=2, c_id=512)
G.eval()
G.load_state_dict(torch.load('weights/G_unet_2blocks.pth', map_location=torch.device('cpu')))
G = G.cuda()
G = G.half()

# arcface model to get face embedding (loaded on CPU like G, then moved to GPU)
netArc = iresnet100(fp16=False)
netArc.load_state_dict(torch.load('arcface_model/backbone.pth', map_location=torch.device('cpu')))
netArc = netArc.cuda()
netArc.eval()

# model to get face landmarks
handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

if use_sr:
    torch.backends.cudnn.benchmark = True
    opt = TestOptions()
    model = Pix2PixModel(opt)
    # NOTE(review): the super-resolution generator is put in train() mode here;
    # kept as-is -- confirm this is intended for inference.
    model.netG.train()

### Set here the paths to the source image and the target image or video for the face swap

# 특징 빼기 + 인젝션

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

from utils.inference.faceshifter_run import faceshifter_batch
from utils.inference.image_processing import crop_face, normalize_and_torch, normalize_and_torch_batch
from utils.inference.video_processing import read_video, crop_frames_and_get_transforms, resize_frames
from utils.inference.core import transform_target_to_torch

# True: swap into a single target image; False: swap into every frame of a video
image_to_image = True

set_target = False
half = True
similarity_th = 0.15

# Choose videos that are not too long, because processing can otherwise take a long time.
# Choose a photo as the source image -- preferably a selfie of the person.
if image_to_image:
    path_to_target = 'examples/images/jaeseung.jpg'
else:
    #path_to_video = 'examples/videos/random_gif.gif'
    path_to_video = "examples/videos/01__hugging_happy.mp4"
#source_full = cv2.imread('examples/images/elon_musk.jpg')
source_full = cv2.imread('examples/images/Bob-Ross.webp')
OUT_VIDEO_NAME = "examples/results/result_tmp.mp4"
crop_size = 224  # don't change this
BS = 60  # number of target crops passed to the generator per call
# Check that a face can be detected on the source image. If it cannot, stop here:
# every later cell needs `source`, and continuing would only fail later with a
# confusing NameError.
try:
    source = crop_face(source_full, app, crop_size)[0]
    print(source.shape)
    # reverse the channel order of the crop (source_full comes from cv2.imread)
    source = [source[:, :, ::-1]]
    print("Everything is ok!")
except TypeError as err:
    # NOTE(review): presumably crop_face yields something non-indexable when no face
    # is found (or when cv2.imread returned None) -- confirm.
    raise RuntimeError("Bad source images: could not detect/crop a face") from err
    
# Read the target (one image, or every frame of a video) and find a target image
# that contains at least 1 face.
if image_to_image:
    target_full = cv2.imread(path_to_target)
    # cv2.imread returns None instead of raising when the file cannot be read
    if target_full is None:
        raise FileNotFoundError(f"Could not read target image: {path_to_target}")
    full_frames = [target_full]
else:
    full_frames, fps = read_video(path_to_video)
target = get_target(full_frames, app, crop_size)

# Pure inference: no autograd graph is needed for the ArcFace embeddings.
with torch.no_grad():
    target_norm = normalize_and_torch_batch(np.array(target))
    target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

    # Get the cropped faces from original frames and transformations to get those crops
    crop_frames_list, tfm_array_list = crop_frames_and_get_transforms(full_frames, target_embeds, app, netArc, crop_size, set_target, similarity_th=similarity_th)


In [None]:
set_target = False
half = True
similarity_th = 0.15

# Pure inference: build no autograd graph for any of the ArcFace calls.
with torch.no_grad():
    target_norm = normalize_and_torch_batch(np.array(target))
    target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

    # Get the cropped faces from original frames and transformations to get those crops
    crop_frames_list, tfm_array_list = crop_frames_and_get_transforms(full_frames, target_embeds, app, netArc, crop_size, set_target, similarity_th=similarity_th)

    # Normalize source images and transform to torch and get Arcface embeddings
    source_embeds = []
    for source_curr in source:
        source_curr = normalize_and_torch(source_curr)
        source_embeds.append(netArc(F.interpolate(source_curr, scale_factor=0.5, mode='bilinear', align_corners=True)))

print(source_embeds[0].shape)

print(crop_frames_list[0][0].shape, crop_frames_list[0][0].dtype)

# The crops come from cv2 (BGR) frames; reverse the channels for matplotlib,
# as the other preview cells do.
plt.imshow(crop_frames_list[0][0][:, :, ::-1])
plt.show()

# Run the generator for every (target face track, source embedding) pair and
# rebuild a per-frame list where frames without a face hold [].
final_frames_list = []
for idx, (crop_frames, tfm_array, source_embed) in enumerate(zip(crop_frames_list, tfm_array_list, source_embeds)):
    # Resize croped frames and get vector which shows on which frames there were faces
    resized_frs, present = resize_frames(crop_frames)
    resized_frs = np.array(resized_frs)

    # transform embeds of Xs and target frames to use by model
    target_batch_rs = transform_target_to_torch(resized_frs, half=half)
    print(target_batch_rs.shape)
    if half:
        source_embed = source_embed.half()

    # run model in chunks of BS crops
    size = target_batch_rs.shape[0]
    model_output = []

    for i in tqdm(range(0, size, BS)):
        Y_st = faceshifter_batch(source_embed, target_batch_rs[i:i+BS], G)
        model_output.append(Y_st)
    torch.cuda.empty_cache()
    model_output = np.concatenate(model_output)

    # create list of final frames with transformed faces
    final_frames = []
    idx_fs = 0

    for pres in tqdm(present):
        if pres == 1:
            final_frames.append(model_output[idx_fs])
            idx_fs += 1
        else:
            final_frames.append([])
    final_frames_list.append(final_frames)
    # (a leftover debug `assert False` that aborted this cell was removed)


In [None]:
# A bare expression in the middle of a cell is never displayed, so print it.
print(final_frames[0].shape)

# Last generator batch from the loop above; reverse channels for matplotlib,
# as the other swap previews do.
plt.imshow(Y_st[0][:, :, ::-1])
plt.show()

# pca

In [None]:
# Fitted PCA model for the ArcFace embeddings.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("netArcPCA2.pkl", mode="rb") as file:
    pca = pickle.load(file=file)


In [None]:
from utils.inference.core import transform_target_to_torch

# Preview the target crop (channels reversed for matplotlib).
print(crop_frames_list[0][0].shape)
plt.imshow(crop_frames_list[0][0][:,:,::-1])
plt.show()


resized_frs, present = resize_frames(crop_frames_list[0])
resized_frs = np.array(resized_frs)

target_batch_rs = transform_target_to_torch(resized_frs, half=True)

# Round-trip a saved embedding through the PCA (transform, then inverse_transform)
# and use the reconstruction as the source identity.
start_id = 7
tmp_embed = [np.load("./embeds/{}.npy".format(start_id))]
pca_array = pca.transform(tmp_embed)

modified_embed = pca.inverse_transform(pca_array)

source_embed = torch.from_numpy(modified_embed).half().to("cuda")

Y_st = faceshifter_batch(source_embed, target_batch_rs, G)
torch.cuda.empty_cache()

plt.imshow(Y_st[0][:, :, ::-1])
plt.show()

In [None]:
from insightface.utils import face_align
import ipywidgets as widgets

# NOTE: the name `set_target` is rebound to the `set_target` function defined
# later in this cell, so this flag does not survive the cell.
set_target = False
half = True
similarity_th = 0.15

crop_frames_list = None
target_batch_rs = None

# Starting point for the sliders: PCA coordinates of a saved embedding.
start_id = 0
tmp_embed = [np.load("./embeds/{}.npy".format(start_id))]
pca_array = pca.transform(tmp_embed)

# Per-component slider ranges.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("netArcPCA2MinMax.pkl", "rb") as file:
    pca_minmax = pickle.load(file)
pca_min = pca_minmax["min"]
pca_max = pca_minmax["max"]

def set_target(path_to_target='examples/images/jaeseung.jpg', set_target_flag=False):
    """Load a target image (or video) and prepare its face crop for swapping.

    Updates the globals ``crop_frames_list`` (crops returned by
    ``crop_frames_and_get_transforms``) and ``target_batch_rs`` (the first entry's
    crops resized and converted to a half-precision torch batch).

    Args:
        path_to_target: image to load when ``image_to_image`` is True; otherwise the
            frames are read from the global ``path_to_video``.
        set_target_flag: forwarded as the ``set_target`` argument of
            ``crop_frames_and_get_transforms``. It is a parameter because the
            module-level name ``set_target`` refers to this function, not a bool.

    Raises:
        FileNotFoundError: if the target image cannot be read.
    """
    global crop_frames_list, target_batch_rs
    if image_to_image:
        target_full = cv2.imread(path_to_target)
        # cv2.imread returns None instead of raising when the file cannot be read
        if target_full is None:
            raise FileNotFoundError(f"Could not read target image: {path_to_target}")
        full_frames = [target_full]
    else:
        full_frames, _fps = read_video(path_to_video)
    target = get_target(full_frames, app, crop_size)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        target_norm = normalize_and_torch_batch(np.array(target))
        target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

        # Get the cropped faces from original frames and transformations to get those crops
        crop_frames_list, _tfm_array_list = crop_frames_and_get_transforms(full_frames,
                                                                        target_embeds,
                                                                        app,
                                                                        netArc,
                                                                        crop_size,
                                                                        set_target_flag,
                                                                        similarity_th=similarity_th
                                                                        )
    resized_frs, _present = resize_frames(crop_frames_list[0])
    resized_frs = np.array(resized_frs)

    target_batch_rs = transform_target_to_torch(resized_frs, half=True)


set_target()


def inject_drawing():
    """Redraw the target crop beside the swap produced by the current PCA coordinates.

    Reads the globals ``crop_frames_list``, ``pca_array`` (edited by the sliders)
    and ``target_batch_rs``, and draws both panels into figure 1.
    """
    global crop_frames_list, pca_array
    plt.figure(num=1, clear=True, figsize=(12, 6))

    # left panel: the target face, channels reversed for matplotlib
    plt.subplot(1, 2, 1)
    target_view = crop_frames_list[0][0][:, :, ::-1]
    plt.imshow(target_view)
    plt.title("Target Face")

    # right panel: swap driven by the embedding reconstructed from PCA space
    plt.subplot(1, 2, 2)
    embed_np = pca.inverse_transform(pca_array)
    embed_gpu = torch.from_numpy(embed_np).half().to("cuda")
    swapped = faceshifter_batch(embed_gpu, target_batch_rs, G)
    torch.cuda.empty_cache()

    plt.imshow(swapped[0][:, :, ::-1])
    plt.title("After Swap")
    plt.show()

NUM_PCA_SLIDERS = 40  # how many leading PCA components get a slider


def _on_pca_slider_change(change):
    """Store a slider's new value into its PCA component and redraw the swap."""
    pca_array[0, change.owner.idx] = change.new
    iact_plot.update()


pca_sliders = []

iact_plot = widgets.interactive(
    inject_drawing
)

# Never create more sliders than there are PCA components.
n_sliders = min(NUM_PCA_SLIDERS, pca_array.shape[1])
for i in range(n_sliders):
    pca_slider = widgets.FloatSlider(
        value=pca_array[0, i],
        min=pca_min[i],
        max=pca_max[i],
        description=f"PCA #{i}",
        continuous_update=False,
        layout=widgets.Layout(width="300px"),
    )
    pca_slider.idx = i  # lets the shared callback know which component this slider edits
    pca_slider.observe(_on_pca_slider_change, names="value")
    pca_sliders.append(pca_slider)

# Two columns of sliders (20 + 20 for the default 40).
half_point = (n_sliders + 1) // 2
slider_multibox = widgets.HBox(
    children=[
        widgets.VBox(children=pca_sliders[:half_point]),
        widgets.VBox(children=pca_sliders[half_point:]),
    ]
)


display(iact_plot, slider_multibox)


# PCA 값에 따른 변화 그림 (how the swapped face changes as each PCA component is varied)

In [None]:
# NOTE: the name `set_target` is rebound to the `set_target` function defined
# later in this cell, so this flag does not survive the cell.
set_target = False
half = True
similarity_th = 0.15

crop_frames_list = None
target_batch_rs = None

# Fitted PCA model and per-component [min, max] ranges.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("netArcPCA2.pkl", "rb") as file:
    pca = pickle.load(file)

with open("netArcPCA2MinMax.pkl", "rb") as file:
    pca_minmax = pickle.load(file)
pca_min = pca_minmax["min"]
pca_max = pca_minmax["max"]

start_id = 0

# Random starting point: each PCA component is drawn uniformly from the upper 70% of
# its [pca_min, pca_max] range (alternative: the box centre, (pca_min + pca_max) / 2).
# pca_min/pca_max bound individual PCA components (they are the ranges swept per
# component below), so the sample is already in PCA coordinates. It is used as-is
# rather than passed through pca.transform, which expects an embedding like the
# ./embeds/*.npy files.
# NOTE(review): confirm netArcPCA2MinMax.pkl stores bounds in PCA space.
tmp_embed = [((0.3 + 0.7 * np.random.random(pca_min.shape[0])) * (pca_max - pca_min) + pca_min)]
pca_array = np.array(tmp_embed)


def set_target(path_to_target='examples/images/jaeseung.jpg', set_target_flag=False):
    """Load a target image (or video) and prepare its face crop for swapping.

    Updates the globals ``crop_frames_list`` (crops returned by
    ``crop_frames_and_get_transforms``) and ``target_batch_rs`` (the first entry's
    crops resized and converted to a half-precision torch batch).

    Args:
        path_to_target: image to load when ``image_to_image`` is True; otherwise the
            frames are read from the global ``path_to_video``.
        set_target_flag: forwarded as the ``set_target`` argument of
            ``crop_frames_and_get_transforms``. It is a parameter because the
            module-level name ``set_target`` refers to this function, not a bool.

    Raises:
        FileNotFoundError: if the target image cannot be read.
    """
    global crop_frames_list, target_batch_rs
    if image_to_image:
        target_full = cv2.imread(path_to_target)
        # cv2.imread returns None instead of raising when the file cannot be read
        if target_full is None:
            raise FileNotFoundError(f"Could not read target image: {path_to_target}")
        full_frames = [target_full]
    else:
        full_frames, _fps = read_video(path_to_video)
    target = get_target(full_frames, app, crop_size)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        target_norm = normalize_and_torch_batch(np.array(target))
        target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

        # Get the cropped faces from original frames and transformations to get those crops
        crop_frames_list, _tfm_array_list = crop_frames_and_get_transforms(full_frames,
                                                                        target_embeds,
                                                                        app,
                                                                        netArc,
                                                                        crop_size,
                                                                        set_target_flag,
                                                                        similarity_th=similarity_th
                                                                        )
    resized_frs, _present = resize_frames(crop_frames_list[0])
    resized_frs = np.array(resized_frs)

    target_batch_rs = transform_target_to_torch(resized_frs, half=True)


set_target(path_to_target="examples/images/jaeseung.jpg")

num_plotted_pca = 40  # number of leading PCA components to sweep (one row each)
n_ticks = 5           # evenly spaced values per component, from pca_min to pca_max
n_rows = num_plotted_pca + 1  # extra first row: target face + random swap


def _imshow_swap(pca_coords):
    """Swap the current global target with the embedding reconstructed from
    ``pca_coords`` and draw it on the current axes with no frame and no ticks."""
    modified_embed = pca.inverse_transform(pca_coords)
    source_embed = torch.from_numpy(modified_embed).half().to("cuda")
    Y_st = faceshifter_batch(source_embed, target_batch_rs, G)
    torch.cuda.empty_cache()

    plt.imshow(Y_st[0][:, :, ::-1])
    for side in ["top", "right", "bottom", "left"]:
        plt.gca().spines[side].set_visible(False)
    plt.xticks([])
    plt.yticks([])


plt.figure(num=1, clear=True, figsize=(n_ticks*2, num_plotted_pca*2))

# Row 0: the target face, then the swap with the random PCA starting point.
plt.subplot(n_rows, n_ticks, 1)
plt.imshow(crop_frames_list[0][0][:,:,::-1])
plt.title("Target Face")
plt.axis("off")

plt.subplot(n_rows, n_ticks, 2)
_imshow_swap(pca_array)
plt.title("Swapped (Random)")

# One row per component: sweep it across its range, all other components fixed.
for pci_i in range(num_plotted_pca):
    # max(..., 1) avoids a ZeroDivisionError when n_ticks == 1
    interval = (pca_max[pci_i] - pca_min[pci_i]) / max(n_ticks - 1, 1)
    for c in range(n_ticks):
        plt.subplot(n_rows, n_ticks, (pci_i+1)*n_ticks+c+1)
        new_pca_array = pca_array.copy()
        new_pca_array[0, pci_i] = pca_min[pci_i] + (interval * c)
        _imshow_swap(new_pca_array)
        plt.xlabel(f"{new_pca_array[0, pci_i]:.2f}")

        if c == 0:
            plt.ylabel("PCA #{}".format(pci_i))

plt.show()


In [None]:
# for comparison: sweep PCA components for several target faces side by side

set_target = False  # NOTE: rebound to the `set_target` function defined later in this cell
half = True
similarity_th = 0.15
crop_size = 224 # don't change this
BS = 60


crop_frames_list = None
target_batch_rs = None

# Fitted PCA model and per-component [min, max] ranges.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("netArcPCA2.pkl", "rb") as file:
    pca = pickle.load(file)

with open("netArcPCA2MinMax.pkl", "rb") as file:
    pca_minmax = pickle.load(file)
pca_min = pca_minmax["min"]
pca_max = pca_minmax["max"]

start_id = 0

# Random starting point: each PCA component is drawn uniformly from the upper 60% of
# its [pca_min, pca_max] range (alternative: the box centre, (pca_min + pca_max) / 2).
# pca_min/pca_max bound individual PCA components (they are the ranges swept per
# component below), so the sample is already in PCA coordinates. It is used as-is
# rather than passed through pca.transform, which expects an embedding like the
# ./embeds/*.npy files.
# NOTE(review): confirm netArcPCA2MinMax.pkl stores bounds in PCA space.
tmp_embed = [((0.4 + 0.6 * np.random.random(pca_min.shape[0])) * (pca_max - pca_min) + pca_min)]
pca_array = np.array(tmp_embed)


def set_target(path_to_target='examples/images/jaeseung.jpg', set_target_flag=False):
    """Load a target image and prepare its face crop for swapping.

    Updates the globals ``crop_frames_list`` (crops returned by
    ``crop_frames_and_get_transforms``) and ``target_batch_rs`` (the first entry's
    crops resized and converted to a half-precision torch batch).

    Args:
        path_to_target: path of the target image.
        set_target_flag: forwarded as the ``set_target`` argument of
            ``crop_frames_and_get_transforms``. It is a parameter because the
            module-level name ``set_target`` refers to this function, not a bool.

    Raises:
        FileNotFoundError: if the target image cannot be read.
    """
    global crop_frames_list, target_batch_rs
    target_full = cv2.imread(path_to_target)
    # cv2.imread returns None instead of raising when the file cannot be read
    if target_full is None:
        raise FileNotFoundError(f"Could not read target image: {path_to_target}")
    full_frames = [target_full]
    target = get_target(full_frames, app, crop_size)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        target_norm = normalize_and_torch_batch(np.array(target))
        target_embeds = netArc(F.interpolate(target_norm, scale_factor=0.5, mode='bilinear', align_corners=True))

        # Get the cropped faces from original frames and transformations to get those crops
        crop_frames_list, _tfm_array_list = crop_frames_and_get_transforms(full_frames,
                                                                        target_embeds,
                                                                        app,
                                                                        netArc,
                                                                        crop_size,
                                                                        set_target_flag,
                                                                        similarity_th=similarity_th
                                                                        )
    resized_frs, _present = resize_frames(crop_frames_list[0])
    resized_frs = np.array(resized_frs)

    target_batch_rs = transform_target_to_torch(resized_frs, half=True)


set_target(path_to_target="examples/images/jaeseung.jpg")

num_plotted_pca = 20  # PCA components to sweep; each one gets 3 rows (one per target image)
n_ticks = 7           # sampled values per component, i.e. columns of the grid

# 2 inches per panel; 1 header row plus 3 rows per swept component
plt.figure(num=1, clear=True, figsize=(2 * n_ticks, 2 * (3 * num_plotted_pca + 1)))

def draw_top(subplot_tuple, img):
    """Draw ``img`` as an axis-less "Target Face" panel.

    Args:
        subplot_tuple: positional arguments for ``plt.subplot`` (rows, cols, index).
        img: image array shown as-is with ``plt.imshow``.
    """
    plt.subplot(*subplot_tuple)
    plt.imshow(img)
    plt.axis("off")
    plt.title("Target Face")


def draw_swap(subplot_tuple, pca_embed, is_top=False):
    """Draw the swap of the current global target with a PCA-space identity.

    Args:
        subplot_tuple: positional arguments for ``plt.subplot`` (rows, cols, index).
        pca_embed: PCA coordinates; mapped back with ``pca.inverse_transform`` and
            fed to the generator as a half-precision CUDA embedding.
        is_top: when True, title the panel "Swapped (Random)".
    """
    global target_batch_rs, G
    plt.subplot(*subplot_tuple)

    reconstructed = torch.from_numpy(pca.inverse_transform(pca_embed)).half().to("cuda")
    swapped = faceshifter_batch(reconstructed, target_batch_rs, G)
    torch.cuda.empty_cache()
    plt.imshow(swapped[0][:, :, ::-1])

    # frameless panel without ticks
    axes = plt.gca()
    for side in ("top", "right", "bottom", "left"):
        axes.spines[side].set_visible(False)
    axes.set_xticks([])
    axes.set_yticks([])

    if is_top:
        plt.title("Swapped (Random)")


# Target images compared side by side (one header column pair and one sweep row each).
comparison_targets = [
    "examples/images/jaeseung.jpg",
    "examples/images/elon_musk.jpg",
    "examples/images/tgt2.png",
]
n_rows = num_plotted_pca * 3 + 1

# Face detection, ArcFace and preprocessing depend only on the image, so run
# set_target once per target and cache its results instead of once per panel.
target_cache = []
for path in comparison_targets:
    set_target(path_to_target=path)
    target_cache.append((crop_frames_list[0][0][:, :, ::-1], target_batch_rs))

# Header row: each target face followed by its swap with the random PCA point.
for t, (target_img, target_batch) in enumerate(target_cache):
    target_batch_rs = target_batch  # draw_swap reads this global
    draw_top((n_rows, n_ticks, 2 * t + 1), target_img)
    draw_swap((n_rows, n_ticks, 2 * t + 2), pca_array, is_top=True)

# For each component: one row per target, sweeping that component across its range.
for pci_i in range(num_plotted_pca):
    # max(..., 1) avoids a ZeroDivisionError when n_ticks == 1
    interval = (pca_max[pci_i] - pca_min[pci_i]) / max(n_ticks - 1, 1)
    for c in range(n_ticks):
        new_pca_array = pca_array.copy()
        new_pca_array[0, pci_i] = pca_min[pci_i] + (interval * c)

        for t, (_, target_batch) in enumerate(target_cache):
            target_batch_rs = target_batch  # draw_swap reads this global
            draw_swap((n_rows, n_ticks, (pci_i * 3 + t + 1) * n_ticks + c + 1), new_pca_array)
            plt.xlabel(f"{new_pca_array[0, pci_i]:.2f}")
            if c == 0:
                plt.ylabel("PCA #{}".format(pci_i))

plt.show()