In [None]:
import cv2
import sys
import os
import numpy as np
import time
import re
import glob
import insightface

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def init_detector(model_name='scrfd_person_2.5g.onnx', ctx_id=0, nms_thresh=0.5,
                  input_size=(640, 640)):
    """Load an insightface model-zoo detector and prepare it for inference.

    Args:
        model_name: model-zoo file to load; ``download=True`` lets insightface
            fetch it when it is not cached locally.
        ctx_id: device id forwarded to ``prepare``. NOTE(review): in insightface
            a value >= 0 presumably selects that GPU and a negative value the
            CPU -- confirm for the installed version.
        nms_thresh: non-maximum-suppression threshold forwarded to ``prepare``.
        input_size: detector input size forwarded to ``prepare``.

    Returns:
        The prepared model; callers use its ``detect(image)`` method.
    """
    model = insightface.model_zoo.get_model(model_name, download=True)
    model.prepare(ctx_id, nms_thresh=nms_thresh, input_size=input_size)
    return model
# Module-level detector; recog_image_number calls detector.detect on each image.
detector = init_detector()

In [None]:
# Make the local PaddleOCR checkout importable. The membership check keeps the
# cell idempotent: re-running it does not append the same path again.
PADDLEOCR_ROOT = os.path.expanduser('~/Codes/PaddleOCR')
if PADDLEOCR_ROOT not in sys.path:
    sys.path.append(PADDLEOCR_ROOT)
from paddleocr import PaddleOCR

# Module-level OCR engine with PaddleOCR's default settings; recog_image_number
# calls ocr_engine.ocr(...) on every person crop.
ocr_engine = PaddleOCR()

In [None]:
def to_display_rgb(image):
    """Return ``image`` with its channel order reversed if it has 3 channels.

    Converts OpenCV's BGR layout to the RGB layout matplotlib expects. Arrays
    that are not (H, W, 3) -- e.g. grayscale or 4-channel -- are returned
    unchanged. The result is a view; no pixel data is copied.
    """
    if image.ndim == 3 and image.shape[2] == 3:
        return image[:, :, ::-1]
    return image


def im_show(image, bgr=True):
    """Display ``image`` inline with matplotlib.

    The images shown in this notebook come from ``cv2.imread`` /
    ``cv2.VideoCapture``, which store colour channels as BGR, while
    ``plt.imshow`` interprets 3-channel data as RGB. Colour images are therefore
    flipped to RGB first; pass ``bgr=False`` for images that are already RGB.

    ``None`` or zero-size arrays (e.g. a degenerate crop) are skipped with a
    message instead of being plotted.
    """
    if image is None or image.size == 0:
        print("== im_show: empty image, skipped")
        return
    plt.imshow(to_display_rgb(image) if bgr else image)
    plt.show()

In [None]:
# Discriminator over the two text prompts below, from the local clip_api module
# (presumably a CLIP zero-shot classifier, per the class name -- confirm).
# NOTE(review): no live code in this section calls clipDiscriminator; confirm it
# is still needed before paying its construction cost on every run.
from clip_api import ClipDiscriminator
clipDiscriminator = ClipDiscriminator(["player wear white shirt", "other"])

In [None]:
def clip_box_to_image(x1, y1, x2, y2, width, height):
    """Truncate box coordinates to ints and clamp them into a width x height image.

    A negative coordinate used directly as a slice start is counted from the
    end of the axis by numpy, which silently produces an empty or wrong crop;
    clamping to 0 keeps the crop anchored at the image border instead.

    Returns:
        (left, top, right, bottom) with 0 <= left, right <= width and
        0 <= top, bottom <= height.
    """
    def clamp(value, upper):
        return min(max(int(value), 0), upper)

    return clamp(x1, width), clamp(y1, height), clamp(x2, width), clamp(y2, height)


def ocr_lines(ocr_result):
    """Return the ``[box, (text, score)]`` entries of a single-image OCR result.

    ``ocr_result[0]`` holds the lines for the one image passed to
    ``ocr_engine.ocr``. NOTE(review): recent PaddleOCR releases put ``None``
    there when no text is found (confirm for the installed version); that case
    and an empty/None result are treated as "no lines" instead of raising
    TypeError on iteration.
    """
    if not ocr_result or ocr_result[0] is None:
        return []
    return ocr_result[0]


