In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from mmdet.apis import DetInferencer
import glob
import pickle
import numpy as np
import os
# Detector configuration file handed to DetInferencer below.
config = 'air_typing_detection_config.py'
# Checkpoint file handed to DetInferencer below.
checkpoint = 'detection_model.pth'
# Device used for inference. CUDA_VISIBLE_DEVICES is set to '0' at the top of
# this cell, so 'cuda:0' refers to the only GPU left visible to the process.
device = 'cuda:0'
# Build the module-level inferencer once; the functions below call it directly.
# NOTE(review): show_progress=False presumably silences the per-call progress
# bar -- confirm against the mmdet DetInferencer documentation.
inferencer = DetInferencer(config, checkpoint, device,show_progress=False)

class tap_point:
    """A predicted tap: the box it belongs to and the time window it covers.

    Args:
        xyxy: Box corners ``[x1, y1, x2, y2]`` as given by the caller.
        scaled_xyxy: The same box in the scaled coordinate system.
        tap_time_span: Pair ``[start, end]``; each value is multiplied by
            ``1000 / 30`` on storage (presumably frame indices of a 30 fps
            recording converted to milliseconds -- TODO confirm).
    """

    def __init__(self, xyxy, scaled_xyxy, tap_time_span):
        span_start = tap_time_span[0]
        span_end = tap_time_span[1]
        self.xyxy = xyxy
        self.scaled_xyxy = scaled_xyxy
        # Keep the "/30*1000" evaluation order so the stored floats do not change.
        self.tap_time_span = [span_start / 30 * 1000, span_end / 30 * 1000]

class gt_tap_point:
    """A ground-truth tap: its box corners and the single timestamp of the tap.

    Args:
        xyxy: Box corners ``[x1, y1, x2, y2]``.
        tap_timestamp: Timestamp of the tap; it is compared against
            ``tap_point.tap_time_span`` values elsewhere in this file.
    """

    def __init__(self, xyxy, tap_timestamp):
        self.xyxy, self.tap_timestamp = xyxy, tap_timestamp

class detected_box:
    """A detector box plus the timestamps at which a finger was inside it.

    Typical use is ``get_finger_in_timestamp`` -> ``get_timestamp_of_tap`` ->
    ``get_tap_point``.

    Args:
        xyxy: Box corners ``[x1, y1, x2, y2]`` in original image coordinates.
        scaled_xyxy: The same box in the coordinates of the temporal grid
            (column index along x, row index along y).
    """

    def __init__(self, xyxy, scaled_xyxy):
        self.xyxy = xyxy
        self.scaled_xyxy = scaled_xyxy
        # Sorted timestamps collected by get_finger_in_timestamp().
        self.finger_in_timestamp_list = []
        # List of [start, end] runs; None until get_timestamp_of_tap() runs.
        self.actual_tap_time_span = None

    def get_finger_in_timestamp(self, tap_time_data):
        """Collect the timestamps stored in every grid cell covered by the box.

        Args:
            tap_time_data: 2-D grid indexed as ``tap_time_data[row][col]``.
                Each cell is a sequence of entries whose first element is a
                timestamp; empty cells are skipped.

        The timestamps are appended to ``finger_in_timestamp_list``, which is
        then re-sorted. Calling this twice therefore accumulates entries.
        """
        x_min, y_min, x_max, y_max = self.scaled_xyxy[:4]
        for row_idx, row in enumerate(tap_time_data):
            # Rows outside the box cannot contribute; skip them outright.
            if not (y_min <= row_idx < y_max):
                continue
            for col_idx, cell in enumerate(row):
                if len(cell) == 0 or not (x_min <= col_idx < x_max):
                    continue
                self.finger_in_timestamp_list.extend(entry[0] for entry in cell)
        self.finger_in_timestamp_list = sorted(self.finger_in_timestamp_list)

    def get_timestamp_of_tap(self, min_run_length=10):
        """Find runs of consecutive timestamps long enough to count as a tap.

        Two neighbouring entries of the sorted ``finger_in_timestamp_list``
        belong to the same run when they differ by exactly 1. Every maximal
        run with strictly more than ``min_run_length`` entries is stored in
        ``actual_tap_time_span`` as ``[first_timestamp, last_timestamp]``.

        Args:
            min_run_length: A run must contain more than this many timestamps
                to be kept. The default of 10 is the previously hard-coded
                threshold.

        An empty or single-element timestamp list yields an empty result
        instead of raising ``IndexError``.
        """
        timestamps = self.finger_in_timestamp_list
        total = len(timestamps)
        spans = []
        run_start = 0  # index of the first timestamp of the current run
        for idx in range(1, total + 1):
            # The run continues only if there is a next entry exactly 1 later.
            if idx < total and timestamps[idx] - timestamps[idx - 1] == 1:
                continue
            if idx - run_start > min_run_length:
                spans.append([timestamps[run_start], timestamps[idx - 1]])
            run_start = idx
        self.actual_tap_time_span = spans

    def get_tap_point(self):
        """Return one ``tap_point`` per span found by ``get_timestamp_of_tap``.

        ``get_timestamp_of_tap`` must have been called first; otherwise
        ``actual_tap_time_span`` is still ``None`` and iterating it raises
        ``TypeError``.
        """
        return [
            tap_point(self.xyxy, self.scaled_xyxy, span)
            for span in self.actual_tap_time_span
        ]

