In [1]:
# Performance analysis of Connectivity model

import os
import math
import time

import torch
import numpy as np
from tqdm import tqdm

from util.result import import_subsequences_results
from util.feature import triplet_feature_to_list
from util.feature import pdist_np
# Parameters
# NOTE(review): these are absolute, machine-specific paths; adjust them (or
# derive them from one configurable base directory) before running elsewhere.
csv_dir = "G:/tim/Project/Connectivity_train_test/real_data/Kranji-2018"
feat_dir = "D:/Bitbucket/triplettraining/feat/Kranji_test"
race_dir = "D:/Bitbucket/triplettraining/feat/Kranji_test"

# csv_dir = "G:/tim/Project/Connectivity_train_test/real_data/HVT-2018-2019_full_recall"
# feat_dir = "G:/tim/Project/Connectivity_train_test/train_data/trip_feat/HVT-2018-2019_full_recall"
# race_dir = "D:/Bitbucket/triplettraining/feat/HVT_test"

# Load Connectivity Model
from model.conn import ConnNet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_path = "G:/tim/Project/Connectivity_train_test/res/Connectivity_batch_256_Kranji_2022-08-12-16-24-09/model_35.pkl"
#model_path = "./conn.pth"
model = ConnNet().to(device)
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved from a GPU run still loads when only the CPU is available.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# race label (file name without extension) -> {"feat_path": ..., "csv_path": ...}
race_dict = {}

# The files in race_dir define which races are evaluated.
# Directory listings are sorted so the processing order is reproducible.
for race_path in sorted(os.listdir(race_dir)):
    racelb = os.path.splitext(race_path)[0]
    race_dict[racelb] = {}

# Attach the feature file of every selected race.
for feat_path in sorted(os.listdir(feat_dir)):
    racelb = os.path.splitext(feat_path)[0]
    if racelb in race_dict:
        race_dict[racelb]["feat_path"] = os.path.join(feat_dir, feat_path)

# Attach the csv file of every selected race; files of other races are skipped.
# The extension is not checked, so if several files share a stem the last one
# in sorted order wins.
for csv_path in sorted(os.listdir(csv_dir)):
    racelb = os.path.splitext(csv_path)[0]
    if racelb in race_dict:
        race_dict[racelb]["csv_path"] = os.path.join(csv_dir, csv_path)

# Use only races that have both a feature file and a csv file; a race missing
# either one would raise a KeyError when it is processed later.
all_races = [
    racelb
    for racelb, paths in race_dict.items()
    if "feat_path" in paths and "csv_path" in paths
]
skipped_races = [racelb for racelb in race_dict if racelb not in set(all_races)]
if skipped_races:
    print(f"Skipping {len(skipped_races)} race(s) without both csv and feature files: {skipped_races}")

Processing Kranji-20180318-10 (1/92)...



KeyboardInterrupt



In [None]:
# Frame offsets to evaluate; Connectivity is trained with offset = 1,2
t_diff = [1,2]

# Whether to sweep the match threshold to find the best one
THRESHOLD_OPTIMISE = True
match_thres_list = [0.8]
num_thres = len(match_thres_list)

if THRESHOLD_OPTIMISE:
    num_thres = 20
    match_thres_list = np.linspace(0, 1, num_thres)

# precisions[i] / recalls[i] collect one value per (race, offset) for threshold i
precisions = [[] for _ in range(num_thres)]
recalls = [[] for _ in range(num_thres)]
best_f1_score = -1
best_thres = 0
best_precision = 0
best_recall = 0
best_match_thres_id = -1

# Per-race match names and probabilities, recorded at the threshold that was
# best (on the running averages) at the time the race was processed.
TPs = {}
TP_dist = {}
FPs = {}
FP_dist = {}


def _box_centre(box):
    """Return the integer (cx, cy) of a box.

    Elements 1 and 3 are averaged for x and elements 0 and 2 for y, which is
    the indexing this notebook uses everywhere for box coordinates.
    """
    return int((box[1] + box[3]) / 2), int((box[0] + box[2]) / 2)


def _pair_features(t0_boxes, t1_boxes, t0_feats, t1_feats, offset):
    """Build one [feat_diff, dx, dy, norm_t] row per (t0 box, t1 box) pair.

    Rows are emitted with the t0 box as the outer loop and the t1 box as the
    inner loop, so the model output can be reshaped row-major to
    (len(t0_boxes), len(t1_boxes)).
    """
    rows = []
    norm_t = offset / 2
    for box_id_0, cap_0 in enumerate(t0_boxes):
        cap_0_feat = t0_feats[box_id_0]
        cap_0_cx, cap_0_cy = _box_centre(cap_0)
        # NOTE(review): the original names these values cap_1_w / cap_1_h but
        # computes them as the centre coordinates of the t0 box, which looks
        # like a copy-paste slip for the t1 box width and height. The
        # computation is kept unchanged because the checkpoint may have been
        # trained with this exact normalisation - confirm against the training
        # code before changing it.
        norm_scale = math.sqrt(cap_0_cx * cap_0_cy)
        for box_id_1, cap_1 in enumerate(t1_boxes):
            cap_1_feat = t1_feats[box_id_1]
            cap_1_cx, cap_1_cy = _box_centre(cap_1)
            # Euclidean distance between the two feature vectors
            feat_diff = float(np.sum((cap_0_feat - cap_1_feat) ** 2) ** 0.5)
            dx = abs(cap_0_cx - cap_1_cx) / (norm_scale * 100)
            dy = abs(cap_0_cy - cap_1_cy) / (norm_scale * 50)
            rows.append([feat_diff, dx, dy, norm_t])
    return rows


# Loop through each race
for race_i, each_race in enumerate(all_races):
    t_start = time.time()
    print(f"Processing {each_race} ({race_i+1}/{len(all_races)})...")

    race_paths = race_dict[each_race]
    if "csv_path" not in race_paths or "feat_path" not in race_paths:
        # Without both files the race cannot be evaluated.
        print(f"Skipping {each_race}: csv or feature file is missing")
        continue

    # Read CSV
    tar_csv_path = race_paths["csv_path"]
    start_frm, end_frm, boxes_list = import_subsequences_results(tar_csv_path)

    # Read Feature (memory-mapped read-only).
    # NOTE: allow_pickle=True can execute code from a crafted file; only load
    # feature files from a trusted source.
    trip_feat_npy = np.load(race_paths["feat_path"], mmap_mode="r", allow_pickle=True)
    trip_feat_list = triplet_feature_to_list(trip_feat_npy, boxes_list)

    # For performance analysis of this race: per-threshold match names and
    # probabilities, accumulated over all offsets.
    TPs_all = [[] for _ in range(num_thres)]
    TP_dist_all = [[] for _ in range(num_thres)]
    FPs_all = [[] for _ in range(num_thres)]
    FP_dist_all = [[] for _ in range(num_thres)]

    # Loop through each frame offset
    for offset in t_diff:
        # Number of TPs and FPs per threshold for this offset
        TPs_stats = [0]*num_thres
        FPs_stats = [0]*num_thres
        total_matchable = 0

        # Only frames that still have a partner frame `offset` steps ahead
        for t0 in range(len(trip_feat_list) - offset):
            t1 = t0 + offset
            if len(trip_feat_list[t0]) == 0 or len(trip_feat_list[t1]) == 0:
                continue

            # Construct input for connectivity model
            t0_boxes = boxes_list[t0]
            t1_boxes = boxes_list[t1]
            # One label per box (last element of the box), kept in box order so
            # that row/column k of the probability matrix always refers to box k.
            # Deduplicating the labels here would break that alignment and make
            # the reshape below fail whenever a label repeats within a frame.
            t0_ids = [box[-1] for box in t0_boxes]
            t1_ids = [box[-1] for box in t1_boxes]

            rows = _pair_features(t0_boxes, t1_boxes, trip_feat_list[t0], trip_feat_list[t1], offset)
            if not rows:
                continue
            inputs = torch.tensor(rows, dtype=torch.float32, device=device)

            # Run inputs with connectivity model, and reshape it into a
            # num_box_t0 * num_box_t1 matrix. Column 1 of the model output is
            # used as the match probability - presumably the "connected" class;
            # confirm against ConnNet.
            with torch.no_grad():
                prob_mtx = model(inputs).cpu().numpy()[:, 1]
            prob_mtx = prob_mtx.reshape(len(t0_boxes), len(t1_boxes))

            # Matchable ids are the labels present in both t0 and t1
            total_matchable += len(set(t0_ids) & set(t1_ids))

            # The most probable t1 box of each t0 box does not depend on the threshold
            max_prob_ids = np.argmax(prob_mtx, axis=1)

            for match_thres_id in range(num_thres):
                match_thres = match_thres_list[match_thres_id]
                # A t0 box proposes a match only if some probability exceeds the threshold
                matchable_mtx = np.any(prob_mtx > match_thres, axis=1)

                for id_in_t0 in range(len(max_prob_ids)):
                    if not matchable_mtx[id_in_t0]:
                        continue
                    id_in_t1 = max_prob_ids[id_in_t0]
                    tar_name_t0 = t0_ids[id_in_t0]
                    tar_name_t1 = t1_ids[id_in_t1]
                    prob = prob_mtx[id_in_t0][id_in_t1]
                    match_name = f"frm_{t0}_{tar_name_t0}_frm_{t1}_{tar_name_t1}"
                    if tar_name_t0 == tar_name_t1:
                        TPs_all[match_thres_id].append(match_name)
                        TP_dist_all[match_thres_id].append(prob)
                        TPs_stats[match_thres_id] += 1
                    else:
                        FPs_all[match_thres_id].append(match_name)
                        FP_dist_all[match_thres_id].append(prob)
                        FPs_stats[match_thres_id] += 1

        # Calculate the precisions and recalls for each threshold
        for match_thres_id in range(num_thres):
            num_TP = TPs_stats[match_thres_id]
            num_FP = FPs_stats[match_thres_id]
            total_matches = num_TP + num_FP
            # Both ratios fall back to 0 when their denominator is 0
            # (no proposed match / nothing matchable) instead of raising.
            precision = num_TP / total_matches if total_matches > 0 else 0
            recall = num_TP / total_matchable if total_matchable > 0 else 0

            precisions[match_thres_id].append(precision)
            recalls[match_thres_id].append(recall)

    # Re-select the best threshold from the running averages over every
    # (race, offset) seen so far. The search starts from scratch for each race:
    # comparing against the best score of the previous race would measure
    # thresholds against a stale value and could miss the true maximum.
    race_best_id = -1
    race_best_f1 = -1
    race_best_precision = 0
    race_best_recall = 0
    for match_thres_id in range(num_thres):
        if not precisions[match_thres_id]:
            continue
        avg_precision = sum(precisions[match_thres_id])/len(precisions[match_thres_id])
        avg_recall = sum(recalls[match_thres_id])/len(recalls[match_thres_id])
        avg_f1 = 0
        if (avg_precision+avg_recall) > 0:
            avg_f1 = 2*avg_precision*avg_recall/(avg_precision+avg_recall)
        # Strict comparison: on ties the lowest threshold wins
        if avg_f1 > race_best_f1:
            race_best_id = match_thres_id
            race_best_f1 = avg_f1
            race_best_precision = avg_precision
            race_best_recall = avg_recall

    if race_best_id >= 0:
        best_match_thres_id = race_best_id
        best_f1_score = race_best_f1
        best_thres = match_thres_list[race_best_id]
        best_precision = race_best_precision
        best_recall = race_best_recall

        # Keep this race's matches at the currently best threshold.
        # Races processed earlier keep the threshold that was best at their time.
        TPs[each_race] = TPs_all[best_match_thres_id]
        TP_dist[each_race] = TP_dist_all[best_match_thres_id]
        FPs[each_race] = FPs_all[best_match_thres_id]
        FP_dist[each_race] = FP_dist_all[best_match_thres_id]

    t_end = time.time()
    print(f"Time_taken: {t_end-t_start}, best_match_thres: {best_thres}, best_f1_score: {best_f1_score}, with avg_precision: {best_precision}, avg_recall: {best_recall}")

In [13]:
# Inspect the false positives recorded for one race, then list the averaged
# precision / recall of every evaluated threshold.
tar_race = "Kranji-20180615-10"
# A bare expression is only displayed when it is the last statement of a cell,
# so the count is printed explicitly. The race may be absent if the evaluation
# loop was interrupted or the race was skipped.
if tar_race in FPs:
    print(f"Number of FPs in {tar_race}: {len(FPs[tar_race])}")
else:
    print(f"{tar_race} has no recorded results")
#print(FP_dist[tar_race])

for i in range(num_thres):
    if not precisions[i]:
        # Nothing has been evaluated for this threshold yet
        continue
    match_thres = match_thres_list[i]
    avg_precision = sum(precisions[i])/len(precisions[i])
    avg_recall = sum(recalls[i])/len(recalls[i])
    # Every threshold is listed here, not only the best one
    print(f"Match_thres: {match_thres}, with avg_precision: {avg_precision}, avg_recall: {avg_recall}")
    

Best_match_thres: 0.0, with avg_precision: 0.966143094879151, avg_recall: 0.9983455343687493
Best_match_thres: 0.05263157894736842, with avg_precision: 0.993944932158197, avg_recall: 0.9966537349964872
Best_match_thres: 0.10526315789473684, with avg_precision: 0.9950476144185382, avg_recall: 0.9951649556036265
Best_match_thres: 0.15789473684210525, with avg_precision: 0.995592103160094, avg_recall: 0.993933754087289
Best_match_thres: 0.21052631578947367, with avg_precision: 0.9959649477901585, avg_recall: 0.992664240010454
Best_match_thres: 0.2631578947368421, with avg_precision: 0.9962255493121445, avg_recall: 0.9913314600784027
Best_match_thres: 0.3157894736842105, with avg_precision: 0.9964291955036053, avg_recall: 0.9899488566943169
Best_match_thres: 0.3684210526315789, with avg_precision: 0.9966008327694706, avg_recall: 0.9884295546384415
Best_match_thres: 0.42105263157894735, with avg_precision: 0.9967597235069553, avg_recall: 0.986859837815178
Best_match_thres: 0.473684210526315