Fix the dynamic conv stride bug #90

Open · wants to merge 1 commit into base: master
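For context, a minimal sketch of the stride bug this PR fixes (illustrative variable names, not code from the PR): F.conv2d with stride=2 halves the spatial dimensions, so reshaping the grouped-convolution output with the input's h and w no longer matches the element count.

import torch
from torch.nn import functional as F

bs, in_planes, out_planes, h, w = 2, 32, 64, 64, 64
x = torch.randn(bs, in_planes, h, w).view(1, -1, h, w)  # fold the batch into the channel dim
weight = torch.randn(bs * out_planes, in_planes, 3, 3)  # per-sample aggregated kernels

# groups=bs gives every sample its own kernels; stride=2 halves the spatial dims
out = F.conv2d(x, weight, stride=2, padding=1, groups=bs)
print(out.shape)  # torch.Size([1, 128, 32, 32])

out = out.view(bs, out_planes, out.size(-2), out.size(-1))  # the fix: use the actual output size
# out.view(bs, out_planes, h, w)  # RuntimeError: shape '[2, 64, 64, 64]' is invalid for input of size 131072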
4 changes: 2 additions & 2 deletions model/conv/CondConv.py
@@ -71,11 +71,11 @@ def forward(self,x):
         else:
             output=F.conv2d(x,weight=aggregate_weight,bias=None,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)
 
-        output=output.view(bs,self.out_planes,h,w)
+        output=output.view(bs,self.out_planes, output.size(-2), output.size(-1))
         return output
 
 if __name__ == '__main__':
     input=torch.randn(2,32,64,64)
-    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
+    m=CondConv(in_planes=32,out_planes=64,kernel_size=3,stride=2,padding=1,bias=False)
     out=m(input)
     print(out.shape)
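With the corrected reshape, the stride=2 smoke test above should print torch.Size([2, 64, 32, 32]) instead of raising a view error: each spatial dimension follows the standard convolution size formula H_out = floor((H + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, which for H = 64, kernel_size = 3, stride = 2, padding = 1 gives 32.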
118 changes: 64 additions & 54 deletions model/conv/DynamicConv.py
@@ -2,89 +2,99 @@
 from torch import nn
 from torch.nn import functional as F
 
+
 class Attention(nn.Module):
-    def __init__(self,in_planes,ratio,K,temprature=30,init_weight=True):
+    def __init__(self, in_planes, ratio, K, temprature=30, init_weight=True):
         super().__init__()
-        self.avgpool=nn.AdaptiveAvgPool2d(1)
-        self.temprature=temprature
-        assert in_planes>ratio
-        hidden_planes=in_planes//ratio
-        self.net=nn.Sequential(
-            nn.Conv2d(in_planes,hidden_planes,kernel_size=1,bias=False),
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.temprature = temprature
+        assert in_planes > ratio
+        hidden_planes = in_planes // ratio
+        self.net = nn.Sequential(
+            nn.Conv2d(in_planes, hidden_planes, kernel_size=1, bias=False),
             nn.ReLU(),
-            nn.Conv2d(hidden_planes,K,kernel_size=1,bias=False)
+            nn.Conv2d(hidden_planes, K, kernel_size=1, bias=False)
         )
 
-        if(init_weight):
+        if (init_weight):
             self._initialize_weights()
 
     def update_temprature(self):
-        if(self.temprature>1):
-            self.temprature-=1
+        if (self.temprature > 1):
+            self.temprature -= 1
 
     def _initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
-            if isinstance(m ,nn.BatchNorm2d):
+            if isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
 
-    def forward(self,x):
-        att=self.avgpool(x) #bs,dim,1,1
-        att=self.net(att).view(x.shape[0],-1) #bs,K
-        return F.softmax(att/self.temprature,-1)
+    def forward(self, x):
+        att = self.avgpool(x)  # bs,dim,1,1
+        att = self.net(att).view(x.shape[0], -1)  # bs,K
+        return F.softmax(att / self.temprature, -1)
+
 
 class DynamicConv(nn.Module):
-    def __init__(self,in_planes,out_planes,kernel_size,stride,padding=0,dilation=1,grounps=1,bias=True,K=4,temprature=30,ratio=4,init_weight=True):
+    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0, dilation=1, grounps=1, bias=True, K=4,
+                 temprature=30, ratio=4, init_weight=True):
         super().__init__()
-        self.in_planes=in_planes
-        self.out_planes=out_planes
-        self.kernel_size=kernel_size
-        self.stride=stride
-        self.padding=padding
-        self.dilation=dilation
-        self.groups=grounps
-        self.bias=bias
-        self.K=K
-        self.init_weight=init_weight
-        self.attention=Attention(in_planes=in_planes,ratio=ratio,K=K,temprature=temprature,init_weight=init_weight)
-
-        self.weight=nn.Parameter(torch.randn(K,out_planes,in_planes//grounps,kernel_size,kernel_size),requires_grad=True)
-        if(bias):
-            self.bias=nn.Parameter(torch.randn(K,out_planes),requires_grad=True)
+        self.in_planes = in_planes
+        self.out_planes = out_planes
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.dilation = dilation
+        self.groups = grounps
+        self.bias = bias
+        self.K = K
+        self.init_weight = init_weight
+        self.attention = Attention(in_planes=in_planes, ratio=ratio, K=K, temprature=temprature,
+                                   init_weight=init_weight)
+
+        self.weight = nn.Parameter(torch.randn(K, out_planes, in_planes // grounps, kernel_size, kernel_size),
+                                   requires_grad=True)
+        if (bias):
+            self.bias = nn.Parameter(torch.randn(K, out_planes), requires_grad=True)
         else:
-            self.bias=None
-        if(self.init_weight):
+            self.bias = None
+
+        if (self.init_weight):
             self._initialize_weights()
 
-        #TODO initialization
+        # TODO initialization
 
     def _initialize_weights(self):
         for i in range(self.K):
             nn.init.kaiming_uniform_(self.weight[i])
 
-    def forward(self,x):
-        bs,in_planels,h,w=x.shape
-        softmax_att=self.attention(x) #bs,K
-        x=x.view(1,-1,h,w)
-        weight=self.weight.view(self.K,-1) #K,-1
-        aggregate_weight=torch.mm(softmax_att,weight).view(bs*self.out_planes,self.in_planes//self.groups,self.kernel_size,self.kernel_size) #bs*out_p,in_p,k,k
-
-        if(self.bias is not None):
-            bias=self.bias.view(self.K,-1) #K,out_p
-            aggregate_bias=torch.mm(softmax_att,bias).view(-1) #bs,out_p
-            output=F.conv2d(x,weight=aggregate_weight,bias=aggregate_bias,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)
+    def forward(self, x):
+        bs, in_planels, h, w = x.shape
+        softmax_att = self.attention(x)  # bs,K
+        x = x.view(1, -1, h, w)
+        weight = self.weight.view(self.K, -1)  # K,-1
+        aggregate_weight = torch.mm(softmax_att, weight).view(bs * self.out_planes, self.in_planes // self.groups,
+                                                              self.kernel_size, self.kernel_size)  # bs*out_p,in_p,k,k
+
+        if (self.bias is not None):
+            bias = self.bias.view(self.K, -1)  # K,out_p
+            aggregate_bias = torch.mm(softmax_att, bias).view(-1)  # bs,out_p
+            output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
+                              groups=self.groups * bs, dilation=self.dilation)
         else:
-            output=F.conv2d(x,weight=aggregate_weight,bias=None,stride=self.stride,padding=self.padding,groups=self.groups*bs,dilation=self.dilation)
-
-        output=output.view(bs,self.out_planes,h,w)
+            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
+                              groups=self.groups * bs, dilation=self.dilation)
+
+        output = output.view(bs, self.out_planes, output.size(-2), output.size(-1))
         return output
 
+
 if __name__ == '__main__':
-    input=torch.randn(2,32,64,64)
-    m=DynamicConv(in_planes=32,out_planes=64,kernel_size=3,stride=1,padding=1,bias=False)
-    out=m(input)
-    print(out.shape)
+    input = torch.randn(2, 32, 64, 64)
+    m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=2, padding=1, bias=False)
+    out = m(input)
+    print(out.shape)
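As a quick sanity check (a sketch, not part of the PR; the import path is assumed from the repo layout above), the patched DynamicConv should now agree with a plain nn.Conv2d on output shape for any stride:

import torch
from torch import nn
from model.conv.DynamicConv import DynamicConv  # assumed import path

x = torch.randn(2, 32, 64, 64)
for stride in (1, 2, 4):
    m = DynamicConv(in_planes=32, out_planes=64, kernel_size=3, stride=stride, padding=1, bias=False)
    ref = nn.Conv2d(32, 64, kernel_size=3, stride=stride, padding=1, bias=False)
    assert m(x).shape == ref(x).shape  # fails on the pre-patch code for stride > 1
    print(stride, m(x).shape)  # e.g. stride=2 -> torch.Size([2, 64, 32, 32])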