def get_directed_detected_result(image_path):
    """Run the module-level ``inferencer`` on one image and collect its boxes.

    The inferencer is called with ``out_dir='./output'``. Only the first entry
    of ``result['predictions']`` is read, and boxes whose score is below 0.3
    are dropped.

    Args:
        image_path: Path of the image to run detection on.

    Returns:
        ``(scaled_boxes, original_boxes)``: two parallel lists. Each original
        box is ``[x1, y1, x2, y2]`` as returned by the detector; each scaled
        box holds the same corners divided by 10 and rounded (presumably the
        resolution of the temporal grid -- TODO confirm).
    """
    scale = 10
    prediction = inferencer(image_path, out_dir='./output')['predictions'][0]
    scaled_boxes = []
    original_boxes = []
    for idx, bbox in enumerate(prediction['bboxes']):
        if prediction['scores'][idx] < 0.3:
            continue
        corners = [bbox[0], bbox[1], bbox[2], bbox[3]]
        scaled_boxes.append([round(value / scale) for value in corners])
        original_boxes.append(corners)
    return scaled_boxes, original_boxes

def get_box_and_coresponding_tap_time(image_path,tap_time_path):
    """Detect boxes on an image and derive the taps that occurred in each box.

    Args:
        image_path: Image passed to ``get_directed_detected_result``.
        tap_time_path: Pickle file holding the temporal grid consumed by
            ``detected_box.get_finger_in_timestamp``.

    Returns:
        List of ``tap_point`` objects ordered by the start of their time span.
        Neighbouring entries (in that order) that share the same ``xyxy`` are
        merged into one whose span end is taken from the later entry.
    """

    def span_start(point):
        return point.tap_time_span[0]

    def merge_neighbours_on_same_box(points):
        merged = []
        previous = None
        for point in points:
            if previous is not None and point.xyxy == previous.xyxy:
                # Same box as the preceding tap: extend the kept tap in place.
                merged[-1].tap_time_span[1] = point.tap_time_span[1]
            else:
                merged.append(point)
            previous = point
        return merged

    scaled_boxes, original_boxes = get_directed_detected_result(image_path)

    # NOTE: pickle executes arbitrary code on load; only open trusted files.
    with open(tap_time_path, 'rb') as handle:
        tap_time_data = pickle.load(handle)

    collected = []
    for original, scaled in zip(original_boxes, scaled_boxes):
        candidate = detected_box(original, scaled)
        candidate.get_finger_in_timestamp(tap_time_data)
        candidate.get_timestamp_of_tap()
        collected += candidate.get_tap_point()

    ordered = sorted(collected, key=span_start)
    return merge_neighbours_on_same_box(ordered)

def get_iou(pred_xyxy, gt_xyxy):
    """Return the intersection-over-union of two axis-aligned boxes.

    Args:
        pred_xyxy: Predicted box as ``(x1, y1, x2, y2)``.
        gt_xyxy: Ground-truth box as ``(x1, y1, x2, y2)``.

    Returns:
        IoU as a float. Boxes that do not overlap give ``0.0``. When the union
        has no area (both boxes are degenerate) the result is ``0.0`` instead
        of a ``ZeroDivisionError``.
    """
    pred_x1, pred_y1, pred_x2, pred_y2 = pred_xyxy
    gt_x1, gt_y1, gt_x2, gt_y2 = gt_xyxy

    # Clamp at zero so disjoint boxes contribute no intersection area.
    inter_width = max(0, min(pred_x2, gt_x2) - max(pred_x1, gt_x1))
    inter_height = max(0, min(pred_y2, gt_y2) - max(pred_y1, gt_y1))
    inter_area = inter_width * inter_height

    pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
    gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
    union_area = pred_area + gt_area - inter_area

    if union_area <= 0:
        return 0.0
    return inter_area / union_area



