In [None]:
import cv2
import sys
import os
import numpy as np
import glob
from clip_api import ClipFeat

import insightface
import time

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def init_detector(model_name='scrfd_person_2.5g.onnx', ctx_id=0,
                  nms_thresh=0.5, input_size=(640, 640)):
    """Load and prepare the insightface person detector.

    Called with no arguments this performs exactly the original setup; the
    parameters only expose what used to be hard-coded.

    Args:
        model_name: Model file handed to ``insightface.model_zoo.get_model``
            (requested with ``download=True``).
        ctx_id: First positional argument of ``detector.prepare``.
            NOTE(review): in insightface this is the device/context id
            (negative values select CPU) -- confirm against the installed
            version.
        nms_thresh: Forwarded to ``prepare`` as ``nms_thresh``.
        input_size: Forwarded to ``prepare`` as ``input_size``.

    Returns:
        The prepared detector object returned by ``get_model``.
    """
    detector = insightface.model_zoo.get_model(model_name, download=True)
    detector.prepare(ctx_id, nms_thresh=nms_thresh, input_size=input_size)
    return detector

In [None]:
# Feature extractor from the project-local `clip_api` module; `trace_image`
# reads it as a module-level global to embed every person crop.
clip_feat = ClipFeat()
# Person detector, built once and shared; `trace_image` reads it as a
# module-level global as well.
detector = init_detector()

In [None]:
# --- Input / output locations -------------------------------------------
in_video = '/home/avs/Downloads/LNBGvsZJCZ_615.ts'
basename = os.path.basename(in_video)
output_dir = './datas/recog'
output_video = os.path.join(output_dir, '{}_{}'.format('person_track', basename))

# cv2.VideoWriter does not create missing directories; it just fails to open
# and then drops every frame without raising.
os.makedirs(output_dir, exist_ok=True)

reader = cv2.VideoCapture(in_video)
if not reader.isOpened():
    # Fail loudly instead of silently producing an empty output file.
    raise IOError('Cannot open input video: {}'.format(in_video))

# Take the frame rate and frame size from the source so the writer accepts
# the frames it is given; fall back to the previous hard-coded values when
# the container does not report them (0 / NaN).
fps = reader.get(cv2.CAP_PROP_FPS)
if not (fps > 0):
    fps = 30
frame_width = int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1920
frame_height = int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 1080

writer = cv2.VideoWriter(output_video, cv2.VideoWriter_fourcc(*"mp4v"), fps,
                         (frame_width, frame_height))

# --- Tracker state --------------------------------------------------------
# Traces carried from frame to frame.
trace_list = []
# Next id to hand out; incremented by Trace.__init__ through `global trace_id`.
trace_id = 1

def get_color(idx):
    """Map an integer id to a deterministic 3-channel colour tuple.

    The id is tripled and multiplied by a fixed per-channel factor
    (37, 17, 29); each product is reduced modulo 255.
    """
    scaled = 3 * idx
    return tuple((factor * scaled) % 255 for factor in (37, 17, 29))

class Trace:
    """One tracked identity, described by the appearance features seen so far.

    Attributes:
        feats: Every feature assigned to this trace, oldest first. Only the
            most recent one is used for matching.
        miss_cnt: Number of consecutive frames in which the trace was not
            matched (reset by `update`).
        id: Unique integer id taken from the module-level `trace_id` counter.
    """

    def __init__(self, feat):
        """Start a new trace from `feat` and claim the next global id."""
        global trace_id
        self.feats = [feat]
        self.miss_cnt = 0
        self.id = trace_id
        trace_id += 1

    def update(self, feat):
        """Record a successful match with feature `feat`.

        The miss counter is cleared so that it counts *consecutive* misses;
        otherwise a long-lived trace would be discarded after a handful of
        unrelated misses spread over its whole lifetime.
        """
        self.feats.append(feat)
        self.miss_cnt = 0

    def miss(self):
        """Record that the trace was not matched in the current frame."""
        self.miss_cnt += 1

    def get_id(self):
        """Return the trace's unique id."""
        return self.id

    def cos_sim(self, feat):
        """Cosine similarity between the latest stored feature and `feat`.

        Both inputs are treated as (rows, dim) matrices; a 1-D vector is
        promoted to a single row. The similarity of the first stored row
        against the first row of `feat` is returned.
        """
        epsilon = 1e-10  # keeps the division finite for all-zero vectors
        ref = np.atleast_2d(self.feats[-1])
        query = np.atleast_2d(feat)

        dots = np.dot(ref, np.transpose(query))
        ref_norms = np.linalg.norm(ref, ord=2, axis=1, keepdims=True)
        query_norms = np.linalg.norm(query, ord=2, axis=1, keepdims=True)
        # Same (ref rows, query rows) orientation as `dots`, so the division
        # is element-wise rather than relying on broadcasting.
        norm_products = np.dot(ref_norms, np.transpose(query_norms))

        cos_distances = np.divide(dots, norm_products + epsilon)
        return cos_distances[0][0]
    
