-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
159 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,56 @@ | ||
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q | ||
# code: https://github.com/YimianDai/open-aff | ||
|
||
# https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py | ||
# AFF | ||
class AXYforXplusYAddFuse(HybridBlock): | ||
def __init__(self, channels=64): | ||
super(AXYforXplusYAddFuse, self).__init__() | ||
|
||
with self.name_scope(): | ||
|
||
self.local_att = nn.HybridSequential(prefix='local_att') | ||
self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.local_att.add(nn.BatchNorm()) | ||
|
||
self.global_att = nn.HybridSequential(prefix='global_att') | ||
self.global_att.add(nn.GlobalAvgPool2D()) | ||
self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.global_att.add(nn.BatchNorm()) | ||
|
||
self.sig = nn.Activation('sigmoid') | ||
|
||
def hybrid_forward(self, F, x, residual): | ||
|
||
xi = x + residual | ||
xl = self.local_att(xi) | ||
xg = self.global_att(xi) | ||
xlg = F.broadcast_add(xl, xg) | ||
wei = self.sig(xlg) | ||
|
||
xo = F.broadcast_mul(wei, residual) + x | ||
|
||
return xo | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class DAF(nn.Module): | ||
''' | ||
直接相加 DirectAddFuse | ||
''' | ||
|
||
def __init__(self): | ||
super(DAF, self).__init__() | ||
|
||
def forward(self, x, residual): | ||
return x + residual | ||
|
||
class AFF(nn.Module): | ||
''' | ||
多特征融合 AFF | ||
''' | ||
|
||
def __init__(self, channels=64, r=4): | ||
super(AFF, self).__init__() | ||
inter_channels = int(channels // r) | ||
|
||
self.local_att = nn.Sequential( | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
self.global_att = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
self.sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, x, residual): | ||
xa = x + residual | ||
xl = self.local_att(xa) | ||
xg = self.global_att(xa) | ||
xlg = xl + xg | ||
wei = self.sigmoid(xlg) | ||
|
||
xo = 2 * x * wei + 2 * residual * (1 - wei) | ||
return xo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,33 +1,37 @@ | ||
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q | ||
# code: https://github.com/YimianDai/open-aff | ||
|
||
|
||
|
||
class ResGlobLocaChaFuse(HybridBlock): | ||
def __init__(self, channels=64): | ||
super(ResGlobLocaChaFuse, self).__init__() | ||
|
||
with self.name_scope(): | ||
|
||
self.local_att = nn.HybridSequential(prefix='local_att') | ||
self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.local_att.add(nn.BatchNorm()) | ||
|
||
self.global_att = nn.HybridSequential(prefix='global_att') | ||
self.global_att.add(nn.GlobalAvgPool2D()) | ||
self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.global_att.add(nn.BatchNorm()) | ||
|
||
self.sig = nn.Activation('sigmoid') | ||
|
||
def hybrid_forward(self, F, x, residual): | ||
|
||
xa = x + residual | ||
xl = self.local_att(xa) | ||
xg = self.global_att(xa) | ||
xlg = F.broadcast_add(xl, xg) | ||
wei = self.sig(xlg) | ||
|
||
xo = 2 * F.broadcast_mul(x, wei) + 2 * F.broadcast_mul(residual, 1-wei) | ||
|
||
return xo | ||
class MS_CAM(nn.Module): | ||
''' | ||
单特征 进行通道加权,作用类似SE模块 | ||
''' | ||
|
||
def __init__(self, channels=64, r=4): | ||
super(MS_CAM, self).__init__() | ||
inter_channels = int(channels // r) | ||
|
||
self.local_att = nn.Sequential( | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
self.global_att = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
self.sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, x): | ||
xl = self.local_att(x) | ||
xg = self.global_att(x) | ||
xlg = xl + xg | ||
wei = self.sigmoid(xlg) | ||
return x * wei |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,83 @@ | ||
# source: https://mp.weixin.qq.com/s/La6rbQpnZzjWH3psB2gD6Q | ||
# code: https://github.com/YimianDai/open-aff | ||
class AXYforXYAddFuse(HybridBlock): | ||
def __init__(self, channels=64): | ||
super(AXYforXYAddFuse, self).__init__() | ||
|
||
with self.name_scope(): | ||
|
||
self.local_att = nn.HybridSequential(prefix='local_att') | ||
self.local_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.local_att.add(nn.BatchNorm()) | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
self.global_att = nn.HybridSequential(prefix='global_att') | ||
self.global_att.add(nn.GlobalAvgPool2D()) | ||
self.global_att.add(nn.Conv2D(channels, kernel_size=1, strides=1, padding=0)) | ||
self.global_att.add(nn.BatchNorm()) | ||
|
||
self.sig = nn.Activation('sigmoid') | ||
class DAF(nn.Module): | ||
''' | ||
直接相加 DirectAddFuse | ||
''' | ||
|
||
def hybrid_forward(self, F, x, residual): | ||
def __init__(self): | ||
super(DAF, self).__init__() | ||
|
||
xi = x + residual | ||
xl = self.local_att(xi) | ||
xg = self.global_att(xi) | ||
xlg = F.broadcast_add(xl, xg) | ||
wei = self.sig(xlg) | ||
def forward(self, x, residual): | ||
return x + residual | ||
|
||
xo = F.broadcast_mul(wei, xi) | ||
|
||
class iAFF(nn.Module): | ||
''' | ||
多特征融合 iAFF | ||
''' | ||
|
||
def __init__(self, channels=64, r=4): | ||
super(iAFF, self).__init__() | ||
inter_channels = int(channels // r) | ||
|
||
# 本地注意力 | ||
self.local_att = nn.Sequential( | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
# 全局注意力 | ||
self.global_att = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
# 第二次本地注意力 | ||
self.local_att2 = nn.Sequential( | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
# 第二次全局注意力 | ||
self.global_att2 = nn.Sequential( | ||
nn.AdaptiveAvgPool2d(1), | ||
nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(inter_channels), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0), | ||
nn.BatchNorm2d(channels), | ||
) | ||
|
||
self.sigmoid = nn.Sigmoid() | ||
|
||
def forward(self, x, residual): | ||
xa = x + residual | ||
xl = self.local_att(xa) | ||
xg = self.global_att(xa) | ||
xlg = xl + xg | ||
wei = self.sigmoid(xlg) | ||
xi = x * wei + residual * (1 - wei) | ||
|
||
xl2 = self.local_att2(xi) | ||
xg2 = self.global_att(xi) | ||
xlg2 = xl2 + xg2 | ||
wei2 = self.sigmoid(xlg2) | ||
xo = x * wei2 + residual * (1 - wei2) | ||
return xo |