Skip to content

Commit

Permalink
Merge pull request open-mmlab#4 from ElectronicElephant/yolo-dev
Browse files Browse the repository at this point in the history
Refactor backbone, lint and format
  • Loading branch information
ElectronicElephant committed Jun 29, 2020
2 parents 1f64c24 + 4f24fa1 commit 4f82cd5
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 308 deletions.
10 changes: 6 additions & 4 deletions configs/yolo/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#YOLOv3
# YOLOv3

## Introduction
```
Expand All @@ -16,8 +16,10 @@

Test set: COCO val2017

bbox_mAP: 0.3520
bbox_mAP: 0.3640

bbox_mAP_50: 0.6100
bbox_mAP_50: 0.6350

Checkpoint link: [here](https://drive.google.com/drive/folders/1NzQ5LwBaYPlu1gywnRAViNz70NV9743O?usp=sharing)
Checkpoint link: [here](https://drive.google.com/drive/folders/1NzQ5LwBaYPlu1gywnRAViNz70NV9743O?usp=sharing)

This implementation originates from the project of Haoyu Wu(@wuhy08) at Western Digital.
45 changes: 21 additions & 24 deletions configs/yolo/yolov3_ms_aug_273e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
type='YoloNet',
pretrained='./work_dirs/darknet_state_dict_only.pth',
backbone=dict(
type='DarkNet53',),
neck=dict(
type='YoloNeck',),
bbox_head=dict(
type='YoloHead',))
type='Darknet',
depth=53,
out_indices=(3, 4, 5),
),
neck=dict(type='YoloNeck', ),
bbox_head=dict(type='YoloHead', ))
# training and testing settings
train_cfg = dict(
one_hot_smoother=0.,
ignore_config=0.5,
xy_use_logit=False,
debug=False)
one_hot_smoother=0., ignore_config=0.5, xy_use_logit=False, debug=False)
test_cfg = dict(
nms_pre=1000,
min_bbox_size=0,
Expand All @@ -26,21 +24,21 @@
# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
# TODO: Add PhotoMetricDistortion
train_pipeline = [
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='PhotoMetricDistortion'),
dict(type='Expand',
mean=img_norm_cfg['mean'],
to_rgb=img_norm_cfg['to_rgb'],
ratio_range=(1, 2)
),
dict(type='MinIoURandomCrop',
min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
min_crop_size=0.3
),
dict(
type='Expand',
mean=img_norm_cfg['mean'],
to_rgb=img_norm_cfg['to_rgb'],
ratio_range=(1, 2)),
dict(
type='MinIoURandomCrop',
min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
min_crop_size=0.3),
dict(type='Resize', img_scale=[(320, 320), (608, 608)], keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
Expand Down Expand Up @@ -83,8 +81,7 @@
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline,
)
)
))
# optimizer
optimizer = dict(type='SGD', lr=5e-4, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
Expand All @@ -101,16 +98,16 @@
interval=50,
hooks=[
dict(type='TextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
# runtime settings
total_epochs = 273
device_ids = range(8)
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = './work_dirs/yolo_pretrained'
load_from = None
resume_from = None
workflow = [('train', 1)]
evaluation = dict(interval=1, metric=['bbox'])
# TODO: Remove hot fix
find_unused_parameters = True
116 changes: 0 additions & 116 deletions configs/yolo/yolov3_ms_aug_273e_no_pretrain.py

This file was deleted.

4 changes: 2 additions & 2 deletions mmdet/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from .resnet import ResNet, ResNetV1d
from .resnext import ResNeXt
from .ssd_vgg import SSDVGG
from .darknet import DarkNet53
from .darknet import Darknet

__all__ = [
'RegNet', 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'Res2Net',
'HourglassNet', 'DarkNet53'
'HourglassNet', 'Darknet'
]
83 changes: 62 additions & 21 deletions mmdet/models/backbones/darknet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

class ResBlock(nn.Module):
"""The basic residual block used in YoloV3.
Each ResBlock consists of two ConvLayers and the input is added to the final output.
Each ResBlock consists of two ConvModules and the input is added to the final output.
Each ConvModule is composed of Conv, BN, and LeakyReLU
In YoloV3 paper, the first convLayer has half of the number of the filters as much as the second convLayer.
The first convLayer has filter size of 1x1 and the second one has the filter size of 3x3.
"""
Expand Down Expand Up @@ -46,7 +47,7 @@ def forward(self, x):


def make_conv_and_res_block(in_channels, out_channels, res_repeat):
"""In Darknet 53 backbone, there is usually one Conv Layer followed by some ResBlock.
"""In Darknet backbone, there is usually one Conv Layer followed by some ResBlock.
This function will make that.
The Conv layers always have 3x3 filters with stride=2.
The number of the filters in Conv layer is the same as the out channels of the ResBlock"""
Expand All @@ -64,37 +65,77 @@ def make_conv_and_res_block(in_channels, out_channels, res_repeat):


@BACKBONES.register_module()
class DarkNet53(nn.Module):
class Darknet(nn.Module):
"""Darknet backbone.
Args:
depth (int): Depth of Darknet. Currently only support 53.
out_indices (Sequence[int]): Output from which stages.
Note: By default, the sequence of the layers will be returned
in a **reversed** manner. i.e., from bottom to up.
See the example bellow.
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
reverse_output (bool): If True, the sequence of the output layers
will be from bottom to up. Default: True. (To cope with YoloNeck)
Example:
>>> from mmdet.models import Darknet
>>> import torch
>>> self = Darknet(depth=53)
>>> self.eval()
>>> inputs = torch.rand(1, 3, 416, 416)
>>> level_outputs = self.forward(inputs)
>>> for level_out in level_outputs:
... print(tuple(level_out.shape))
...
(1, 1024, 13, 13)
(1, 512, 26, 26)
(1, 256, 52, 52)
"""

def __init__(self,
depth=53,
out_indices=(3, 4, 5),
norm_eval=True,
reverse_output=False):
super(DarkNet53, self).__init__()
reverse_output=True):
super(Darknet, self).__init__()
self.depth = depth
self.out_indices = out_indices
if self.depth == 53:
self.layers = [1, 2, 8, 8, 4]
self.channels = [[32, 64], [64, 128], [128, 256], [256, 512], [512, 1024]]
else:
raise KeyError(f'invalid depth {depth} for darknet')

self.conv1 = ConvModule(3, 32, 3,
padding=1,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1))
self.cr_block1 = make_conv_and_res_block(32, 64, 1)
self.cr_block2 = make_conv_and_res_block(64, 128, 2)
self.cr_block3 = make_conv_and_res_block(128, 256, 8)
self.cr_block4 = make_conv_and_res_block(256, 512, 8)
self.cr_block5 = make_conv_and_res_block(512, 1024, 4)

self.cr_blocks = ['conv1']
for i, n_layers in enumerate(self.layers):
layer_name = f'cr_block{i + 1}'
in_c, out_c = self.channels[i]
self.add_module(layer_name, make_conv_and_res_block(in_c, out_c, n_layers))
self.cr_blocks.append(layer_name)

self.norm_eval = norm_eval
self.reverse_output=reverse_output

def forward(self, x):
tmp = self.conv1(x)
tmp = self.cr_block1(tmp)
tmp = self.cr_block2(tmp)
out3 = self.cr_block3(tmp)
out2 = self.cr_block4(out3)
out1 = self.cr_block5(out2)

if not self.reverse_output:
return out1, out2, out3
outs = []
for i, layer_name in enumerate(self.cr_blocks):
cr_block = getattr(self, layer_name)
x = cr_block(x)
if i in self.out_indices:
outs.append(x)

if self.reverse_output:
return tuple(outs[::-1])
else:
return out3, out2, out1
return tuple(outs)

def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
Expand All @@ -115,7 +156,7 @@ def _freeze_stages(self):
param.requires_grad = False

def train(self, mode=True):
super(DarkNet53, self).train(mode)
super(Darknet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
Expand Down

0 comments on commit 4f82cd5

Please sign in to comment.