-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path predict.py
executable file
·74 lines (55 loc) · 2.13 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import logging
import os
import cv2
import tensorflow as tf
from lib.config import Config
from lib.prediction.input_params import InputParamsResolver
from lib.prediction.model import Model
from lib.stream import ImageReader
from lib.stream import assert_video_port_availability, VideoReader
from lib.stream.writer.image_writter import ImageWriter
from lib.stream.writer.video_writter import VideoWriter
# Lower the root logger's threshold to INFO so this script's informational
# messages are emitted; loggers left at NOTSET inherit this effective level.
logging.getLogger().setLevel('INFO')
def hide_tensorflow_logs():
    """Suppress TensorFlow's native (C++) and Python-side log output.

    ``TF_CPP_MIN_LOG_LEVEL='3'`` asks the native runtime to drop INFO,
    WARNING and ERROR messages. The Python ``tensorflow`` logger is raised to
    ERROR separately. Otherwise it inherits the INFO level configured on the
    root logger by this script and keeps printing info/deprecation chatter.

    NOTE(review): the native runtime presumably reads the environment
    variable when it first initialises logging. ``tensorflow`` is already
    imported at the top of this file, so messages emitted during that import
    may not be filtered. Setting the variable before the import would be
    needed to catch those.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    # Only errors from TensorFlow's Python code remain visible.
    logging.getLogger('tensorflow').setLevel(logging.ERROR)
def create_reader(params):
    """Build the frame reader for the input source selected in *params*.

    The input keys are checked in the order ``input_image``,
    ``input_video``, ``input_webcam``, and the first one present wins.

    Args:
        params: Resolved parameters (from ``InputParamsResolver().resolve()``).
            The input keys are treated as optional; presumably the resolver
            only includes the one that was supplied.

    Returns:
        An ``ImageReader`` over the single input image, or a ``VideoReader``
        over the video file or webcam port.

    Raises:
        ValueError: If none of the supported input keys is present. Without
            this, the caller would receive ``None`` and fail later with an
            unrelated error when iterating it.
    """
    if 'input_image' in params:
        return ImageReader([params['input_image']])
    if 'input_video' in params:
        return VideoReader(params['input_video'])
    if 'input_webcam' in params:
        # Presumably validates that the requested camera port can be opened
        # and returns the port to use; semantics live in lib.stream.
        video_port = assert_video_port_availability(params['input_webcam'])
        return VideoReader(video_port)
    raise ValueError(
        "No input source given: expected one of "
        "'input_image', 'input_video' or 'input_webcam'"
    )
def create_writer(params):
    """Build the output writer that matches the selected input source.

    Image input is written with an ``ImageWriter``. Video-file and webcam
    input are both written with a ``VideoWriter``. In every case the
    destination is ``params['output']``.

    Args:
        params: Resolved parameters. One input key and ``output`` are
            expected.

    Returns:
        An ``ImageWriter`` or ``VideoWriter`` targeting ``params['output']``.

    Raises:
        ValueError: If none of the supported input keys is present.
    """
    if 'input_image' in params:
        return ImageWriter(params['output'])
    # A video file and a webcam stream both yield a sequence of frames, so
    # both are written out as video. Without the 'input_video' case,
    # video-file input got no writer and the first write() call crashed.
    if 'input_video' in params or 'input_webcam' in params:
        return VideoWriter(params['output'])
    raise ValueError(
        "No input source given: expected one of "
        "'input_image', 'input_video' or 'input_webcam'"
    )
def show(frame, preview_params=None):
    """Display *frame* in the 'Object detection' window if previewing is on.

    A preview is shown when ``show_preview`` is truthy or a webcam input is
    in use. The frame is first rescaled by ``preview_scale``. Per OpenCV's
    documentation, ``cv2.imshow`` only repaints once the caller services the
    GUI loop with ``cv2.waitKey``.

    Args:
        frame: Frame object providing ``scale(factor)``, whose result exposes
            the image as ``.raw``.
        preview_params: Parameters to read the preview settings from. When
            omitted, the module-level ``params`` bound by the ``__main__``
            block is used, which keeps the old one-argument call working.
    """
    opts = params if preview_params is None else preview_params
    # create_reader treats 'input_webcam' as optional, so it must not be
    # indexed unconditionally. The `is not None` test keeps camera index 0,
    # which is falsy, counted as a webcam input.
    webcam = opts['input_webcam'] if 'input_webcam' in opts else None
    if opts['show_preview'] or webcam is not None:
        scaled_frame = frame.scale(opts['preview_scale'])
        cv2.imshow('Object detection', scaled_frame.raw)
def create_model(params: object, config_path: str = './config.yml') -> "Model":
    """Instantiate the detection model described by *params*.

    The class labels come from the ``labels`` property of the config file at
    *config_path*.

    Args:
        params: Resolved parameters. They must provide ``model_path`` and
            ``label_map_path``.
        config_path: Path of the project config file. The default is a
            relative path, so (assuming ``Config`` opens it as given) it is
            looked up from the current working directory, not from this
            script's directory. Pass an explicit path when running from
            elsewhere.

    Returns:
        A ``Model`` built from the model and label-map paths and the
        configured labels.
    """
    cfg = Config(config_path)
    return Model(
        params['model_path'],
        params['label_map_path'],
        classes=cfg.property('labels'),
    )
def check_end_prediction_action_keys():
    """Wait up to 25 ms for a key press and report whether it was 'q'.

    The ``cv2.waitKey`` call also services OpenCV's GUI event loop, which
    lets windows opened with ``cv2.imshow`` repaint.

    Returns:
        True when the pressed key's low byte equals ``ord('q')``, else False.
    """
    # Only the low byte of waitKey's result is compared. When no key is
    # pressed, waitKey returns -1, which masks to 255 and never matches.
    key_code = cv2.waitKey(25) & 0xFF
    quit_requested = key_code == ord('q')
    return quit_requested
if __name__ == '__main__':
    import contextlib

    # NOTE(review): tensorflow is imported at the top of this file, before
    # this call, so native messages emitted during that import may not be
    # filtered by TF_CPP_MIN_LOG_LEVEL.
    hide_tensorflow_logs()
    params = InputParamsResolver().resolve()
    model = create_model(params)

    # ExitStack releases the preview windows, the writer and the reader even
    # when prediction raises or the run is interrupted with Ctrl+C. Its
    # callbacks run in reverse registration order, so windows are destroyed
    # first, then the writer and the reader are closed.
    with contextlib.ExitStack() as cleanup:
        reader = create_reader(params)
        cleanup.callback(reader.close)
        writer = create_writer(params)
        cleanup.callback(writer.close)
        cleanup.callback(cv2.destroyAllWindows)

        with model.graph.as_default():
            with tf.compat.v1.Session(graph=model.graph) as session:
                for frame in reader:
                    if not params['disable_bboxes']:
                        # Presumably annotates `frame` in place with the
                        # detections; Model lives in lib.prediction.model.
                        model.predict(session, frame)
                    writer.write(frame)
                    # show() reads the module-level `params` bound above.
                    show(frame)
                    if check_end_prediction_action_keys():
                        break