Commit
use configurable for box head
Reviewed By: rbgirshick

Differential Revision: D20808474

fbshipit-source-id: 7790a22f20c7c119f99a7d4a29246e52959c7f5c
ppwwyyxx authored and facebook-github-bot committed Apr 3, 2020
1 parent ffff8ac commit 41ab438
Showing 3 changed files with 33 additions and 35 deletions.
19 changes: 0 additions & 19 deletions configs/quick_schedules/mask_rcnn_R_50_C4_GCN_instant_test.yaml

This file was deleted.

4 changes: 3 additions & 1 deletion detectron2/layers/batch_norm.py
@@ -127,7 +127,9 @@ def convert_frozen_batchnorm(cls, module):
 def get_norm(norm, out_channels):
     """
     Args:
-        norm (str or callable):
+        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
+            or a callable that takes a channel number and returns
+            the normalization layer as a nn.Module.

     Returns:
         nn.Module or None: the normalization layer
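
For reference, both accepted forms of the norm argument in practice; a minimal usage sketch, assuming only what the docstring above states (the GroupNorm callable and the 256-channel count are illustrative, not part of this commit):

from torch import nn
from detectron2.layers import get_norm

# String alias: one of the built-in names listed in the docstring.
gn = get_norm("GN", 256)

# Callable: receives the channel count and returns an nn.Module.
custom = get_norm(lambda channels: nn.GroupNorm(8, channels), 256)

# The empty string is falsy and yields None (no normalization); the box
# head below relies on this when it sets bias=not conv_norm.
assert get_norm("", 256) is None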
45 changes: 30 additions & 15 deletions detectron2/modeling/roi_heads/box_head.py
@@ -5,6 +5,7 @@
 from torch import nn
 from torch.nn import functional as F

+from detectron2.config import configurable
 from detectron2.layers import Conv2d, Linear, ShapeSpec, get_norm
 from detectron2.utils.registry import Registry

@@ -19,26 +20,29 @@
 @ROI_BOX_HEAD_REGISTRY.register()
 class FastRCNNConvFCHead(nn.Module):
     """
-    A head with several 3x3 conv layers (each followed by norm & relu) and
+    A head with several 3x3 conv layers (each followed by norm & relu) and then
     several fc layers (each followed by relu).
     """

-    def __init__(self, cfg, input_shape: ShapeSpec):
+    @configurable
+    def __init__(
+        self,
+        input_shape: ShapeSpec,
+        num_conv: int,
+        conv_dim: int,
+        num_fc: int,
+        fc_dim: int,
+        conv_norm="",
+    ):
         """
-        The following attributes are parsed from config:
+        Args:
+            input_shape (ShapeSpec): shape of the input feature.
             num_conv, num_fc: the number of conv/fc layers
-            conv_dim/fc_dim: the dimension of the conv/fc layers
-            norm: normalization for the conv layers
+            conv_dim/fc_dim: the output dimension of the conv/fc layers
+            conv_norm: normalization for the conv layers. See :func:`detectron2.layers.get_norm`
+                for supported types.
         """
         super().__init__()

-        # fmt: off
-        num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV
-        conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM
-        num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC
-        fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM
-        norm = cfg.MODEL.ROI_BOX_HEAD.NORM
-        # fmt: on
         assert num_conv + num_fc > 0

         self._output_size = (input_shape.channels, input_shape.height, input_shape.width)
@@ -50,8 +54,8 @@ def __init__(self, cfg, input_shape: ShapeSpec):
                 conv_dim,
                 kernel_size=3,
                 padding=1,
-                bias=not norm,
-                norm=get_norm(norm, conv_dim),
+                bias=not conv_norm,
+                norm=get_norm(conv_norm, conv_dim),
                 activation=F.relu,
             )
             self.add_module("conv{}".format(k + 1), conv)
@@ -70,6 +74,17 @@ def __init__(self, cfg, input_shape: ShapeSpec):
         for layer in self.fcs:
             weight_init.c2_xavier_fill(layer)

+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {
+            "num_conv": cfg.MODEL.ROI_BOX_HEAD.NUM_CONV,
+            "conv_dim": cfg.MODEL.ROI_BOX_HEAD.CONV_DIM,
+            "num_fc": cfg.MODEL.ROI_BOX_HEAD.NUM_FC,
+            "fc_dim": cfg.MODEL.ROI_BOX_HEAD.FC_DIM,
+            "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM,
+            "input_shape": input_shape,
+        }
+
     def forward(self, x):
         for layer in self.conv_norm_relus:
             x = layer(x)
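
With @configurable and from_config in place, FastRCNNConvFCHead can be built either from a config object or from explicit arguments. A minimal sketch of both call paths, assuming detectron2's @configurable dispatch on a CfgNode first argument (the shape and the explicit values below are the library defaults, chosen only for illustration):

from detectron2.config import get_cfg
from detectron2.layers import ShapeSpec
from detectron2.modeling.roi_heads.box_head import FastRCNNConvFCHead

shape = ShapeSpec(channels=256, height=7, width=7)

# Config path: @configurable detects the CfgNode argument and routes it
# through from_config() to recover the explicit keyword arguments.
cfg = get_cfg()
head = FastRCNNConvFCHead(cfg, shape)

# Explicit path: the same class is now usable without any cfg object.
head = FastRCNNConvFCHead(
    input_shape=shape, num_conv=0, conv_dim=256, num_fc=2, fc_dim=1024
)

Either way, the assertion num_conv + num_fc > 0 must hold, so the head always gets at least one conv or fc layer.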
