Skip to content

Commit 451a314

Browse files
committed
load config from sample yaml file and optionally update it with a user-specific config file
1 parent 56b7a2d commit 451a314

File tree

4 files changed

+41
-20
lines changed

4 files changed

+41
-20
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# our custom Ansible configuration
22
config/config.local.yml
33

4+
# our custom obj_detect configuration
5+
config/config.obj_detect.yml
6+
47
# nano swp files and gedit temp files
58
*.swp
69
*~

ansible/roles/tf_object_detection/tasks/main.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
with_items:
1515
- pillow
1616
- lxml
17+
- pyyaml
1718
- jupyter
1819
- matplotlib
1920
- numpy

config/config.obj_detect.sample.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
3+
## ATTENTION: Do not modify 'config.obj_detect.sample.yml' !! You should create a copy named 'config.obj_detect.yml' and modify that one !!
4+
5+
6+
model_name: 'ssd_mobilenet_v1_coco_11_06_2017'
7+
model_dl_base_path: 'http://download.tensorflow.org/models/object_detection/'
8+
model_dl_file_format: '.tar.gz'

object_detection_tutorial.py renamed to obj_detect.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import zipfile
88
import time
99
import cv2
10+
import yaml
1011

1112
from Xlib import display, X
1213

1314
from collections import defaultdict
1415
from io import StringIO
15-
#from PIL import Image
16+
from PIL import Image
1617

17-
cap = cv2.VideoCapture(0)
18+
#cap = cv2.VideoCapture(0)
1819
#cap = cv2.VideoCapture('../opencv_extra/testdata/highgui/video/big_buck_bunny.mp4')
1920

2021
sys.path.append('../tensorflow_models/research')
@@ -24,27 +25,35 @@
2425
from utils import label_map_util
2526
from utils import visualization_utils as vis_util
2627

28+
29+
# Load config values from config.obj_detect.sample.yml (as default values) updated by optional user-specific config.obj_detect.yml
30+
## see also http://treyhunner.com/2016/02/how-to-merge-dictionaries-in-python/
31+
cfg = yaml.load(open("config/config.obj_detect.sample.yml", 'r'))
32+
if os.path.isfile("config/config.obj_detect.yml"):
33+
cfg_user = yaml.load(open("config/config.obj_detect.yml", 'r'))
34+
cfg.update(cfg_user)
35+
#for section in cfg:
36+
# print(section, ":", cfg[section])
37+
38+
39+
2740
# Any model exported using the `export_inference_graph.py` tool can be loaded here simply by changing `PATH_TO_CKPT` to point to a new .pb file.
2841
# See the [detection model zoo](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md) for a list of other models that can be run out-of-the-box with varying speeds and accuracies.
2942

30-
# What model to download.
31-
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
32-
MODEL_FILE = MODEL_NAME + '.tar.gz'
33-
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'
34-
3543
# Path to frozen detection graph. This is the actual model that is used for the object detection.
36-
PATH_TO_CKPT = '../' + MODEL_NAME + '/frozen_inference_graph.pb'
44+
PATH_TO_CKPT = '../' + cfg['model_name'] + '/frozen_inference_graph.pb'
3745

3846
# List of the strings that is used to add correct label for each box.
3947
PATH_TO_LABELS = os.path.join('../tensorflow_models/research/object_detection/data', 'mscoco_label_map.pbtxt')
4048

4149
NUM_CLASSES = 90
4250

4351
# ## Download Model
52+
MODEL_FILE = cfg['model_name'] + cfg['model_dl_file_format']
4453
if not os.path.isfile(PATH_TO_CKPT):
4554
print('Model not found. We will download it now.')
4655
opener = urllib.request.URLopener()
47-
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, '../' + MODEL_FILE)
56+
opener.retrieve(cfg['model_dl_base_path'] + MODEL_FILE, '../' + MODEL_FILE)
4857
tar_file = tarfile.open('../' + MODEL_FILE)
4958
for file in tar_file.getmembers():
5059
file_name = os.path.basename(file.name)
@@ -95,25 +104,25 @@
95104

96105
windowPlacedYet = False
97106

98-
while(cap.isOpened()):
99-
# while(True):
107+
# while(cap.isOpened()):
108+
while(True):
100109

101110
dsp = display.Display()
102111
root = dsp.screen().root
103112
reso = root.get_geometry()
104-
# W,H = int(reso.width/2),int(reso.height/2)
113+
W,H = int(reso.width/2),int(reso.height/2)
105114
#W,H = 600,600
106-
# raw = root.get_image(0, 0, W, H, X.ZPixmap, 0xffffffff)
107-
# image = Image.frombytes("RGB", (W, H), raw.data, "raw", "RGBX")
108-
# image_np = np.array(image);
115+
raw = root.get_image(0, 0, W, H, X.ZPixmap, 0xffffffff)
116+
image = Image.frombytes("RGB", (W, H), raw.data, "raw", "RGBX")
117+
image_np = np.array(image);
109118

110119
# image_np_bgr = np.array(ImageGrab.grab(bbox=(0,0,600,600))) # grab(bbox=(10,10,500,500)) or just grab()
111120
# image_np = cv2.cvtColor(image_np_bgr, cv2.COLOR_BGR2RGB)
112121

113-
ret, image_np = cap.read()
114-
if not ret:
115-
print("Video finished!")
116-
break
122+
# ret, image_np = cap.read()
123+
# if not ret:
124+
# print("Video finished!")
125+
# break
117126

118127
# for image_path in TEST_IMAGE_PATHS:
119128
# image = Image.open(image_path)
@@ -150,5 +159,5 @@
150159
counter = 0
151160
start_time = time.time()
152161

153-
cap.release()
162+
#cap.release()
154163
cv2.destroyAllWindows()

0 commit comments

Comments
 (0)