In [None]:

import os

# Enumerate GPUs in PCI bus order and expose only device 0 to CUDA libraries.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

CODE_DIR = '../../sber-swap'

# CODE_DIR is relative, so a plain chdir would be applied a second time on top
# of the new working directory if this cell were re-run. Remember where the
# kernel started and always resolve CODE_DIR against that directory instead.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, CODE_DIR))

In [2]:
import cv2
import torch
import time
import os

from utils.inference.image_processing import crop_face, get_final_image, show_images
from utils.inference.video_processing import read_video, get_target, get_final_video, add_audio_from_another_video, face_enhancement
from utils.inference.core import model_inference

from network.AEI_Net import AEI_Net
from coordinate_reg.image_infer import Handler
from insightface_func.face_detect_crop_multi import Face_detect_crop
from arcface_model.iresnet import iresnet100
from models.pix2pix_model import Pix2PixModel
from models.config_sr import TestOptions



In [3]:
from argparse import Namespace

In [4]:
def load_model(use_sr=False):
    """Construct every network used by the face-swap pipeline.

    Args:
        use_sr: When truthy, a Pix2PixModel for face super resolution is also
            built; otherwise the last element of the returned tuple is None.

    Returns:
        A 5-tuple ``(app, G, netArc, handler, model)``: the Face_detect_crop
        instance, the AEI_Net generator, the iresnet100 ArcFace network, the
        landmark Handler, and the super-resolution model (or None).
    """
    # Face detector / cropper.
    detector = Face_detect_crop(name='antelope', root='./insightface_func/models')
    detector.prepare(ctx_id=0, det_thresh=0.6, det_size=(640, 640))

    # Main generator: weights are read onto the CPU first, then the module is
    # moved to the GPU and converted to half precision.
    generator = AEI_Net(backbone='unet', num_blocks=2, c_id=512)
    generator.eval()
    generator_weights = torch.load('weights/G_unet_2blocks.pth',
                                   map_location=torch.device('cpu'))
    generator.load_state_dict(generator_weights)
    generator = generator.cuda().half()

    # ArcFace network that produces the face embedding.
    arcface = iresnet100(fp16=False)
    arcface_weights = torch.load('arcface_model/backbone.pth')
    arcface.load_state_dict(arcface_weights)
    arcface = arcface.cuda()
    arcface.eval()

    # Face landmark model.
    landmark_handler = Handler('./coordinate_reg/model/2d106det', 0, ctx_id=0, det_size=640)

    # Optional face super-resolution model.
    sr_model = None
    if use_sr:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        torch.backends.cudnn.benchmark = True
        opt = TestOptions()
        #opt.which_epoch ='10_7'
        sr_model = Pix2PixModel(opt)
        # NOTE(review): the generator is switched to train mode here, not
        # eval mode -- confirm this is intentional for inference.
        sr_model.netG.train()

    return detector, generator, arcface, landmark_handler, sr_model

In [5]:
# Notebook-level flags. Neither is read by the cells shown in this notebook;
# inside run_inference the name `aligned` refers to that function's own
# parameter, not to this global.
image_to_image = aligned = True

In [6]:
def run_inference(source_full, target_full, use_sr, aligned=False):
    """Swap the face from ``source_full`` onto the face found in ``target_full``.

    Args:
        source_full: Source image array providing the face to transfer.
            NOTE(review): presumably an HxWxC array as delivered by the Gradio
            image input -- confirm the channel order the models expect.
        target_full: Target image array that receives the swapped face.
        use_sr: When truthy, the super-resolution model is loaded and applied
            to the generated frames via ``face_enhancement``.
        aligned: When truthy, ``source_full`` is used as the source face
            directly (only resized to ``crop_size`` x ``crop_size`` if its
            height exceeds ``crop_size``). Otherwise the face is located and
            cropped with ``crop_face``.

    Returns:
        Whatever ``get_final_image`` returns for the single target frame.

    Raises:
        ValueError: If either image is missing, or if cropping the source
            face fails with a TypeError.
    """
    if source_full is None or target_full is None:
        raise ValueError("Both a source image and a target image are required")

    crop_size = 224 # don't change this
    BS = 60

    # Honour the caller's choice; the SR model is only built when requested.
    # NOTE(review): all models are rebuilt on every call. Caching them would
    # speed up repeated requests but needs a check that sharing the instances
    # between concurrent calls is safe.
    app, G, netArc, handler, model = load_model(use_sr=use_sr)

    # check, if we can detect face on the source image
    try:
        if not aligned:
            source = [crop_face(source_full, app, crop_size)[0]]
        elif source_full.shape[0] > crop_size:
            source = [cv2.resize(source_full, (crop_size, crop_size))]
        else:
            source = [source_full]
        print("Everything is ok!")
    except TypeError as err:
        print("Bad source images")
        # Stop here: continuing would hit an unbound `source` further down.
        raise ValueError(
            "Bad source images: could not extract a face from the source image"
        ) from err

    full_frames = [target_full]
    target = get_target(full_frames, app, crop_size)

    final_frames_list, crop_frames_list, full_frames, tfm_array_list = model_inference(
        full_frames,
        source,
        target,
        netArc,
        G,
        app,
        set_target=False,
        crop_size=crop_size,
        BS=BS,
    )
    if use_sr:
        final_frames_list = face_enhancement(final_frames_list, model)
    result = get_final_image(final_frames_list, crop_frames_list, full_frames[0], tfm_array_list, handler)
    return result

