From 93bd531be8fcd1c2fee56b0d3626a7a769005847 Mon Sep 17 00:00:00 2001 From: Taehoon Lee Date: Mon, 7 May 2018 15:39:52 +0900 Subject: [PATCH] Fix `get_v3_boxes` and Clean up YOLO utils --- README.md | 2 +- .../references/darkflow_utils/get_boxes.pyx | 30 ++++++++++++++----- tensornets/references/darkflow_utils/nms.pxd | 11 ------- tensornets/references/darkflow_utils/nms.pyx | 2 +- tensornets/references/yolo_utils.py | 21 ++++++------- 5 files changed, 33 insertions(+), 33 deletions(-) delete mode 100644 tensornets/references/darkflow_utils/nms.pxd diff --git a/README.md b/README.md index 72267df..ab116b5 100644 --- a/README.md +++ b/README.md @@ -308,7 +308,7 @@ with tf.Session() as sess: | | mAP | Size | Speed | FPS | References | |------------------------------------------------------------------------|--------|--------|-------|-------|------------| -| [YOLOv3VOC](tensornets/references/yolos.py#L175) | 0.7247 | 62M | 24.09 | 41.51 | [[paper]](https://pjreddie.com/media/files/papers/YOLOv3.pdf) [[darknet]](https://pjreddie.com/darknet/yolo/) [[darkflow]](https://github.com/thtrieu/darkflow) | +| [YOLOv3VOC](tensornets/references/yolos.py#L175) | 0.7423 | 62M | 24.09 | 41.51 | [[paper]](https://pjreddie.com/media/files/papers/YOLOv3.pdf) [[darknet]](https://pjreddie.com/darknet/yolo/) [[darkflow]](https://github.com/thtrieu/darkflow) | | [YOLOv2VOC](tensornets/references/yolos.py#L195) | 0.7320 | 51M | 14.75 | 67.80 | [[paper]](https://arxiv.org/abs/1612.08242) [[darknet]](https://pjreddie.com/darknet/yolov2/) [[darkflow]](https://github.com/thtrieu/darkflow) | | [TinyYOLOv2VOC](tensornets/references/yolos.py#L205) | 0.5303 | 16M | 6.534 | 153.0 | [[paper]](https://arxiv.org/abs/1612.08242) [[darknet]](https://pjreddie.com/darknet/yolov2/) [[darkflow]](https://github.com/thtrieu/darkflow) | | [FasterRCNN\_ZF\_VOC](tensornets/references/rcnns.py#L151) | 0.4466 | 59M | 241.4 | 3.325 | [[paper]](https://arxiv.org/abs/1506.01497) [[caffe]](https://github.com/rbgirshick/py-faster-rcnn) [[roi-pooling]](https://github.com/deepsense-ai/roi-pooling) | diff --git a/tensornets/references/darkflow_utils/get_boxes.pyx b/tensornets/references/darkflow_utils/get_boxes.pyx index 9e1d5cf..d6ead23 100644 --- a/tensornets/references/darkflow_utils/get_boxes.pyx +++ b/tensornets/references/darkflow_utils/get_boxes.pyx @@ -5,8 +5,9 @@ cimport numpy as np cimport cython ctypedef np.float_t DTYPE_t from libc.math cimport exp +from libc.math cimport pow from .box import BoundBox -from .nms cimport NMS +from .nms import NMS #expit @cython.boundscheck(False) # turn off bounds-checking for entire function @@ -52,7 +53,7 @@ cdef void _softmax_c(float* x, int classes): @cython.cdivision(True) @cython.boundscheck(False) # turn off bounds-checking for entire function @cython.wraparound(False) # turn off negative index wrapping for entire function -def yolov3_box(meta,np.ndarray[float,ndim=3] net_out_in): +def _yolov3_box(meta,np.ndarray[float,ndim=3] net_out_in,scale_idx): cdef: np.intp_t H, W, _, C, B, row, col, box_loop, class_loop np.intp_t row1, col1, box_loop1,index,index2 @@ -61,10 +62,12 @@ def yolov3_box(meta,np.ndarray[float,ndim=3] net_out_in): double[:] anchors = np.asarray(meta['anchors']) list boxes = list() - H, W, _ = meta['out_size'] + H, W = net_out_in.shape[:2] C = meta['classes'] B = 3 # meta['num'] - anchor_idx = meta['anchor_idx'] + anchor_idx = 6 - 3 * scale_idx + Hin = H * pow(2, 5 - scale_idx) + Win = W * pow(2, 5 - scale_idx) cdef: float[:, :, :, ::1] net_out = net_out_in.reshape([H, W, B, net_out_in.shape[2]/B]) @@ -80,8 +83,8 @@ def yolov3_box(meta,np.ndarray[float,ndim=3] net_out_in): Bbox_pred[row, col, box_loop, 4] = expit_c(Bbox_pred[row, col, box_loop, 4]) Bbox_pred[row, col, box_loop, 0] = (col + expit_c(Bbox_pred[row, col, box_loop, 0])) / W Bbox_pred[row, col, box_loop, 1] = (row + expit_c(Bbox_pred[row, col, box_loop, 1])) / H - Bbox_pred[row, col, box_loop, 2] = exp(Bbox_pred[row, col, box_loop, 2]) * anchors[2 * (box_loop + anchor_idx) + 0] / (W * 32) - Bbox_pred[row, col, box_loop, 3] = exp(Bbox_pred[row, col, box_loop, 3]) * anchors[2 * (box_loop + anchor_idx) + 1] / (H * 32) + Bbox_pred[row, col, box_loop, 2] = exp(Bbox_pred[row, col, box_loop, 2]) * anchors[2 * (box_loop + anchor_idx) + 0] / Win + Bbox_pred[row, col, box_loop, 3] = exp(Bbox_pred[row, col, box_loop, 3]) * anchors[2 * (box_loop + anchor_idx) + 1] / Hin #SOFTMAX BLOCK, no more pointer juggling for class_loop in range(C): arr_max=max_c(arr_max,Classes[row,col,box_loop,class_loop]) @@ -97,7 +100,18 @@ def yolov3_box(meta,np.ndarray[float,ndim=3] net_out_in): #NMS - return NMS(np.ascontiguousarray(probs).reshape(H*W*B,C), np.ascontiguousarray(Bbox_pred).reshape(H*B*W,5)) + return np.ascontiguousarray(probs).reshape(H*W*B,C), np.ascontiguousarray(Bbox_pred).reshape(H*B*W,5) + + +#BOX CONSTRUCTOR +@cython.cdivision(True) +@cython.boundscheck(False) # turn off bounds-checking for entire function +@cython.wraparound(False) # turn off negative index wrapping for entire function +def yolov3_box(meta,np.ndarray[float,ndim=3] out0,np.ndarray[float,ndim=3] out1,np.ndarray[float,ndim=3] out2): + a0, b0 = _yolov3_box(meta, out0, 0) + a1, b1 = _yolov3_box(meta, out1, 1) + a2, b2 = _yolov3_box(meta, out2, 2) + return NMS(np.concatenate([a2, a1, a0], axis=0), np.concatenate([b2, b1, b0], axis=0)) #BOX CONSTRUCTOR @@ -113,7 +127,7 @@ def yolov2_box(meta,np.ndarray[float,ndim=3] net_out_in): double[:] anchors = np.asarray(meta['anchors']) list boxes = list() - H, W, _ = meta['out_size'] + H, W = net_out_in.shape[:2] C = meta['classes'] B = meta['num'] diff --git a/tensornets/references/darkflow_utils/nms.pxd b/tensornets/references/darkflow_utils/nms.pxd deleted file mode 100644 index 5ba188e..0000000 --- a/tensornets/references/darkflow_utils/nms.pxd +++ /dev/null @@ -1,11 +0,0 @@ -from __future__ import absolute_import - -import numpy as np -cimport numpy as np -cimport cython -ctypedef np.float_t DTYPE_t -from libc.math cimport exp -from .box import BoundBox - - -cdef NMS(float[:, ::1] , float[:, ::1] ) diff --git a/tensornets/references/darkflow_utils/nms.pyx b/tensornets/references/darkflow_utils/nms.pyx index 1c42884..b0f00b9 100644 --- a/tensornets/references/darkflow_utils/nms.pyx +++ b/tensornets/references/darkflow_utils/nms.pyx @@ -62,7 +62,7 @@ cdef float box_iou_c(float ax, float ay, float aw, float ah, float bx, float by, @cython.boundscheck(False) # turn off bounds-checking for entire function @cython.wraparound(False) # turn off negative index wrapping for entire function @cython.cdivision(True) -cdef NMS(float[:, ::1] final_probs , float[:, ::1] final_bbox): +def NMS(np.ndarray[float,ndim=2] final_probs, np.ndarray[float,ndim=2] final_bbox): cdef list boxes = list() cdef set indices = set() cdef: diff --git a/tensornets/references/yolo_utils.py b/tensornets/references/yolo_utils.py index a8a617b..2a6594e 100644 --- a/tensornets/references/yolo_utils.py +++ b/tensornets/references/yolo_utils.py @@ -75,15 +75,14 @@ def get_v3_boxes(opts, outs, source_size, threshold=0.1): h, w = source_size boxes = [[] for _ in xrange(opts['classes'])] opts['thresh'] = threshold - for i in range(3): - opts['out_size'] = list(outs[i][0].shape) - opts['anchor_idx'] = 6 - 3 * i - o = np.array(outs[i][0], dtype=np.float32) - results = yolov3_box(opts, o) - for b in results: - idx, box = parse_box(b, threshold, w, h) - if idx is not None: - boxes[idx].append(box) + results = yolov3_box(opts, + np.array(outs[0][0], dtype=np.float32), + np.array(outs[1][0], dtype=np.float32), + np.array(outs[2][0], dtype=np.float32)) + for b in results: + idx, box = parse_box(b, threshold, w, h) + if idx is not None: + boxes[idx].append(box) for i in xrange(opts['classes']): boxes[i] = np.asarray(boxes[i], dtype=np.float32) return boxes @@ -93,9 +92,7 @@ def get_v2_boxes(opts, outs, source_size, threshold=0.1): h, w = source_size boxes = [[] for _ in xrange(opts['classes'])] opts['thresh'] = threshold - opts['out_size'] = list(outs[0].shape) - o = np.array(outs[0], dtype=np.float32) - results = yolov2_box(opts, o) + results = yolov2_box(opts, np.array(outs[0], dtype=np.float32)) for b in results: idx, box = parse_box(b, threshold, w, h) if idx is not None: