-
Notifications
You must be signed in to change notification settings - Fork 0
/
pyserver.py
161 lines (137 loc) · 5.59 KB
/
pyserver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import zmq
import time
import sys
import cv2
import numpy as np
import copy
import json
import pdb

# REP socket port; overridable via the first CLI argument.  The default is
# kept as a string -- the "%s" bind format below accepts both str and int.
port = "5550"
if len(sys.argv) > 1:
    port = int(sys.argv[1])

# NOTE: the name `socket` shadows the stdlib module; kept unchanged because
# the server loop below refers to it.
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:%s" % port)
class Recognizer:
    """Detects game objects (man, skull, ladder, door, key) in a cropped
    Montezuma's-Revenge-style frame, via colour blob matching for the dynamic
    sprites and template matching for the static ones, and packs the
    detections into a fixed-length object list for the RL client.
    """

    def __init__(self):
        # RGB colours used for blob detection of the dynamic sprites.
        self.colors = {'man': [200, 72, 72], 'skull': [236, 236, 236]}
        # Object name -> one-hot class index (see get_onehot).
        self.map = {'man': 0, 'skull': 1, 'ladder': 2, 'door': 3, 'key': 4}

    def blob_detect(self, img, id):
        """Return the mean (y, x) position of pixels matching colour `id`.

        NOTE(review): the match is per-channel, not per-pixel -- any single
        channel equal to the target colour's channel value contributes to the
        mean; confirm frames never contain accidental channel collisions.
        Raises ZeroDivisionError when nothing matches (original behaviour).
        """
        mask = np.zeros(np.shape(img))
        mask[:, :, 0] = self.colors[id][0]
        mask[:, :, 1] = self.colors[id][1]
        mask[:, :, 2] = self.colors[id][2]
        # The original also zeroed/255'd the diff array afterwards, but those
        # writes happened after `indxs` was captured and were never read --
        # dead code, removed.
        indxs = np.where(img - mask == 0)
        # `//` keeps the Python-2 truncating division the original relied on.
        mean_y = np.sum(indxs[0]) // np.shape(indxs[0])[0]
        mean_x = np.sum(indxs[1]) // np.shape(indxs[1])[0]
        return (mean_y, mean_x)  # (row, col) -- numpy indexes y first

    def template_detect(self, img, id):
        """Match templates/<id>.png against `img`.

        Returns (loc, w, h) where loc is a (rows, cols) pair of arrays
        holding the centre of every match scoring >= 0.8, and w/h are the
        template's width and height.
        """
        template = cv2.imread('templates/' + id + '.png')
        w = np.shape(template)[1]
        h = np.shape(template)[0]
        res = cv2.matchTemplate(img, template, cv2.TM_CCOEFF_NORMED)
        threshold = 0.8
        loc = np.where(res >= threshold)
        # np.where results are read-only views; make them writable so the
        # corner->centre shift below can mutate them in place.
        loc[0].setflags(write=True)
        loc[1].setflags(write=True)
        for i in range(np.shape(loc[0])[0]):
            # Shift top-left corners to box centres.  `//` keeps integer
            # coordinates under Python 3 as well (Python 2 `/` truncated).
            loc[0][i] += h // 2
            loc[1][i] += w // 2
        return loc, w, h

    def get(self, img):
        """Run every detector on a frame; return a dict of coordinates plus
        the template sizes needed later for drawing bounding boxes."""
        man_coords = self.blob_detect(img, 'man')
        skull_coords = self.blob_detect(img, 'skull')
        ladder_coords, ladder_w, ladder_h = self.template_detect(img, 'ladder')
        key_coords, key_w, key_h = self.template_detect(img, 'key')
        door_coords, door_w, door_h = self.template_detect(img, 'door_new')
        return {'man': man_coords, 'skull': skull_coords,
                'ladder': ladder_coords, 'key': key_coords,
                'door': door_coords,
                'ladder_w': ladder_w, 'ladder_h': ladder_h,
                'key_w': key_w, 'key_h': key_h,
                'door_w': door_w, 'door_h': door_h}

    def drawbbox(self, inputim, coords):
        """Return a copy of `inputim` with red boxes drawn on every detection.

        Template matches get boxes of the template's size; man and skull get
        fixed 10x10 boxes around their blob centre.
        """
        img = copy.deepcopy(inputim)
        for name in ['ladder', 'key', 'door']:
            # coords[name] is (rows, cols); reverse to (x, y) for OpenCV.
            for pt in zip(*coords[name][::-1]):
                cv2.rectangle(img, pt,
                              (pt[0] + coords[name + '_w'],
                               pt[1] + coords[name + '_h']),
                              (0, 0, 255), 2)
        # NOTE(review): blob_detect returns (y, x) but cv2.rectangle expects
        # (x, y) points, so the man/skull boxes look transposed -- confirm
        # against a rendered frame before relying on this visualisation.
        cv2.rectangle(img,
                      (coords['man'][0] - 5, coords['man'][1] - 5),
                      (coords['man'][0] + 5, coords['man'][1] + 5),
                      (0, 0, 255), 2)
        cv2.rectangle(img,
                      (coords['skull'][0] - 5, coords['skull'][1] - 5),
                      (coords['skull'][0] + 5, coords['skull'][1] + 5),
                      (0, 0, 255), 2)
        return img

    def get_lives(self, img):
        """Total pixel intensity of the score strip.

        NOTE(review): this is a raw sum, not a life count -- callers must
        interpret or threshold the value themselves.
        """
        return np.sum(img)

    def get_onehot(self, ID):
        """Return a one-hot list of length len(self.map) with index ID set."""
        tmp = list(np.zeros(len(self.map)))
        tmp[ID] = 1
        return tmp

    def process_objects(self, objects):
        """Flatten a detection dict from get() into a fixed-length list of
        [y, x, onehot...] rows, zero-padded up to 8 entries.

        Rows 0 and 1 are always man and skull; template detections follow in
        dict order; int-valued entries (the *_w / *_h sizes) are skipped.
        """
        objects_list = []
        objects_list.append([objects['man'][0], objects['man'][1]]
                            + self.get_onehot(self.map['man']))
        objects_list.append([objects['skull'][0], objects['skull'][1]]
                            + self.get_onehot(self.map['skull']))
        for obj, val in objects.items():
            # Fixed: the original used `is not 'man'` -- string identity
            # comparison that only works via CPython interning and is a
            # SyntaxWarning on Python 3.8+.
            if obj != 'man' and obj != 'skull':
                if not isinstance(val, int):  # skip the *_w/*_h size entries
                    if isinstance(val[0], np.int64):
                        # A single detection stored as a scalar pair.
                        objects_list.append([val[0], val[1]]
                                            + self.get_onehot(self.map[obj]))
                    else:
                        # Arrays of row/col coordinates, one row per match.
                        for i in range(np.shape(val[0])[0]):
                            objects_list.append([val[0][i], val[1][i]]
                                                + self.get_onehot(self.map[obj]))
        # Pad with zero rows to guarantee a fixed-length state representation.
        # NOTE(review): there is no truncation -- more than 8 detections
        # would yield a longer list than the client may expect.
        fill_objects = 8 - len(objects_list)
        for j in range(fill_objects):
            objects_list.append([0, 0] + list(np.zeros(len(self.map))))
        return objects_list
