Skip to content

Commit

Permalink
Fix for the case where there are no detections (apache#9784)
Browse files Browse the repository at this point in the history
Use cv2.cvtColor instead of np to convert from BGR 2 RGB
Fix context in detector
  • Loading branch information
larroy authored and zhreshold committed Apr 23, 2018
1 parent c9b8b4b commit 60116b5
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions example/ssd/detect/detector.py
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import mxnet as mx
import numpy as np
from timeit import default_timer as timer
from dataset.testdb import TestDB
from dataset.iterator import DetIter
import logging

class Detector(object):
"""
Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(self, symbol, model_prefix, epoch, data_shape, mean_pixels, \
load_symbol, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
if symbol is None:
symbol = load_symbol
self.mod = mx.mod.Module(symbol, label_names=None, context=ctx)
self.mod = mx.mod.Module(symbol, label_names=None, context=self.ctx)
if not isinstance(data_shape, tuple):
data_shape = (data_shape, data_shape)
self.data_shape = data_shape
Expand Down Expand Up @@ -81,13 +81,9 @@ def detect(self, det_iter, show_timer=False):
detections = self.mod.predict(det_iter).asnumpy()
time_elapsed = timer() - start
if show_timer:
print("Detection time for {} images: {:.4f} sec".format(
logging.info("Detection time for {} images: {:.4f} sec".format(
num_images, time_elapsed))
result = []
for i in range(detections.shape[0]):
det = detections[i, :, :]
res = det[np.where(det[:, 0] >= 0)[0]]
result.append(res)
result = Detector.filter_positive_detections(detections)
return result

def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False):
Expand Down Expand Up @@ -136,31 +132,52 @@ class names
height = img.shape[0]
width = img.shape[1]
colors = dict()
for i in range(dets.shape[0]):
cls_id = int(dets[i, 0])
if cls_id >= 0:
score = dets[i, 1]
if score > thresh:
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(dets[i, 2] * width)
ymin = int(dets[i, 3] * height)
xmax = int(dets[i, 4] * width)
ymax = int(dets[i, 5] * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
for det in dets:
(klass, score, x0, y0, x1, y1) = det
if score < thresh:
continue
cls_id = int(klass)
if cls_id not in colors:
colors[cls_id] = (random.random(), random.random(), random.random())
xmin = int(x0 * width)
ymin = int(y0 * height)
xmax = int(x1 * width)
ymax = int(y1 * height)
rect = plt.Rectangle((xmin, ymin), xmax - xmin,
ymax - ymin, fill=False,
edgecolor=colors[cls_id],
linewidth=3.5)
plt.gca().add_patch(rect)
class_name = str(cls_id)
if classes and len(classes) > cls_id:
class_name = classes[cls_id]
plt.gca().text(xmin, ymin - 2,
'{:s} {:.3f}'.format(class_name, score),
bbox=dict(facecolor=colors[cls_id], alpha=0.5),
fontsize=12, color='white')
plt.show()

@staticmethod
def filter_positive_detections(detections):
"""
First column (class id) is -1 for negative detections
:param detections:
:return:
"""
class_idx = 0
assert(isinstance(detections, mx.nd.NDArray) or isinstance(detections, np.ndarray))
detections_per_image = []
# for each image
for i in range(detections.shape[0]):
result = []
det = detections[i, :, :]
for obj in det:
if obj[class_idx] >= 0:
result.append(obj)
detections_per_image.append(result)
logging.info("%d positive detections", len(result))
return detections_per_image

def detect_and_visualize(self, im_list, root_dir=None, extension=None,
classes=[], thresh=0.6, show_timer=False):
"""
Expand All @@ -187,5 +204,5 @@ def detect_and_visualize(self, im_list, root_dir=None, extension=None,
assert len(dets) == len(im_list)
for k, det in enumerate(dets):
img = cv2.imread(im_list[k])
img[:, :, (0, 1, 2)] = img[:, :, (2, 1, 0)]
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
self.visualize_detection(img, det, classes, thresh)

0 comments on commit 60116b5

Please sign in to comment.