From 75c60f53fe1cc72ac2c8912eacf6cd57f8b65e98 Mon Sep 17 00:00:00 2001
From: TangTT <1489272013@qq.com>
Date: Mon, 29 Jun 2020 17:28:27 +0800
Subject: [PATCH 1/3] Refactor Backbone

---
 configs/yolo/yolov3_ms_aug_273e.py            |  12 +-
 .../yolo/yolov3_ms_aug_273e_no_pretrain.py    | 116 ------------------
 mmdet/models/backbones/__init__.py            |   4 +-
 mmdet/models/backbones/darknet.py             |  83 +++++++++----
 4 files changed, 72 insertions(+), 143 deletions(-)
 delete mode 100644 configs/yolo/yolov3_ms_aug_273e_no_pretrain.py

diff --git a/configs/yolo/yolov3_ms_aug_273e.py b/configs/yolo/yolov3_ms_aug_273e.py
index d04cfc21f31..b81c320ca32 100644
--- a/configs/yolo/yolov3_ms_aug_273e.py
+++ b/configs/yolo/yolov3_ms_aug_273e.py
@@ -5,7 +5,10 @@
     type='YoloNet',
     pretrained='./work_dirs/darknet_state_dict_only.pth',
     backbone=dict(
-        type='DarkNet53',),
+        type='Darknet',
+        depth=53,
+        out_indices=(3, 4, 5),
+    ),
     neck=dict(
         type='YoloNeck',),
     bbox_head=dict(
@@ -28,6 +31,7 @@
 data_root = 'data/coco/'
 img_norm_cfg = dict(
     mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
+# TODO: Add PhotoMetricDistortion
 train_pipeline = [
     dict(type='LoadImageFromFile', to_float32=True),
     dict(type='LoadAnnotations', with_bbox=True),
@@ -86,7 +90,7 @@
     )
 )
 # optimizer
-optimizer = dict(type='SGD', lr=5e-4, momentum=0.9, weight_decay=0.0005)
+optimizer = dict(type='SGD', lr=2e-3, momentum=0.9, weight_decay=0.0005)
 optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
 # learning policy
 lr_config = dict(
@@ -101,12 +105,10 @@
     interval=50,
     hooks=[
         dict(type='TextLoggerHook'),
-        # dict(type='TensorboardLoggerHook')
     ])
 # yapf:enable
 # runtime settings
 total_epochs = 273
-device_ids = range(8)
 dist_params = dict(backend='nccl')
 log_level = 'INFO'
 work_dir = './work_dirs/yolo_pretrained'
@@ -114,3 +116,5 @@
 resume_from = None
 workflow = [('train', 1)]
 evaluation = dict(interval=1, metric=['bbox'])
+# TODO: Remove hot fix
+find_unused_parameters=True
diff --git a/configs/yolo/yolov3_ms_aug_273e_no_pretrain.py b/configs/yolo/yolov3_ms_aug_273e_no_pretrain.py
deleted file mode 100644
index e0ee71896ce..00000000000
--- a/configs/yolo/yolov3_ms_aug_273e_no_pretrain.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Copyright (c) 2019 Western Digital Corporation or its affiliates.
-
-# model settings
-model = dict(
-    type='YoloNet',
-    pretrained=None,
-    backbone=dict(
-        type='DarkNet53',),
-    neck=dict(
-        type='YoloNeck',),
-    bbox_head=dict(
-        type='YoloHead',))
-# training and testing settings
-train_cfg = dict(
-    one_hot_smoother=0.,
-    ignore_config=0.5,
-    xy_use_logit=False,
-    debug=False)
-test_cfg = dict(
-    nms_pre=1000,
-    min_bbox_size=0,
-    score_thr=0.05,
-    conf_thr=0.005,
-    nms=dict(type='nms', iou_thr=0.45),
-    max_per_img=100)
-# dataset settings
-dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
-train_pipeline = [
-    dict(type='LoadImageFromFile', to_float32=True),
-    dict(type='LoadAnnotations', with_bbox=True),
-    dict(type='PhotoMetricDistortion'),
-    dict(type='Expand',
-         mean=img_norm_cfg['mean'],
-         to_rgb=img_norm_cfg['to_rgb'],
-         ratio_range=(1, 2)
-         ),
-    dict(type='MinIoURandomCrop',
-         min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
-         min_crop_size=0.3
-         ),
-    dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True),
-    dict(type='RandomFlip', flip_ratio=0.5),
-    dict(type='Normalize', **img_norm_cfg),
-    dict(type='Pad', size_divisor=32),
-    dict(type='DefaultFormatBundle'),
-    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
-]
-test_pipeline = [
-    dict(type='LoadImageFromFile'),
-    dict(
-        type='MultiScaleFlipAug',
-        img_scale=(608, 608),
-        flip=False,
-        transforms=[
-            dict(type='Resize', keep_ratio=True),
-            dict(type='RandomFlip'),
-            dict(type='Normalize', **img_norm_cfg),
-            dict(type='Pad', size_divisor=32),
-            dict(type='ImageToTensor', keys=['img']),
-            dict(type='Collect', keys=['img']),
-        ])
-]
-data = dict(
-    samples_per_gpu=8,
-    workers_per_gpu=8,
-    train=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_train2017.json',
-        img_prefix=data_root + 'train2017/',
-        pipeline=train_pipeline,
-    ),
-    val=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline,
-    ),
-    test=dict(
-        type=dataset_type,
-        ann_file=data_root + 'annotations/instances_val2017.json',
-        img_prefix=data_root + 'val2017/',
-        pipeline=test_pipeline,
-    )
-)
-# optimizer
-optimizer = dict(type='SGD', lr=5e-4, momentum=0.9, weight_decay=0.0005)
-optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
-# learning policy
-lr_config = dict(
-    policy='step',
-    warmup='linear',
-    warmup_iters=2000,  # same as burn-in in darknet
-    warmup_ratio=0.1,
-    step=[218, 246])
-checkpoint_config = dict(interval=1)
-# yapf:disable
-log_config = dict(
-    interval=50,
-    hooks=[
-        dict(type='TextLoggerHook'),
-        # dict(type='TensorboardLoggerHook')
-    ])
-# yapf:enable
-# runtime settings
-total_epochs = 273
-device_ids = range(8)
-dist_params = dict(backend='nccl')
-log_level = 'INFO'
-work_dir = './work_dirs/yolo_no_pretrain'
-load_from = None
-resume_from = None
-workflow = [('train', 1)]
-evaluation = dict(interval=1, metric=['bbox'])
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
index e851854fb90..26a61735d3a 100644
--- a/mmdet/models/backbones/__init__.py
+++ b/mmdet/models/backbones/__init__.py
@@ -5,9 +5,9 @@
 from .resnet import ResNet, ResNetV1d
 from .resnext import ResNeXt
 from .ssd_vgg import SSDVGG
-from .darknet import DarkNet53
+from .darknet import Darknet

 __all__ = [
     'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
-    'HourglassNet', 'DarkNet53'
+    'HourglassNet', 'Darknet'
 ]
diff --git a/mmdet/models/backbones/darknet.py b/mmdet/models/backbones/darknet.py
index 1d30881cda0..85a6dc0f32e 100644
--- a/mmdet/models/backbones/darknet.py
+++ b/mmdet/models/backbones/darknet.py
@@ -13,7 +13,8 @@ class ResBlock(nn.Module):
     """The basic residual block used in YoloV3.

-    Each ResBlock consists of two ConvLayers and the input is added to the final output.
+    Each ResBlock consists of two ConvModules and the input is added to the final output.
+    Each ConvModule is composed of Conv, BN, and LeakyReLU.
     In YoloV3 paper, the first convLayer has half of the number of the filters as much as the second convLayer.
     The first convLayer has filter size of 1x1 and the second one has the filter size of 3x3.
     """
@@ -46,7 +47,7 @@ def forward(self, x):


 def make_conv_and_res_block(in_channels, out_channels, res_repeat):
-    """In Darknet 53 backbone, there is usually one Conv Layer followed by some ResBlock.
+    """In the Darknet backbone, there is usually one Conv Layer followed by some ResBlocks.
     This function will make that.
     The Conv layers always have 3x3 filters with stride=2.
     The number of the filters in Conv layer is the same as the out channels of the ResBlock"""
@@ -64,37 +65,77 @@ def make_conv_and_res_block(in_channels, out_channels, res_repeat):


 @BACKBONES.register_module()
-class DarkNet53(nn.Module):
+class Darknet(nn.Module):
+    """Darknet backbone.
+
+    Args:
+        depth (int): Depth of Darknet. Currently only supports 53.
+        out_indices (Sequence[int]): Output from which stages.
+            Note: By default, the sequence of the layers will be returned
+            in a **reversed** manner, i.e., from the deepest stage to the
+            shallowest. See the example below.
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
+        reverse_output (bool): If True, the output layers are ordered from
+            the deepest stage to the shallowest. Default: True.
+            (To cope with YoloNeck)
+
+    Example:
+        >>> from mmdet.models import Darknet
+        >>> import torch
+        >>> self = Darknet(depth=53)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 416, 416)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        ...
+        (1, 1024, 13, 13)
+        (1, 512, 26, 26)
+        (1, 256, 52, 52)
+    """

     def __init__(self,
+                 depth=53,
+                 out_indices=(3, 4, 5),
                  norm_eval=True,
-                 reverse_output=False):
-        super(DarkNet53, self).__init__()
+                 reverse_output=True):
+        super(Darknet, self).__init__()
+        self.depth = depth
+        self.out_indices = out_indices
+        if self.depth == 53:
+            self.layers = [1, 2, 8, 8, 4]
+            self.channels = [[32, 64], [64, 128], [128, 256], [256, 512],
+                             [512, 1024]]
+        else:
+            raise KeyError(f'invalid depth {depth} for darknet')
+
         self.conv1 = ConvModule(3, 32, 3,
                                 padding=1,
                                 norm_cfg=dict(type='BN', requires_grad=True),
                                 act_cfg=dict(type='LeakyReLU',
                                              negative_slope=0.1))
-        self.cr_block1 = make_conv_and_res_block(32, 64, 1)
-        self.cr_block2 = make_conv_and_res_block(64, 128, 2)
-        self.cr_block3 = make_conv_and_res_block(128, 256, 8)
-        self.cr_block4 = make_conv_and_res_block(256, 512, 8)
-        self.cr_block5 = make_conv_and_res_block(512, 1024, 4)
+
+        self.cr_blocks = ['conv1']
+        for i, n_layers in enumerate(self.layers):
+            layer_name = f'cr_block{i + 1}'
+            in_c, out_c = self.channels[i]
+            self.add_module(layer_name,
+                            make_conv_and_res_block(in_c, out_c, n_layers))
+            self.cr_blocks.append(layer_name)

         self.norm_eval = norm_eval
         self.reverse_output=reverse_output

     def forward(self, x):
-        tmp = self.conv1(x)
-        tmp = self.cr_block1(tmp)
-        tmp = self.cr_block2(tmp)
-        out3 = self.cr_block3(tmp)
-        out2 = self.cr_block4(out3)
-        out1 = self.cr_block5(out2)
-
-        if not self.reverse_output:
-            return out1, out2, out3
+        outs = []
+        for i, layer_name in enumerate(self.cr_blocks):
+            cr_block = getattr(self, layer_name)
+            x = cr_block(x)
+            if i in self.out_indices:
+                outs.append(x)
+
+        if self.reverse_output:
+            return tuple(outs[::-1])
         else:
-            return out3, out2, out1
+            return tuple(outs)

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
@@ -115,7 +156,7 @@ def _freeze_stages(self):
                 param.requires_grad = False

     def train(self, mode=True):
-        super(DarkNet53, self).train(mode)
+        super(Darknet, self).train(mode)
         self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():

From 3e59739a6a578e2868b2722a1433f739643c86a7 Mon Sep 17 00:00:00 2001
From: TangTT <1489272013@qq.com>
Date: Mon, 29 Jun 2020 17:32:08 +0800
Subject: [PATCH 2/3] Update README

---
 configs/yolo/README.md | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/configs/yolo/README.md b/configs/yolo/README.md
index 92ecbbce795..6ba3cef0dfd 100644
--- a/configs/yolo/README.md
+++ b/configs/yolo/README.md
@@ -1,4 +1,4 @@
-#YOLOv3
+# YOLOv3

 ## Introduction
 ```
@@ -16,8 +16,10 @@

 Test set: COCO val2017

-bbox_mAP: 0.3520
+bbox_mAP: 0.3640

-bbox_mAP_50: 0.6100
+bbox_mAP_50: 0.6350

-Checkpoint link: [here](https://drive.google.com/drive/folders/1NzQ5LwBaYPlu1gywnRAViNz70NV9743O?usp=sharing)
\ No newline at end of file
+Checkpoint link: [here](https://drive.google.com/drive/folders/1NzQ5LwBaYPlu1gywnRAViNz70NV9743O?usp=sharing)
+
+This implementation originates from the project of Haoyu Wu (@wuhy08) at Western Digital.
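For reference, a minimal usage sketch of the backbone API introduced by PATCH 1/3. It assumes only the `Darknet` class and the arguments shown above; the 416x416 input is illustrative, mirroring the docstring example:

    # Usage sketch for the refactored Darknet backbone (PATCH 1/3).
    # out_indices=(3, 4, 5) selects cr_block3/4/5; with the default
    # reverse_output=True the deepest feature map comes first, which is
    # the order YoloNeck consumes.
    import torch

    from mmdet.models import Darknet

    backbone = Darknet(depth=53, out_indices=(3, 4, 5), reverse_output=True)
    backbone.eval()

    with torch.no_grad():
        feats = backbone(torch.rand(1, 3, 416, 416))

    for feat in feats:
        print(tuple(feat.shape))
    # (1, 1024, 13, 13)
    # (1, 512, 26, 26)
    # (1, 256, 52, 52)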
From 4f24fa1a188f1a1ea807f9d3abba5da5b5a92995 Mon Sep 17 00:00:00 2001
From: TangTT <1489272013@qq.com>
Date: Mon, 29 Jun 2020 18:00:02 +0800
Subject: [PATCH 3/3] Lint and format

---
 configs/yolo/yolov3_ms_aug_273e.py    |  41 ++--
 mmdet/models/dense_heads/yolo_head.py | 270 +++++++++++++++-----------
 mmdet/models/necks/yolo_neck.py       |  87 +++++----
 3 files changed, 233 insertions(+), 165 deletions(-)

diff --git a/configs/yolo/yolov3_ms_aug_273e.py b/configs/yolo/yolov3_ms_aug_273e.py
index b81c320ca32..b91ece70158 100644
--- a/configs/yolo/yolov3_ms_aug_273e.py
+++ b/configs/yolo/yolov3_ms_aug_273e.py
@@ -8,17 +8,12 @@
         type='Darknet',
         depth=53,
         out_indices=(3, 4, 5),
-        ),
-    neck=dict(
-        type='YoloNeck',),
-    bbox_head=dict(
-        type='YoloHead',))
+    ),
+    neck=dict(type='YoloNeck', ),
+    bbox_head=dict(type='YoloHead', ))
 # training and testing settings
 train_cfg = dict(
-    one_hot_smoother=0.,
-    ignore_config=0.5,
-    xy_use_logit=False,
-    debug=False)
+    one_hot_smoother=0., ignore_config=0.5, xy_use_logit=False, debug=False)
 test_cfg = dict(
     nms_pre=1000,
     min_bbox_size=0,
@@ -29,22 +24,21 @@
 # dataset settings
 dataset_type = 'CocoDataset'
 data_root = 'data/coco/'
-img_norm_cfg = dict(
-    mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
+img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
 # TODO: Add PhotoMetricDistortion
 train_pipeline = [
     dict(type='LoadImageFromFile', to_float32=True),
     dict(type='LoadAnnotations', with_bbox=True),
     dict(type='PhotoMetricDistortion'),
-    dict(type='Expand',
-         mean=img_norm_cfg['mean'],
-         to_rgb=img_norm_cfg['to_rgb'],
-         ratio_range=(1, 2)
-         ),
-    dict(type='MinIoURandomCrop',
-         min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
-         min_crop_size=0.3
-         ),
+    dict(
+        type='Expand',
+        mean=img_norm_cfg['mean'],
+        to_rgb=img_norm_cfg['to_rgb'],
+        ratio_range=(1, 2)),
+    dict(
+        type='MinIoURandomCrop',
+        min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
+        min_crop_size=0.3),
     dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True),
     dict(type='RandomFlip', flip_ratio=0.5),
     dict(type='Normalize', **img_norm_cfg),
@@ -87,10 +81,9 @@
         ann_file=data_root + 'annotations/instances_val2017.json',
         img_prefix=data_root + 'val2017/',
         pipeline=test_pipeline,
-    )
-)
+    ))
 # optimizer
-optimizer = dict(type='SGD', lr=2e-3, momentum=0.9, weight_decay=0.0005)
+optimizer = dict(type='SGD', lr=5e-4, momentum=0.9, weight_decay=0.0005)
 optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
 # learning policy
 lr_config = dict(
@@ -117,4 +110,4 @@
 workflow = [('train', 1)]
 evaluation = dict(interval=1, metric=['bbox'])
 # TODO: Remove hot fix
-find_unused_parameters=True
+find_unused_parameters = True
diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py
index 757d459143b..735b345db0c 100644
--- a/mmdet/models/dense_heads/yolo_head.py
+++ b/mmdet/models/dense_heads/yolo_head.py
@@ -1,22 +1,17 @@
 # Copyright (c) 2019 Western Digital Corporation or its affiliates.
+import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.runner import load_checkpoint

-import logging
-
+from mmdet.core import force_fp32, multiclass_nms
 from ..builder import HEADS
 from .base_dense_head import BaseDenseHead
-from mmcv.cnn import xavier_init, ConvModule
-from mmcv.runner import load_checkpoint
-
-from mmdet.core import (multiclass_nms, force_fp32)
-
-from mmdet.ops.nms import nms_wrapper
-
-
 _EPSILON = 1e-6
@@ -35,12 +30,12 @@ class YoloHead(BaseDenseHead):
     out_channels = [1024, 512, 256]
     scales = ['l', 'm', 's']
     strides = [32, 16, 8]
-    # scale_params = [(1024, 'l', 32), (512, 'm', 16), (256, 's', 8)] #out_channel, scale, stride
-    anchor_base_sizes = [[(116, 90), (156, 198), (373, 326)],
-                         [(30, 61), (62, 45), (59, 119)],
-                         [(10, 13), (16, 30), (33, 23)],
-                         ]
+    anchor_base_sizes = [
+        [(116, 90), (156, 198), (373, 326)],
+        [(30, 61), (62, 45), (59, 119)],
+        [(10, 13), (16, 30), (33, 23)],
+    ]

     def __init__(self, train_cfg=None, test_cfg=None):
         super(YoloHead, self).__init__()
@@ -49,12 +44,13 @@ def __init__(self, train_cfg=None, test_cfg=None):
         for i_scale in range(self.num_scales):
             in_c = self.in_channels[i_scale]
             out_c = self.out_channels[i_scale]
-            conv_bridge = ConvModule(in_c,
-                                     out_c,
-                                     3,
-                                     padding=1,
-                                     norm_cfg=dict(type='BN', requires_grad=True),
-                                     act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+            conv_bridge = ConvModule(
+                in_c,
+                out_c,
+                3,
+                padding=1,
+                norm_cfg=dict(type='BN', requires_grad=True),
+                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
             conv_final = nn.Conv2d(out_c, self.last_layer_dim, 1, bias=True)

             self.convs_bridge.append(conv_bridge)
@@ -85,7 +81,7 @@ def forward(self, feats):

         return tuple(results),

-    @force_fp32(apply_to=('results_raw',))
+    @force_fp32(apply_to=('results_raw', ))
     def get_bboxes(self, results_raw, img_metas, cfg=None, rescale=False):
         result_list = []
         for img_id in range(len(img_metas)):
@@ -93,14 +89,19 @@ def get_bboxes(self, results_raw, img_metas, cfg=None, rescale=False):
                 results_raw[i][img_id].detach() for i in range(self.num_scales)
             ]
             scale_factor = img_metas[img_id]['scale_factor']
-            proposals = self.get_bboxes_single(result_raw_list, scale_factor, cfg, rescale)
+            proposals = self.get_bboxes_single(result_raw_list, scale_factor,
+                                               cfg, rescale)
             result_list.append(proposals)
         return result_list

     @staticmethod
     def _get_anchors_grid_xy(num_grid_h, num_grid_w, stride, device='cpu'):
-        grid_x = torch.arange(num_grid_w, dtype=torch.float, device=device).repeat(num_grid_h, 1)
-        grid_y = torch.arange(num_grid_h, dtype=torch.float, device=device).repeat(num_grid_w, 1)
+        grid_x = torch.arange(
+            num_grid_w, dtype=torch.float,
+            device=device).repeat(num_grid_h, 1)
+        grid_y = torch.arange(
+            num_grid_h, dtype=torch.float,
+            device=device).repeat(num_grid_w, 1)

         grid_x = grid_x.unsqueeze(0) * stride
         grid_y = grid_y.t().unsqueeze(0) * stride
@@ -120,21 +121,28 @@ def get_bboxes_single(self, results_raw, scale_factor, cfg, rescale=False):

             prediction_raw = result_raw.view(self.num_anchors_per_scale,
                                              self.num_attrib,
-                                             num_grid_h,
-                                             num_grid_w).permute(0, 2, 3, 1).contiguous()
+                                             num_grid_h, num_grid_w).permute(
+                                                 0, 2, 3, 1).contiguous()

             # grid x y offset, with stride step included
             stride = self.strides[i_scale]
-            grid_x, grid_y = self._get_anchors_grid_xy(num_grid_h, num_grid_w, stride, result_raw.device)
+            grid_x, grid_y = self._get_anchors_grid_xy(num_grid_h, num_grid_w,
+                                                       stride,
+                                                       result_raw.device)

             # Get outputs x, y
-            x_center_pred = torch.sigmoid(prediction_raw[..., 0]) * stride + grid_x  # Center x
-            y_center_pred = torch.sigmoid(prediction_raw[..., 1]) * stride + grid_y  # Center y
+            x_center_pred = torch.sigmoid(
+                prediction_raw[..., 0]) * stride + grid_x  # Center x
+            y_center_pred = torch.sigmoid(
+                prediction_raw[..., 1]) * stride + grid_y  # Center y

-            anchors = torch.tensor(self.anchor_base_sizes[i_scale], device=result_raw.device, dtype=torch.float32)
+            anchors = torch.tensor(
+                self.anchor_base_sizes[i_scale],
+                device=result_raw.device,
+                dtype=torch.float32)

             anchor_w = anchors[:, 0:1].view((-1, 1, 1))
             anchor_h = anchors[:, 1:2].view((-1, 1, 1))
@@ -148,10 +156,11 @@ def get_bboxes_single(self, results_raw, scale_factor, cfg, rescale=False):
             x2_pred = x_center_pred + w_pred / 2
             y2_pred = y_center_pred + h_pred / 2

-            bbox_pred = torch.stack((x1_pred, y1_pred, x2_pred, y2_pred), dim=3).view(
-                (-1, 4))  # cxcywh
+            bbox_pred = torch.stack((x1_pred, y1_pred, x2_pred, y2_pred),
+                                    dim=3).view((-1, 4))  # cxcywh
             conf_pred = torch.sigmoid(prediction_raw[..., 4]).view(-1)  # Conf
-            cls_pred = torch.sigmoid(prediction_raw[..., 5:]).view(-1, self.num_classes_no_bkg)  # Cls pred one-hot.
+            cls_pred = torch.sigmoid(prediction_raw[..., 5:]).view(
+                -1, self.num_classes_no_bkg)  # Cls pred one-hot.

             conf_thr = cfg.get('conf_thr', -1)
             conf_inds = conf_pred.ge(conf_thr).nonzero().flatten()
@@ -174,29 +183,33 @@ def get_bboxes_single(self, results_raw, scale_factor, cfg, rescale=False):
         multi_lvl_conf_scores = torch.cat(multi_lvl_conf_scores)

         if multi_lvl_conf_scores.size(0) == 0:
-            return torch.zeros((0, 5)), torch.zeros((0,))
+            return torch.zeros((0, 5)), torch.zeros((0, ))

         if rescale:
             multi_lvl_bboxes /= multi_lvl_bboxes.new_tensor(scale_factor)

-        padding = multi_lvl_cls_scores.new_zeros(multi_lvl_cls_scores.shape[0], 1)
-        multi_lvl_cls_scores = torch.cat([padding, multi_lvl_cls_scores], dim=1)
-
-        det_bboxes, det_labels = multiclass_nms(multi_lvl_bboxes, multi_lvl_cls_scores,
-                                                cfg.score_thr, cfg.nms,
-                                                cfg.max_per_img, score_factors=multi_lvl_conf_scores)
+        padding = multi_lvl_cls_scores.new_zeros(multi_lvl_cls_scores.shape[0],
+                                                 1)
+        multi_lvl_cls_scores = torch.cat([padding, multi_lvl_cls_scores],
+                                         dim=1)
+
+        det_bboxes, det_labels = multiclass_nms(
+            multi_lvl_bboxes,
+            multi_lvl_cls_scores,
+            cfg.score_thr,
+            cfg.nms,
+            cfg.max_per_img,
+            score_factors=multi_lvl_conf_scores)

         det_labels -= 1  # Hot fix

         return det_bboxes, det_labels

-    @force_fp32(apply_to=('preds_raw',))
+    @force_fp32(apply_to=('preds_raw', ))
     def loss(self,
              preds_raw,
              gt_bboxes,
              gt_labels,
              img_metas,
-             # cfg,  # removed since it's removed in mmdet 2.0
              gt_bboxes_ignore=None):
-
         losses = {'loss_xy': 0, 'loss_wh': 0, 'loss_conf': 0, 'loss_cls': 0}

         for img_id in range(len(img_metas)):
@@ -207,10 +220,11 @@ def loss(self,
             num_grid_h = pred_raw.size(1)
             num_grid_w = pred_raw.size(2)
             pred_raw = pred_raw.view(self.num_anchors_per_scale,
-                                     self.num_attrib,
-                                     num_grid_h,
-                                     num_grid_w).permute(0, 2, 3, 1).contiguous()
-            anchor_grid = self.get_anchors(num_grid_h, num_grid_w, i_scale, device=pred_raw.device)
+                                     self.num_attrib, num_grid_h,
+                                     num_grid_w).permute(0, 2, 3,
+                                                         1).contiguous()
+            anchor_grid = self.get_anchors(
+                num_grid_h, num_grid_w, i_scale, device=pred_raw.device)
             pred_raw_list.append(pred_raw)
             anchor_grids.append(anchor_grid)

@@ -229,15 +243,16 @@ def loss(self,
                 self._preprocess_target_single_img(gt_bboxes_per_img,
                                                    gt_labels_per_img,
                                                    anchor_grids,
-                                                   ignore_thresh=ignore_thresh,
-                                                   one_hot_smoother=one_hot_smoother,
-                                                   xy_use_logit=xy_use_logit)
+                                                   ignore_thresh,
+                                                   one_hot_smoother,
+                                                   xy_use_logit)

-            losses_per_img = self.loss_single(pred_raw_list,
-                                              gt_t_across_scale,
-                                              negative_mask_across_scale,
-                                              xy_use_logit=xy_use_logit,
-                                              balance_conf=balance_conf)
+            losses_per_img = self.loss_single(
+                pred_raw_list,
+                gt_t_across_scale,
+                negative_mask_across_scale,
+                xy_use_logit=xy_use_logit,
+                balance_conf=balance_conf)

             for loss_term in losses:
                 term_no_loss = loss_term[5:]
@@ -245,7 +260,12 @@

         return losses

-    def loss_single(self, preds_raw, gts_t, neg_masks, xy_use_logit=False, balance_conf=False):
+    def loss_single(self,
+                    preds_raw,
+                    gts_t,
+                    neg_masks,
+                    xy_use_logit=False,
+                    balance_conf=False):

         losses = {'xy': 0, 'wh': 0, 'conf': 0, 'cls': 0}

@@ -257,7 +277,7 @@
             pos_and_neg_mask = neg_mask + pos_mask
             pos_mask = pos_mask.unsqueeze(dim=-1)
             if torch.max(pos_and_neg_mask) > 1.:
-                print("Warning: pos_and_neg_mask gives max of more than 1. Some bugs in the program.")
+                raise Warning('pos_and_neg_mask gives max of more than 1.')
                 pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.)

             # ignore_mask = (1. - pos_and_neg_mask).clamp(min=0)
@@ -285,22 +305,24 @@

             pos_weight = gt_label.new_tensor(pos_weight)

-            losses_cls = F.binary_cross_entropy_with_logits(pred_label, gt_label, reduction='none')
+            losses_cls = F.binary_cross_entropy_with_logits(
+                pred_label, gt_label, reduction='none')
             losses_cls *= pos_mask

-            losses_conf = F.binary_cross_entropy_with_logits(pred_conf,
-                                                             gt_conf,
-                                                             reduction='none',
-                                                             pos_weight=pos_weight
-                                                             ) * pos_and_neg_mask * conf_loss_weight
+            losses_conf = F.binary_cross_entropy_with_logits(
+                pred_conf, gt_conf, reduction='none',
+                pos_weight=pos_weight) * pos_and_neg_mask * conf_loss_weight

             if xy_use_logit:
-                losses_xy = F.mse_loss(pred_t_xy, gt_t_xy, reduction='none') * pos_mask * 2
+                losses_xy = F.mse_loss(
+                    pred_t_xy, gt_t_xy, reduction='none') * pos_mask * 2
             else:
-                losses_xy = F.binary_cross_entropy_with_logits(pred_t_xy, gt_t_xy, reduction='none') * pos_mask * 2
+                losses_xy = F.binary_cross_entropy_with_logits(
+                    pred_t_xy, gt_t_xy, reduction='none') * pos_mask * 2

-            losses_wh = F.mse_loss(pred_t_wh, gt_t_wh, reduction='none') * pos_mask * 2
+            losses_wh = F.mse_loss(
+                pred_t_wh, gt_t_wh, reduction='none') * pos_mask * 2

             losses['cls'] += torch.sum(losses_cls)
             losses['conf'] += torch.sum(losses_conf)
@@ -309,7 +331,13 @@

         return losses

-    def _preprocess_target_single_img(self, gt_bboxes, gt_labels, anchor_grids, ignore_thresh, one_hot_smoother=0, xy_use_logit=False):
+    def _preprocess_target_single_img(self,
+                                      gt_bboxes,
+                                      gt_labels,
+                                      anchor_grids,
+                                      ignore_thresh,
+                                      one_hot_smoother=0,
+                                      xy_use_logit=False):
         """Generate matching bounding box prior and converted GT."""
         assert gt_bboxes.size(1) == 4
         assert gt_bboxes.size(0) == gt_labels.size(0)
@@ -319,14 +347,16 @@ def _preprocess_target_single_img(self, gt_bboxes, gt_labels, anchor_grids, igno

         # each tensor has dimension of AxWxH
         # where A is the number of anchors in this scale,
         # W and H is the size of the grid in this scale
-        # each element of the tensor represents whether the prediction should have generate non-objectness
+        # each element of the tensor represents whether the prediction
+        # should have generated non-objectness

         negative_mask_across_scale = []
         gt_t_across_scale = []

         for anchor_grid in anchor_grids:
             negative_mask_size = list(anchor_grid.size())[:-1]
-            negative_mask = anchor_grid.new_ones(negative_mask_size, dtype=torch.uint8)
+            negative_mask = anchor_grid.new_ones(
+                negative_mask_size, dtype=torch.uint8)
             negative_mask_across_scale.append(negative_mask)
             gt_t_size = negative_mask_size + [self.num_attrib]
             gt_t = anchor_grid.new_zeros(gt_t_size)
@@ -348,7 +378,8 @@
             for i_scale in range(self.num_scales):
                 stride = self.strides[i_scale]
                 anchor_grid = anchor_grids[i_scale]
-                iou_gt_anchor = iou_multiple_to_one(anchor_grid, gt_bbox_cxywh, center=True)
+                iou_gt_anchor = iou_multiple_to_one(
+                    anchor_grid, gt_bbox_cxywh, center=True)
                 negative_mask = (iou_gt_anchor <= ignore_thresh)
                 w_grid = int(gt_cx // stride)
                 h_grid = int(gt_cy // stride)
@@ -360,55 +391,72 @@
                 grid_coord_across_scale.append((h_grid, w_grid))

             itmas = iou_to_match_across_scale  # make the name shorter
-            max_match_iou_idx = max(range(len(itmas)), key=lambda x: itmas[x])  # get idx of max iou
+            max_match_iou_idx = max(
+                range(len(itmas)),
+                key=lambda x: itmas[x])  # get idx of max iou
             match_scale = max_match_iou_idx // self.num_anchors_per_scale
-            match_anchor_in_scale = max_match_iou_idx - match_scale * self.num_anchors_per_scale
+            match_anchor_in_scale = max_match_iou_idx - \
+                match_scale * self.num_anchors_per_scale
             match_grid_h, match_grid_w = grid_coord_across_scale[match_scale]
-            match_anchor_w, match_anchor_h = self.anchor_base_sizes[match_scale][match_anchor_in_scale]
+            match_anchor_w, match_anchor_h = self.anchor_base_sizes[
+                match_scale][match_anchor_in_scale]

             gt_tw = torch.log((gt_w / match_anchor_w).clamp(min=_EPSILON))
             gt_th = torch.log((gt_h / match_anchor_h).clamp(min=_EPSILON))

-            gt_tcx = (gt_cx / self.strides[match_scale] - match_grid_w).clamp(_EPSILON, 1 - _EPSILON)
-            gt_tcy = (gt_cy / self.strides[match_scale] - match_grid_h).clamp(_EPSILON, 1 - _EPSILON)
+            gt_tcx = (gt_cx / self.strides[match_scale] - match_grid_w).clamp(
+                _EPSILON, 1 - _EPSILON)
+            gt_tcy = (gt_cy / self.strides[match_scale] - match_grid_h).clamp(
+                _EPSILON, 1 - _EPSILON)

             if xy_use_logit:
-                gt_tcx = torch.log(gt_tcx / (1. - gt_tcx))  # inverse of sigmoid
-                gt_tcy = torch.log(gt_tcy / (1. - gt_tcy))  # inverse of sigmoid
+                gt_tcx = torch.log(gt_tcx /
+                                   (1. - gt_tcx))  # inverse of sigmoid
+                gt_tcy = torch.log(gt_tcy /
+                                   (1. - gt_tcy))  # inverse of sigmoid

             gt_t_bbox = torch.stack((gt_tcx, gt_tcy, gt_tw, gt_th))

-            # in mmdet 1.x, raw label start from 1, need to minus 1 to compensate that
-            # gt_label_one_hot = F.one_hot(gt_label - 1, num_classes=self.num_classes_no_bkg).float()
-            # However, in mmdet 2.x, label “K” means background,
-            # and labels [0, K-1] correspond to the K = num_categories object categories.
-            gt_label_one_hot = F.one_hot(gt_label, num_classes=self.num_classes_no_bkg).float()
+            # in mmdet 1.x, raw labels started from 1,
+            # and we needed to subtract 1 to compensate for that.
+            # However, in mmdet 2.x, label “K” means background, and labels
+            # [0, K-1] correspond to the K = num_categories object categories.
+            gt_label_one_hot = F.one_hot(
+                gt_label, num_classes=self.num_classes_no_bkg).float()

             # TODO: Check is pending
-            gt_label_one_hot = gt_label_one_hot * (1 - one_hot_smoother) + one_hot_smoother / self.num_classes_no_bkg
+            gt_label_one_hot = gt_label_one_hot * (
+                1 -
+                one_hot_smoother) + one_hot_smoother / self.num_classes_no_bkg

-            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h, match_grid_w, :4] = gt_t_bbox
-            # if gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h, match_grid_w, 4] == 1.:
-            #     print("Warning: Target confidence has been assigned to 1 already.")
-            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h, match_grid_w, 4] = 1.
-            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h, match_grid_w, 5:] = gt_label_one_hot
+            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h,
+                                           match_grid_w, :4] = gt_t_bbox
+            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h,
+                                           match_grid_w, 4] = 1.
+            gt_t_across_scale[match_scale][match_anchor_in_scale, match_grid_h,
+                                           match_grid_w, 5:] = gt_label_one_hot

-            # although iou fall under a certain thres, since it has max iou, still positive
-            negative_mask_across_scale[match_scale][match_anchor_in_scale, match_grid_h, match_grid_w] = 0
+            # although the iou falls below a certain thres,
+            # it is still positive since it has the max iou
+            negative_mask_across_scale[match_scale][match_anchor_in_scale,
+                                                    match_grid_h,
+                                                    match_grid_w] = 0

         return gt_t_across_scale, negative_mask_across_scale

     def get_anchors(self, num_grid_h, num_grid_w, scale, device='cpu'):
         assert scale in range(self.num_scales)
-        anchors = torch.tensor(self.anchor_base_sizes[scale], device=device, dtype=torch.float32)
+        anchors = torch.tensor(
+            self.anchor_base_sizes[scale], device=device, dtype=torch.float32)
         num_anchors = anchors.size(0)
         stride = self.strides[scale]
-        grid_x, grid_y = self._get_anchors_grid_xy(num_grid_h, num_grid_w, stride, device)
+        grid_x, grid_y = self._get_anchors_grid_xy(num_grid_h, num_grid_w,
+                                                   stride, device)

-        grid_x += stride / 2  # convert to center of the grid, that is, making the raw prediction 0, not -inf
-        grid_y += stride / 2
+        grid_x += stride / 2  # convert to center of the grid,
+        grid_y += stride / 2  # that is, making the raw prediction 0, not -inf

         grid_x = grid_x.expand((num_anchors, -1, -1))
         grid_y = grid_y.expand((num_anchors, -1, -1))
@@ -423,17 +471,22 @@


 def iou_multiple_to_one(bboxes1, bbox2, center=False, zero_center=False):
-    """Calculate the IOUs between bboxes1 (multiple) and bbox2 (one).
-    :param
-    bboxes1: (Tensor) A n-D tensor representing first group of bboxes.
-        The dimension is (..., 4).
-        The lst dimension represent the bbox, with coordinate (x, y, w, h) or (cx, cy, w, h).
-    bbox2: (Tensor) A 1D tensor representing the second bbox.
-        The dimension is (4,).
-    center: (bool). Whether the bboxes are in format (cx, cy, w, h).
-    zero_center: (bool). Whether to align two bboxes so their center is aligned.
+    """
+    Calculate the IOUs between bboxes1 (multiple) and bbox2 (one).
+    Args:
+        bboxes1: (Tensor) A n-D tensor representing first group of bboxes.
+            The dimension is (..., 4).
+            The last dimension represents the bbox, with coordinate (x, y, w, h)
+            or (cx, cy, w, h).
+        bbox2: (Tensor) A 1D tensor representing the second bbox.
+            The dimension is (4,).
+        center: (bool). Whether the bboxes are in format (cx, cy, w, h).
+        zero_center: (bool). Whether to align two bboxes so their center
+            is aligned.
     :return
-    iou_: (Tensor) A (n-1)-D tensor representing the IOUs.It has one less dim than bboxes1"""
+        iou_: (Tensor) A (n-1)-D tensor representing the IOUs.
+            It has one less dim than bboxes1
+    """

     epsilon = 1e-6

@@ -467,11 +520,12 @@
         left2 = x2
         bottom1 = y1
         bottom2 = y2
-    w_intersect = (torch.min(right1, right2) - torch.max(left1, left2)).clamp(min=0)
-    h_intersect = (torch.min(top1, top2) - torch.max(bottom1, bottom2)).clamp(min=0)
+    w_intersect = (torch.min(right1, right2) -
+                   torch.max(left1, left2)).clamp(min=0)
+    h_intersect = (torch.min(top1, top2) -
+                   torch.max(bottom1, bottom2)).clamp(min=0)

     area_intersect = h_intersect * w_intersect

     iou_ = area_intersect / (area1 + area2 - area_intersect + epsilon)
     return iou_
-
diff --git a/mmdet/models/necks/yolo_neck.py b/mmdet/models/necks/yolo_neck.py
index 4601f35fe62..5e0b7883646 100644
--- a/mmdet/models/necks/yolo_neck.py
+++ b/mmdet/models/necks/yolo_neck.py
@@ -5,15 +5,15 @@

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.runner import load_checkpoint

 from ..builder import NECKS
-from mmcv.cnn import xavier_init, ConvModule
-from mmcv.runner import load_checkpoint
-

 class DetectionNeck(nn.Module):
-    """The DetectionBlock contains:
+    """
+    The DetectionBlock contains:
     Six ConvLayers, 1 Conv2D Layer and 1 YoloLayer.
     The first 6 ConvLayers are formed the following way:
     1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n,
@@ -25,26 +25,39 @@ class DetectionNeck(nn.Module):

     def __init__(self, in_channels, out_channels):
         super(DetectionNeck, self).__init__()
-        # assert double_out_channels % 2 == 0  #assert out_channels is an even number
-        # out_channels = double_out_channels // 2
         double_out_channels = out_channels * 2

-        self.conv1 = ConvModule(in_channels, out_channels, 1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
-        self.conv2 = ConvModule(out_channels, double_out_channels, 3,
-                                padding=1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
-        self.conv3 = ConvModule(double_out_channels, out_channels, 1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
-        self.conv4 = ConvModule(out_channels, double_out_channels, 3,
-                                padding=1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
-        self.conv5 = ConvModule(double_out_channels, out_channels, 1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv1 = ConvModule(
+            in_channels,
+            out_channels,
+            1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv2 = ConvModule(
+            out_channels,
+            double_out_channels,
+            3,
+            padding=1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv3 = ConvModule(
+            double_out_channels,
+            out_channels,
+            1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv4 = ConvModule(
+            out_channels,
+            double_out_channels,
+            3,
+            padding=1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv5 = ConvModule(
+            double_out_channels,
+            out_channels,
+            1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))

     def forward(self, x):
         tmp = self.conv1(x)
@@ -57,22 +70,30 @@ def forward(self, x):

 @NECKS.register_module()
 class YoloNeck(nn.Module):
-
-    """The tail side of the YoloNet.
-    It will take the result from DarkNet53BackBone and do some upsampling and concatenation.
+    """
+    The neck of the YoloNet.
+    It can be treated as a simplified version of FPN.
+    It will take the result from Darknet backbone
+    and do some upsampling and concatenation.
     It will finally output the detection result.
-    Assembling YoloNetTail and DarkNet53BackBone will give you final result"""
+    """

     def __init__(self):
         super(YoloNeck, self).__init__()
         self.detect1 = DetectionNeck(1024, 512)
-        self.conv1 = ConvModule(512, 256, 1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv1 = ConvModule(
+            512,
+            256,
+            1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
         self.detect2 = DetectionNeck(768, 256)
-        self.conv2 = ConvModule(256, 128, 1,
-                                norm_cfg=dict(type='BN', requires_grad=True),
-                                act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
+        self.conv2 = ConvModule(
+            256,
+            128,
+            1,
+            norm_cfg=dict(type='BN', requires_grad=True),
+            act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
         self.detect3 = DetectionNeck(384, 128)

     def forward(self, x):
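The last hunk is truncated before the body of `YoloNeck.forward`, but the channel arithmetic above already pins down the wiring: `detect2` consumes 768 = 512 (backbone stage) + 256 (upsampled) channels, and `detect3` consumes 384 = 256 + 128. A sketch of that top-down pass, assuming the standard YOLOv3 upsample-and-concatenate wiring rather than the PR's verbatim body:

    # Sketch of the top-down pass implied by DetectionNeck(768, 256) and
    # DetectionNeck(384, 128). The patch is cut off before YoloNeck.forward
    # finishes, so this wiring is an assumption based on standard YOLOv3,
    # not the PR's verbatim code.
    import torch
    import torch.nn.functional as F


    def yolo_neck_forward(self, x):
        # x arrives deepest-first, i.e. Darknet's reverse_output=True order:
        # (N, 1024, h, w), (N, 512, 2h, 2w), (N, 256, 4h, 4w)
        out1, out2, out3 = x

        det1 = self.detect1(out1)  # 1024 -> 512
        tmp = F.interpolate(self.conv1(det1), scale_factor=2)  # 512 -> 256
        det2 = self.detect2(torch.cat((tmp, out2), dim=1))  # 256 + 512 = 768 -> 256
        tmp = F.interpolate(self.conv2(det2), scale_factor=2)  # 256 -> 128
        det3 = self.detect3(torch.cat((tmp, out3), dim=1))  # 128 + 256 = 384 -> 128

        # Deepest-first again, one feature map per YoloHead scale ('l', 'm', 's').
        return det1, det2, det3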