Commit
update aff & iaff #14
pprp committed Dec 15, 2020
1 parent 253e66b commit 098daa2
Showing 3 changed files with 159 additions and 78 deletions.
82 changes: 53 additions & 29 deletions Plug-and-play module/attention/AFF/AFF.py
@@ -1,32 +1,56 @@
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q
# code: https://github.com/YimianDai/open-aff

# https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
# AFF
# Removed: original MXNet/Gluon implementation
from mxnet.gluon import nn, HybridBlock

class AXYforXplusYAddFuse(HybridBlock):
def __init__(self, channels=64):
super(AXYforXplusYAddFuse, self).__init__()

with self.name_scope():

self.local_att = nn.HybridSequential(prefix='local_att')
self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
self.local_att.add(nn.BatchNorm())

self.global_att = nn.HybridSequential(prefix='global_att')
self.global_att.add(nn.GlobalAvgPool2D())
self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
self.global_att.add(nn.BatchNorm())

self.sig = nn.Activation('sigmoid')

def hybrid_forward(self, F, x, residual):

xi = x + residual
xl = self.local_att(xi)
xg = self.global_att(xi)
xlg = F.broadcast_add(xl, xg)
wei = self.sig(xlg)

xo = F.broadcast_mul(wei, residual) + x

return xo
# Added: PyTorch reimplementation
import torch
import torch.nn as nn


class DAF(nn.Module):
'''
    Direct addition (DirectAddFuse)
'''

def __init__(self):
super(DAF, self).__init__()

def forward(self, x, residual):
return x + residual

class AFF(nn.Module):
'''
    Multi-feature fusion (AFF)
'''

def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)

self.local_att = nn.Sequential(
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

self.global_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

self.sigmoid = nn.Sigmoid()

def forward(self, x, residual):
xa = x + residual
xl = self.local_att(xa)
xg = self.global_att(xa)
xlg = xl + xg
wei = self.sigmoid(xlg)

xo = 2 * x * wei + 2 * residual * (1 - wei)
return xo
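
A quick sanity check of the new AFF module (a hypothetical sketch, not part of the commit; shapes are illustrative assumptions). Note the fusion weights 2 * wei and 2 * (1 - wei) sum to 2, so the output scale roughly matches the plain sum x + residual:

# Hypothetical usage sketch for AFF (assumed shapes)
x = torch.randn(2, 64, 32, 32)         # feature map
residual = torch.randn(2, 64, 32, 32)  # skip/residual branch, same shape
aff = AFF(channels=64, r=4)
out = aff(x, residual)                 # attention-weighted fusion
print(out.shape)                       # torch.Size([2, 64, 32, 32])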
64 changes: 34 additions & 30 deletions Plug-and-play module/attention/AFF/MC-CAM.py
@@ -1,33 +1,37 @@
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q
# code: https://github.com/YimianDai/open-aff



# Removed: original MXNet/Gluon implementation
from mxnet.gluon import nn, HybridBlock

class ResGlobLocaChaFuse(HybridBlock):
def __init__(self, channels=64):
super(ResGlobLocaChaFuse, self).__init__()

with self.name_scope():

self.local_att = nn.HybridSequential(prefix='local_att')
self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
self.local_att.add(nn.BatchNorm())

self.global_att = nn.HybridSequential(prefix='global_att')
self.global_att.add(nn.GlobalAvgPool2D())
self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
self.global_att.add(nn.BatchNorm())

self.sig = nn.Activation('sigmoid')

def hybrid_forward(self, F, x, residual):

xa = x + residual
xl = self.local_att(xa)
xg = self.global_att(xa)
xlg = F.broadcast_add(xl, xg)
wei = self.sig(xlg)

xo = 2 * F.broadcast_mul(x, wei) + 2 * F.broadcast_mul(residual, 1-wei)

return xo
# Added: PyTorch reimplementation
import torch
import torch.nn as nn

class MS_CAM(nn.Module):
'''
    Channel attention for a single feature map, similar to an SE block
'''

def __init__(self, channels=64, r=4):
super(MS_CAM, self).__init__()
inter_channels = int(channels // r)

self.local_att = nn.Sequential(
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

self.global_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

self.sigmoid = nn.Sigmoid()

def forward(self, x):
xl = self.local_att(x)
xg = self.global_att(x)
xlg = xl + xg
wei = self.sigmoid(xlg)
return x * wei
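
MS_CAM takes a single input and rescales it, SE-style, by sigmoid weights built from a local (pointwise) and a global (pooled) branch. A hypothetical usage sketch, not part of the commit; shapes are illustrative:

# Hypothetical usage sketch for MS_CAM (assumed shapes)
x = torch.randn(2, 64, 32, 32)
ms_cam = MS_CAM(channels=64, r=4)
out = ms_cam(x)      # x scaled element-wise by attention weights in (0, 1)
print(out.shape)     # torch.Size([2, 64, 32, 32])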
91 changes: 72 additions & 19 deletions Plug-and-play module/attention/AFF/iAFF.py
@@ -1,30 +1,83 @@
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q
# code: https://github.com/YimianDai/open-aff
# Removed: original MXNet/Gluon implementation
from mxnet.gluon import nn, HybridBlock

class AXYforXYAddFuse(HybridBlock):
    def __init__(self, channels=64):
        super(AXYforXYAddFuse, self).__init__()

        with self.name_scope():

            self.local_att = nn.HybridSequential(prefix='local_att')
            self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.local_att.add(nn.BatchNorm())

            self.global_att = nn.HybridSequential(prefix='global_att')
            self.global_att.add(nn.GlobalAvgPool2D())
            self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0))
            self.global_att.add(nn.BatchNorm())

            self.sig = nn.Activation('sigmoid')

    def hybrid_forward(self, F, x, residual):

        xi = x + residual
        xl = self.local_att(xi)
        xg = self.global_att(xi)
        xlg = F.broadcast_add(xl, xg)
        wei = self.sig(xlg)

        xo = F.broadcast_mul(wei, xi)

        return xo

# Added: PyTorch reimplementation
import torch
import torch.nn as nn


class DAF(nn.Module):
    '''
    Direct addition (DirectAddFuse)
    '''

    def __init__(self):
        super(DAF, self).__init__()

    def forward(self, x, residual):
        return x + residual

class iAFF(nn.Module):
'''
    Multi-feature fusion (iAFF)
'''

def __init__(self, channels=64, r=4):
super(iAFF, self).__init__()
inter_channels = int(channels // r)

        # Local attention
self.local_att = nn.Sequential(
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

        # Global attention
self.global_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

        # Second local attention
self.local_att2 = nn.Sequential(
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)
        # Second global attention
self.global_att2 = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.ReLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)

self.sigmoid = nn.Sigmoid()

def forward(self, x, residual):
xa = x + residual
xl = self.local_att(xa)
xg = self.global_att(xa)
xlg = xl + xg
wei = self.sigmoid(xlg)
xi = x * wei + residual * (1 - wei)

xl2 = self.local_att2(xi)
        xg2 = self.global_att2(xi)
xlg2 = xl2 + xg2
wei2 = self.sigmoid(xlg2)
xo = x * wei2 + residual * (1 - wei2)
return xo
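
iAFF fuses the two inputs once, then feeds the intermediate fusion xi through a second attention pair to refine the weights. A hypothetical usage sketch, not part of the commit; shapes are illustrative:

# Hypothetical usage sketch for iAFF (assumed shapes)
x = torch.randn(2, 64, 32, 32)
residual = torch.randn(2, 64, 32, 32)
iaff = iAFF(channels=64, r=4)
out = iaff(x, residual)  # two-stage (iterative) attentional fusion
print(out.shape)         # torch.Size([2, 64, 32, 32])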
