In [1]:
!git clone https://github.com/yachty66/headswap.git
%cd headswap/leslie_headswap
!pip install -r requirements.txt
!python downloader.py

Cloning into 'headswap'...
remote: Enumerating objects: 352, done.[K
remote: Counting objects: 100% (288/288), done.[K
remote: Compressing objects: 100% (240/240), done.[K
remote: Total 352 (delta 46), reused 268 (delta 39), pack-reused 64 (from 1)[K
Receiving objects: 100% (352/352), 28.31 MiB | 16.77 MiB/s, done.
Resolving deltas: 100% (47/47), done.
/content/headswap/leslie_headswap
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting onnxruntime-gpu (from -r requirements.txt (line 7))
  Downloading onnxruntime_gpu-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting face-alignment (from -r requirements.txt (line 8))
  Downloading face_alignment-1.4.1-py2.py3-none-any.whl.metadata (7.4 kB)
Collecting coloredlogs (from onnxruntime-gpu->-r requirements.txt (line 7))
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime-g

In [2]:
# Both paths must be absolute: the notebook moves between directories with %cd.
path_to_image_from_where_to_steal_face_from, path_to_image_to_put_face_on = (
    "/content/headswap/images/source.png",  # image whose face/head is taken
    "/content/headswap/images/target.png",  # image that receives it
)

In [3]:
#inference leslie model
from model.AlignModule.generator import FaceGenerator
from model.BlendModule.generator import Generator as Decoder
from model.AlignModule.config import Params as AlignParams
from model.BlendModule.config import Params as BlendParams
from model.third.faceParsing.model import BiSeNet
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch
import cv2
import numpy as np
import pdb
from process.process_func import Process
from process.process_utils import *
import os
import onnxruntime as ort
from utils.utils import color_transfer2

class Infer(Process):
    """Head-swap inference: align -> blend -> super-resolve -> paste back.

    Wraps the AlignModule generator, the BlendModule decoder, a 19-class
    BiSeNet face parser and an ONNX super-resolution model. The helpers
    ``preprocess_align``, ``preprocess``, ``get_params``, ``postprocess``,
    ``preprocess_parsing`` and ``postprocess_parsing`` are not defined here and
    are inherited from ``Process``.
    """

    def __init__(self,align_path,blend_path,parsing_path,params_path,bfm_folder,
                 sr_path='./pretrained_models/sr_cf.onnx'):
        """Build the networks, load their weights and open the SR session.

        Args:
            align_path: checkpoint holding the aligner weights under ``'net_G_ema'``.
            blend_path: checkpoint holding the blender weights under ``'G'``.
            parsing_path: BiSeNet state dict.
            params_path, bfm_folder: forwarded unchanged to ``Process.__init__``.
            sr_path: ONNX super-resolution model. The default is the path that
                used to be hard-coded, so existing callers are unaffected.
        """
        Process.__init__(self,params_path,bfm_folder)
        align_params = AlignParams()
        blend_params = BlendParams()
        # Run the torch models on the GPU when one is available.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.parsing = BiSeNet(n_classes=19).to(self.device)
        self.netG = FaceGenerator(align_params).to(self.device)
        self.decoder = Decoder(blend_params).to(self.device)

        self.loadModel(align_path,blend_path,parsing_path)
        self.eval_model(self.netG,self.decoder,self.parsing)

        # The super-resolution model always runs on CPU through ONNX Runtime.
        self.ort_session_sr = ort.InferenceSession(sr_path, providers=['CPUExecutionProvider'])

    def run(self,src_img_path_list,tgt_img_path_list,save_base,crop_align=False,cat=False):
        """Swap every (source, target) pair and save each result as a PNG.

        Pairs are formed with ``zip``, so extra entries in the longer list are
        ignored. Each result goes to ``save_base/<src stem>-<tgt stem>.png``.
        A pair for which ``run_single`` returns None (unreadable image or failed
        alignment) is reported and skipped rather than crashing ``cv2.imwrite``.
        """
        os.makedirs(save_base,exist_ok=True)
        pairs = zip(src_img_path_list,tgt_img_path_list)
        for i,(src_img_path,tgt_img_path) in enumerate(pairs):
            gen = self.run_single(src_img_path,tgt_img_path,crop_align=crop_align,cat=cat)
            if gen is None:
                print('\nskipped %s -> %s (unreadable image or no aligned face)'
                      % (src_img_path,tgt_img_path))
                continue
            src_stem = os.path.splitext(os.path.basename(src_img_path))[0]
            tgt_stem = os.path.splitext(os.path.basename(tgt_img_path))[0]
            cv2.imwrite(os.path.join(save_base,src_stem+'-'+tgt_stem+'.png'),gen)
            print('\rhave done %04d'%i,end='',flush=True)
        print()

    def run_single(self,src_img_path,tgt_img_path,crop_align=False,cat=False):
        """Put the head from ``src_img_path`` onto the image at ``tgt_img_path``.

        Args:
            crop_align: also run ``preprocess_align`` on the source image.
            cat: return ``[target | result]`` side by side, with the aligned
                source pasted as a 256x256 thumbnail in the bottom-left corner.

        Returns:
            The blended image (float array produced by the alpha blend), or
            None when either image cannot be read or alignment yields nothing.
        """
        tgt_img = cv2.imread(tgt_img_path)
        if tgt_img is None:  # cv2.imread returns None for missing/unreadable files
            return None

        tgt_align,info = self.preprocess_align(tgt_img)
        if tgt_align is None:
            return None

        src_img = cv2.imread(src_img_path)
        if src_img is None:
            return None
        src_align = src_img
        if crop_align:
            src_align,_ = self.preprocess_align(src_img,top_scale=0.55)
            if src_align is None:
                return None

        src_inp = self.preprocess(src_align)
        tgt_inp = self.preprocess(tgt_align)

        # Landmarks are halved to match the 256x256 resize. This presumably
        # assumes the aligned crop is 512x512 -- TODO confirm.
        tgt_params = self.get_params(cv2.resize(tgt_align,(256,256)),
                                info['rotated_lmk']/2.0).unsqueeze(0)

        gen = self.forward(src_inp,tgt_inp,tgt_params)
        gen = self.postprocess(gen[0])
        gen = self.run_sr(gen)

        final = self._paste_back(gen,info,tgt_img)

        if cat:
            final = np.concatenate([tgt_img,final],1)
            final[-256:,:256] = cv2.resize(src_align,(256,256))

        return final

    def _paste_back(self,gen,info,tgt_img):
        """Warp the generated crop into target coordinates and alpha-blend it in."""
        height,width = tgt_img.shape[:2]
        # NOTE(review): info['im'] presumably maps the aligned crop back to the
        # original image (its first two rows form an affine matrix) -- confirm.
        rotate_matrix = info['im'][:2]

        rotate_gen = cv2.warpAffine(gen, rotate_matrix, (width, height))
        mask = cv2.warpAffine(info['mask'][...,0], rotate_matrix, (width, height)) * 1.0

        # Shrink the mask, then feather its edge so the seam is soft.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(17, 17))
        mask = cv2.erode(mask*1.0,kernel)
        # Dividing by 255 assumes info['mask'] is on a 0..255 scale -- TODO confirm.
        mask = cv2.blur(mask*1.0, (15, 15)) / 255.0
        mask = np.clip(mask,0,1.0)[:,:,np.newaxis]

        return rotate_gen * mask + tgt_img * (1-mask)

    def forward(self,xs,xt,params):
        """Generate the swapped face from source ``xs``, target ``xt`` and ``params``.

        Also stores ``self.mask``, a per-pixel 0/1 map of the parsing labels
        found in either the generated face or the target.
        """
        with torch.no_grad():
            # The aligner works at 256 px; its output is pooled back to 512.
            xg = F.adaptive_avg_pool2d(self.netG(F.adaptive_avg_pool2d(xs,256),
                            F.adaptive_avg_pool2d(xt,256),
                            params)['fake_image'],512)

            M_a = self.postprocess_parsing(self.parsing(self.preprocess_parsing(xg)))
            M_t = self.postprocess_parsing(self.parsing(self.preprocess_parsing(xt)))

            xg_gray = TF.rgb_to_grayscale(xg,num_output_channels=1)
            fake = self.decoder(xg,xg_gray,xt,M_a,M_t,xt,train=False)

            gen_mask = self.postprocess_parsing(self.parsing(self.preprocess_parsing(fake)))
            gen_mask = gen_mask[0][0].cpu().numpy()
            mask_t = M_t[0][0].cpu().numpy()
            # Labels 0, 14, 15 and 16 are excluded. What each label means depends
            # on the parsing model's class map, which is not visible here.
            labels = [1,2,3,4,5,6,7,8,9,10,11,12,13,17,18]
            hit = np.isin(gen_mask,labels) | np.isin(mask_t,labels)
            self.mask = hit.astype(gen_mask.dtype)
        return fake

    def run_sr(self,input_np):
        """Super-resolve a BGR image with the ONNX model and return it as BGR."""
        input_np = cv2.cvtColor(input_np, cv2.COLOR_BGR2RGB)
        # HWC -> NCHW with a batch of one, as uint8.
        input_np = input_np.transpose((2,0,1))
        input_np = np.array(input_np[np.newaxis, :])
        outputs_onnx = self.ort_session_sr.run(None, {'input_image':input_np.astype(np.uint8)})

        outimg = outputs_onnx[0][0,...].transpose(1,2,0)
        # BGR2RGB and RGB2BGR are the same channel swap, so this restores BGR.
        outimg = cv2.cvtColor(outimg, cv2.COLOR_BGR2RGB)
        return outimg

    def loadModel(self,align_path,blend_path,parsing_path):
        """Load the aligner, blender and parser weights, mapping tensors to CPU first."""
        # map_location keeps GPU-saved checkpoints loadable on CPU-only machines.
        ckpt = torch.load(align_path, map_location=lambda storage, loc: storage)
        self.netG.load_state_dict(ckpt['net_G_ema'])

        ckpt = torch.load(blend_path, map_location=lambda storage, loc: storage)
        # strict=False: keys missing from the checkpoint keep their init values.
        self.decoder.load_state_dict(ckpt['G'],strict=False)

        self.parsing.load_state_dict(
            torch.load(parsing_path, map_location=lambda storage, loc: storage))

    def eval_model(self,*args):
        """Switch every given module to inference mode."""
        for arg in args:
            arg.eval()



if __name__ == "__main__":
    # Weight files are resolved relative to the working directory
    # (leslie_headswap/, entered with %cd in the first cell).
    model = Infer(
        align_path='pretrained_models/epoch_00190_iteration_000400000_checkpoint.pt',
        blend_path='pretrained_models/Blender-401-00012900.pth',
        parsing_path='pretrained_models/parsing.pth',
        params_path='pretrained_models/epoch_20.pth',
        bfm_folder='pretrained_models/BFM',
    )

    # Parallel lists: the i-th source face goes onto the i-th target image.
    src_paths, tgt_paths = (
        [path_to_image_from_where_to_steal_face_from],
        [path_to_image_to_put_face_on],
    )

    # Writes "<src stem>-<tgt stem>.png" (here source-target.png) into
    # save_base; the FaceFusion cell reads that file as its target.
    model.run(src_paths, tgt_paths,
              save_base='/content/headswap/images/',
              crop_align=True,
              cat=False)

Downloading: "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth" to /root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth
100%|██████████| 85.7M/85.7M [00:03<00:00, 23.3MB/s]
Downloading: "https://www.adrianbulat.com/downloads/python-fan/3DFAN4-4a694010b9.zip" to /root/.cache/torch/hub/checkpoints/3DFAN4-4a694010b9.zip
100%|██████████| 91.9M/91.9M [00:03<00:00, 25.0MB/s]
Downloading: "https://www.adrianbulat.com/downloads/python-fan/depth-6c4283c0e0.zip" to /root/.cache/torch/hub/checkpoints/depth-6c4283c0e0.zip
100%|██████████| 224M/224M [00:08<00:00, 28.0MB/s]
  self.ParamsModel.load_state_dict(torch.load(params_path)['net_recon'])
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 158MB/s]
  ckpt = torch.load(align_path, map_location=lambda storage, loc: storage)
  ckpt = torch.load(blend_path, map_location=lambda storage, loc: storage)
  s

have done 0000


In [4]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Stand-in for ``locale.getpreferredencoding`` that always answers UTF-8.

    ``do_setlocale`` is accepted only to match the original signature; it is ignored.
    """
    utf8 = "UTF-8"
    return utf8


# Presumably a workaround for tools that refuse to run under a non-UTF-8
# locale (e.g. the pip install in the next cell) -- confirm it is still needed.
locale.getpreferredencoding = getpreferredencoding

In [5]:
#install ff
%cd ..
%cd ff
!pip install -r requirements.txt

/content/headswap
/content/headswap/ff
Collecting filetype==1.2.0 (from -r requirements.txt (line 1))
  Downloading filetype-1.2.0-py2.py3-none-any.whl.metadata (6.5 kB)
Collecting gradio==5.9.1 (from -r requirements.txt (line 2))
  Downloading gradio-5.9.1-py3-none-any.whl.metadata (16 kB)
Collecting gradio-rangeslider==0.0.8 (from -r requirements.txt (line 3))
  Downloading gradio_rangeslider-0.0.8-py3-none-any.whl.metadata (10 kB)
Collecting numpy==2.2.0 (from -r requirements.txt (line 4))
  Downloading numpy-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.0/62.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting onnx==1.17.0 (from -r requirements.txt (line 5))
  Downloading onnx-1.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting onnxruntime==1.20.1 (from -r requirements.txt (line 6))
  Downloading onnxruntime-1.20.1-cp311-cp311-



In [6]:
# FaceFusion pass: swap the face from source.png onto the head-swap output
# (source-target.png, written by the Infer cell), then enhance it with GFPGAN.
# ff.py is resolved relative to the current directory, so the previous cell's
# %cd into ff/ must have run first.
# All paths must be absolute, and the file extensions must be the same.
!python ff.py headless-run \
    --keep-temp \
    --log-level warn \
    --execution-providers cpu \
    --execution-thread-count 16 \
    --execution-queue-count 2 \
    --temp-frame-format png \
    --face-selector-mode many \
    --face-mask-types box occlusion \
    --processors face_swapper face_enhancer \
    --face-enhancer-model gfpgan_1.4 \
    --face-enhancer-blend 80 \
    --face-swapper-model inswapper_128 \
    --face-detector-model retinaface \
    -s /content/headswap/images/source.png \
    -t /content/headswap/images/source-target.png \
    -o /content/headswap/images/result.png