In [None]:
import gradio as gr
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.outputs import OutputKeys
from PIL import Image
import json
import os
import copy
import numpy as np
from util import *
import cv2

# ModelScope pipelines are created once at import time and shared by every
# request handler below.

# Face detector; `inference` reads its boxes and scores.
face_detector = pipeline(Tasks.face_detection, model='gaosheng/face_detect')
# Earlier model id under the 'damo' namespace, kept for reference:
# face_recognizer = pipeline(Tasks.face_recognition, model='damo/cv_ir101_facerecognition_cfglint')
# Face recognizer; handed to `get_face_embedding` and `load_face_bank`.
face_recognizer = pipeline(Tasks.face_recognition, model='iic/cv_ir101_facerecognition_cfglint')
# Facial-expression recognizer; handed to `draw_faces` when box drawing is on.
emotion_recognizer = pipeline(Tasks.facial_expression_recognition, 'damo/cv_vgg19_facial-expression-recognition_fer')
# Portrait matting; applied to the cropped face in `search_face_cutouts`.
portrait_matting = pipeline(Tasks.portrait_matting, model='damo/cv_unet_image-matting')
# Speech recognizer, explicitly pinned to the CPU; used when the name is
# supplied as a voice recording instead of text.
speech_recognizer = pipeline(task=Tasks.auto_speech_recognition, model='iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', device='cpu')
# Known faces loaded from the 'face_bank/' directory using the recognizer
# (loader presumably comes from `util` -- confirm).
face_bank = load_face_bank('face_bank/', face_recognizer)
# name (first two characters stripped) -> face box, filled by `inference`
# and read by `search_face_cutouts`.
name_box_map = {}
# RGB, resized copy of the last submitted image; the stored boxes are
# expressed in this image's coordinates.
detected_image = None
# Untouched copy of the last submitted image, used for the final crop.
original_image = None


def inference(img: "Image.Image", draw_detect_enabled, detect_threshold, sim_threshold) -> tuple:
    """Detect and identify the faces in ``img`` and annotate the picture.

    Side effects: rebinds the module globals ``original_image`` (deep copy of
    the submitted image) and ``detected_image`` (deep copy of the resized RGB
    image), and rebuilds ``name_box_map`` from scratch so that a later name
    search only ever sees boxes belonging to the current image.

    Args:
        img: Image submitted through the UI (PIL image).
        draw_detect_enabled: When truthy, ``draw_faces`` is called to draw the
            detected faces on the result image.
        detect_threshold: Detections scoring below this value are ignored.
        sim_threshold: Faces whose best similarity is below this value are
            labelled '未知' (unknown) and are not made searchable.

    Returns:
        A tuple ``(annotated_image, row_names_text)`` as produced by
        ``draw_name`` / ``get_row_names_text``.

    Raises:
        gr.Error: If no image was submitted.
    """
    global original_image, detected_image

    # Submitting with an empty image component passes None; report it to the
    # user instead of failing somewhere inside the helpers.
    if img is None:
        raise gr.Error("请先上传待识别图片")

    original_image = copy.deepcopy(img)

    img = resize_img(img)
    img = img.convert('RGB')
    detected_image = copy.deepcopy(img)

    # Drop boxes left over from a previous image: they are expressed in that
    # image's coordinates and would produce a wrong crop on the new one.
    name_box_map.clear()

    detection_result = face_detector(img)
    boxes = np.array(detection_result[OutputKeys.BOXES])
    scores = np.array(detection_result[OutputKeys.SCORES])
    faces = []

    for box, score in zip(boxes, scores):
        if score < detect_threshold:
            continue
        face_embedding = get_face_embedding(img, box, face_recognizer)
        name, sim = get_name_sim(face_embedding, face_bank)
        if name is None:
            continue
        if sim < sim_threshold:
            faces.append({'box': box, 'name': '未知', 'sim': sim})
            continue
        faces.append({'box': box, 'name': name, 'sim': sim})
        real_name = name[2:]  # strip the 2-character student-number prefix
        # Store an independent copy: `box` is a view into `boxes`, so a later
        # in-place edit of the stored value would otherwise alias the array.
        name_box_map[real_name] = np.array(box)

    rows = get_rows(faces)
    row_names = get_row_names(faces, rows)
    draw_name(img, row_names)
    if draw_detect_enabled:
        draw_faces(img, faces, emotion_recognizer)
    return img, get_row_names_text(row_names)

