Skip to content

Commit

Permalink
Work on DANet for GL, 2
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Oct 31, 2020
1 parent 8aa71f5 commit 98cf126
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ before_script:
# stop the build if there are Python syntax errors or undefined names
- flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics --exclude=./venv*
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
- flake8 . --count --max-complexity=44 --max-line-length=127 --ignore=W504,F403,F405,E126,E127,E402,W605 --statistics --exclude=./gluon/gluoncv2/models/others,./pytorch/pytorchcv/models/others,./chainer_/chainercv2/models/others,./keras_/kerascv/models/others,./tensorflow_/tensorflowcv/models/others,./venv*
- flake8 . --count --max-complexity=45 --max-line-length=127 --ignore=W504,F403,F405,E126,E127,E402,W605 --statistics --exclude=./gluon/gluoncv2/models/others,./pytorch/pytorchcv/models/others,./chainer_/chainercv2/models/others,./keras_/kerascv/models/others,./tensorflow_/tensorflowcv/models/others,./venv*
script:
- true # pytest --capture=sys # add others tests here
notifications:
Expand Down
9 changes: 9 additions & 0 deletions convert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,15 @@ def convert_gl2gl(dst_net,
src_model,
ctx):
if src_model.startswith("oth_danet_resnet"):
src6 = list(filter(re.compile("^head.sa.gamma").search, src_param_keys))
src6n = [key for key in src_param_keys if key not in src6]
src_param_keys = src6n + src6
src7 = list(filter(re.compile("^head.conv51").search, src_param_keys))
src7n = [key for key in src_param_keys if key not in src7]
src_param_keys = src7n + src7
src8 = list(filter(re.compile("^head.conv6").search, src_param_keys))
src8n = [key for key in src_param_keys if key not in src8]
src_param_keys = src8n + src8
src1 = list(filter(re.compile("^head.conv5c").search, src_param_keys))
src1n = [key for key in src_param_keys if key not in src1]
src_param_keys = src1n + src1
Expand Down
46 changes: 38 additions & 8 deletions gluon/gluoncv2/models/danet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Original paper: 'Dual Attention Network for Scene Segmentation,' https://arxiv.org/abs/1809.02983.
"""

__all__ = ['DANet', 'danet_resnetd50b_cityscapes', 'danet_resnetd101b_cityscapes']
__all__ = ['DANet', 'danet_resnetd50b_cityscapes', 'danet_resnetd101b_cityscapes', 'ScaleBlock']


import os
Expand All @@ -14,6 +14,36 @@
from .resnetd import resnetd50b, resnetd101b


class ScaleBlock(HybridBlock):
    """
    Simple scale block: multiplies the input by a single learnable scalar `alpha`.

    The `alpha` parameter has shape (1,) and is zero-initialized, so the block outputs zeros until `alpha` is
    trained or loaded from a checkpoint.
    """
    def __init__(self,
                 **kwargs):
        super(ScaleBlock, self).__init__(**kwargs)
        with self.name_scope():
            # Registered through the block's ParameterDict so that it is passed to `hybrid_forward` as `alpha`.
            self.alpha = self.params.get(
                "alpha",
                shape=(1,),
                init=mx.init.Zero(),
                allow_deferred_init=True)

    def hybrid_forward(self, F, x, alpha):
        """
        Scale the input tensor by `alpha`.

        Parameters:
        ----------
        F : module
            Operator namespace supplied by Gluon; must provide `broadcast_mul`.
        x : Symbol or NDArray
            Input tensor.
        alpha : Symbol or NDArray
            Scale parameter with shape (1,), broadcast over `x`.

        Returns:
        -------
        Symbol or NDArray
            Result of `broadcast_mul(alpha, x)`.
        """
        return F.broadcast_mul(alpha, x)

    def __repr__(self):
        s = '{name}(alpha={alpha})'
        # The keyword must match the `{alpha}` placeholder; passing `gamma=` here raised KeyError on repr().
        return s.format(
            name=self.__class__.__name__,
            alpha=self.alpha.shape[0])

    def calc_flops(self, x):
        """
        Calculate the computational cost of the block for a given input.

        Parameters:
        ----------
        x : NDArray
            Input tensor; its first dimension must be equal to 1.

        Returns:
        -------
        num_flops : int
            Number of FLOPs, counted as one operation per element of `x`.
        num_macs : int
            Number of MACs (always 0 for this block).
        """
        assert (x.shape[0] == 1)
        num_flops = x.size
        num_macs = 0
        return num_flops, num_macs


class PosAttBlock(HybridBlock):
"""
Position attention block from 'Dual Attention Network for Scene Segmentation,' https://arxiv.org/abs/1809.02983.
Expand Down Expand Up @@ -46,9 +76,9 @@ def __init__(self,
in_channels=channels,
out_channels=channels,
use_bias=True)
self.gamma = self.params.get("gamma", shape=(1,), init=mx.init.Zero())
self.scale = ScaleBlock()

