Skip to content

Commit

Permalink
fix mask
Browse files Browse the repository at this point in the history
  • Loading branch information
dingguo1996 committed May 14, 2019
1 parent fb0b4dd commit 4f44845
Showing 1 changed file with 66 additions and 4 deletions.
70 changes: 66 additions & 4 deletions detectron/modeling/mask_rcnn_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,14 @@ def mask_rcnn_fcn_head_v1upXconvs(
spatial_scale=spatial_scale
)

mask_final_sum=add_three_stage_mask(model, dim_in)

dilation = cfg.MRCNN.DILATION
dim_inner = cfg.MRCNN.DIM_REDUCED

for i in range(num_convs):
current = model.Conv(
current,
mask_final_sum = model.Conv(
mask_final_sum,
'_[mask]_fcn' + str(i + 1),
dim_in,
dim_inner,
Expand All @@ -160,12 +162,12 @@ def mask_rcnn_fcn_head_v1upXconvs(
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
current = model.Relu(current, current)
mask_final_sum = model.Relu(mask_final_sum, mask_final_sum)
dim_in = dim_inner

# upsample layer
model.ConvTranspose(
current,
mask_final_sum,
'conv5_mask',
dim_inner,
dim_inner,
Expand Down Expand Up @@ -327,3 +329,63 @@ def add_ResNet_roi_conv5_head_for_masks(model, blob_in, dim_in, spatial_scale):
)

return s, 2048

def add_three_stage_mask(model,dim_in):
# for the first stage
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL):
model.AveragePool('_[mask]_roi_feat_fpn' + str(lvl), '_[mask]_ave_fpn' + str(lvl), kernel=14, stride=1,
channel=dim_in)
model.Conv(
'_[mask]_roi_feat_fpn' + str(lvl),
'_[mask]_conv_fpn' + str(lvl),
dim_in,
dim_in,
kernel=1,
pad=0,
stride=1,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
model.ConvTranspose(
'_[mask]_ave_fpn' + str(lvl),
'_[mask]_deconv_fpn' + str(lvl),
dim_in,
dim_in,
kernel=14,
pad=0,
stride=1,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
model.Sum(['_[mask]_conv_fpn' + str(lvl), '_[mask]_deconv_fpn' + str(lvl)], '_[mask]_part_fpn' + str(lvl))
model.Sum(['_[mask]_part_fpn2', '_[mask]_part_fpn3', '_[mask]_part_fpn4', '_[mask]_part_fpn5'], '_[mask]_part_sum')
for lvl in range(cfg.FPN.RPN_MIN_LEVEL, cfg.FPN.RPN_MAX_LEVEL):
model.Conv(
'_[mask]_part_fpn' + str(lvl),
'_[mask]_part_conv_fpn' + str(lvl),
dim_in,
dim_in,
kernel=1,
pad=0,
stride=1,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
model.Conv(
'_[mask]_part_sum',
'_[mask]_part_sum_conv_fpn' + str(lvl),
dim_in,
dim_in,
kernel=1,
pad=0,
stride=1,
weight_init=(cfg.MRCNN.CONV_INIT, {'std': 0.001}),
bias_init=('ConstantFill', {'value': 0.})
)
model.Sum(['_[mask]_part_sum_conv_fpn' + str(lvl), '_[mask]_part_conv_fpn' + str(lvl)],
'_[mask]_handle_1_roi_feat_fpn' + str(lvl))

mask_final_sum=model.Sum(
['_[mask]_handle_1_roi_feat_fpn2', '_[mask]_handle_1_roi_feat_fpn3', '_[mask]_handle_1_roi_feat_fpn4',
'_[mask]_handle_1_roi_feat_fpn5'], '_[mask]_final_sum')
return mask_final_sum

0 comments on commit 4f44845

Please sign in to comment.