-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
DynamicConv.py
90 lines (77 loc) · 3.42 KB
/
DynamicConv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from torch import nn
from torch.nn import functional as F
class Attention(nn.Module):
    """Produce a softmax distribution over ``K`` candidates for each sample.

    The input feature map is globally average-pooled to ``(bs, in_planes, 1, 1)``,
    passed through a two-layer 1x1-conv bottleneck (``in_planes ->
    in_planes // ratio -> K``, both without bias) and turned into
    probabilities with a temperature-scaled softmax.

    Args:
        in_planes: Number of input channels; must be strictly greater than
            ``ratio`` (enforced with ``assert``).
        ratio: Channel reduction factor of the bottleneck.
        K: Number of attention outputs per sample.
        temprature: Softmax temperature (spelling kept, it is part of the
            public interface). Larger values flatten the distribution.
        init_weight: When true, run ``_initialize_weights`` after building
            the layers.
    """

    def __init__(self, in_planes, ratio, K, temprature=30, init_weight=True):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.temprature = temprature

        assert in_planes > ratio
        hidden_planes = in_planes // ratio

        # Squeeze -> ReLU -> expand to K logits, all as bias-free 1x1 convs.
        squeeze = nn.Conv2d(in_planes, hidden_planes, kernel_size=1, bias=False)
        expand = nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False)
        self.net = nn.Sequential(squeeze, nn.ReLU(), expand)

        if init_weight:
            self._initialize_weights()

    def update_temprature(self):
        """Lower the temperature by 1 as long as it is still above 1."""
        if self.temprature <= 1:
            return
        self.temprature -= 1

    def _initialize_weights(self):
        """Kaiming-initialise conv layers; reset BatchNorm2d to weight 1, bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_out', nonlinearity='relu'
                )
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return attention weights of shape ``(bs, K)`` whose rows sum to 1."""
        batch = x.shape[0]
        pooled = self.avgpool(x)  # bs, dim, 1, 1
        logits = self.net(pooled).view(batch, -1)  # bs, K
        return F.softmax(logits / self.temprature, -1)
class DynamicConv(nn.Module):
    """2-D convolution whose kernel is a per-sample mixture of ``K`` kernels.

    An ``Attention`` module predicts, for every sample in the batch, a
    softmax weighting over ``K`` learned kernels (and biases). The weighted
    kernels are applied to the whole batch with a single grouped
    convolution: the batch is folded into the channel axis and ``groups``
    is multiplied by the batch size so each sample only sees its own kernel.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square kernel (an int).
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        grounps: Number of convolution groups (spelling kept, it is part of
            the public interface); stored as ``self.groups``.
        bias: Whether to learn ``K`` candidate bias vectors.
        K: Number of candidate kernels.
        temprature: Initial softmax temperature handed to ``Attention``.
        ratio: Channel reduction ratio handed to ``Attention``.
        init_weight: Whether to run the weight initialisers.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0,
                 dilation=1, grounps=1, bias=True, K=4, temprature=30,
                 ratio=4, init_weight=True):
        super().__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = grounps
        self.K = K
        self.init_weight = init_weight
        self.attention = Attention(
            in_planes=in_planes, ratio=ratio, K=K,
            temprature=temprature, init_weight=init_weight,
        )

        # K candidate kernels, each shaped like an ordinary conv2d weight.
        self.weight = nn.Parameter(
            torch.randn(K, out_planes, in_planes // grounps,
                        kernel_size, kernel_size),
            requires_grad=True,
        )
        if bias:
            # K candidate bias vectors, one per kernel.
            self.bias = nn.Parameter(torch.randn(K, out_planes),
                                     requires_grad=True)
        else:
            self.bias = None

        if self.init_weight:
            self._initialize_weights()
        # TODO: initialisation -- only the kernels are re-initialised below;
        # the bias keeps its randn values.

    def _initialize_weights(self):
        """Apply Kaiming-uniform initialisation to each of the K kernels."""
        for i in range(self.K):
            nn.init.kaiming_uniform_(self.weight[i])

    def forward(self, x):
        """Apply the attention-weighted convolution.

        Args:
            x: Input of shape ``(bs, in_planes, h, w)``.

        Returns:
            Tensor of shape ``(bs, out_planes, h_out, w_out)`` where the
            spatial size is whatever the convolution produces for the
            configured kernel size, stride, padding and dilation.
        """
        bs, _, h, w = x.shape
        softmax_att = self.attention(x)  # bs, K

        # Fold the batch into the channel axis. reshape (not view) so that
        # non-contiguous inputs are accepted as well.
        x = x.reshape(1, -1, h, w)

        # Mix the K kernels per sample: (bs, K) @ (K, out*in/g*k*k).
        weight = self.weight.view(self.K, -1)  # K, -1
        aggregate_weight = torch.mm(softmax_att, weight).view(
            bs * self.out_planes,
            self.in_planes // self.groups,
            self.kernel_size,
            self.kernel_size,
        )  # bs*out_p, in_p/g, k, k

        aggregate_bias = None
        if self.bias is not None:
            bias = self.bias.view(self.K, -1)  # K, out_p
            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs*out_p

        # groups * bs keeps every sample's channels paired with its own kernel.
        output = F.conv2d(
            x,
            weight=aggregate_weight,
            bias=aggregate_bias,
            stride=self.stride,
            padding=self.padding,
            groups=self.groups * bs,
            dilation=self.dilation,
        )

        # Unfold the batch using the conv's actual output size; reusing the
        # input h, w breaks whenever the convolution changes the spatial size.
        return output.view(bs, self.out_planes,
                           output.shape[-2], output.shape[-1])
if __name__ == '__main__':
    # Smoke test: run a random batch through a bias-free DynamicConv and
    # print the resulting shape.
    sample = torch.randn(2, 32, 64, 64)
    layer = DynamicConv(
        in_planes=32,
        out_planes=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    result = layer(sample)
    print(result.shape)