In [None]:
import cv2
import sys
import os
import numpy as np
import logging
import glob
import insightface
from insightface.app import FaceAnalysis
from cpu_searcher import numpy_searcher
import time
import json

#from PIL import ImageDraw,Image,ImageFont

import matplotlib.pyplot as plt
%matplotlib inline

# Load the initial face-database images

In [None]:
# Collect every .jpg then every .jpeg image from the per-player folders.
img_pathes = [
    path
    for pattern in ("/home/avs/Codes/face_db/CBA辽宁/*/*.jpg",
                    "/home/avs/Codes/face_db/CBA辽宁/*/*.jpeg")
    for path in glob.glob(pattern)
]

In [None]:
def path_2_img_key(path):
    """Build a database key of the form ``"<parent folder>+<file name>"``.

    For ``.../<person>/<image>.jpg`` this returns ``"<person>+<image>.jpg"``.
    """
    parent_dir, file_name = os.path.split(path)
    owner = os.path.basename(parent_dir)
    return "%s+%s" % (owner, file_name)

def read_image_to_data(image_path):
    """Read an image from disk and return it in RGB channel order.

    Args:
        image_path: path to an image file readable by ``cv2.imread``.

    Returns:
        The decoded image converted from OpenCV's BGR order to RGB (what
        matplotlib / ``im_show`` expects), or ``None`` if OpenCV could not
        read the file.
    """
    image_data = cv2.imread(image_path)
    if image_data is None:
        # The path needs a %s placeholder: logging's lazy formatting fails
        # with "not all arguments converted" if an argument has none.
        logging.error("== read invalid data %s", image_path)
        return None

    # NOTE(review): elsewhere in this notebook app.get() receives BGR input
    # (cv2.imread / VideoCapture frames); callers feeding this RGB result to
    # app.get() should convert it back for consistent embeddings.
    return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

def im_show(image):
    """Draw ``image`` on the current matplotlib axes and display it.

    matplotlib treats 3-channel arrays as RGB, matching the arrays returned by
    ``read_image_to_data``.
    """
    rendered = plt.imshow(image)  # AxesImage handle, not otherwise needed
    plt.show()

In [None]:
# Decode every database image up front. Each entry is an RGB array, or None for
# an unreadable file, and the list stays index-aligned with img_pathes.
img_datas = [read_image_to_data(path) for path in img_pathes]

# Initialize the face model and the face database

In [None]:
def is_rectangle_cross(rect1, rect2):
    """Return True if two axis-aligned rectangles overlap with positive area.

    Each rectangle is a sequence ``(x1, y1, x2, y2)`` giving two opposite
    corners. Corner order is not assumed: label boxes are built from two raw
    annotation points (see ``load_label``) whose order is never enforced in
    this notebook. Each rectangle is therefore normalised to (min, max)
    before testing. Rectangles that only touch along an edge or at a corner
    do not count as crossing.

    Returns:
        A Python ``bool``.
    """
    left = max(min(rect1[0], rect1[2]), min(rect2[0], rect2[2]))
    right = min(max(rect1[0], rect1[2]), max(rect2[0], rect2[2]))
    top = max(min(rect1[1], rect1[3]), min(rect2[1], rect2[3]))
    bottom = min(max(rect1[1], rect1[3]), max(rect2[1], rect2[3]))
    # bool() so numpy scalar comparisons still yield a plain True/False.
    return bool(left < right and top < bottom)

In [None]:
# Face detector + recognizer. NOTE(review): ctx_id=0 presumably selects the first
# GPU and det_size the detector input resolution; confirm in the insightface docs.
app = FaceAnalysis()
app.prepare(det_size=(640, 640), ctx_id=0)

# Embedding index queried by search_image. NOTE(review): the arguments look like
# (capacity, embedding_dim); confirm against cpu_searcher.numpy_searcher.
searcher = numpy_searcher(1000, 512)

In [None]:
def load_label(json_path):
    """Return the ``face`` annotations stored in an annotation JSON file.

    Args:
        json_path: path to a JSON file with a top-level ``shapes`` list whose
            entries carry ``label``, ``points`` and ``group_id`` keys.

    Returns:
        A list of ``(points, group_id)`` tuples, in file order, for every shape
        whose label is exactly ``'face'``.
    """
    # JSON is UTF-8 by spec and these annotations contain non-ASCII text
    # (e.g. Chinese folder names); don't depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as fin:
        labels = json.load(fin)
    return [(shape['points'], shape['group_id'])
            for shape in labels['shapes']
            if shape['label'] == 'face']

def load_real_db(json_glob="/home/avs/Downloads/CBA辽宁_facedb/*.json"):
    """Add embeddings of annotated faces to the global ``searcher``.

    For each annotation JSON matched by ``json_glob``:

    * The sibling image is read with ``cv2.imread`` (BGR), preferring
      ``.jpg`` and falling back to ``.png``.
    * Faces are detected with the global ``app``.
    * Every detected face whose bbox overlaps an annotated ``face`` box is
      stored under the key ``"<group_id>_<img_name>+<img_name>"``.
      ``search_image`` takes the player id from the text before the first
      ``_``.

    Images that cannot be read are reported and skipped.

    Args:
        json_glob: glob pattern for the annotation files; defaults to the
            original dataset location.
    """
    for test_json in glob.glob(json_glob):
        stem = os.path.splitext(test_json)[0]
        test_image_path = stem + '.jpg'
        if not os.path.exists(test_image_path):
            test_image_path = stem + '.png'
        print("== read test ", test_image_path)
        test_image = cv2.imread(test_image_path)
        if test_image is None:
            # No readable .jpg/.png next to the JSON: skip instead of crashing
            # inside app.get(None).
            print("== skip unreadable image ", test_image_path)
            continue
        img_name = os.path.basename(test_image_path)
        faces = app.get(test_image)
        # Flatten each annotated box once per image, not once per detected face.
        label_boxes = [(np.array(points).flatten(), player_num)
                       for points, player_num in load_label(test_json)]
        for face in faces:
            for label_box, player_num in label_boxes:
                if is_rectangle_cross(face['bbox'], label_box):
                    feat = face.normed_embedding.tolist()
                    key = '{}_{}+{}'.format(player_num, img_name, img_name)
                    searcher.update(key, feat)

In [None]:
def load_db_image(img_datas, img_pathes, searcher, app):
    """Add embeddings of annotated faces in ``img_datas`` to ``searcher``.

    Args:
        img_datas: RGB images as returned by ``read_image_to_data``; ``None``
            entries (unreadable files) are skipped.
        img_pathes: paths aligned one-to-one with ``img_datas``. Each image's
            annotations are read from the ``.json`` file with the same stem;
            images without one are skipped.
        searcher: index exposing ``update(key, feat)``.
        app: face analyser exposing ``get(image)``.

    Each detected face overlapping an annotated ``face`` box is stored under
    the key ``"<group_id>_<img_name>+<img_name>"``.
    """
    for img_data, img_path in zip(img_datas, img_pathes):
        if img_data is None:
            print("== skip unreadable image ", img_path)
            continue
        # splitext strips only the final extension, so dots in directory names
        # ("./datas/...") or multi-dot file names are handled correctly.
        label_path = os.path.splitext(img_path)[0] + '.json'
        if not os.path.exists(label_path):
            print("== skip image without label ", img_path)
            continue
        # read_image_to_data returns RGB, but every other app.get() input here
        # (cv2.imread in load_real_db, VideoCapture frames in the search loop)
        # is BGR. Convert back so DB and query embeddings use the same order.
        faces = app.get(cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR))
        img_name = os.path.basename(img_path)
        # Flatten each annotated box once per image, not once per detected face.
        label_boxes = [(np.array(points).flatten(), player_num)
                       for points, player_num in load_label(label_path)]
        for face in faces:
            for label_box, player_num in label_boxes:
                if is_rectangle_cross(face['bbox'], label_box):
                    print("== update ", player_num, ' ', img_name)
                    feat = face.normed_embedding.tolist()
                    key = '{}_{}+{}'.format(player_num, img_name, img_name)
                    searcher.update(key, feat)


load_db_image(img_datas, img_pathes, searcher, app)

# Load test images and search the face database

In [None]:
# Query images from the other team's folders, decoded to RGB (None if unreadable).
test_images = glob.glob("/home/avs/Codes/face_db/CBA广东东莞/*/*.png")
test_image_datas = [read_image_to_data(path) for path in test_images]

In [None]:
def search_image(test_image_data, threshold=0.55, face_app=None, face_searcher=None):
    """Detect faces in an image and identify them against the face database.

    Args:
        test_image_data: image array passed unchanged to ``face_app.get``
            (the video loop below passes BGR frames).
        threshold: minimum similarity score for a match to be reported.
        face_app: face analyser exposing ``get(image)``; defaults to the
            notebook-global ``app``.
        face_searcher: index exposing ``topk(feats, topk=1)``; defaults to the
            notebook-global ``searcher``.

    Returns:
        ``(out_image, searched_faces)``:

        * ``out_image`` is a copy of the input with a box and player id drawn
          for every match.
        * ``searched_faces`` is a list of ``[x1, y1, x2, y2, name, score]``
          for matches scoring at least ``threshold``.

    Example::

        out_show, _ = search_image(test_image_data)
        im_show(out_show)
    """
    if face_app is None:
        face_app = app
    if face_searcher is None:
        face_searcher = searcher

    out_image = test_image_data.copy()
    searched_faces = []
    faces = face_app.get(test_image_data)
    if len(faces) == 0:
        print("empty image")
        return (out_image, searched_faces)

    test_feats = np.array([face.normed_embedding.tolist() for face in faces])
    topk_keys = face_searcher.topk(test_feats, topk=1)

    for face_result, face in zip(topk_keys, faces):
        if len(face_result) == 0:
            # No candidate for this face (e.g. empty index): skip it rather
            # than failing on face_result[0].
            continue
        searched_key, score = face_result[0]
        if score < threshold:
            continue
        # Keys are "<player>_<img>+<img>"; keep only "<player>".
        searched_name = searched_key.split('+')[0].split('_')[0]
        x1, y1, x2, y2 = face['bbox']
        print("== searched ", searched_name, " ", score)
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.rectangle(out_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
        cv2.putText(out_image, str(searched_name), (int(x1 - 10), int(y1 - 10)), font, 1, (255, 0, 0), 3)
        searched_faces.append([x1, y1, x2, y2, searched_name, score])
    return (out_image, searched_faces)

In [None]:
#in_video = '/home/avs/Codes/face_recognition/datas/recog/playerNum_LNBGvsZJCZ_615.ts'
in_video = '/home/avs/Downloads/LNBGvsZJCZ_615.ts'
output_dir = './datas/recog'
basename = os.path.basename(in_video)
# NOTE(review): output_video keeps the input's .ts extension; confirm the
# OpenCV backend can write mp4v into it before enabling write_video.
output_video = os.path.join(output_dir, '{}_{}'.format('faceid', basename))
output_csv = os.path.join(output_dir, '{}_{}.txt'.format('faceid', basename))

search_threshold = 0.55  # minimum similarity score passed to search_image
interval = 50            # report throughput every `interval` frames
write_video = False      # True: also save annotated frames to output_video

# Create the output folder up front so open()/VideoWriter don't fail when it
# does not exist yet.
os.makedirs(output_dir, exist_ok=True)

reader = cv2.VideoCapture(in_video)
if not reader.isOpened():
    # Without this check read() just returns (False, None) and the cell
    # "succeeds" with an empty result file.
    raise IOError("cannot open video {}".format(in_video))

writer = None
if write_video:
    # Size the writer from the source stream instead of assuming 30 fps 1080p.
    fps = reader.get(cv2.CAP_PROP_FPS) or 30
    frame_size = (int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
                  int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

frame_id = -1
last_report_frame = -1
tic = time.time()
try:
    # Each output line: frame_id,x1,y1,x2,y2,name,score
    with open(output_csv, 'w') as fout:
        while True:
            more, frame = reader.read()
            if not more or frame is None:
                break
            frame_id += 1
            searched_out, searched_faces = search_image(frame, search_threshold)
            for search_face in searched_faces:
                print(search_face)
                face_out = ','.join(map(str, search_face))
                fout.write('{},{}\n'.format(frame_id, face_out))
            if writer is not None:
                writer.write(searched_out)
            if frame_id % interval == 0:
                toc = time.time()
                # Divide by the frames actually processed since the last report
                # (1 at frame 0, `interval` afterwards), not always `interval`.
                frames_done = frame_id - last_report_frame
                print("== frames speed", frames_done / max(toc - tic, 1e-9))
                last_report_frame = frame_id
                tic = time.time()
                print(frame_id)
finally:
    # Release capture/writer even if detection raises mid-video; the `with`
    # block closes the results file.
    reader.release()
    if writer is not None:
        writer.release()