diff --git a/README.md b/README.md index 22ccf73..daa7b3f 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,21 @@ SSD is an unified framework for object detection with a single network. You can use the code to train/evaluate/test for object detection task. -*This repo is still under construction.* - ### Disclaimer This is a re-implementation of original SSD which is based on caffe. The official repository is available [here](https://github.com/weiliu89/caffe/tree/ssd). The arXiv paper is available [here](http://arxiv.org/abs/1512.02325). This example is intended for reproducing the nice detector while fully utilize the -remarkable traits of MXNet. However: -* The model is not compatible with caffe version due to the implementation details. +remarkable traits of MXNet. +* The model is fully compatible with caffe version due to the implementation details. +* Model converter from caffe is available, I'll release it once I can convert any symbol other than VGG16. + +### Demo results +![demo1](https://cloud.githubusercontent.com/assets/3307514/19171057/8e1a0cc4-8be0-11e6-9d8f-088c25353b40.png) +![demo2](https://cloud.githubusercontent.com/assets/3307514/19171063/91ec2792-8be0-11e6-983c-773bd6868fa8.png) +![demo3](https://cloud.githubusercontent.com/assets/3307514/19171086/a9346842-8be0-11e6-8011-c17716b22ad3.png) + ### Getting started * You will need python modules: `easydict`, `cv2`, `matplotlib` and `numpy`. @@ -34,20 +39,19 @@ git clone --recursive https://github.com/zhreshold/mxnet-ssd.git # git submodule update --recursive --init cd mxnet-ssd/mxnet ``` -* Build MXNet with extra layers: Follow the official instructions -[here](http://mxnet.readthedocs.io/en/latest/how_to/build.html), and add extra -layers in `config.mk` by pointing `EXTRA_OPERATORS = ../operator/`. +* Build MXNet: Follow the official instructions +[here](http://mxnet.readthedocs.io/en/latest/how_to/build.html). Remember to enable CUDA if you want to be able to train, since CPU training is -insanely slow. Using CUDNN is not fully tested but should be fine. +insanely slow. Using CUDNN is optional, it's not fully tested but should be fine. ### Try the demo -* Download the pretrained model: `to_be_added`, and extract to `model/` directory. +* Download the pretrained model: [`ssd_300_voc_0712.zip`](https://dl.dropboxusercontent.com/u/39265872/ssd_300_voc0712.zip), and extract to `model/` directory. (This model is converted from VGG_VOC0712_SSD_300x300_iter_60000.caffemodel provided by paper author). * Run ``` # cd /path/to/mxnet-ssd python demo.py # play with examples: -python demo.py --images ./data/demo/dog.jpg --thresh 0.3 +python demo.py --epoch 0 --images ./data/demo/dog.jpg --thresh 0.3 ``` * Check `python demo.py --help` for more options. @@ -55,7 +59,7 @@ python demo.py --images ./data/demo/dog.jpg --thresh 0.3 This example only covers training on Pascal VOC dataset. Other datasets should be easily supported by adding subclass derived from class `Imdb` in `dataset/imdb.py`. See example of `dataset/pascal_voc.py` for details. -* Download the converted pretrained `vgg16_reduced` model: , put `.param` and `.json` files +* Download the converted pretrained `vgg16_reduced` model [here](https://dl.dropboxusercontent.com/u/39265872/vgg16_reduced.zip), unzip `.param` and `.json` files into `model/` directory by default. * Download the PASCAL VOC dataset, skip this step if you already have one. ``` @@ -75,10 +79,11 @@ in the same `VOCdevkit` folder. `ln -s /path/to/VOCdevkit /path/to/this_example/data/VOCdevkit`. Use hard link instead of copy could save us a bit disk space. * Start training: `python train.py` -* By default, this example will use `batch-size=32` and `learning_rate=0.004`. +* By default, this example will use `batch-size=32` and `learning_rate=0.002`. You might need to change the parameters a bit if you have different configurations. Check `python train.py --help` for more training options. For example, if you have 4 GPUs, use: ``` +# note that a perfect training parameter set is yet to be found for multi-gpu python train.py --gpus 0,1,2,3 --batch-size 128 --lr 0.005 ``` diff --git a/data/demo/000005.jpg b/data/demo/000005.jpg deleted file mode 100644 index b42aaa7..0000000 Binary files a/data/demo/000005.jpg and /dev/null differ diff --git a/data/demo/000012.jpg b/data/demo/000012.jpg deleted file mode 100644 index b829107..0000000 Binary files a/data/demo/000012.jpg and /dev/null differ diff --git a/data/demo/2008_000145.jpg b/data/demo/2008_000145.jpg deleted file mode 100755 index 28f672e..0000000 Binary files a/data/demo/2008_000145.jpg and /dev/null differ diff --git a/data/demo/bangkok2.jpg b/data/demo/bangkok2.jpg deleted file mode 100644 index 10c52db..0000000 Binary files a/data/demo/bangkok2.jpg and /dev/null differ diff --git a/data/demo/dogcat.jpg b/data/demo/dogcat.jpg deleted file mode 100644 index 297e086..0000000 Binary files a/data/demo/dogcat.jpg and /dev/null differ diff --git a/data/demo/monitor.jpg b/data/demo/monitor.jpg deleted file mode 100644 index 48bd32c..0000000 Binary files a/data/demo/monitor.jpg and /dev/null differ diff --git a/data/demo/stoplight.jpg b/data/demo/stoplight.jpg deleted file mode 100644 index 0427a1b..0000000 Binary files a/data/demo/stoplight.jpg and /dev/null differ diff --git a/data/demo/umbrella.jpg b/data/demo/umbrella.jpg deleted file mode 100644 index ce62ca8..0000000 Binary files a/data/demo/umbrella.jpg and /dev/null differ diff --git a/demo.py b/demo.py index e6f782c..e40e192 100644 --- a/demo.py +++ b/demo.py @@ -48,7 +48,7 @@ def parse_args(): parser.add_argument('--ext', dest='extension', help='image extension, optional', type=str, nargs='?') parser.add_argument('--epoch', dest='epoch', help='epoch of trained model', - default=200, type=int) + default=0, type=int) parser.add_argument('--prefix', dest='prefix', help='trained model prefix', default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str) parser.add_argument('--cpu', dest='cpu', help='(override GPU) use CPU to detect', @@ -63,7 +63,7 @@ def parse_args(): help='green mean value') parser.add_argument('--mean-b', dest='mean_b', type=float, default=104, help='blue mean value') - parser.add_argument('--thresh', dest='thresh', type=float, default=0.6, + parser.add_argument('--thresh', dest='thresh', type=float, default=0.5, help='object visualize score threshold, default 0.6') parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.5, help='non-maximum suppression threshold, default 0.5') diff --git a/mxnet b/mxnet index 3ae36d3..6ca3354 160000 --- a/mxnet +++ b/mxnet @@ -1 +1 @@ -Subproject commit 3ae36d3eb8185df5e591c04d1fafa23968a24096 +Subproject commit 6ca33546e5c9d10a6cb6ff4c279bb28e285f6f31 diff --git a/operator/multibox_detection-inl.h b/operator/multibox_detection-inl.h deleted file mode 100644 index 2fb7e16..0000000 --- a/operator/multibox_detection-inl.h +++ /dev/null @@ -1,254 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection-inl.h - * \brief post-process multibox detection predictions - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" - -namespace mxnet { -namespace op { -namespace mboxdet_enum { -enum MultiBoxDetectionOpInputs {kClsProb, kLocPred, kAnchor}; -enum MultiBoxDetectionOpOutputs {kOut}; -enum MultiBoxDetectionOpResource {kTempSpace}; -} // namespace mboxdet_enum - -struct VarInfo { - VarInfo() {} - explicit VarInfo(std::vector in) : info(in) {} - - std::vector info; -}; // struct VarInfo - -inline std::istream &operator>>(std::istream &is, VarInfo &size) { - while (true) { - char ch = is.get(); - if (ch == '(') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - float f; - std::vector tmp; - // deal with empty case - // safe to remove after stop using target_size - size_t pos = is.tellg(); - char ch = is.get(); - if (ch == ')') { - size.info = tmp; - return is; - } - is.seekg(pos); - // finish deal - while (is >> f) { - tmp.push_back(f); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')') { - is.get(); break; - } - break; - } - if (ch == ')') break; - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - size.info = tmp; - return is; -} - -inline std::ostream &operator<<(std::ostream &os, const VarInfo &size) { - os << '('; - for (index_t i = 0; i < size.info.size(); ++i) { - if (i != 0) os << ','; - os << size.info[i]; - } - // python style tuple - if (size.info.size() == 1) os << ','; - os << ')'; - return os; -} - -struct MultiBoxDetectionParam : public dmlc::Parameter { - bool clip; - float threshold; - int background_id; - float nms_threshold; - bool force_suppress; - VarInfo variances; - DMLC_DECLARE_PARAMETER(MultiBoxDetectionParam) { - DMLC_DECLARE_FIELD(clip).set_default(true) - .describe("Clip out-of-boundary boxes."); - DMLC_DECLARE_FIELD(threshold).set_default(0.01f) - .describe("Threshold to be a positive prediction."); - DMLC_DECLARE_FIELD(background_id).set_default(0) - .describe("Background id."); - DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5f) - .describe("Non-maximum suppression threshold."); - DMLC_DECLARE_FIELD(force_suppress).set_default(false) - .describe("Suppress all detections regardless of class_id."); - DMLC_DECLARE_FIELD(variances).set_default(VarInfo({0.1, 0.1, 0.2, 0.2})) - .describe("Variances to be decoded from box regression output."); - } -}; // struct MultiBoxDetectionParam - -template -class MultiBoxDetectionOp : public Operator { - public: - explicit MultiBoxDetectionOp(MultiBoxDetectionParam param) { - this->param_ = param; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 3) << "Input: [cls_prob, loc_pred, anchor]"; - TShape cshape = in_data[mboxdet_enum::kClsProb].shape_; - TShape lshape = in_data[mboxdet_enum::kLocPred].shape_; - TShape ashape = in_data[mboxdet_enum::kAnchor].shape_; - CHECK_EQ(cshape.ndim(), 3); - CHECK_EQ(lshape.ndim(), 2); - CHECK_EQ(ashape.ndim(), 3); - CHECK_GE(cshape[1], 2) << "Number of classes must > 1"; - CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; - CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; - CHECK_GT(ashape[1], 0) << "Number of anchors must > 0"; - CHECK_EQ(ashape[2], 4); - CHECK_EQ(out_data.size(), 1); - CHECK_EQ(out_data[mboxdet_enum::kOut].size(1), ashape[1]); - - Stream *s = ctx.get_stream(); - Tensor cls_prob = in_data[mboxdet_enum::kClsProb] - .get(s); - Tensor loc_pred = in_data[mboxdet_enum::kLocPred] - .get(s); - Tensor anchors = in_data[mboxdet_enum::kAnchor] - .get_with_shape(Shape2(ashape[1], 4), s); - Tensor out = out_data[mboxdet_enum::kOut] - .get(s); - Tensor temp_space = ctx.requested[mboxdet_enum::kTempSpace] - .get_space_typed(out.shape_, s); - - MultiBoxDetectionForward(out, cls_prob, loc_pred, anchors, - param_.threshold, param_.clip, param_.variances.info); - NonMaximumSuppression(out, temp_space, param_.nms_threshold, param_.force_suppress); - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - using namespace mshadow; - using namespace mshadow::expr; -} - - private: - MultiBoxDetectionParam param_; -}; // class MultiBoxDetectionOp - -template -Operator *CreateOp(MultiBoxDetectionParam, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxDetectionProp : public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - std::vector ListArguments() const override { - return {"cls_prob", "loc_pred", "anchor"}; - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 3) << "Inputs: [cls_prob, loc_pred, anchor]"; - TShape cshape = in_shape->at(mboxdet_enum::kClsProb); - TShape lshape = in_shape->at(mboxdet_enum::kLocPred); - TShape ashape = in_shape->at(mboxdet_enum::kAnchor); - CHECK_EQ(cshape.ndim(), 3) << "Provided: " << cshape; - CHECK_EQ(lshape.ndim(), 2) << "Provided: " << lshape; - CHECK_EQ(ashape.ndim(), 3) << "Provided: " << ashape; - CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch"; - CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc"; - CHECK_GT(ashape[1], 0) << "Number of anchors must > 0"; - CHECK_EQ(ashape[2], 4); - TShape oshape = TShape(3); - oshape[0] = cshape[0]; - oshape[1] = ashape[1]; - oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax] - out_shape->clear(); - out_shape->push_back(oshape); - return true; - } - - OperatorProperty* Copy() const override { - auto ptr = new MultiBoxDetectionProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "MultiBoxDetection"; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxDetectionParam param_; -}; // class MultiBoxDetectionProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_DETECTION_INL_H_ diff --git a/operator/multibox_detection.cc b/operator/multibox_detection.cc deleted file mode 100644 index 77d57c7..0000000 --- a/operator/multibox_detection.cc +++ /dev/null @@ -1,178 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection.cc - * \brief MultiBoxDetection op - * \author Joshua Zhang -*/ -#include "./multibox_detection-inl.h" -#include - -namespace mshadow { -template -struct SortElemDescend { - DType value; - int index; - - SortElemDescend(DType v, int i) { - value = v; - index = i; - } - - bool operator<(const SortElemDescend &other) const { - return value > other.value; - } -}; - -template -inline void TransformLocations(DType *out, const DType *anchors, - const DType *loc_pred, bool clip, - float vx, float vy, float vw, float vh) { - // transform predictions to detection results - DType al = anchors[0]; - DType at = anchors[1]; - DType ar = anchors[2]; - DType ab = anchors[3]; - DType aw = ar - al; - DType ah = ab - at; - DType ax = (al + ar) / 2.f; - DType ay = (at + ab) / 2.f; - DType px = loc_pred[0]; - DType py = loc_pred[1]; - DType pw = loc_pred[2]; - DType ph = loc_pred[3]; - DType ox = px * vx * aw + ax; - DType oy = py * vy * ah + ay; - DType ow = exp(pw * vw) * aw / 2; - DType oh = exp(ph * vh) * ah / 2; - out[0] = clip ? std::max(DType(0), std::min(DType(1), ox - ow)) : (ox - ow); - out[1] = clip ? std::max(DType(0), std::min(DType(1), oy - oh)) : (oy - oh); - out[2] = clip ? std::max(DType(0), std::min(DType(1), ox + ow)) : (ox + ow); - out[3] = clip ? std::max(DType(0), std::min(DType(1), oy + oh)) : (oy + oh); -} - -template -inline void MultiBoxDetectionForward(const Tensor &out, - const Tensor &cls_prob, - const Tensor &loc_pred, - const Tensor &anchors, - float threshold, bool clip, - const std::vector &variances) { - CHECK_EQ(variances.size(), 4) << "Variance size must be 4"; - index_t num_classes = cls_prob.size(1); - index_t num_anchors = cls_prob.size(2); - const DType *p_anchor = anchors.dptr_; - for (index_t nbatch = 0; nbatch < cls_prob.size(0); ++nbatch) { - const DType *p_cls_prob = cls_prob.dptr_ + nbatch * num_classes * num_anchors; - const DType *p_loc_pred = loc_pred.dptr_ + nbatch * num_anchors * 4; - DType *p_out = out.dptr_ + nbatch * num_anchors * 6; - for (index_t i = 0; i < num_anchors; ++i) { - // find the predicted class id and probability - DType score = p_cls_prob[i]; - int id = 0; - for (int j = 1; j < num_classes; ++j) { - DType temp = p_cls_prob[j * num_anchors + i]; - if (temp > score) { - score = temp; - id = j; - } - } - if (id > 0 && score < threshold) { - id = 0; - } - // [id, prob, xmin, ymin, xmax, ymax] - p_out[i * 6] = id - 1; // remove background, restore original id - p_out[i * 6 + 1] = (id == 0 ? DType(-1) : score); - index_t offset = i * 4; - TransformLocations(p_out + i * 6 + 2, p_anchor + offset, - p_loc_pred + offset, clip, variances[0], variances[1], - variances[2], variances[3]); - } // end iter num_anchors - } // end iter batch -} - -template -inline DType CalculateOverlap(const DType *a, const DType *b) { - DType w = std::max(DType(0), std::min(a[2], b[2]) - std::max(a[0], b[0])); - DType h = std::max(DType(0), std::min(a[3], b[3]) - std::max(a[1], b[1])); - DType i = w * h; - DType u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; - return u <= 0.f ? static_cast(0) : static_cast(i / u); -} - -template -inline void NonMaximumSuppression(const Tensor &out, - const Tensor &temp_space, - float nms_threshold, bool force_suppress) { - Copy(temp_space, out, out.stream_); - index_t num_anchors = out.size(1); - for (index_t nbatch = 0; nbatch < out.size(0); ++nbatch) { - DType *pout = out.dptr_ + nbatch * num_anchors * 6; - // sort confidence in descend order - std::vector> sorter; - sorter.reserve(num_anchors); - for (index_t i = 0; i < num_anchors; ++i) { - DType id = pout[i * 6]; - if (id >= 0) { - sorter.push_back(SortElemDescend(pout[i * 6 + 1], i)); - } else { - sorter.push_back(SortElemDescend(DType(0), i)); - } - } - std::stable_sort(sorter.begin(), sorter.end()); - // re-order output - DType *ptemp = temp_space.dptr_ + nbatch * num_anchors * 6; - for (index_t i = 0; i < sorter.size(); ++i) { - for (index_t j = 0; j < 6; ++j) { - pout[i * 6 + j] = ptemp[sorter[i].index * 6 + j]; - } - } - // apply nms - for (index_t i = 0; i < num_anchors; ++i) { - index_t offset_i = i * 6; - if (pout[offset_i] < 0) continue; // skip eliminated - for (index_t j = i + 1; j < num_anchors; ++j) { - index_t offset_j = j * 6; - if (pout[offset_j] < 0) continue; // skip eliminated - if (force_suppress || (pout[offset_i] == pout[offset_j])) { - // when foce_suppress == true or class_id equals - DType iou = CalculateOverlap(pout + offset_i + 2, pout + offset_j + 2); - if (iou >= nms_threshold) { - pout[offset_j] = -1; - } - } - } - } - } -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxDetectionParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxDetectionOp(param); - }); - return op; -} - -Operator* MultiBoxDetectionProp::CreateOperatorEx(Context ctx, - std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxDetectionParam); -MXNET_REGISTER_OP_PROPERTY(MultiBoxDetection, MultiBoxDetectionProp) -.describe("Convert multibox detection predictions.") -.add_argument("cls_prob", "Symbol", "Class probabilities.") -.add_argument("loc_pred", "Symbol", "Location regression predictions.") -.add_argument("anchors", "Symbol", "Multibox prior anchor boxes") -.add_arguments(MultiBoxDetectionParam::__FIELDS__()); -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_detection.cu b/operator/multibox_detection.cu deleted file mode 100644 index cfb55d7..0000000 --- a/operator/multibox_detection.cu +++ /dev/null @@ -1,217 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_detection.cu - * \brief MultiBoxDetection op - * \author Joshua Zhang -*/ -#include "./multibox_detection-inl.h" - -#define WARPS_PER_BLOCK 16 -#define THREADS_PER_WARP 32 - -#define MULTIBOX_DETECTION_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__device__ void Clip(DType *value, DType lower, DType upper) { - if ((*value) < lower) *value = lower; - if ((*value) > upper) *value = upper; -} - -template -__global__ void MergePredictions(DType *out, const DType *cls_prob, - const DType *loc_pred, const DType *anchors, - int num_classes, int num_anchors, - int num_batches, float threshold, bool clip, - float vx, float vy, float vw, float vh) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_batches * num_anchors) return; - for (int i = index; i < num_batches * num_anchors; i += blockDim.x * gridDim.x) { - int n_batch = i / num_anchors; - int n_anchor = i % num_anchors; - const DType *p_cls_prob = cls_prob + n_batch * num_classes * num_anchors; - const DType *p_loc_pred = loc_pred + n_batch * num_anchors * 4; - DType *p_out = out + n_batch * num_anchors * 6; - DType score = p_cls_prob[n_anchor]; - int id = 0; - for (int j = 1; j < num_classes; ++j) { - DType temp = p_cls_prob[j * num_anchors + n_anchor]; - if (temp > score) { - score = temp; - id = j; - } - } - if (id > 0 && score < threshold) { - id = 0; - } - p_out[n_anchor * 6] = id - 1; // restore original class id - p_out[n_anchor * 6 + 1] = (id == 0 ? DType(-1) : score); - int offset = n_anchor * 4; - DType al = anchors[offset]; - DType at = anchors[offset + 1]; - DType ar = anchors[offset + 2]; - DType ab = anchors[offset + 3]; - DType aw = ar - al; - DType ah = ab - at; - DType ax = (al + ar) / 2.f; - DType ay = (at + ab) / 2.f; - DType ox = p_loc_pred[offset] * vx * aw + ax; - DType oy = p_loc_pred[offset + 1] * vy * ah + ay; - DType ow = exp(p_loc_pred[offset + 2] * vw) * aw / 2; - DType oh = exp(p_loc_pred[offset + 3] * vh) * ah / 2; - DType xmin = ox - ow; - DType ymin = oy - oh; - DType xmax = ox + ow; - DType ymax = oy + oh; - if (clip) { - Clip(&xmin, DType(0), DType(1)); - Clip(&ymin, DType(0), DType(1)); - Clip(&xmax, DType(0), DType(1)); - Clip(&ymax, DType(0), DType(1)); - } - p_out[n_anchor * 6 + 2] = xmin; - p_out[n_anchor * 6 + 3] = ymin; - p_out[n_anchor * 6 + 4] = xmax; - p_out[n_anchor * 6 + 5] = ymax; - } -} - -template -__global__ void MergeSortDescend(DType *src, DType *dst, int size, - int width, int slices, int step, int offset) { - int index = blockDim.x * blockIdx.x + threadIdx.x; - int start = width * index * slices; - for (int slice = 0; slice < slices; ++slice) { - if (start >= size) break; - int middle = start + (width >> 1); - if (middle > size) middle = size; - int end = start + width; - if (end > size) end = size; - int i = start; - int j = middle; - for (int k = start; k < end; ++k) { - DType score_i = i < size ? src[i * step + offset] : DType(-1); - DType score_j = j < size ? src[j * step + offset] : DType(-1); - if (i < middle && (j >= end || score_i > score_j)) { - for (int n = 0; n < step; ++n) { - dst[k * step + n] = src[i * step + n]; - } - ++i; - } else { - for (int n = 0; n < step; ++n) { - dst[k * step + n] = src[j * step + n]; - } - ++j; - } - } - start += width; - } -} - -template -__device__ void CalculateOverlap(const DType *a, const DType *b, DType *iou) { - DType w = max(DType(0), min(a[2], b[2]) - max(a[0], b[0])); - DType h = max(DType(0), min(a[3], b[3]) - max(a[1], b[1])); - DType i = w * h; - DType u = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - i; - (*iou) = u <= 0.f ? static_cast(0) : static_cast(i / u); -} - -template -__global__ void ApplyNMS(DType *out, int pos, int num_anchors, - int step, int id_index, int loc_index, - bool force_suppress, float nms_threshold) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - DType compare_id = out[pos * step + id_index]; - if (compare_id < 0) return; // not a valid positive detection, skip - DType *compare_loc_ptr = out + pos * step + loc_index; - for (int i = index; i < num_anchors; i += blockDim.x * gridDim.x) { - if (i <= pos) continue; - DType class_id = out[i * step + id_index]; - if (class_id < 0) continue; - if (force_suppress || (class_id == compare_id)) { - DType iou; - CalculateOverlap(compare_loc_ptr, out + i * step + loc_index, &iou); - if (iou >= nms_threshold) { - out[i * step + id_index] = -1; - } - } - } -} -} // namespace cuda - -template -inline void MultiBoxDetectionForward(const Tensor &out, - const Tensor &cls_prob, - const Tensor &loc_pred, - const Tensor &anchors, - float threshold, bool clip, - const std::vector &variances) { - CHECK_EQ(variances.size(), 4) << "Variance size must be 4"; - int num_classes = cls_prob.size(1); - int num_anchors = cls_prob.size(2); - int num_batches = cls_prob.size(0); - const int num_threads = THREADS_PER_WARP * WARPS_PER_BLOCK; - int num_samples = num_batches * num_anchors; - int num_blocks = (num_samples - 1) / num_threads + 1; - cuda::MergePredictions<<>>(out.dptr_, cls_prob.dptr_, - loc_pred.dptr_, anchors.dptr_, num_classes, num_anchors, num_batches, - threshold, clip, variances[0], variances[1], variances[2], variances[3]); - MULTIBOX_DETECTION_CUDA_CHECK(cudaPeekAtLastError()); -} - -template -inline void NonMaximumSuppression(const Tensor &out, - const Tensor &temp_space, - float nms_threshold, bool force_suppress) { - int num_anchors = out.size(1); - int total_threads = num_anchors / 2 + 1; - const int num_threads = WARPS_PER_BLOCK * THREADS_PER_WARP; - int num_blocks = (total_threads - 1) / num_threads + 1; - // sort detection results - for (int nbatch = 0; nbatch < out.size(0); ++nbatch) { - DType *src_ptr = out.dptr_ + nbatch * num_anchors * 6; - DType *dst_ptr = temp_space.dptr_ + nbatch * num_anchors * 6; - DType *src = src_ptr; - DType *dst = dst_ptr; - for (int width = 2; width < (num_anchors << 1); width <<= 1) { - int slices = (num_anchors - 1) / (total_threads * width) + 1; - cuda::MergeSortDescend<<>>(src, dst, num_anchors, - width, slices, 6, 1); - MULTIBOX_DETECTION_CUDA_CHECK(cudaPeekAtLastError()); - src = src == src_ptr? dst_ptr : src_ptr; - dst = dst == src_ptr? dst_ptr : src_ptr; - } - } - // apply nms - num_blocks = (num_anchors - 1) / num_threads + 1; - for (int nbatch = 0; nbatch < out.size(0); ++nbatch) { - DType *ptr = out.dptr_ + nbatch * num_anchors * 6; - for (int pos = 0; pos < num_anchors; ++pos) { - // suppress against position: pos - cuda::ApplyNMS<<>>(ptr, pos, num_anchors, - 6, 0, 2, force_suppress, nms_threshold); - MULTIBOX_DETECTION_CUDA_CHECK(cudaPeekAtLastError()); - } - } -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxDetectionParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxDetectionOp(param); - }); - return op; -} -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_prior-inl.h b/operator/multibox_prior-inl.h deleted file mode 100644 index 261cfd1..0000000 --- a/operator/multibox_prior-inl.h +++ /dev/null @@ -1,250 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior-inl.h - * \brief generate multibox prior boxes - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" - - -namespace mxnet { -namespace op { - -namespace mshadow_op { -struct clip_zero_one { - template - MSHADOW_XINLINE static DType Map(DType a) { - if (a < 0.f) return DType(0.f); - if (a > 1.f) return DType(1.f); - return DType(a); - } -}; // struct clip_zero_one -} // namespace mshadow_op - -namespace mboxprior_enum { -enum MultiBoxPriorOpInputs {kData}; -enum MultiBoxPriorOpOutputs {kOut}; -} // namespace mboxprior_enum - -struct SizeInfo { - SizeInfo() {} - explicit SizeInfo(std::vector in) : info(in) {} - - std::vector info; -}; // struct SizeInfo - -inline std::istream &operator>>(std::istream &is, SizeInfo &size) { - while (true) { - char ch = is.get(); - if (ch == '(') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - float f; - std::vector tmp; - // deal with empty case - // safe to remove after stop using target_size - size_t pos = is.tellg(); - char ch = is.get(); - if (ch == ')') { - size.info = tmp; - return is; - } - is.seekg(pos); - // finish deal - while (is >> f) { - tmp.push_back(f); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')') { - is.get(); break; - } - break; - } - if (ch == ')') break; - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - size.info = tmp; - return is; -} - -inline std::ostream &operator<<(std::ostream &os, const SizeInfo &size) { - os << '('; - for (index_t i = 0; i < size.info.size(); ++i) { - if (i != 0) os << ','; - os << size.info[i]; - } - // python style tuple - if (size.info.size() == 1) os << ','; - os << ')'; - return os; -} - -struct MultiBoxPriorParam : public dmlc::Parameter { - SizeInfo sizes; - SizeInfo ratios; - bool clip; - DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) { - DMLC_DECLARE_FIELD(sizes).set_default(SizeInfo({1.0f})) - .describe("List of sizes of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(ratios).set_default(SizeInfo({1.0f})) - .describe("List of aspect ratios of generated MultiBoxPriores."); - DMLC_DECLARE_FIELD(clip).set_default(false) - .describe("Whether to clip out-of-boundary boxes."); - } -}; // struct MultiBoxPriorParam - -template -class MultiBoxPriorOp : public Operator { - public: - explicit MultiBoxPriorOp(MultiBoxPriorParam param) - : clip_(param.clip), sizes_(param.sizes.info), ratios_(param.ratios.info) { - CHECK_GT(sizes_.size(), 0); - CHECK_GT(ratios_.size(), 0); - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(static_cast(in_data.size()), 1); - CHECK_GE(in_data[mboxprior_enum::kData].ndim(), 4); // require spatial information - int in_height = in_data[mboxprior_enum::kData].size(2); - CHECK_GT(in_height, 0); - int in_width = in_data[mboxprior_enum::kData].size(3); - CHECK_GT(in_width, 0); - CHECK_EQ(out_data.size(), 1); - Stream *s = ctx.get_stream(); - Tensor out; - // TODO(Joshua Zhang): this implementation is to be compliant to original ssd in caffe - // The prior boxes could be implemented in more versatile ways - // since input sizes are same in each batch, we could share MultiBoxPrior - int num_sizes = static_cast(sizes_.size()); - int num_ratios = static_cast(ratios_.size()); - int num_anchors = num_sizes - 1 + num_ratios; // anchors per location - Shape<2> oshape = Shape2(num_anchors * in_width * in_height, 4); - out = out_data[mboxprior_enum::kOut].get_with_shape(oshape, s); - MultiBoxPriorForward(out, sizes_, ratios_, in_width, in_height); - - if (clip_) { - Assign(out, req[mboxprior_enum::kOut], F(out)); - } - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_states) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor grad = in_grad[mboxprior_enum::kData].FlatTo2D(s); - grad = 0.f; - } - - private: - bool clip_; - std::vector sizes_; - std::vector ratios_; -}; // class MultiBoxPriorOp - -template -Operator *CreateOp(MultiBoxPriorParam, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxPriorProp: public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - std::vector ListArguments() const override { - return {"data"}; - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 1) << "Inputs: [data]" << in_shape->size(); - TShape dshape = in_shape->at(mboxprior_enum::kData); - CHECK_GE(dshape.ndim(), 4) << "Input data should be 4D: batch-channel-y-x"; - int in_height = dshape[2]; - CHECK_GT(in_height, 0) << "Input height should > 0"; - int in_width = dshape[3]; - CHECK_GT(in_width, 0) << "Input width should > 0"; - // since input sizes are same in each batch, we could share MultiBoxPrior - TShape oshape = TShape(3); - int num_sizes = param_.sizes.info.size(); - int num_ratios = param_.ratios.info.size(); - oshape[0] = 1; - oshape[1] = in_height * in_width * (num_sizes + num_ratios - 1); - oshape[2] = 4; - out_shape->clear(); - out_shape->push_back(oshape); - return true; - } - - OperatorProperty* Copy() const override { - auto ptr = new MultiBoxPriorProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "MultiBoxPrior"; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxPriorParam param_; -}; // class MultiBoxPriorProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_PRIOR_INL_H_ diff --git a/operator/multibox_prior.cc b/operator/multibox_prior.cc deleted file mode 100644 index fa47faf..0000000 --- a/operator/multibox_prior.cc +++ /dev/null @@ -1,82 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior.cc - * \brief generate multibox prior boxes cpu implementation - * \author Joshua Zhang -*/ - -#include "./multibox_prior-inl.h" - -namespace mshadow { -template -inline void MultiBoxPriorForward(const Tensor &out, - const std::vector &sizes, - const std::vector &ratios, - int in_width, int in_height) { - float step_x = 1.f / in_width; - float step_y = 1.f / in_height; - int num_sizes = static_cast(sizes.size()); - int num_ratios = static_cast(ratios.size()); - int count = 0; - - for (int r = 0; r < in_height; ++r) { - float center_y = (r + 0.5) * step_y; - for (int c = 0; c < in_width; ++c) { - float center_x = (c + 0.5) * step_x; - // ratio = 1, various sizes - for (int i = 0; i < num_sizes; ++i) { - float size = sizes[i]; - float w = size / 2; - float h = size / 2; - out[count][0] = center_x - w; // xmin - out[count][1] = center_y - h; // ymin - out[count][2] = center_x + w; // xmax - out[count][3] = center_y + h; // ymax - ++count; - } - // various ratios, size = min_size = size[0] - float size = sizes[0]; - for (int j = 1; j < num_ratios; ++j) { - float ratio = sqrtf(ratios[j]); - float w = size * ratio / 2; - float h = size / ratio / 2; - out[count][0] = center_x - w; // xmin - out[count][1] = center_y - h; // ymin - out[count][2] = center_x + w; // xmax - out[count][3] = center_y + h; // ymax - ++count; - } - } - } -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(MultiBoxPriorParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxPriorOp(param); - }); - return op; -} - -Operator* MultiBoxPriorProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxPriorParam); - -MXNET_REGISTER_OP_PROPERTY(MultiBoxPrior, MultiBoxPriorProp) -.add_argument("data", "Symbol", "Input data.") -.add_arguments(MultiBoxPriorParam::__FIELDS__()) -.describe("Generate prior(anchor) boxes from data, sizes and ratios."); - -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_prior.cu b/operator/multibox_prior.cu deleted file mode 100644 index 613ac5e..0000000 --- a/operator/multibox_prior.cu +++ /dev/null @@ -1,90 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_prior.cu - * \brief generate multibox prior boxes cuda kernels - * \author Joshua Zhang -*/ - -#include "./multibox_prior-inl.h" - -#define WARPS_PER_BLOCK 1 -#define THREADS_PER_WARP 32 - -#define MULTIBOXPRIOR_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__global__ void AssignPriors(DType *out, float size, float sqrt_ratio, int in_width, - int in_height, float step_x, float step_y, int stride, int offset) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= in_width * in_height) return; - int r = index / in_width; - int c = index % in_width; - float center_x = (c + 0.5) * step_x; - float center_y = (r + 0.5) * step_y; - float w = size * sqrt_ratio / 2; // half width - float h = size / sqrt_ratio / 2; // half height - DType *ptr = out + index * stride + 4 * offset; - *(ptr++) = center_x - w; // xmin - *(ptr++) = center_y - h; // ymin - *(ptr++) = center_x + w; // xmax - *(ptr++) = center_y + h; // ymax -} -} // namespace cuda - -template -inline void MultiBoxPriorForward(const Tensor &out, - const std::vector &sizes, - const std::vector &ratios, - int in_width, int in_height) { - CHECK_EQ(out.CheckContiguous(), true); - cudaStream_t stream = Stream::GetStream(out.stream_); - DType *out_ptr = out.dptr_; - float step_x = 1.f / in_width; - float step_y = 1.f / in_height; - int num_sizes = static_cast(sizes.size()); - int num_ratios = static_cast(ratios.size()); - - int num_thread = THREADS_PER_WARP * WARPS_PER_BLOCK; - dim3 thread_dim(num_thread); - dim3 block_dim((in_width * in_height - 1) / num_thread + 1); - - int stride = 4 * (num_sizes + num_ratios - 1); - int offset = 0; - // ratio = 1, various sizes - for (int i = 0; i < num_sizes; ++i) { - cuda::AssignPriors<<>>(out_ptr, - sizes[i], 1.f, in_width, in_height, step_x, step_y, stride, offset); - ++offset; - } - MULTIBOXPRIOR_CUDA_CHECK(cudaPeekAtLastError()); - - // size = sizes[0], various ratios - for (int j = 1; j < num_ratios; ++j) { - cuda::AssignPriors<<>>(out_ptr, - sizes[0], sqrtf(ratios[j]), in_width, in_height, step_x, step_y, stride, offset); - ++offset; - } - MULTIBOXPRIOR_CUDA_CHECK(cudaPeekAtLastError()); -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(MultiBoxPriorParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxPriorOp(param); - }); - return op; -} - -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_target-inl.h b/operator/multibox_target-inl.h deleted file mode 100644 index 6de78e7..0000000 --- a/operator/multibox_target-inl.h +++ /dev/null @@ -1,341 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target-inl.h - * \brief - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ -#define MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { - -namespace mshadow_op { -struct safe_divide { - template - MSHADOW_XINLINE static DType Map(DType a, DType b) { - if (b == DType(0.0f)) return DType(0.0f); - return DType(a / b); - } -}; // struct safe_divide -} // namespace mshadow_op - -namespace mboxtarget_enum { -enum MultiBoxTargetOpInputs {kAnchor, kLabel, kClsPred}; -enum MultiBoxTargetOpOutputs {kLoc, kLocMask, kCls}; -enum MultiBoxTargetOpResource {kTempSpace}; -} // namespace mboxtarget_enum - -struct VarsInfo { - VarsInfo() {} - explicit VarsInfo(std::vector in) : info(in) {} - - std::vector info; -}; // struct VarsInfo - -inline std::istream &operator>>(std::istream &is, VarsInfo &size) { - while (true) { - char ch = is.get(); - if (ch == '(') break; - if (!isspace(ch)) { - is.setstate(std::ios::failbit); - return is; - } - } - float f; - std::vector tmp; - // deal with empty case - // safe to remove after stop using target_size - size_t pos = is.tellg(); - char ch = is.get(); - if (ch == ')') { - size.info = tmp; - return is; - } - is.seekg(pos); - // finish deal - while (is >> f) { - tmp.push_back(f); - char ch; - do { - ch = is.get(); - } while (isspace(ch)); - if (ch == ',') { - while (true) { - ch = is.peek(); - if (isspace(ch)) { - is.get(); continue; - } - if (ch == ')') { - is.get(); break; - } - break; - } - if (ch == ')') break; - } else if (ch == ')') { - break; - } else { - is.setstate(std::ios::failbit); - return is; - } - } - size.info = tmp; - return is; -} - -inline std::ostream &operator<<(std::ostream &os, const VarsInfo &size) { - os << '('; - for (index_t i = 0; i < size.info.size(); ++i) { - if (i != 0) os << ','; - os << size.info[i]; - } - // python style tuple - if (size.info.size() == 1) os << ','; - os << ')'; - return os; -} - -struct MultiBoxTargetParam : public dmlc::Parameter { - float overlap_threshold; - float ignore_label; - float negative_mining_ratio; - float negative_mining_thresh; - int minimum_negative_samples; - VarsInfo variances; - DMLC_DECLARE_PARAMETER(MultiBoxTargetParam) { - DMLC_DECLARE_FIELD(overlap_threshold).set_default(0.5f) - .describe("Anchor-GT overlap threshold to be regarded as a possitive match."); - DMLC_DECLARE_FIELD(ignore_label).set_default(-1.0f) - .describe("Label for ignored anchors."); - DMLC_DECLARE_FIELD(negative_mining_ratio).set_default(-1.0f) - .describe("Max negative to positive samples ratio, use -1 to disable mining"); - DMLC_DECLARE_FIELD(negative_mining_thresh).set_default(0.5f) - .describe("Threshold used for negative mining."); - DMLC_DECLARE_FIELD(minimum_negative_samples).set_default(0) - .describe("Minimum number of negative samples."); - DMLC_DECLARE_FIELD(variances).set_default(VarsInfo({0.1, 0.1, 0.2, 0.2})) - .describe("Variances to be encoded in box regression target."); - } -}; // struct MultiBoxTargetParam - -template -class MultiBoxTargetOp : public Operator { - public: - explicit MultiBoxTargetOp(MultiBoxTargetParam param) { - this->param_ = param; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow_op; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 3); - CHECK_EQ(in_data[mboxtarget_enum::kAnchor].ndim(), 3) << "Anchors: 1*N*4."; - CHECK_EQ(in_data[mboxtarget_enum::kAnchor].size(0), 1); - CHECK_GT(in_data[mboxtarget_enum::kAnchor].size(1), 0) - << "Number of anchors must > 0."; - CHECK_EQ(in_data[mboxtarget_enum::kAnchor].size(2), 4); - CHECK_EQ(in_data[mboxtarget_enum::kLabel].ndim(), 3) << "Labels: batch*M*5"; - CHECK_GT(in_data[mboxtarget_enum::kLabel].size(1), 0) - << "Number of ground-truth must > 0."; - CHECK_EQ(in_data[mboxtarget_enum::kLabel].size(2), 5); - CHECK_EQ(in_data[mboxtarget_enum::kAnchor].size(1), - in_data[mboxtarget_enum::kClsPred].size(2)) << "# anchors mismatch"; - CHECK_GT(in_data[mboxtarget_enum::kClsPred].size(1), 1); - CHECK_EQ(out_data.size(), 3); - Stream *s = ctx.get_stream(); - Tensor anchors = in_data[mboxtarget_enum::kAnchor] - .get_with_shape( - Shape2(in_data[mboxtarget_enum::kAnchor].size(1), 4), s); - Tensor labels = in_data[mboxtarget_enum::kLabel] - .get(s); - Tensor cls_preds = in_data[mboxtarget_enum::kClsPred] - .get(s); - Tensor loc_target = out_data[mboxtarget_enum::kLoc] - .get(s); - Tensor loc_mask = out_data[mboxtarget_enum::kLocMask] - .get(s); - Tensor cls_target = out_data[mboxtarget_enum::kCls] - .get(s); - - index_t num_batches = labels.size(0); - index_t num_anchors = anchors.size(0); - index_t num_labels = labels.size(1); - // TODO(Joshua Zhang): use maximum valid ground-truth in batch rather than # in dataset - Shape<4> temp_shape = Shape4(11, num_batches, num_anchors, num_labels); - Tensor temp_space = ctx.requested[mboxtarget_enum::kTempSpace] - .get_space_typed(temp_shape, s); - loc_target = 0.f; - loc_mask = 0.0f; - cls_target = param_.ignore_label; - temp_space = -1.0f; - CHECK_EQ(anchors.CheckContiguous(), true); - CHECK_EQ(labels.CheckContiguous(), true); - CHECK_EQ(cls_preds.CheckContiguous(), true); - CHECK_EQ(loc_target.CheckContiguous(), true); - CHECK_EQ(loc_mask.CheckContiguous(), true); - CHECK_EQ(cls_target.CheckContiguous(), true); - CHECK_EQ(temp_space.CheckContiguous(), true); - - // compute overlaps - // TODO(Joshua Zhang): squeeze temporary memory space - // temp_space, 0:out, 1:l1, 2:t1, 3:r1, 4:b1, 5:l2, 6:t2, 7:r2, 8:b2 - // 9: intersection, 10:union - temp_space[1] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 0, 1), -1, - num_batches), 2, num_labels); - temp_space[2] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 1, 2), -1, - num_batches), 2, num_labels); - temp_space[3] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 2, 3), -1, - num_batches), 2, num_labels); - temp_space[4] = broadcast_keepdim(broadcast_with_axis(slice<1>(anchors, 3, 4), -1, - num_batches), 2, num_labels); - Shape<3> temp_reshape = Shape3(num_batches, 1, num_labels); - temp_space[5] = broadcast_keepdim(reshape(slice<2>(labels, 1, 2), temp_reshape), 1, - num_anchors); - temp_space[6] = broadcast_keepdim(reshape(slice<2>(labels, 2, 3), temp_reshape), 1, - num_anchors); - temp_space[7] = broadcast_keepdim(reshape(slice<2>(labels, 3, 4), temp_reshape), 1, - num_anchors); - temp_space[8] = broadcast_keepdim(reshape(slice<2>(labels, 4, 5), temp_reshape), 1, - num_anchors); - temp_space[9] = F(ScalarExp(0.0f), - F(temp_space[3], temp_space[7]) - F(temp_space[1], temp_space[5])) - * F(ScalarExp(0.0f), - F(temp_space[4], temp_space[8]) - F(temp_space[2], temp_space[6])); - temp_space[10] = (temp_space[3] - temp_space[1]) * (temp_space[4] - temp_space[2]) - + (temp_space[7] - temp_space[5]) * (temp_space[8] - temp_space[6]) - - temp_space[9]; - temp_space[0] = F(temp_space[9], temp_space[10]); - - MultiBoxTargetForward(loc_target, loc_mask, cls_target, - anchors, labels, cls_preds, temp_space, - param_.overlap_threshold, - param_.ignore_label, - param_.negative_mining_ratio, - param_.negative_mining_thresh, - param_.minimum_negative_samples, - param_.variances.info); - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor grad = in_grad[mboxtarget_enum::kClsPred].FlatTo2D(s); - grad = 0.f; -} - - private: - MultiBoxTargetParam param_; -}; // class MultiBoxTargetOp - -template -Operator* CreateOp(MultiBoxTargetParam param, int dtype); - -#if DMLC_USE_CXX11 -class MultiBoxTargetProp : public OperatorProperty { - public: - std::vector ListArguments() const override { - return {"anchor", "label", "cls_pred"}; - } - - std::vector ListOutputs() const override { - return {"loc_target", "loc_mask", "cls_target"}; - } - - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 3) << "Input: [anchor, label, clsPred]"; - TShape ashape = in_shape->at(mboxtarget_enum::kAnchor); - CHECK_EQ(ashape.ndim(), 3) << "Anchor should be batch shared N*4 tensor"; - CHECK_EQ(ashape[0], 1) << "Anchors are shared across batches, first dim=1"; - CHECK_GT(ashape[1], 0) << "Number boxes should > 0"; - CHECK_EQ(ashape[2], 4) << "Box dimension should be 4: [xmin-ymin-xmax-ymax]"; - TShape lshape = in_shape->at(mboxtarget_enum::kLabel); - CHECK_EQ(lshape.ndim(), 3) << "Label should be [batch-num_labels-5] tensor"; - CHECK_GT(lshape[1], 0) << "Padded label should > 0"; - CHECK_EQ(lshape[2], 5) << "Label should be [batch-num_labels-5] tensor"; - TShape pshape = in_shape->at(mboxtarget_enum::kClsPred); - CHECK_EQ(pshape.ndim(), 3) << "Prediction: [nbatch-num_classes-num_anchors]"; - CHECK_EQ(pshape[2], ashape[1]) << "Number of anchors mismatch"; - TShape loc_shape = Shape2(lshape[0], ashape.Size()); // batch - (num_box * 4) - TShape lm_shape = loc_shape; - TShape label_shape = Shape2(lshape[0], ashape[1]); // batch - num_box - out_shape->clear(); - out_shape->push_back(loc_shape); - out_shape->push_back(lm_shape); - out_shape->push_back(label_shape); - return true; - } - - OperatorProperty* Copy() const override { - MultiBoxTargetProp* MultiBoxTarget_sym = new MultiBoxTargetProp(); - MultiBoxTarget_sym->param_ = this->param_; - return MultiBoxTarget_sym; - } - - std::string TypeString() const override { - return "MultiBoxTarget"; - } - - // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {}; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - MultiBoxTargetParam param_; -}; // class MultiBoxTargetProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_MULTIBOX_TARGET_INL_H_ diff --git a/operator/multibox_target.cc b/operator/multibox_target.cc deleted file mode 100644 index 3686797..0000000 --- a/operator/multibox_target.cc +++ /dev/null @@ -1,288 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target.cc - * \brief MultiBoxTarget op - * \author Joshua Zhang -*/ -#include "./multibox_target-inl.h" -#include -#include "./mshadow_op.h" - -namespace mshadow { -template -inline void AssignLocTargets(const DType *anchor, const DType *l, DType *dst, - float vx, float vy, float vw, float vh) { - float al = *(anchor); - float at = *(anchor+1); - float ar = *(anchor+2); - float ab = *(anchor+3); - float aw = ar - al; - float ah = ab - at; - float ax = (al + ar) * 0.5; - float ay = (at + ab) * 0.5; - float gl = *(l); - float gt = *(l+1); - float gr = *(l+2); - float gb = *(l+3); - float gw = gr - gl; - float gh = gb - gt; - float gx = (gl + gr) * 0.5; - float gy = (gt + gb) * 0.5; - *(dst) = DType((gx - ax) / aw / vx); - *(dst+1) = DType((gy - ay) / ah / vy); - *(dst+2) = DType(std::log(gw / aw) / vw); - *(dst+3) = DType(std::log(gh / ah) / vh); -} - -struct SortElemDescend { - float value; - int index; - - SortElemDescend(float v, int i) { - value = v; - index = i; - } - - bool operator<(const SortElemDescend &other) const { - return value > other.value; - } -}; - -template -inline void MultiBoxTargetForward(const Tensor &loc_target, - const Tensor &loc_mask, - const Tensor &cls_target, - const Tensor &anchors, - const Tensor &labels, - const Tensor &cls_preds, - const Tensor &temp_space, - float overlap_threshold, float background_label, - float negative_mining_ratio, - float negative_mining_thresh, - int minimum_negative_samples, - const std::vector &variances) { - const DType *p_anchor = anchors.dptr_; - index_t num_labels = labels.size(1); - index_t num_anchors = anchors.size(0); - for (index_t nbatch = 0; nbatch < labels.size(0); ++nbatch) { - const DType *p_label = labels.dptr_ + nbatch * num_labels * 5; - const DType *p_overlaps = temp_space.dptr_ + nbatch * num_anchors * num_labels; - index_t num_valid_gt = 0; - for (index_t i = 0; i < num_labels; ++i) { - if (static_cast(*(p_label + i * 5)) == -1.0f) { - CHECK_EQ(static_cast(*(p_label + i * 5 + 1)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * 5 + 2)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * 5 + 3)), -1.0f); - CHECK_EQ(static_cast(*(p_label + i * 5 + 4)), -1.0f); - break; - } - ++num_valid_gt; - } // end iterate labels - - if (num_valid_gt > 0) { - std::vector gt_flags(num_valid_gt, false); - std::vector> max_matches(num_anchors, - std::pair(-1.0f, -1)); - std::vector anchor_flags(num_anchors, -1); // -1 means don't care - int num_positive = 0; - while (std::find(gt_flags.begin(), gt_flags.end(), false) != gt_flags.end()) { - // ground-truths not fully matched - int best_anchor = -1; - int best_gt = -1; - float max_overlap = 1e-6; // start with a very small positive overlap - for (index_t j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - const DType *pp_overlaps = p_overlaps + j * num_labels; - for (index_t k = 0; k < num_valid_gt; ++k) { - if (gt_flags[k]) { - continue; // already matched this gt - } - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_overlap) { - best_anchor = j; - best_gt = k; - max_overlap = iou; - } - } - } - - if (best_anchor == -1) { - CHECK_EQ(best_gt, -1); - break; // no more good match - } else { - CHECK_EQ(max_matches[best_anchor].first, -1.0f); - CHECK_EQ(max_matches[best_anchor].second, -1); - max_matches[best_anchor].first = max_overlap; - max_matches[best_anchor].second = best_gt; - num_positive += 1; - // mark as visited - gt_flags[best_gt] = true; - anchor_flags[best_anchor] = 1; - } - } // end while - - if (overlap_threshold > 0) { - // find positive matches based on overlaps - for (index_t j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - const DType *pp_overlaps = p_overlaps + j * num_labels; - int best_gt = -1; - int max_iou = -1.0f; - for (index_t k = 0; k < num_valid_gt; ++k) { - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_iou) { - best_gt = k; - max_iou = iou; - } - } - if (best_gt != -1) { - CHECK_EQ(max_matches[j].first, -1.0f); - CHECK_EQ(max_matches[j].second, -1); - max_matches[j].first = max_iou; - max_matches[j].second = best_gt; - if (max_iou > overlap_threshold) { - num_positive += 1; - // mark as visited - gt_flags[best_gt] = true; - anchor_flags[j] = 1; - } - } - } // end iterate anchors - } - - if (negative_mining_ratio > 0) { - index_t num_classes = cls_preds.size(1); - DType *p_cls_preds = cls_preds.dptr_ + nbatch * num_classes * num_anchors; - CHECK_GT(negative_mining_thresh, 0); - int num_negative = num_positive * negative_mining_ratio; - if (num_negative > (num_anchors - num_positive)) { - num_negative = num_anchors - num_positive; - } - if (num_negative > 0) { - // use negative mining, pick up "best" negative samples - std::vector temp; - temp.reserve(num_anchors - num_positive); - for (index_t j = 0; j < num_anchors; ++j) { - if (anchor_flags[j] == 1) { - continue; // already matched this anchor - } - if (max_matches[j].first < 0) { - // not yet calculated - const DType *pp_overlaps = p_overlaps + j * num_labels; - int best_gt = -1; - int max_iou = -1.0f; - for (index_t k = 0; k < num_valid_gt; ++k) { - float iou = static_cast(*(pp_overlaps + k)); - if (iou > max_iou) { - best_gt = k; - max_iou = iou; - } - } - if (best_gt != -1) { - CHECK_EQ(max_matches[j].first, -1.0f); - CHECK_EQ(max_matches[j].second, -1); - max_matches[j].first = max_iou; - max_matches[j].second = best_gt; - } - if (max_matches[j].first < negative_mining_thresh && - max_matches[j].first >= 0) { - // calcuate class predictions - DType max_val = p_cls_preds[j]; - DType max_val_pos = p_cls_preds[j + num_anchors]; - for (int k = 2; k < num_classes; ++k) { - DType tmp = p_cls_preds[j + num_anchors * k]; - if (tmp > max_val_pos) max_val_pos = tmp; - } - if (max_val_pos > max_val) max_val = max_val_pos; - DType sum = 0.f; - for (int k = 0; k < num_classes; ++k) { - DType tmp = p_cls_preds[j + num_anchors * k]; - sum += std::exp(tmp - max_val); - } - max_val_pos = std::exp(max_val_pos - max_val) / sum; - temp.push_back(SortElemDescend(max_val_pos, j)); - } - } - } // end iterate anchors - - CHECK_GE(temp.size(), num_negative); - std::stable_sort(temp.begin(), temp.end()); - for (int i = 0; i < num_negative; ++i) { - anchor_flags[temp[i].index] = 0; // mark as negative sample - } - } - } else { - // use all negative samples - for (index_t i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] != 1) { - anchor_flags[i] = 0; - } - } - } - - // assign training targets - DType *p_loc_target = loc_target.dptr_ + nbatch * num_anchors * 4; - DType *p_loc_mask = loc_mask.dptr_ + nbatch * num_anchors * 4; - DType *p_cls_target = cls_target.dptr_ + nbatch * num_anchors; - for (index_t i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] == 1) { - // positive sample - CHECK_GE(max_matches[i].second, 0); - // 0 reserved for background - *(p_cls_target + i) = *(p_label + 5 * max_matches[i].second) + 1; - index_t offset = i * 4; - *(p_loc_mask + offset) = 1; - *(p_loc_mask + offset + 1) = 1; - *(p_loc_mask + offset + 2) = 1; - *(p_loc_mask + offset + 3) = 1; - AssignLocTargets(p_anchor + i * 4, - p_label + 5 * max_matches[i].second + 1, p_loc_target + offset, - variances[0], variances[1], variances[2], variances[3]); - } else if (anchor_flags[i] == 0) { - // negative sample - *(p_cls_target + i) = 0; - index_t offset = i * 4; - *(p_loc_mask + offset) = 0; - *(p_loc_mask + offset + 1) = 0; - *(p_loc_mask + offset + 2) = 0; - *(p_loc_mask + offset + 3) = 0; - } - } // end iterate anchors - } - } // end iterate batches -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxTargetParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxTargetOp(param); - }); - return op; -} - -Operator* MultiBoxTargetProp::CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(MultiBoxTargetParam); -MXNET_REGISTER_OP_PROPERTY(MultiBoxTarget, MultiBoxTargetProp) -.describe("Compute Multibox training targets") -.add_argument("anchor", "Symbol", "Generated anchor boxes.") -.add_argument("label", "Symbol", "Object detection labels.") -.add_argument("cls_pred", "Symbol", "Class predictions.") -.add_arguments(MultiBoxTargetParam::__FIELDS__()); -} // namespace op -} // namespace mxnet diff --git a/operator/multibox_target.cu b/operator/multibox_target.cu deleted file mode 100644 index c1c517f..0000000 --- a/operator/multibox_target.cu +++ /dev/null @@ -1,405 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file multibox_target.cu - * \brief MultiBoxTarget op - * \author Joshua Zhang -*/ -#include "./multibox_target-inl.h" - -#define WARPS_PER_BLOCK 16 -#define THREADS_PER_WARP 32 - -#define MULTIBOX_TARGET_CUDA_CHECK(condition) \ - /* Code block avoids redefinition of cudaError_t error */ \ - do { \ - cudaError_t error = condition; \ - CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \ - } while (0) - -namespace mshadow { -namespace cuda { -template -__global__ void InitGroundTruthFlags(DType *gt_flags, const DType *labels, - int num_batches, int num_labels) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index >= num_batches * num_labels) return; - int b = index / num_labels; - int l = index % num_labels; - if (*(labels + b * num_labels * 5 + l * 5) == -1.f) { - *(gt_flags + b * num_labels + l) = 0; - } else { - *(gt_flags + b * num_labels + l) = 1; - } -} - -template -__global__ void FindBestMatches(DType *best_matches, DType *gt_flags, - DType *anchor_flags, const DType *overlaps, - int num_anchors, int num_labels) { - int nbatch = blockIdx.x; - gt_flags += nbatch * num_labels; - overlaps += nbatch * num_anchors * num_labels; - best_matches += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - const int num_threads = WARPS_PER_BLOCK * THREADS_PER_WARP; - __shared__ int max_indices_y[WARPS_PER_BLOCK * THREADS_PER_WARP]; - __shared__ int max_indices_x[WARPS_PER_BLOCK * THREADS_PER_WARP]; - __shared__ float max_values[WARPS_PER_BLOCK * THREADS_PER_WARP]; - - while (1) { - // check if all done. - bool finished = true; - for (int i = 0; i < num_labels; ++i) { - if (gt_flags[i] > .5) { - finished = false; - break; - } - } - if (finished) break; // all done. - - // finding max indices in different threads - int max_x = -1; - int max_y = -1; - DType max_value = 1e-6; // start with very small overlap - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] > .5) continue; - for (int j = 0; j < num_labels; ++j) { - if (gt_flags[j] > .5) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_value) { - max_x = j; - max_y = i; - max_value = temp; - } - } - } - } - max_indices_x[threadIdx.x] = max_x; - max_indices_y[threadIdx.x] = max_y; - max_values[threadIdx.x] = max_value; - __syncthreads(); - - if (threadIdx.x == 0) { - // merge results and assign best match - int max_x = -1; - int max_y = -1; - DType max_value = -1; - for (int k = 0; k < num_threads; ++k) { - if (max_indices_y[k] < 0 || max_indices_x[k] < 0) continue; - float temp = max_values[k]; - if (temp > max_value) { - max_x = max_indices_x[k]; - max_y = max_indices_y[k]; - max_value = temp; - } - } - if (max_x >= 0 && max_y >= 0) { - best_matches[max_y] = max_x; - // mark flags as visited - gt_flags[max_x] = 0.f; - anchor_flags[max_y] = 1.f; - } else { - // no more good matches - for (int i = 0; i < num_labels; ++i) { - gt_flags[i] = 0.f; - } - } - } - __syncthreads(); - } -} - -template -__global__ void FindGoodMatches(DType *best_matches, DType *anchor_flags, - const DType *overlaps, int num_anchors, - int num_labels, float overlap_threshold) { - int nbatch = blockIdx.x; - overlaps += nbatch * num_anchors * num_labels; - best_matches += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - const int num_threads = WARPS_PER_BLOCK * THREADS_PER_WARP; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] < 0) { - int idx = -1; - float max_value = -1.f; - for (int j = 0; j < num_labels; ++j) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_value) { - max_value = temp; - idx = j; - } - } - if (max_value > overlap_threshold && (idx >= 0)) { - best_matches[i] = idx; - anchor_flags[i] = 0.9f; - } - } - } -} - -template -__global__ void UseAllNegatives(DType *anchor_flags, int num) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= num) return; - if (anchor_flags[idx] < 0.5) { - anchor_flags[idx] = 0; // regard all non-positive as negatives - } -} - -template -__global__ void NegativeMining(const DType *overlaps, const DType *cls_preds, - DType *anchor_flags, DType *buffer, - float negative_mining_ratio, - float negative_mining_thresh, - int minimum_negative_samples, - int num_anchors, - int num_labels, int num_classes) { - int nbatch = blockIdx.x; - overlaps += nbatch * num_anchors * num_labels; - cls_preds += nbatch * num_classes * num_anchors; - anchor_flags += nbatch * num_anchors; - buffer += nbatch * num_anchors * 3; - const int num_threads = WARPS_PER_BLOCK * THREADS_PER_WARP; - int num_positive; - __shared__ int num_negative; - - if (threadIdx.x == 0) { - num_positive = 0; - for (int i = 0; i < num_anchors; ++i) { - if (anchor_flags[i] > .5) { - ++num_positive; - } - } - num_negative = num_positive * negative_mining_ratio; - if (num_negative < minimum_negative_samples) { - num_negative = minimum_negative_samples; - } - if (num_negative > (num_anchors - num_positive)) { - num_negative = num_anchors - num_positive; - } - } - __syncthreads(); - - if (num_negative < 1) return; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - buffer[i] = -1.f; - if (anchor_flags[i] < 0) { - // compute max class prediction score - DType max_val = cls_preds[i]; - DType max_val_pos = cls_preds[i + num_anchors]; - for (int j = 2; j < num_classes; ++j) { - DType temp = cls_preds[i + num_anchors * j]; - if (temp > max_val_pos) max_val_pos = temp; - } - if (max_val_pos > max_val) max_val = max_val_pos; - DType sum = 0.f; - for (int j = 0; j < num_classes; ++j) { - DType temp = cls_preds[i + num_anchors * j]; - sum += exp(temp - max_val); - } - max_val_pos = exp(max_val_pos - max_val) / sum; - DType max_iou = -1.f; - for (int j = 0; j < num_labels; ++j) { - DType temp = overlaps[i * num_labels + j]; - if (temp > max_iou) max_iou = temp; - } - if (max_iou < negative_mining_thresh) { - // only do it for anchors with iou < thresh - buffer[i] = max_val_pos; - } - } - } - __syncthreads(); - - // descend merge sorting for negative mining - DType *index_src = buffer + num_anchors; - DType *index_dst = buffer + num_anchors * 2; - DType *src = index_src; - DType *dst = index_dst; - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - index_src[i] = i; - } - __syncthreads(); - - for (int width = 2; width < (num_anchors << 1); width <<= 1) { - int slices = (num_anchors - 1) / (num_threads * width) + 1; - int start = width * threadIdx.x * slices; - for (int slice = 0; slice < slices; ++slice) { - if (start >= num_anchors) break; - int middle = start + (width >> 1); - if (num_anchors < middle) middle = num_anchors; - int end = start + width; - if (num_anchors < end) end = num_anchors; - int i = start; - int j = middle; - for (int k = start; k < end; ++k) { - int idx_i = static_cast(src[i]); - int idx_j = static_cast(src[j]); - if (i < middle && (j >= end || buffer[idx_i] > buffer[idx_j])) { - dst[k] = src[i]; - ++i; - } else { - dst[k] = src[j]; - ++j; - } - } - start += width; - } - __syncthreads(); - // swap src/dst - src = src == index_src? index_dst : index_src; - dst = dst == index_src? index_dst : index_src; - } - __syncthreads(); - - for (int i = threadIdx.x; i < num_negative; i += num_threads) { - int idx = static_cast(src[i]); - if (anchor_flags[idx] < 0) { - anchor_flags[idx] = 0; - } - } -} - -template -__global__ void AssignTrainigTargets(DType *loc_target, DType *loc_mask, - DType *cls_target, DType *anchor_flags, - DType *best_matches, DType *labels, - DType *anchors, int num_anchors, - int num_labels, float vx, float vy, - float vw, float vh) { - int nbatch = blockIdx.x; - loc_target += nbatch * num_anchors * 4; - loc_mask += nbatch * num_anchors * 4; - cls_target += nbatch * num_anchors; - anchor_flags += nbatch * num_anchors; - best_matches += nbatch * num_anchors; - labels += nbatch * num_labels * 5; - const int num_threads = WARPS_PER_BLOCK * THREADS_PER_WARP; - - for (int i = threadIdx.x; i < num_anchors; i += num_threads) { - if (anchor_flags[i] > 0.5) { - // positive sample - int offset_l = static_cast(best_matches[i]) * 5; - cls_target[i] = labels[offset_l] + 1; // 0 reserved for background - int offset = i * 4; - loc_mask[offset] = 1; - loc_mask[offset + 1] = 1; - loc_mask[offset + 2] = 1; - loc_mask[offset + 3] = 1; - // regression targets - float al = anchors[offset]; - float at = anchors[offset + 1]; - float ar = anchors[offset + 2]; - float ab = anchors[offset + 3]; - float aw = ar - al; - float ah = ab - at; - float ax = (al + ar) * 0.5; - float ay = (at + ab) * 0.5; - float gl = labels[offset_l + 1]; - float gt = labels[offset_l + 2]; - float gr = labels[offset_l + 3]; - float gb = labels[offset_l + 4]; - float gw = gr - gl; - float gh = gb - gt; - float gx = (gl + gr) * 0.5; - float gy = (gt + gb) * 0.5; - loc_target[offset] = DType((gx - ax) / aw / vx); // xmin - loc_target[offset + 1] = DType((gy - ay) / ah / vy); // ymin - loc_target[offset + 2] = DType(log(gw / aw) / vw); // xmax - loc_target[offset + 3] = DType(log(gh / ah) / vh); // ymax - } else if (anchor_flags[i] < 0.5 && anchor_flags[i] > -0.5) { - // background - cls_target[i] = 0; - } - } -} -} // namespace cuda - -template -inline void MultiBoxTargetForward(const Tensor &loc_target, - const Tensor &loc_mask, - const Tensor &cls_target, - const Tensor &anchors, - const Tensor &labels, - const Tensor &cls_preds, - const Tensor &temp_space, - float overlap_threshold, float background_label, - float negative_mining_ratio, - float negative_mining_thresh, - int minimum_negative_samples, - const std::vector &variances) { - int num_batches = labels.size(0); - int num_labels = labels.size(1); - int num_anchors = anchors.size(0); - int num_classes = cls_preds.size(1); - CHECK_GE(num_batches, 1); - CHECK_GT(num_labels, 2); - CHECK_GE(num_anchors, 1); - - // init ground-truth flags, by checking valid labels - temp_space[1] = 0.f; - DType *gt_flags = temp_space[1].dptr_; - const int num_threads = THREADS_PER_WARP * WARPS_PER_BLOCK; - dim3 init_thread_dim(num_threads); - dim3 init_block_dim((num_batches * num_labels - 1) / num_threads + 1); - cuda::InitGroundTruthFlags<<>>( - gt_flags, labels.dptr_, num_batches, num_labels); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - - // compute best matches - temp_space[2] = -1.f; - temp_space[3] = -1.f; - DType *anchor_flags = temp_space[2].dptr_; - DType *best_matches = temp_space[3].dptr_; - const DType *overlaps = temp_space[0].dptr_; - cuda::FindBestMatches<<>>(best_matches, - gt_flags, anchor_flags, overlaps, num_anchors, num_labels); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - - // find good matches with overlap > threshold - if (overlap_threshold > 0) { - cuda::FindGoodMatches<<>>(best_matches, - anchor_flags, overlaps, num_anchors, num_labels, - overlap_threshold); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } - - // do negative mining or not - if (negative_mining_ratio > 0) { - CHECK_GT(negative_mining_thresh, 0); - temp_space[4] = 0; - DType *buffer = temp_space[4].dptr_; - cuda::NegativeMining<<>>(overlaps, - cls_preds.dptr_, anchor_flags, buffer, negative_mining_ratio, - negative_mining_thresh, minimum_negative_samples, - num_anchors, num_labels, num_classes); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } else { - int num_blocks = (num_batches * num_anchors - 1) / num_threads + 1; - cuda::UseAllNegatives<<>>(anchor_flags, - num_batches * num_anchors); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); - } - - cuda::AssignTrainigTargets<<>>(loc_target.dptr_, - loc_mask.dptr_, cls_target.dptr_, anchor_flags, best_matches, labels.dptr_, - anchors.dptr_, num_anchors, num_labels, variances[0], variances[1], - variances[2], variances[3]); - MULTIBOX_TARGET_CUDA_CHECK(cudaPeekAtLastError()); -} -} // namespace mshadow - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(MultiBoxTargetParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new MultiBoxTargetOp(param); - }); - return op; -} -} // namespace op -} // namespace mxnet diff --git a/operator/scale-inl.h b/operator/scale-inl.h deleted file mode 100644 index aaba76d..0000000 --- a/operator/scale-inl.h +++ /dev/null @@ -1,280 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file scale-inl.h - * \brief A scaling layer with inital scale, and adjusted with backprop - * \author Joshua Zhang -*/ -#ifndef MXNET_OPERATOR_SCALE_INL_H_ -#define MXNET_OPERATOR_SCALE_INL_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { -namespace scale_enum { -enum ScaleOpInputs {kData, kWeight}; -enum ScaleOpOutputs {kOut}; -enum ScaleOpResource {kTempSpace}; -enum ScaleOpType {kInstance, kChannel, kSpatial}; -} // scale_enum - -struct ScaleParam : public dmlc::Parameter { - int mode; - DMLC_DECLARE_PARAMETER(ScaleParam) { - DMLC_DECLARE_FIELD(mode) - .add_enum("instance", scale_enum::kInstance) - .add_enum("spatial", scale_enum::kSpatial) - .add_enum("channel", scale_enum::kChannel) - .set_default(scale_enum::kInstance) - .describe("Scaling Mode. If set to instance, this operator will use independent " - "scale for each instance in the batch; this is the default mode. " - "If set to channel, this operator will share scales cross channel at " - "each position of each instance. If set to spatial, this operator shares scales " - "in each channel."); - } -}; // struct ScaleParam - -template -class ScaleOp : public Operator { - public: - explicit ScaleOp(ScaleParam p) { - this->param_ = p; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 2); - CHECK_EQ(out_data.size(), 1); - Stream *s = ctx.get_stream(); - TShape orig_shape = in_data[scale_enum::kData].shape_; - index_t nbatch = orig_shape[0]; - if (param_.mode == scale_enum::kInstance) { - Shape<2> dshape = Shape2(orig_shape[0], - orig_shape.ProdShape(1, orig_shape.ndim())); - Tensor data = in_data[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor out = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor weight = in_data[scale_enum::kWeight].get(s); - out = data * broadcast<0>(broadcast_keepdim(weight, 0, nbatch), out.shape_); - } else if (param_.mode == scale_enum::kChannel) { - CHECK_GE(orig_shape.ndim(), 3); - Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], - orig_shape.ProdShape(2, orig_shape.ndim())); - Tensor data = in_data[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor out = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor weight = in_data[scale_enum::kWeight] - .get_with_shape(Shape2(1, dshape[2]), s); - out = data * broadcast_with_axis( - broadcast_keepdim(weight, 0, nbatch), 0, dshape[1]); - } else if (param_.mode == scale_enum::kSpatial) { - CHECK_GE(orig_shape.ndim(), 3); - Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], - orig_shape.ProdShape(2, orig_shape.ndim())); - Tensor data = in_data[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor out = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor weight = in_data[scale_enum::kWeight] - .get_with_shape(Shape2(1, dshape[1]), s); - out = data * broadcast_with_axis( - broadcast_keepdim(weight, 0, nbatch), 1, dshape[2]); - } else { - LOG(FATAL) << "Unknown scaling mode."; - } - } - - virtual void Backward(const OpContext &ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - CHECK_EQ(in_data.size(), 2); - CHECK_EQ(out_data.size(), 1); - Stream *s = ctx.get_stream(); - TShape orig_shape = out_data[scale_enum::kOut].shape_; - index_t nbatch = orig_shape[0]; - if (param_.mode == scale_enum::kInstance) { - Shape<2> dshape = Shape2(orig_shape[0], - orig_shape.ProdShape(1, orig_shape.ndim())); - Tensor data = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor grad_in = in_grad[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor grad_out = out_grad[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor wgrad_in = in_grad[scale_enum::kWeight].get(s); - Tensor weight = in_data[scale_enum::kWeight].get(s); - Tensor temp = ctx.requested[scale_enum::kTempSpace] - .get_space_typed(mshadow::Shape1(1), s); - temp = sumall_except_dim<0>(reduce_keepdim(grad_out * data, 0)); - Assign(wgrad_in, req[scale_enum::kWeight], temp / weight); - Assign(grad_in, req[scale_enum::kData], grad_out * - broadcast<0>(broadcast_keepdim(weight, 0, nbatch), grad_out.shape_)); - } else if (param_.mode == scale_enum::kChannel) { - CHECK_GE(orig_shape.ndim(), 3); - Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], - orig_shape.ProdShape(2, orig_shape.ndim())); - Tensor data = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor grad_in = in_grad[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor grad_out = out_grad[scale_enum::kOut] - .get_with_shape(dshape, s); - Shape<2> wshape = Shape2(1, dshape[2]); - Tensor wgrad_in = in_grad[scale_enum::kWeight] - .get_with_shape(wshape, s); - Tensor weight = in_data[scale_enum::kWeight] - .get_with_shape(wshape, s); - Tensor temp = ctx.requested[scale_enum::kTempSpace] - .get_space_typed(mshadow::Shape2(1, data.shape_[2]), s); - temp = reduce_keepdim( - reduce_with_axis(grad_out * data, 1), 0); - Assign(wgrad_in, req[scale_enum::kWeight], temp / weight); - Assign(grad_in, req[scale_enum::kWeight], - grad_out * broadcast_with_axis( - broadcast_keepdim(weight, 0, nbatch), 0, orig_shape[1])); - } else if (param_.mode == scale_enum::kSpatial) { - CHECK_GE(orig_shape.ndim(), 3); - Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], - orig_shape.ProdShape(2, orig_shape.ndim())); - Tensor data = out_data[scale_enum::kOut] - .get_with_shape(dshape, s); - Tensor grad_in = in_grad[scale_enum::kData] - .get_with_shape(dshape, s); - Tensor grad_out = out_grad[scale_enum::kOut] - .get_with_shape(dshape, s); - Shape<2> wshape = Shape2(1, dshape[1]); - Tensor wgrad_in = in_grad[scale_enum::kWeight] - .get_with_shape(wshape, s); - Tensor weight = in_data[scale_enum::kWeight] - .get_with_shape(wshape, s); - Tensor temp = ctx.requested[scale_enum::kTempSpace] - .get_space_typed(mshadow::Shape2(1, data.shape_[1]), s); - temp = reduce_keepdim( - reduce_with_axis(grad_out * data, 2), 0); - Assign(wgrad_in, req[scale_enum::kWeight], temp / weight); - Assign(grad_in, req[scale_enum::kData], - grad_out * broadcast_with_axis( - broadcast_keepdim(weight, 0, nbatch), 1, dshape[2])); - } else { - LOG(FATAL) << "Unknown scaling mode"; - } - } - - private: - ScaleParam param_; -}; // class ScaleOp - -template -Operator* CreateOp(ScaleParam param, int dtype); - -#if DMLC_USE_CXX11 -class ScaleProp : public OperatorProperty { - public: - std::vector ListArguments() const override { - return {"data", "scale"}; - } - - std::vector ListOutputs() const override { - return {"output"}; - } - - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - using namespace mshadow; - CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; - const TShape &dshape = (*in_shape)[scale_enum::kData]; - if (dshape.ndim() == 0) return false; - if (param_.mode == scale_enum::kInstance) { - CHECK_GE(dshape.ndim(), 2); - SHAPE_ASSIGN_CHECK(*in_shape, scale_enum::kWeight, Shape1(1)); - } else if (param_.mode == scale_enum::kChannel) { - CHECK_GE(dshape.ndim(), 3) - << "At least 3 dimensions required in channel mode"; - SHAPE_ASSIGN_CHECK(*in_shape, scale_enum::kWeight, - Shape1(dshape.ProdShape(2, dshape.ndim()))); - } else if (param_.mode == scale_enum::kSpatial) { - CHECK_GE(dshape.ndim(), 3) - << "At least 3 dimensions required in spatial mode"; - SHAPE_ASSIGN_CHECK(*in_shape, scale_enum::kWeight, Shape1(dshape[1])); - } - out_shape->clear(); - out_shape->push_back(dshape); - return true; - } - - - - OperatorProperty* Copy() const override { - auto ptr = new ScaleProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "Scale"; - } - - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {out_grad[scale_enum::kOut], out_data[scale_enum::kOut], in_data[scale_enum::kWeight]}; - } - - std::vector ForwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - std::vector BackwardResource( - const std::vector &in_shape) const override { - return {ResourceRequest::kTempSpace}; - } - - Operator* CreateOperator(Context ctx) const override { - LOG(FATAL) << "Not implemented"; - return NULL; - } - - Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, - std::vector *in_type) const override; - - private: - ScaleParam param_; -}; // class ScaleProp -#endif // DMLC_USE_CXX11 - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_SCALE_INL_H_ diff --git a/operator/scale.cc b/operator/scale.cc deleted file mode 100644 index 04d53c6..0000000 --- a/operator/scale.cc +++ /dev/null @@ -1,36 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file scale.cc - * \brief scale operator -*/ -#include "./scale-inl.h" - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(ScaleParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new ScaleOp(param); - }); - return op; -} - -Operator* ScaleProp::CreateOperatorEx(Context ctx, - std::vector *in_shape, - std::vector *in_type) const { - std::vector out_shape, aux_shape; - std::vector out_type, aux_type; - CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - CHECK(InferType(in_type, &out_type, &aux_type)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); -} - -DMLC_REGISTER_PARAMETER(ScaleParam); - -MXNET_REGISTER_OP_PROPERTY(Scale, ScaleProp) -.describe("Scale the input initialized by user and learned through backpropogation.") -.add_argument("data", "Symbol", "Input data to the ScaleOp.") -.add_arguments(ScaleParam::__FIELDS__()); -} // namespace op -} // namespace mxnet diff --git a/operator/scale.cu b/operator/scale.cu deleted file mode 100644 index d02e672..0000000 --- a/operator/scale.cu +++ /dev/null @@ -1,19 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file scale.cu - * \brief scale operator -*/ -#include "./scale-inl.h" - -namespace mxnet { -namespace op { -template<> -Operator* CreateOp(ScaleParam param, int dtype) { - Operator *op = NULL; - MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new ScaleOp(param); - }); - return op; -} -} // namespace op -} // namespace mxnet diff --git a/symbol/symbol_vgg16_reduced.py b/symbol/symbol_vgg16_reduced.py index ed1a3c7..d031246 100644 --- a/symbol/symbol_vgg16_reduced.py +++ b/symbol/symbol_vgg16_reduced.py @@ -51,7 +51,7 @@ def get_symbol_train(num_classes=20): relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3") pool3 = mx.symbol.Pooling( data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2, 2), \ - pad=(1,1), name="pool3") + pooling_convention="full", name="pool3") # group 4 conv4_1 = mx.symbol.Convolution( data=pool3, kernel=(3, 3), pad=(1, 1), num_filter=512, name="conv4_1") diff --git a/train.py b/train.py index 7e1dc31..a4e295a 100644 --- a/train.py +++ b/train.py @@ -41,7 +41,7 @@ def parse_args(): default=20, type=int) parser.add_argument('--data-shape', dest='data_shape', type=int, default=300, help='set image shape') - parser.add_argument('--lr', dest='learning_rate', type=float, default=0.001, + parser.add_argument('--lr', dest='learning_rate', type=float, default=0.002, help='learning rate') parser.add_argument('--momentum', dest='momentum', type=float, default=0.9, help='momentum')