In [7]:
import gradio as gr

# gr.Interface(fn=swap_image, inputs=["image", "image", "text", "text", 'text'], outputs="image").launch(server_name='0.0.0.0')

In [8]:
# Pre-filled example rows for the demo UI. Columns follow the parameter order
# of run_inference: source_full, target_full, use_sr, aligned.
examples = [
    [
        'examples/images/320_persona.png',  # source_full
        'examples/images/sample_dst.jpg',   # target_full
        True,                               # use_sr
        True,                               # aligned
    ],
]

In [9]:
# Build and serve the demo UI. The inputs are passed positionally to
# run_inference as (source_full, target_full, use_sr, aligned).
# The top-level gr.Image component is used for both inputs and output so the
# cell no longer mixes the deprecated gr.inputs / gr.outputs namespaces with
# the top-level gr.Checkbox component.
gr.Interface(
    run_inference,
    inputs=[
        gr.Image(label="Source face"),
        gr.Image(label="Target image"),
        gr.Checkbox(label="Use super resolution"),
        gr.Checkbox(label="Source image is already an aligned face crop"),
    ],
    outputs=gr.Image(label="Result"),
    examples=examples
    # server_name='0.0.0.0' binds to all network interfaces, so the demo is
    # reachable from other machines on the network.
).launch(server_name='0.0.0.0')
        
    



Running on local URL:  http://localhost:7860/

To create a public link, set `share=True` in `launch()`.


(<gradio.routes.App at 0x7f750c436880>, 'http://localhost:7860/', None)

input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)
loading ./coordinate_reg/model/2d106det 0
input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)


[15:36:00] ../src/nnvm/legacy_json_util.cc:208: Loading symbol saved by previous version v1.5.0. Attempting to upgrade...
[15:36:00] ../src/nnvm/legacy_json_util.cc:216: Symbol successfully upgraded!


Network [LIPSPADEGenerator] was created. Total number of parameters: 72.2 million. To see the architecture, do print(network).
Load checkpoint from path:  weights/10_net_G.pth
Everything is ok!


100%|██████████| 1/1 [00:00<00:00, 12.04it/s]
1it [00:00, 306.78it/s]
1it [00:00, 3358.13it/s]
100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00, 13443.28it/s]
1it [00:01,  1.69s/it]
[15:36:08] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:96: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)


input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)
loading ./coordinate_reg/model/2d106det 0
input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)


[15:37:06] ../src/nnvm/legacy_json_util.cc:208: Loading symbol saved by previous version v1.5.0. Attempting to upgrade...
[15:37:06] ../src/nnvm/legacy_json_util.cc:216: Symbol successfully upgraded!


Network [LIPSPADEGenerator] was created. Total number of parameters: 72.2 million. To see the architecture, do print(network).




Load checkpoint from path:  weights/10_net_G.pth
Everything is ok!


100%|██████████| 1/1 [00:00<00:00, 10.79it/s]
1it [00:00, 293.64it/s]
1it [00:00, 3437.95it/s]
100%|██████████| 1/1 [00:00<00:00, 37.08it/s]
100%|██████████| 1/1 [00:00<00:00, 15477.14it/s]
1it [00:00,  4.49it/s]


input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)
loading ./coordinate_reg/model/2d106det 0
input mean and std: 127.5 127.5
find model: ./insightface_func/models/antelope/glintr100.onnx recognition
find model: ./insightface_func/models/antelope/scrfd_10g_bnkps.onnx detection
set det-size: (640, 640)


[15:37:54] ../src/nnvm/legacy_json_util.cc:208: Loading symbol saved by previous version v1.5.0. Attempting to upgrade...
[15:37:54] ../src/nnvm/legacy_json_util.cc:216: Symbol successfully upgraded!


Network [LIPSPADEGenerator] was created. Total number of parameters: 72.2 million. To see the architecture, do print(network).




Load checkpoint from path:  weights/10_net_G.pth
Everything is ok!


100%|██████████| 1/1 [00:00<00:00, 10.63it/s]
1it [00:00, 257.67it/s]
1it [00:00, 3637.73it/s]
100%|██████████| 1/1 [00:00<00:00, 33.70it/s]
100%|██████████| 1/1 [00:00<00:00, 15947.92it/s]
1it [00:00,  4.61it/s]
