-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
147 lines (129 loc) · 5.71 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import argparse
import os
import cv2
import trackers
from trackers import AssembleModel, YOLOModel, RTDETRModel, utils
def make_parser():
    """Build the command-line parser for running a model/tracker over a dataset split.

    Option names are upper-case (``--SEQUENCES_DIR``, ``--MODEL``, ...), so the
    parsed namespace exposes attributes with exactly those names.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("Run YOLO model on VisDrone Dataset")
    parser.add_argument("--SEQUENCES_DIR",
                        type=str,
                        default=".",
                        help="Path to VisDrone dataset. It should point to train, or test directory of the dataset.")
    # Restrict --MODEL to the names create_model() dispatches on, so a typo is
    # rejected at parse time instead of after handle_args() has already
    # created result folders on disk.
    parser.add_argument("--MODEL",
                        type=str,
                        required=True,
                        choices=["yolo", "rtdetr", "ucmc", "smiletrack", "gt/ucmc"],
                        help="The model used for prediction.")
    parser.add_argument("--WEIGHTS_PATH",
                        type=str,
                        required=True,
                        help="Path to the YOLO weight.")
    parser.add_argument("--TRACKER",
                        type=str,
                        default="bytetrack",
                        choices=["bytetrack", "botsort"],
                        help="YOLO Trackers")
    # store_true already defaults to False.
    parser.add_argument("--SHOW",
                        action="store_true",
                        help="Whether to show the tracked sequences (using CV2).")
    parser.add_argument("--SAVE_RESULTS",
                        action="store_true",
                        help="save result (.txt file)")
    parser.add_argument("--SAVE_RESULTS_DIR",
                        default="results",
                        help="Where to save the results.txt files. Default to results/")
    return parser
def create_model(args):
    """Create model based on user choice.

    Args:
        args: parsed namespace. Reads ``MODEL`` and ``WEIGHTS_PATH``; the
            'smiletrack' branch also reads ``TRACKER_NAME`` and the 'gt/ucmc'
            branch reads ``SEQUENCES_DIR``.

    Returns:
        The model object for ``args.MODEL`` (a YOLOModel, an RTDETRModel, or an
        AssembleModel pairing a detector with an associator).

    Raises:
        ValueError: if ``args.MODEL`` is not a supported model name.

    Note:
        For 'smiletrack' this mutates ``args`` by writing SMILETrack
        hyper-parameters onto it, because the namespace itself is passed to
        SMILETrackAssociator (presumably it reads them as attributes).
    """
    if args.MODEL == "yolo":
        return YOLOModel(weights_path=args.WEIGHTS_PATH)
    if args.MODEL == "rtdetr":
        return RTDETRModel(weights_path=args.WEIGHTS_PATH)
    if args.MODEL == "ucmc":
        return AssembleModel(
            detector=trackers.detection.UCMCDetector('demo/ucmc/cam_para.txt', args.WEIGHTS_PATH),
            associator=trackers.association.UCMCAssociator(),
        )
    if args.MODEL == "smiletrack":
        # Keeping the settings in one dict avoids stray trailing commas that
        # would silently turn a path string into a 1-tuple.
        smiletrack_settings = {
            "name": args.TRACKER_NAME,
            "track_high_thresh": 0.3,
            "track_low_thresh": 0.1,
            "new_track_thresh": 0.4,
            "track_buffer": 20,
            "match_thresh": 0.7,
            "aspect_ratio_thresh": 1.6,
            "min_box_area": 10,
            "mot20": False,
            "with_reid": False,
            "fast_reid_config": r"fast_reid/configs/MOT17/sbs_S50.yml",
            "fast_reid_weights": r"pretrained/mot17_sbs_S50.pth",
            "proximity_thresh": 0.5,
            "appearance_thresh": 0.25,
            "cmc_method": "sparseOptFlow",
            "ablation": False,
        }
        for key, value in smiletrack_settings.items():
            setattr(args, key, value)
        return AssembleModel(
            detector=trackers.detection.UCMCDetector('demo/ucmc/cam_para.txt', args.WEIGHTS_PATH),
            associator=trackers.association.SMILETrackAssociator(args),
        )
    if args.MODEL == "gt/ucmc":
        # Detector built from SEQUENCES_DIR (presumably ground-truth boxes);
        # WEIGHTS_PATH is not used by this branch.
        return AssembleModel(
            detector=trackers.detection.GTDetector(args.SEQUENCES_DIR),
            associator=trackers.association.UCMCAssociator(),
        )
    raise ValueError("Unsupported model type, got {}".format(args.MODEL))
def main(args):
    """Run the selected model over every entry of ``args.SEQUENCES_DIR``.

    Entries are processed in sorted order. A new model is created for each
    entry, presumably so tracker state does not carry over between sequences.
    Each entry is handed to ``utils.run`` together with ``args``.

    Args:
        args: namespace prepared by handle_args(); reads ``SHOW``,
            ``TRACKER_NAME`` (window title) and ``SEQUENCES_DIR``.
    """
    # Create window for display the result
    if args.SHOW:
        cv2.namedWindow(args.TRACKER_NAME, cv2.WINDOW_KEEPRATIO)
    # List the directory once so the progress total always matches the
    # entries actually iterated.
    sequences = sorted(os.listdir(args.SEQUENCES_DIR))
    n_seqs = len(sequences)
    for seq_index, current_seq in enumerate(sequences, start=1):
        model = create_model(args)
        print(f'[INFO] [{seq_index}/{n_seqs}] Working on {current_seq}...')
        seq_path = os.path.join(args.SEQUENCES_DIR, current_seq)
        utils.run(model, seq_path, args)
def _results_dir_name(sequences_dir):
    """Derive a results folder name such as 'MOT17-train' from a dataset path.

    Trailing split folders ('test', 'train', 'val') are collected and joined,
    outermost first, after the first non-split ancestor folder.
    """
    skipped_dir_name = ('test', 'train', 'val')
    # abspath guarantees there is always a parent to climb to, so a bare
    # 'train' or '.' resolves to a real folder name instead of failing.
    head = os.path.abspath(sequences_dir)
    tail = os.path.basename(head)
    parts = []
    while tail in skipped_dir_name:
        parts.append(tail)
        head = os.path.dirname(head)
        tail = os.path.basename(head)
    parts.append(tail)
    return '-'.join(reversed(parts))


def _create_results_dir(args):
    """Create and return a fresh '<root>/<RESULTS_DIR_NAME>/<TRACKER_NAME>_NNNNN/data' folder."""
    root = args.SAVE_RESULTS_DIR
    # When the default 'results' is requested and a TrackEval folder exists,
    # write into TrackEval's data/trackers/mot_challenge tree instead.
    if root == 'results' and os.path.exists('TrackEval'):
        root = os.path.join('TrackEval', 'data', 'trackers', 'mot_challenge')
    dataset_dir = os.path.join(root, args.RESULTS_DIR_NAME)
    os.makedirs(dataset_dir, exist_ok=True)
    # Claim the first free index with mkdir itself (EAFP) so two runs started
    # at the same time cannot both pick the same folder.
    tracker_index = 0
    while True:
        run_dir = os.path.join(dataset_dir, f'{args.TRACKER_NAME}_{tracker_index:05d}')
        try:
            os.mkdir(run_dir)
        except FileExistsError:
            tracker_index += 1
        else:
            break
    data_dir = os.path.join(run_dir, 'data')
    os.mkdir(data_dir)
    return data_dir


def handle_args(args):
    """
    This function handle and process arguments so the program can run smoothly after.

    Sets ``args.TRACKER_NAME`` (weights file name without its extension) and
    ``args.RESULTS_DIR_NAME`` (see _results_dir_name). When ``args.SAVE_RESULTS``
    is set, it also creates a new numbered output folder and replaces
    ``args.SAVE_RESULTS_DIR`` with the path of its 'data' subfolder.

    Returns:
        The same ``args`` namespace, mutated in place.
    """
    # splitext strips an extension of any length ('.pt', '.pth', '.onnx', ...).
    args.TRACKER_NAME = os.path.splitext(os.path.basename(args.WEIGHTS_PATH))[0]
    args.RESULTS_DIR_NAME = _results_dir_name(args.SEQUENCES_DIR)
    if args.SAVE_RESULTS:
        args.SAVE_RESULTS_DIR = _create_results_dir(args)
        print(f'[INFO] Results file will be saved to {args.SAVE_RESULTS_DIR}')
    return args
if __name__ == "__main__":
    args = None
    try:
        args = make_parser().parse_args()
        args = handle_args(args)
        main(args)
    except KeyboardInterrupt:
        print('[INFO] Stopped by User...')
    finally:
        # Runs on normal completion, Ctrl-C, and errors alike. Windows are
        # only torn down when --SHOW was given, since that is the only case in
        # which main() opens one.
        if args is not None and args.SHOW:
            cv2.destroyAllWindows()