-
Notifications
You must be signed in to change notification settings - Fork 8
/
tusimple.py
102 lines (91 loc) · 3.78 KB
/
tusimple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os.path as osp
import numpy as np
import cv2
import os
import json
from .base_dataset import BaseDataset
from .builder import DATASETS
from pplanedet.utils.tusimple_metric import LaneEval
import logging
import random
# Mapping from dataset split name to the TuSimple annotation files that make
# up that split. Each file is line-delimited JSON (one labeled frame per line).
# Note: 'trainval' = 'train' + 'val'; 'test' uses the official test labels.
SPLIT_FILES = {
    'trainval': ['label_data_0313.json', 'label_data_0601.json', 'label_data_0531.json'],
    'train': ['label_data_0313.json', 'label_data_0601.json'],
    'val': ['label_data_0531.json'],
    'test': ['test_label.json'],
}
@DATASETS.register()
class TuSimple(BaseDataset):
    """TuSimple lane-detection dataset.

    Loads line-delimited JSON annotations for the requested split, converts
    model predictions back to the TuSimple submission format, and evaluates
    them with the official ``LaneEval`` metric.

    Args:
        data_root (str): Root directory containing the annotation files and clips.
        split (str): One of the keys of ``SPLIT_FILES``.
        processes: Pipeline of transforms, forwarded to ``BaseDataset``.
        cfg: Config object; must provide ``ori_img_h``, ``ori_img_w`` and
            ``test_json_file`` (used by ``pred2lanes`` / ``evaluate``).
    """

    def __init__(self, data_root, split, processes=None, cfg=None):
        super().__init__(data_root, split, processes, cfg)
        self.anno_files = SPLIT_FILES[split]
        self.load_annotations()
        # Fixed TuSimple row anchors: y = 160, 170, ..., 710 in the original image.
        self.h_samples = list(range(160, 720, 10))

    def load_annotations(self):
        """Parse all annotation files of the split into ``self.data_infos``.

        Each entry holds the image path, its relative name, the path of the
        pre-generated segmentation mask, and the lane point lists (negative
        x values, which mark missing points, are dropped).
        """
        self.logger.info('Loading TuSimple annotations...')
        self.data_infos = []
        max_lanes = 0
        for anno_file in self.anno_files:
            anno_file = osp.join(self.data_root, anno_file)
            with open(anno_file, 'r') as anno_obj:
                lines = anno_obj.readlines()
            for line in lines:
                data = json.loads(line)
                y_samples = data['h_samples']
                gt_lanes = data['lanes']
                # Masks mirror the clip layout under 'seg_label' with a .png suffix.
                mask_path = data['raw_file'].replace('clips', 'seg_label')[:-3] + 'png'
                lanes = [[(x, y) for (x, y) in zip(lane, y_samples) if x >= 0] for lane in gt_lanes]
                lanes = [lane for lane in lanes if len(lane) > 0]
                max_lanes = max(max_lanes, len(lanes))
                self.data_infos.append({
                    'img_path': osp.join(self.data_root, data['raw_file']),
                    'img_name': data['raw_file'],
                    'mask_path': osp.join(self.data_root, mask_path),
                    'lanes': lanes,
                })
        if self.training:
            random.shuffle(self.data_infos)
        self.max_lanes = max_lanes

    def pred2lanes(self, pred):
        """Sample each predicted lane (a callable of normalized y) at the
        fixed TuSimple row anchors and return per-lane x lists in pixels.

        Points where the lane evaluates to a negative x are encoded as -2,
        the TuSimple convention for "no point at this row".
        """
        ys = np.array(self.h_samples) / self.cfg.ori_img_h
        lanes = []
        for lane in pred:
            xs = lane(ys)
            invalid_mask = xs < 0
            lane = (xs * self.cfg.ori_img_w).astype(int)
            lane[invalid_mask] = -2
            lanes.append(lane.tolist())
        return lanes

    def pred2tusimpleformat(self, idx, pred, runtime):
        """Serialize one prediction as a TuSimple submission JSON line.

        Args:
            idx (int): Index into ``self.data_infos`` (for the image name).
            pred: Predicted lanes, as accepted by ``pred2lanes``.
            runtime (float): Inference time in seconds (stored in ms).
        """
        runtime *= 1000.  # s to ms
        img_name = self.data_infos[idx]['img_name']
        lanes = self.pred2lanes(pred)
        output = {'raw_file': img_name, 'lanes': lanes, 'run_time': runtime}
        return json.dumps(output)

    def save_tusimple_predictions(self, predictions, filename, runtimes=None):
        """Write all predictions to ``filename``, one JSON object per line.

        When ``runtimes`` is omitted, a nominal 1 ms per image is recorded.
        """
        if runtimes is None:
            runtimes = np.ones(len(predictions)) * 1.e-3
        lines = []
        for idx, (prediction, runtime) in enumerate(zip(predictions, runtimes)):
            line = self.pred2tusimpleformat(idx, prediction, runtime)
            lines.append(line)
        with open(filename, 'w') as output_file:
            output_file.write('\n'.join(lines))

    def evaluate(self, predictions, output_basedir, runtimes=None):
        """Dump predictions in TuSimple format and score them with LaneEval.

        Returns the official accuracy, or -1 when the prediction file could
        not be read back or contains no entries.
        """
        self.logger.info("Generating evaluation result on Tusimple")
        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # check-then-create race of the previous exists()/mkdir() pair.
        os.makedirs(output_basedir, exist_ok=True)
        pred_filename = os.path.join(output_basedir, 'tusimple_predictions.json')
        self.save_tusimple_predictions(predictions, pred_filename, runtimes)
        # Sanity-check the file we just wrote. The previous code used a bare
        # except and then read `json_pred` unconditionally, raising NameError
        # when open() failed; it also leaked the file handle.
        try:
            with open(pred_filename, 'r') as pred_file:
                json_pred = [json.loads(line) for line in pred_file]
        except (OSError, json.JSONDecodeError):
            return -1
        if not json_pred:
            return -1
        result, acc = LaneEval.bench_one_submit(pred_filename, self.cfg.test_json_file)
        self.logger.info(result)
        return acc