In [None]:
import glob
import logging
import os
import sys
import time

import cv2
import numpy as np
from cpu_searcher import numpy_searcher
from clip_api import ClipFeat
from reid_api import PersonFeat
import insightface

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Gallery of reference images laid out as <db_root>/<player_dir>/<image>.jpg;
# the directory name becomes the key prefix (see path_2_img_key) and is
# parsed as an int by search_image.
DB_GLOB = "/home/avs/Codes/face_recognition/datas/arsenal_players/*/*.jpg"
# glob returns paths in arbitrary, filesystem-dependent order; sort so the
# gallery keys / feature insertion order are reproducible between runs.
db_pathes = sorted(glob.glob(DB_GLOB))

In [None]:
def path_2_img_key(path):
    """Build the searcher key ``'<parent_dir>+<file_name>'`` for an image path.

    Args:
        path: path of a gallery image, e.g. ``.../<player_dir>/<image>.jpg``.

    Returns:
        The name of the directory directly containing the file, a ``+``, and
        the file name itself.
    """
    parent_dir, file_name = os.path.split(path)
    return f"{os.path.basename(parent_dir)}+{file_name}"

def read_image_to_data(image_path):
    """Load an image file as an RGB array.

    Args:
        image_path: path of the image file to read with ``cv2.imread``.

    Returns:
        The decoded image converted from OpenCV's BGR channel order to RGB, or
        ``None`` when ``cv2.imread`` cannot read/decode the file (the failure
        is logged at ERROR level).
    """
    image_data = cv2.imread(image_path)
    if image_data is None:
        # logging treats extra positional args as %-format arguments, so the
        # message needs a %s placeholder (print-style args would make logging
        # report a formatting error instead of this message).
        logging.error("== read invalid data %s", image_path)
        return None

    return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

def im_show(image):
    """Draw ``image`` on the current pyplot axes and render it in the cell output."""
    plt.imshow(image)
    plt.show()
    

def init_detector(model_name='scrfd_person_2.5g.onnx', ctx_id=0,
                  nms_thresh=0.5, input_size=(640, 640)):
    """Load and prepare the insightface person detector.

    Args:
        model_name: model file/name passed to ``insightface.model_zoo.get_model``
            (downloaded if missing).
        ctx_id: device id passed to ``prepare`` (insightface uses CPU for a
            negative id).
        nms_thresh: non-maximum-suppression threshold used by the detector.
        input_size: (width, height) the detector resizes inputs to.

    Returns:
        The prepared detector object (``search_image`` calls its ``detect``).

    Raises:
        RuntimeError: if ``get_model`` returns no model for ``model_name``.
    """
    detector = insightface.model_zoo.get_model(model_name, download=True)
    if detector is None:
        # get_model returns None when it cannot route the model file to a
        # known model class; fail loudly instead of on `.prepare`.
        raise RuntimeError('insightface could not load model {!r}'.format(model_name))
    detector.prepare(ctx_id, nms_thresh=nms_thresh, input_size=input_size)
    return detector

In [None]:
# Materialize both as lists: a bare `map` iterator is exhausted after one pass,
# so re-running the indexing cell would silently index nothing.
img_keys = [path_2_img_key(path) for path in db_pathes]
# Entries may be None for unreadable files (see read_image_to_data).
img_datas = [read_image_to_data(path) for path in db_pathes]

In [None]:
# NOTE(review): arguments presumably mean (capacity, feature_dim) — confirm
# against cpu_searcher.numpy_searcher.
searcher = numpy_searcher(1000, 512)
clip_feat = ClipFeat()
# Alternative extractor: PersonFeat('./agw_r50.onnx').forward(img)[0].tolist()
for img_data, img_key in zip(img_datas, img_keys):
    if img_data is None:
        # read_image_to_data returns None for unreadable files; skip them
        # instead of crashing inside the feature extractor.
        continue
    person_feat = clip_feat.forward(img_data).cpu().numpy()[0].tolist()
    searcher.update(img_key, person_feat)

In [None]:
# Sample frame path for ad-hoc experiments; not referenced by search_image
# or the video-processing cell.
img_path = './datas/00000023.jpg'

# Person detector, read as a global by search_image() below.
detector = init_detector()