def trace_image(image, trace_list, thresh=0.7, max_miss=3):
    """Detect people in `image`, associate them with traces and draw them.

    Every detection is cropped, embedded with the module-level `clip_feat`
    and compared with each trace that has not already been claimed in this
    frame. The most similar trace with similarity >= `thresh` takes the
    detection and its box is drawn in the trace's colour. A detection with
    no such trace starts a new `Trace`, which is not drawn in this frame.

    Args:
        image: Frame as an array indexed (row, column, channel).
        trace_list: Traces carried over from earlier frames. New traces are
            appended to this list in place.
        thresh: Minimum cosine similarity for a detection to join a trace.
        max_miss: Traces whose `miss_cnt` has reached this value are left
            out of the returned list.

    Returns:
        (out_image, out_trace_list): an annotated copy of `image` and the
        traces to carry into the next frame.
    """
    bboxes, _ = detector.detect(image)
    out_image = image.copy()
    line_thickness = 1
    img_h, img_w = image.shape[:2]

    claimed = set()   # traces already assigned to a detection in this frame
    new_traces = []
    for bbox in bboxes:
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        # Clamp to the frame: a negative coordinate would be read by NumPy
        # slicing as an offset from the far edge, giving a wrong or empty crop.
        x1, y1 = max(x1, 0), max(y1, 0)
        x2, y2 = min(x2, img_w), min(y2, img_h)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate box, nothing to embed

        person = image[y1:y2, x1:x2, :]
        person_feat = clip_feat.forward(person).cpu().numpy()

        # Best still-unclaimed trace at or above the threshold. Restricting
        # the search to unclaimed traces stops two people in the same frame
        # from receiving the same id.
        best_trace, best_sim = None, None
        for trace in trace_list:
            if trace in claimed:
                continue
            sim = trace.cos_sim(person_feat)
            if sim >= thresh and (best_trace is None or sim > best_sim):
                best_trace, best_sim = trace, sim

        if best_trace is None:
            new_traces.append(Trace(person_feat))
            continue

        color = get_color(best_trace.get_id())
        # Points are passed as tuples, which every OpenCV build accepts.
        cv2.rectangle(out_image, (x1, y1), (x2, y2),
                      color=color, thickness=line_thickness)
        best_trace.update(person_feat)
        claimed.add(best_trace)

    # Pre-existing traces that nobody claimed accumulate a miss.
    for trace in trace_list:
        if trace not in claimed:
            trace.miss()

    trace_list.extend(new_traces)

    out_trace_list = [trace for trace in trace_list if trace.miss_cnt < max_miss]
    return out_image, out_trace_list
        
    
# Run the tracker over every frame and write the annotated frames out.
more = True
frame_id = -1
interval = 50          # report throughput every `interval` frames
tic = time.time()
try:
    while more:
        more, frame = reader.read()
        if not more or frame is None:
            # End of stream or read failure: stop without writing, so the
            # previous frame is not written a second time (and `trace_out`
            # is never referenced before it exists).
            break
        frame_id += 1
        trace_out, trace_list = trace_image(frame, trace_list)
        writer.write(trace_out)
        if frame_id % interval == 0:
            toc = time.time()
            if frame_id > 0:
                # Skipped at frame 0, where no full interval has elapsed yet.
                print("== frames speed", interval / max(toc - tic, 1e-9))
            tic = toc
            print(frame_id)
finally:
    # Release both handles even if tracking raises, so the output file is
    # finalised and the capture is closed.
    reader.release()
    writer.release()