def search_face_cutouts(name_input, audio_input):
    """Return the path of a matted head-shot for the requested person.

    The name comes from ``name_input``; when that is empty and a recording is
    available, the recording is transcribed and matched against the pinyin
    name bank loaded from 'face_bank/' (a match needs similarity >= 0.3).

    Args:
        name_input: Name typed by the user; may be empty.
        audio_input: Path of the recorded audio, or ``None`` when nothing was
            recorded.

    Returns:
        'temp.png' (the cut-out written to disk) on success, or '404.jpg'
        when no name could be determined or the name is not among the faces
        recognised by the last ``inference`` call.
    """
    name = name_input
    # Only fall back to speech when there is a recording to transcribe.
    if not name and audio_input:
        audio_text = recognize_speech_from_audio(audio_input, speech_recognizer)
        name_pinyin_bank = load_name_pinyin_bank('face_bank/')
        found_name, sim = find_name_by_audio_text(audio_text, name_pinyin_bank)
        if sim >= 0.3:
            name = found_name

    if not name or name not in name_box_map:
        return "404.jpg"

    # Work on a private copy: padding the stored box in place would make it
    # grow a little more on every search for the same name.
    box = np.array(name_box_map[name], dtype=float)
    # Pad the box slightly so the whole face is covered without the frame
    # looking oversized.
    box[0] = box[0] - 5
    box[2] = box[2] + 2
    box[1] = box[1] - 2
    box[3] = box[3] + 2

    global original_image
    original_image = original_image.convert('RGB')
    # The box is in `detected_image` (resized) coordinates; rescale it to the
    # original image so the crop is taken at full resolution.
    original_image_box = [
        original_image.width*(box[0]/detected_image.width),
        original_image.height*(box[1]/detected_image.height),
        original_image.width*(box[2]/detected_image.width),
        original_image.height*(box[3]/detected_image.height),
    ]

    face_img = get_face_img(original_image, original_image_box)
    result = portrait_matting(face_img)
    face_cutouts = result[OutputKeys.OUTPUT_IMG]
    # Write to a local file first so the colours survive (returning the raw
    # numpy array loses part of the colour information).
    cv2.imwrite('temp.png', face_cutouts)
    # NOTE: the cut-out is small; proportional upscaling is not applied yet.
    return 'temp.png'

# Sample image offered below the upload widget through gr.Examples.
examples = ['example.jpg']

# UI layout: a row of recognition settings, a row for image submission and
# results, and a row for searching a recognised person's head-shot by name.
with gr.Blocks() as demo:
    with gr.Row():
        # Settings forwarded to `inference`: draw boxes?, detection
        # threshold, recognition (similarity) threshold.
        draw_detect_enabled = gr.Checkbox(label="是否画框", value=True)
        detect_threshold = gr.Slider(label="检测阈值", minimum=0, maximum=1, value=0.3)
        sim_threshold = gr.Slider(label="识别阈值", minimum=0, maximum=1, value=0.3)
    with gr.Row():
        with gr.Column():
            # Image to recognise, delivered to the handler as a PIL image.
            img_input = gr.Image(type="pil", height=350)
            submit = gr.Button("提交待识别图片")
        with gr.Column():
            # Annotated image and the text list of names.
            img_output = gr.Image(type="pil", label="识别结果")
            name_output = gr.Text(label="人名列表")
    with gr.Row():
        with gr.Column():
            # The name can be typed or spoken; the audio arrives as a file path.
            name_input = gr.Text(label="文本输入人名")
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="语音输入人名")
            submit2 = gr.Button("根据人名搜索头像")
        with gr.Column():
            # Shows the image file whose path `search_face_cutouts` returns.
            face_cutouts = gr.Image(type="pil", label="搜索结果")
    # Recognise the submitted image.
    submit.click(
        fn=inference,
        inputs=[img_input, draw_detect_enabled, detect_threshold, sim_threshold],
        outputs=[img_output, name_output])
    # Look up the head-shot for the given name (text or speech).
    submit2.click(
        fn=search_face_cutouts,
        inputs=[name_input, audio_input],
        outputs=[face_cutouts])
    gr.Examples(examples, inputs=[img_input])

# NOTE(review): share=True requests a public Gradio share link -- confirm
# this exposure is intended for the deployment.
demo.launch(share=True)

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/gradio/queueing.py", line 624, in process_events
    response = await route_utils.call_process_api(
  File "/usr/local/lib/python3.10/site-packages/gradio/route_utils.py", line 323, in call_process_api
    output = await app.get_blocks().process_api(
  File "/usr/local/lib/python3.10/site-packages/gradio/blocks.py", line 2043, in process_api
    result = await self.call_function(
  File "/usr/local/lib/python3.10/site-packages/gradio/blocks.py", line 1590, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
  File "/usr/local/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/usr/local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2441, in run_sync_in_worker_thread
    return await future
  File "/usr/local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", 