def search_image(image, only_detect=False, score_thresh=0.88):
    """Detect people in ``image`` and label each with its best gallery match.

    Uses the notebook globals ``detector`` (person detector), ``clip_feat``
    (feature extractor) and ``searcher`` (gallery index).

    Args:
        image: HxWxC image array, passed unchanged to ``detector.detect``.
        only_detect: if True, only draw detection boxes (no identification).
        score_thresh: a top-1 match is labelled only when its score is
            strictly greater than this value.

    Returns:
        A copy of ``image`` with a green box around every non-empty detection
        and, for matched boxes, the integer parsed from the matched key's
        directory prefix drawn above the box. Each number is drawn at most
        once per image (later boxes matching the same number are skipped).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    img_h, img_w = image.shape[:2]
    # A set works for any integer number; the old range(50) dict raised
    # KeyError for numbers >= 50.
    used_numbers = set()
    bboxes, _ = detector.detect(image)
    out_image = image.copy()

    for bbox in bboxes:
        x1, y1, x2, y2, _ = bbox
        # Clamp to the image: a negative coordinate would otherwise be a numpy
        # from-the-end index and crop the wrong region.
        left, top = max(int(x1), 0), max(int(y1), 0)
        right, bottom = min(int(x2), img_w), min(int(y2), img_h)
        person = image[top:bottom, left:right, :]
        if person.shape[0] <= 0 or person.shape[1] <= 0:
            continue

        cv2.rectangle(out_image, (left, top), (right, bottom), (0, 255, 0), 1)
        if only_detect:
            continue

        # NOTE(review): video frames from cv2.VideoCapture are BGR, while the
        # gallery features were built from RGB images (read_image_to_data) —
        # confirm the channel order clip_feat expects.
        test_feats = clip_feat.forward(person).cpu().numpy()
        topk_keys = searcher.topk(test_feats, topk=1)
        searched_key, score = topk_keys[0][0]
        if score > score_thresh:
            print("== searched match", topk_keys)
            search_name = int(searched_key.split('+')[0])
            if search_name in used_numbers:
                continue
            cv2.putText(out_image, str(search_name), (int(x1 - 10), int(y1 - 10)),
                        font, 2, (255, 0, 0), 3)
            used_numbers.add(search_name)
    return out_image

# image_pathes = glob.glob("/home/avs/Codes/PaddleDetection/Arsenal_football_club/10037506_YBS_live55_rzzmtx_3min_cut.ts_imgs/*.jpg")
# dst_folder = "./datas/arsenal_video"
# for img_path in image_pathes:
#     image = cv2.imread(img_path)
#     image_name = os.path.basename(img_path)
#     #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#     searched_out = search_image(image)
#     cv2.imwrite(os.path.join(dst_folder, image_name), searched_out)

In [None]:
in_video = '/home/avs/Codes/PaddleDetection/Arsenal_football_club/10039251_YBS_live55_ikdntk_3min_cut.ts'
basename = os.path.basename(in_video)
output_video = os.path.join('./datas/recog', basename)
# cv2.VideoWriter does not create missing directories; it just fails to open.
os.makedirs(os.path.dirname(output_video), exist_ok=True)

reader = cv2.VideoCapture(in_video)
if not reader.isOpened():
    raise IOError('cannot open input video: {}'.format(in_video))

# Match the writer to the source stream: VideoWriter drops frames whose size
# differs from the size it was opened with, so a hard-coded 1920x1080 would
# yield an empty output for any other resolution.
fps = reader.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 if no fps is reported
frame_size = (int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
              int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)))
writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps, frame_size)

frame_id = -1
interval = 50          # report throughput every `interval` frames
frames_since_tic = 0   # frames processed since the last report
tic = time.time()
try:
    while True:
        more, frame = reader.read()
        if not more:
            break
        frame_id += 1
        frames_since_tic += 1
        searched_out = search_image(frame)
        if frame_id % interval == 0:
            toc = time.time()
            # Divide by the frames actually processed since the last report
            # (the very first report covers a single frame, not `interval`).
            print("== frames speed", frames_since_tic / (toc - tic))
            tic = time.time()
            frames_since_tic = 0
            print(frame_id)
        writer.write(searched_out)
finally:
    # Release even if search_image raises, so the output file gets finalized.
    reader.release()
    writer.release()