def recog_image_number(image, only_detect=False, debug=True,
                       y_range=(0.2, 0.5), x_range=(0.2, 0.5)):
    """Detect people in ``image`` and OCR a player number from each person crop.

    Args:
        image: OpenCV image array indexed as (row, col, channel).
        only_detect: Accepted for compatibility with existing calls; currently
            ignored.
        debug: If True (the default, matching earlier behaviour) print the raw
            OCR output and per-box decisions and show every person crop and OCR
            text box inline. Pass False for batch runs.
        y_range, x_range: Inclusive (low, high) bounds on the top-left corner of
            an OCR text box, as a fraction of the person crop's height and
            width; text outside this region is ignored.

    Returns:
        (out_image, playernum_boxes): a copy of ``image`` with a green box and
        the number drawn for each accepted read, and a list of
        ``[x1, y1, x2, y2, number]`` entries where the coordinates are the raw
        detector box and ``number`` is a string of one or two digits.
    """
    bboxes, _ = detector.detect(image)
    out_image = image.copy()
    img_h, img_w = image.shape[:2]
    font = cv2.FONT_HERSHEY_SIMPLEX

    playernum_boxes = []
    for x1, y1, x2, y2, _score in bboxes:
        left, top, right, bottom = clip_box_to_image(x1, y1, x2, y2, img_w, img_h)
        person = image[top:bottom, left:right, :].copy()
        crop_h, crop_w = person.shape[:2]
        if crop_h == 0 or crop_w == 0:
            continue
        # Optional shirt-colour filter (disabled): skip the person when
        # clipDiscriminator.forward(person)[0] <= 0.6.
        if debug:
            print("== show det person")
            im_show(person)

        ocr_result = ocr_engine.ocr(person, cls=False)
        if debug:
            print(ocr_result)
        for box, recog_text in ocr_lines(ocr_result):
            # box[0] is used as the text box's top-left corner and box[2] as its
            # bottom-right corner, in person-crop pixel coordinates.
            y_ratio = box[0][1] / crop_h
            x_ratio = box[0][0] / crop_w
            if debug:
                bx_min, by_min, bx_max, by_max = clip_box_to_image(
                    box[0][0], box[0][1], box[2][0], box[2][1], crop_w, crop_h)
                box_img = person[by_min:by_max, bx_min:bx_max, :]
                print("== det box show")
                if box_img.size:
                    im_show(box_img)

            # Restrict the accepted text to the expected player-number position.
            in_region = (y_range[0] <= y_ratio <= y_range[1]
                         and x_range[0] <= x_ratio <= x_range[1])
            if not in_region:
                if debug:
                    print("= ignore ratio ", y_ratio, ',', x_ratio)
                continue

            numbers = re.findall(r'\d+', recog_text[0])
            if not numbers:
                continue
            player_num = numbers[0]
            if debug:
                print("== recog player number ", player_num)
            if len(player_num) > 2:  # only 1-2 digit reads are accepted
                continue
            cv2.rectangle(out_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 1)
            cv2.putText(out_image, player_num, (int(x1 - 10), int(y1 - 10)), font, 2, (255, 0, 0), 3)
            playernum_boxes.append([x1, y1, x2, y2, player_num])
    return out_image, playernum_boxes


def handle_video(in_video='/home/avs/Downloads/LNBGvsZJCZ_615.ts',
                 output_dir='./datas/recog', interval=10):
    """Run player-number recognition on every frame of ``in_video``.

    One line ``frame_id,x1,y1,x2,y2,number`` per recognised player is written to
    ``<output_dir>/playerNum_white_ocr_label_<video file name>.txt``; frame ids
    start at 0. Every ``interval`` frames the processing speed (frames per
    second since the previous report) and the current frame id are printed.

    Args:
        in_video: path of the video to read with OpenCV.
        output_dir: directory for the label file; created if missing.
        interval: number of frames between speed reports.

    Raises:
        OSError: if OpenCV cannot open ``in_video``.
    """
    basename = os.path.basename(in_video)
    os.makedirs(output_dir, exist_ok=True)
    output_txt = os.path.join(output_dir, '{}_{}.txt'.format('playerNum_white_ocr_label', basename))
    # Annotated frames (first value returned by recog_image_number) are not
    # saved; write them with a cv2.VideoWriter here if an output video is needed.

    reader = cv2.VideoCapture(in_video)
    if not reader.isOpened():
        reader.release()
        raise OSError('cannot open video: {}'.format(in_video))
    try:
        with open(output_txt, 'w') as fout:
            frame_id = -1
            last_report_id = -1
            tic = time.perf_counter()
            while True:
                ok, frame = reader.read()
                if not ok or frame is None:
                    break
                frame_id += 1
                _annotated, playernum_boxes = recog_image_number(frame)
                for player_box in playernum_boxes:
                    fout.write('{},{}\n'.format(frame_id, ','.join(map(str, player_box))))

                if frame_id % interval == 0:
                    elapsed = max(time.perf_counter() - tic, 1e-9)
                    # Frames done since the last report: 1 at frame 0, else `interval`.
                    print("== frames speed", (frame_id - last_report_id) / elapsed)
                    print(frame_id)
                    last_report_id = frame_id
                    tic = time.perf_counter()
    finally:
        reader.release()

In [None]:
def handle_img_folder(input_folder='/mnt/nas/ActivityDatas/篮球海滨标注/chouzhen_sample/',
                      output_folder='./datas/chouzhen_number',
                      number_outpath='./datas/chouzhen_number.txt'):
    """Run player-number recognition on every image file in ``input_folder``.

    Each annotated image is written to ``output_folder`` under its original file
    name, and one ``image_name,x1,y1,x2,y2,number`` line per recognised player
    is written to ``number_outpath``. Entries are processed in sorted name order
    so runs are reproducible; entries for which ``cv2.imread`` returns None
    (e.g. non-image files) are skipped with a message.

    Args:
        input_folder: directory of images to process.
        output_folder: directory for annotated images; created if missing.
        number_outpath: text file for the recognised numbers; its parent
            directory is created if missing and the file is overwritten.
    """
    os.makedirs(output_folder, exist_ok=True)
    txt_dir = os.path.dirname(number_outpath)
    if txt_dir:
        os.makedirs(txt_dir, exist_ok=True)

    # `with` guarantees the label file is flushed and closed even on errors.
    with open(number_outpath, 'w') as fout:
        for idx, img_name in enumerate(sorted(os.listdir(input_folder))):
            if idx % 10 == 0:
                print("== handle ", idx)
            frame = cv2.imread(os.path.join(input_folder, img_name))
            if frame is None:
                print("== skip unreadable", img_name)
                continue

            searched_out, playernum_boxes = recog_image_number(frame)
            print("== search player number", img_name)
            out_path = os.path.join(output_folder, img_name)
            if not cv2.imwrite(out_path, searched_out):
                print("== failed to write", out_path)

            for player_box in playernum_boxes:
                fout.write('{},{}\n'.format(img_name, ','.join(map(str, player_box))))
            
# Run the batch job over the sample image folder configured in handle_img_folder.
handle_img_folder()