Skip to content

Commit a31e7ca

Browse files
committed
outsource visualization part and let user enable or disable it
1 parent 984bbe8 commit a31e7ca

File tree

3 files changed

+64
-36
lines changed

3 files changed

+64
-36
lines changed

config/config.obj_detect.sample.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ input_type: screen
1717
## will be passed to OpenCV VideoCapture
1818
#input_video: '../opencv_extra/testdata/highgui/video/big_buck_bunny.mp4'
1919
input_video: 0
20+
21+
# visualize the results of the object detection
22+
visualizer_enabled: True

obj_detect.py

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import tarfile
66
import tensorflow as tf
77
import zipfile
8-
import time
8+
from datetime import datetime
99
from Xlib import display
1010
import cv2
1111
import yaml
1212

13-
1413
from collections import defaultdict
1514
from io import StringIO
1615
#from PIL import Image
@@ -19,10 +18,7 @@
1918
sys.path.append('../tensorflow_models/research/slim')
2019
sys.path.append('../tensorflow_models/research/object_detection')
2120

22-
from utils import label_map_util
23-
from utils import visualization_utils as vis_util
24-
25-
from stuff.helper import FPS
21+
from stuff.helper import FPS, Visualizer
2622
from stuff.input import ScreenInput, VideoInput
2723

2824
# Load config values from config.obj_detect.sample.yml (as default values) updated by optional user-specific config.obj_detect.yml
@@ -50,10 +46,7 @@
5046
# Path to frozen detection graph. This is the actual model that is used for the object detection.
5147
PATH_TO_CKPT = '../' + cfg['model_name'] + '/frozen_inference_graph.pb'
5248

53-
# List of the strings that are used to add the correct label for each box.
54-
PATH_TO_LABELS = os.path.join('../tensorflow_models/research/object_detection/data', 'mscoco_label_map.pbtxt')
5549

56-
NUM_CLASSES = 90
5750

5851
# ## Download Model
5952
MODEL_FILE = cfg['model_name'] + cfg['model_dl_file_format']
@@ -79,12 +72,6 @@
7972
od_graph_def.ParseFromString(serialized_graph)
8073
tf.import_graph_def(od_graph_def, name='')
8174

82-
# ## Loading label map
83-
# Label maps map indices to category names, so that when our convolutional network predicts `5`, we know that this corresponds to `airplane`. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.
84-
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
85-
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
86-
category_index = label_map_util.create_category_index(categories)
87-
8875
# # Detection
8976
PATH_TO_TEST_IMAGES_DIR = 'test_images'
9077
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]
@@ -107,16 +94,20 @@
10794
# TODO: Usually FPS calculation lives in a separate thread. As it stands, the interval is a minimum value for each iteration.
10895
fps = FPS(cfg['fps_interval']).start()
10996

110-
windowPlacedYet = False
97+
vis = Visualizer(cfg['visualizer_enabled'])
11198

11299
while(input.isActive()):
100+
101+
# startTime=datetime.now()
102+
113103
ret, image_np = input.getImage()
114104
if not ret:
115-
print("No frames grabbed from input (anymore)! Exit.")
105+
print("No frames grabbed from input (anymore). Exit.")
116106
break
117107

118-
# image_np_bgr = np.array(ImageGrab.grab(bbox=(0,0,600,600))) # grab(bbox=(10,10,500,500)) or just grab()
119-
# image_np = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
108+
# timeElapsed=datetime.now()-startTime
109+
# print('1 Time elapsed (hh:mm:ss.ms) {}'.format(timeElapsed))
110+
# startTime=datetime.now()
120111

121112
# for image_path in TEST_IMAGE_PATHS:
122113
# image = Image.open(image_path)
@@ -130,22 +121,11 @@
130121
(boxes, scores, classes, num) = sess.run(
131122
[detection_boxes, detection_scores, detection_classes, num_detections],
132123
feed_dict={image_tensor: image_np_expanded})
133-
# Visualization of the results of a detection.
134-
vis_util.visualize_boxes_and_labels_on_image_array(
135-
image_np,
136-
np.squeeze(boxes),
137-
np.squeeze(classes).astype(np.int32),
138-
np.squeeze(scores),
139-
category_index,
140-
use_normalized_coordinates=True,
141-
line_thickness=8)
142-
143-
cv2.imshow('object detection', image_np) # alternatively as 2nd param: cv2.resize(image_np, (800, 600)))
144-
if cv2.waitKey(1) & 0xFF == ord('q'):
145-
break
146-
if not windowPlacedYet:
147-
cv2.moveWindow('object detection', (int)(screen.width/3), (int)(screen.height/3))
148-
windowPlacedYet = True
124+
125+
ret = vis.show(image_np, boxes, classes, scores)
126+
if not ret:
127+
print("User asked to quit. Exit")
128+
break
149129

150130
fps.update()
151131

@@ -154,4 +134,4 @@
154134
print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
155135

156136
input.cleanup()
157-
cv2.destroyAllWindows()
137+
vis.cleanup()

stuff/helper.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
1+
import os
2+
import numpy as np
13
import datetime
4+
from Xlib import display
5+
import cv2
6+
7+
from utils import label_map_util
8+
from utils import visualization_utils as vis_util
9+
10+
# Loading label map (mapping indices to category names, e.g. 5 -> airplane)
11+
NUM_CLASSES = 90
12+
PATH_TO_LABELS = os.path.join('../tensorflow_models/research/object_detection/data', 'mscoco_label_map.pbtxt')
13+
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
14+
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
15+
category_index = label_map_util.create_category_index(categories)
216

317
class FPS:
418
def __init__(self, interval):
@@ -32,3 +46,34 @@ def elapsed(self):
3246

3347
def fps(self):
3448
return self._glob_numFrames / self.elapsed()
49+
50+
51+
class Visualizer:
52+
def __init__(self, enabled):
53+
self._enabled = enabled
54+
self._windowPlaced = False
55+
self._screen = display.Display().screen().root.get_geometry()
56+
57+
def show(self, image_np, boxes, classes, scores):
58+
if not self._enabled:
59+
return True
60+
61+
vis_util.visualize_boxes_and_labels_on_image_array(
62+
image_np,
63+
np.squeeze(boxes),
64+
np.squeeze(classes).astype(np.int32),
65+
np.squeeze(scores),
66+
category_index,
67+
use_normalized_coordinates=True,
68+
line_thickness=8)
69+
70+
cv2.imshow('Visualizer', image_np) # alternatively as 2nd param: cv2.resize(image_np, (800, 600)))
71+
if not self._windowPlaced:
72+
cv2.moveWindow('Visualizer', (int)((self._screen.width-image_np.shape[1])/2), (int)((self._screen.height-image_np.shape[0])/2))
73+
self._windowPlaced = True
74+
if cv2.waitKey(1) & 0xFF == ord('q'):
75+
return False
76+
return True
77+
78+
def cleanup(self):
79+
cv2.destroyAllWindows()

0 commit comments

Comments
 (0)