In [6]:
import pandas as pd
from copy import deepcopy
def get_detection_rate(image_path,tap_time_path,gt_path):
    """Match predicted taps against ground-truth taps for one image.

    Ground-truth taps are visited in timestamp order. Each one is matched to
    the temporally compatible prediction with the highest IoU, provided that
    IoU exceeds 0.5; a matched prediction is removed so it cannot be reused.

    Args:
        image_path: Image handed to ``get_box_and_coresponding_tap_time``.
        tap_time_path: Temporal pickle handed to the same function.
        gt_path: Pickled pandas object with columns ``x1``, ``y1``, ``x2``,
            ``y2`` and ``timestamps`` (one sequence of timestamps per row).

    Returns:
        Tuple ``(TP, FP, FN, pred_coordinates_list, gt_num, FP_point_list,
        FN_point_list)`` where ``pred_coordinates_list`` holds the
        ``[center_x, center_y]`` of every prediction (matched or not),
        ``FP_point_list`` the unmatched predictions and ``FN_point_list`` the
        unmatched ground-truth taps.
    """

    def is_in_time_span(gt_tap_point_list, pred_tap_point, i):
        """Tell whether a prediction may belong to ground-truth tap ``i``.

        True when the ground-truth timestamp lies inside the predicted span,
        or when the predicted span lies strictly between the neighbouring
        ground-truth timestamps. A missing neighbour (first or last tap, or a
        single-tap list) imposes no constraint.
        """
        span_start, span_end = pred_tap_point.tap_time_span[0], pred_tap_point.tap_time_span[1]
        if span_start <= gt_tap_point_list[i].tap_timestamp <= span_end:
            return True
        last_index = len(gt_tap_point_list) - 1
        after_previous = i == 0 or span_start > gt_tap_point_list[i - 1].tap_timestamp
        before_next = i == last_index or span_end < gt_tap_point_list[i + 1].tap_timestamp
        return after_previous and before_next

    def mysort(gt_point_item):
        return gt_point_item.tap_timestamp

    pred_tap_point_list = get_box_and_coresponding_tap_time(image_path,tap_time_path)
    # Record the centres up front: the list is consumed by the matching below.
    pred_coordinates_list = [
        [(pred.xyxy[0] + pred.xyxy[2]) / 2, (pred.xyxy[1] + pred.xyxy[3]) / 2]
        for pred in pred_tap_point_list
    ]

    # NOTE: read_pickle unpickles the file; only load trusted ground truth.
    gt = pd.read_pickle(gt_path)
    gt_tap_point_list = []
    for row_idx in range(len(gt)):
        box = gt[['x1','y1','x2','y2']].iloc[row_idx].to_list()
        # One ground-truth tap per timestamp; each gets its own copy of the box.
        for timestamp in gt['timestamps'].iloc[row_idx]:
            gt_tap_point_list.append(gt_tap_point(list(box), timestamp))
    gt_tap_point_list.sort(key=mysort)

    gt_num = len(gt_tap_point_list)
    len_pred = len(pred_tap_point_list)
    TP = 0
    FN_point_list = []
    for i, gt_point in enumerate(gt_tap_point_list):
        max_iou = 0
        index = None
        for j, pred in enumerate(pred_tap_point_list):
            if not is_in_time_span(gt_tap_point_list, pred, i):
                continue
            iou = get_iou(pred.xyxy, gt_point.xyxy)
            if iou > max_iou:
                max_iou = iou
                index = j
        if max_iou > 0.5:
            TP += 1
            pred_tap_point_list.pop(index)
        else:
            FN_point_list.append(gt_point)

    FP = len_pred - TP
    FN = gt_num - TP
    # Whatever was never matched is a false positive.
    FP_point_list = list(pred_tap_point_list)
    return TP,FP,FN,pred_coordinates_list,gt_num,FP_point_list,FN_point_list


In [None]:
import shutil

# Dataset to evaluate. Options noted by the author: conversation, e-mail,
# review, angle/clockwise, angle/counterclockwise (TODO confirm folder names).
name = 'review'
data_folder = 'data'
img_folder = os.path.join(data_folder, 'heatmap_img', name)
temporal_folder = os.path.join(data_folder, 'temporal_information', name)
tgt_dir = os.path.join(data_folder, 'detection_results', name)
gt_folder = os.path.join(data_folder, 'gt', name)
total_TP = 0
total_FP = 0
total_FN = 0
total_gt_num = 0

# Start from an empty results directory. shutil.rmtree avoids building a shell
# command from the path (breaks on spaces, unavailable on some platforms).
if os.path.exists(tgt_dir):
    shutil.rmtree(tgt_dir)
os.makedirs(tgt_dir, exist_ok=True)

angle_list = os.listdir(img_folder)
for angle in angle_list:
    sub_img_folder = os.path.join(img_folder, angle)
    # Stray files next to the per-angle folders (e.g. .DS_Store) are ignored.
    if not os.path.isdir(sub_img_folder):
        continue
    img_name_list = [
        file_name for file_name in os.listdir(sub_img_folder)
        if file_name.endswith('.png')
    ]
    os.makedirs(os.path.join(tgt_dir, angle), exist_ok=True)
    for img_name in img_name_list:
        img_path = os.path.join(sub_img_folder, img_name)
        stem = img_name.split('.png')[0]
        temporal_path = os.path.join(temporal_folder, angle, stem + '.pkl')
        gt_path = os.path.join(gt_folder, angle, stem + '.pickle')
        # get_detection_rate runs the detector itself, so no separate
        # inferencer call is made here.
        TP,FP,FN,pred_coordinates_list,gt_num,FP_point_list,FN_point_list = get_detection_rate(img_path,temporal_path,gt_path)
        tgt_path = os.path.join(tgt_dir, angle, stem + '.npy')
        total_TP += TP
        total_FP += FP
        total_FN += FN
        total_gt_num += gt_num
        # Persist the predicted tap centres for this image.
        np.save(tgt_path,pred_coordinates_list)
print('TP:',total_TP,'FP:',total_FP,'FN:',total_FN,'gt_num:',total_gt_num)