def show(img):
    """Display `img` in an OpenCV window and block until a key is pressed."""
    window_name = 'image'
    cv2.imshow(window_name, img)
    cv2.waitKey(0)
def unit_test():
    """Smoke-test the Recognizer pipeline on a sample frame and display the
    detections.  An image id may be passed on the command line, though the
    frame path below is currently hard-coded."""
    rec = Recognizer()
    try:
        img_id = str(sys.argv[1])
    except IndexError:  # no CLI argument supplied -- was a bare `except:`
        # print(...) form is valid in both Python 2 and 3; the original
        # Python-2 print statement is a syntax error under Python 3.
        print('Using default image 1.png')
        img_id = '1'
    # NOTE(review): img_id is never used -- the frame path is hard-coded.
    img_rgb = cv2.imread('image2.png')
    # Top rows hold the score counter; the play field starts at row 30.
    im_score = img_rgb[15:20, 55:95, :]
    img_rgb = img_rgb[30:, :, :]
    coords = rec.get(img_rgb)
    objects = rec.process_objects(coords)
    # Removed a leftover pdb.set_trace() that halted every run in the
    # debugger.
    img = rec.drawbbox(img_rgb, coords)
    show(img)
# unit_test()

# --- ZMQ REP server main loop (Python 2) ---
# Prime the object-list cache from a reference frame: everything except the
# man and the skull is treated as static and detected only once here.
rec = Recognizer()
img_rgb = cv2.imread('image2.png')
im_score = img_rgb[15:20, 55:95, :]  # score-counter strip (unused below)
img_rgb = img_rgb[30:, :, :]  # drop the score rows; keep the play field
coords = rec.get(img_rgb)
objects_list_cache = rec.process_objects(coords)
while True:
    # Wait for next request from client
    message = socket.recv()
    # print "Received request: ", message
    # The client is expected to write the current frame to tmp_<port>.png
    # before sending; the request payload itself is ignored.
    img_rgb = cv2.imread('tmp_' + str(port) + '.png')
    print "rbg pre: ", img_rgb.shape
    im_score = img_rgb[15:20, 55:95, :]  # score strip of the fresh frame
    print "score shape: ", im_score.shape
    #cv2.imshow("imscore",im_score)
    img_rgb = img_rgb[30:, :, :]  # same crop as the cached reference frame
    print "rgb shape: ", img_rgb.shape
    #cv2.imshow("im_rgb",img_rgb)
    #cv2.waitKey(0)
    coords = rec.get(img_rgb)
    # img = rec.drawbbox(img_rgb, coords)
    # show(img)
    # Start from the cached static detections and overwrite only the
    # dynamic rows (index 0 = man, index 1 = skull).
    objects_list = copy.deepcopy(objects_list_cache)
    objects_list2 = rec.process_objects(coords)
    # print(objects_list2[0])
    #agent and skull is dynamic. everything else is static. TODO for key
    objects_list[0] = objects_list2[0]
    objects_list[1] = objects_list2[1]
    # A (0, 0) position means the skull was not detected this frame; clear
    # its one-hot class flag (row layout is [y, x, onehot...], so index 3
    # is the skull class bit) so the client treats it as absent.
    if objects_list[1][0] == 0 and objects_list[1][1] == 0:
        objects_list[1][3] = 0
    socket.send(json.dumps(objects_list))
    # socket.send("World from %s" % str(coords))
    # print(rec.get_lives(im_score))