def hybrid_forward(self, F, x, gamma):
def hybrid_forward(self, F, x):
proj_query = self.query_conv(x).reshape((0, 0, -1))
proj_key = self.key_conv(x).reshape((0, 0, -1))
proj_value = self.value_conv(x).reshape((0, 0, -1))
Expand All @@ -58,7 +88,7 @@ def hybrid_forward(self, F, x, gamma):

y = F.batch_dot(proj_value, w, transpose_b=True)
y = F.reshape_like(y, x, lhs_begin=2, lhs_end=None, rhs_begin=2, rhs_end=None)
y = F.broadcast_mul(gamma, y) + x
y = self.scale(y) + x
return y


Expand All @@ -71,9 +101,9 @@ def __init__(self,
**kwargs):
super(ChaAttBlock, self).__init__(**kwargs)
with self.name_scope():
self.gamma = self.params.get("gamma", shape=(1,), init=mx.init.Zero())
self.scale = ScaleBlock()

def hybrid_forward(self, F, x, gamma):
def hybrid_forward(self, F, x):
proj_query = x.reshape((0, 0, -1))
proj_key = x.reshape((0, 0, -1))
proj_value = x.reshape((0, 0, -1))
Expand All @@ -84,7 +114,7 @@ def hybrid_forward(self, F, x, gamma):

y = F.batch_dot(w, proj_value)
y = F.reshape_like(y, x, lhs_begin=2, lhs_end=None, rhs_begin=2, rhs_end=None)
y = F.broadcast_mul(gamma, y) + x
y = self.scale(y) + x
return y


Expand Down Expand Up @@ -386,7 +416,7 @@ def _test():
if not pretrained:
net.initialize(ctx=ctx)

# net.hybridize()
net.hybridize()
net_params = net.collect_params()
weight_count = 0
for param in net_params.values():
Expand Down
3 changes: 2 additions & 1 deletion gluon/model_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .gluoncv2.models.proxylessnas import ProxylessUnit
from .gluoncv2.models.lwopenpose_cmupan import LwopDecoderFinalBlock
from .gluoncv2.models.centernet import CenterNetHeatmapMaxDet
from .gluoncv2.models.danet import ScaleBlock

__all__ = ['measure_model']

Expand Down Expand Up @@ -221,7 +222,7 @@ def call_hook(block, x, y):
elif isinstance(block, ProxylessUnit):
extra_num_flops = x[0].size
extra_num_macs = 0
elif type(block) in [InterpolationBlock, HeatmapMaxDetBlock, CenterNetHeatmapMaxDet]:
elif type(block) in [InterpolationBlock, HeatmapMaxDetBlock, CenterNetHeatmapMaxDet, ScaleBlock]:
extra_num_flops, extra_num_macs = block.calc_flops(x[0])
elif isinstance(block, LwopDecoderFinalBlock):
if not block.calc_3d_features:
Expand Down

0 comments on commit 98cf126

Please sign in to comment.