
Commit

add A2N architecture for testing
victorca25 committed Oct 26, 2021
1 parent a0f0a22 commit 12d006f
Showing 3 changed files with 227 additions and 43 deletions.
241 changes: 206 additions & 35 deletions codes/models/modules/architectures/PAN_arch.py
@@ -8,16 +8,19 @@



def pa_upconv_block(nf, unf, kernel_size=3, stride=1, padding=1, mode='nearest', upscale_factor=2, act_type='lrelu'):
def pa_upconv_block(nf, unf, kernel_size=3,
stride=1, padding=1, mode='nearest', upscale_factor=2,
act_type='lrelu'):
upsample = B.Upsample(scale_factor=upscale_factor, mode=mode)
upconv = nn.Conv2d(nf, unf, kernel_size, stride, padding, bias=True)
att = PA(unf)
HRconv = nn.Conv2d(unf, unf, kernel_size, stride, padding, bias=True)
a = B.act(act_type) if act_type else None
return B.sequential(upsample, upconv, att, a, HRconv, a)
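
Not part of the diff: a quick usage sketch of pa_upconv_block as defined above. The channel counts and input size below are assumptions for illustration only.

# One 2x pixel-attention upsampling stage: nearest upsample, conv nf -> unf,
# PA reweighting, activation, HR conv, activation.
stage = pa_upconv_block(nf=40, unf=24, kernel_size=3, stride=1, padding=1,
                        mode='nearest', upscale_factor=2, act_type='lrelu')
# A (1, 40, 32, 32) feature map comes out as (1, 24, 64, 64).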


class PA(nn.Module):
'''PA is pixel attention'''
"""PA is pixel attention"""
def __init__(self, nf):

super(PA, self).__init__()
@@ -31,16 +34,21 @@ def forward(self, x):
out = torch.mul(x, y)

return out
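
Not part of the diff: the hunk above collapses PA.__init__ and part of its forward, so here is a minimal sketch of the full pixel-attention module, following the upstream PAN implementation this file is modified from; the hidden lines may differ slightly.

class PA_sketch(nn.Module):
    """PA is pixel attention: a 1x1 conv plus sigmoid produces a per-pixel
    attention map that reweights the input features."""
    def __init__(self, nf):
        super(PA_sketch, self).__init__()
        self.conv = nn.Conv2d(nf, nf, 1)  # 1x1 conv -> attention map
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.sigmoid(self.conv(x))  # mask in (0, 1), same shape as x
        out = torch.mul(x, y)           # elementwise reweighting of x
        return out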



class PACnv(nn.Module):

def __init__(self, nf, k_size=3):

super(PACnv, self).__init__()
self.k2 = nn.Conv2d(nf, nf, 1) # 1x1 convolution nf->nf
self.sigmoid = nn.Sigmoid()
self.k3 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
self.k4 = nn.Conv2d(nf, nf, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
self.k3 = nn.Conv2d(
nf, nf, kernel_size=k_size,
padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
self.k4 = nn.Conv2d(
nf, nf, kernel_size=k_size,
padding=(k_size - 1) // 2, bias=False) # 3x3 convolution

def forward(self, x):

@@ -51,7 +59,7 @@ def forward(self, x):
out = self.k4(out)

return out

class SCPA(nn.Module):

"""SCPA is modified from SCNet (Jiang-Jiang Liu et al. Improving Convolutional Networks with Self-Calibrated Convolutions. In CVPR, 2020)
@@ -96,37 +104,41 @@ def forward(self, x):
out += residual

return out
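
Not part of the diff: the hunk above also collapses most of SCPA, so the following orientation sketch follows the upstream PAN repository the docstring cites; the branch layout and defaults are assumptions, with this file's PACnv standing in for the pixel-attention branch.

class SCPA_sketch(nn.Module):
    def __init__(self, nf, reduction=2, stride=1, dilation=1):
        super(SCPA_sketch, self).__init__()
        group_width = nf // reduction
        # split into two half-width branches
        self.conv1_a = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)
        self.conv1_b = nn.Conv2d(nf, group_width, kernel_size=1, bias=False)
        # plain 3x3 branch
        self.k1 = nn.Conv2d(group_width, group_width, kernel_size=3,
                            stride=stride, padding=dilation,
                            dilation=dilation, bias=False)
        # pixel-attention branch
        self.PACnv = PACnv(group_width)
        # fuse both branches back to nf channels
        self.conv3 = nn.Conv2d(group_width * reduction, nf,
                               kernel_size=1, bias=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        residual = x
        out_a = self.lrelu(self.conv1_a(x))
        out_b = self.lrelu(self.conv1_b(x))
        out_a = self.lrelu(self.k1(out_a))
        out_b = self.lrelu(self.PACnv(out_b))
        out = self.conv3(torch.cat([out_a, out_b], dim=1))
        out += residual
        return out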



class PAN(nn.Module):
'''
Efficient Image Super-Resolution Using Pixel Attention, in ECCV Workshop, 2020.
"""
Efficient Image Super-Resolution Using Pixel Attention,
in ECCV Workshop, 2020.
Modified from https://github.com/zhaohengyuan1/PAN
'''

def __init__(self, in_nc, out_nc, nf, unf, nb, scale=4, self_attention=True, double_scpa=False, ups_inter_mode = 'nearest'):
"""

def __init__(self, in_nc, out_nc, nf,
unf, nb, scale=4, self_attention=True,
double_scpa=False, ups_inter_mode='nearest'):
super(PAN, self).__init__()
n_upscale = int(math.log(scale, 2))
if scale == 3:
n_upscale = 1
elif scale == 1:
unf = nf

# SCPA
SCPA_block_f = functools.partial(SCPA, nf=nf, reduction=2)
self.scale = scale
self.ups_inter_mode = ups_inter_mode #'nearest' # 'bilinear'
self.ups_inter_mode = ups_inter_mode # 'nearest' # 'bilinear'
self.double_scpa = double_scpa

## self-attention
self.self_attention = self_attention
if self_attention:
if self_attention:
spectral_norm = False
max_pool = True #False
max_pool = True # False
poolsize = 4

### first convolution
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)

### main blocks
self.SCPA_trunk = B.make_layer(SCPA_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
@@ -137,52 +149,59 @@ def __init__(self, in_nc, out_nc, nf, unf, nb, scale=4, self_attention=True, dou

### self-attention
if self.self_attention:
self.FSA = B.SelfAttentionBlock(in_dim=nf, max_pool=max_pool, poolsize=poolsize, spectral_norm=spectral_norm)

'''
self.FSA = B.SelfAttentionBlock(
in_dim=nf, max_pool=max_pool,
poolsize=poolsize, spectral_norm=spectral_norm)

"""
# original upsample
#### upsampling
self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True)
self.att1 = PA(unf)
self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
if self.scale == 4:
self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
self.att2 = PA(unf)
self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
'''
#### new upsample
"""

#### new upsample
upsampler = []
for i in range(n_upscale):
if i < 1:
if self.scale == 3:
upsampler.append(pa_upconv_block(nf, unf, 3, 1, 1, self.ups_inter_mode, 3))
upsampler.append(
pa_upconv_block(nf, unf, 3, 1, 1,
self.ups_inter_mode, 3))
else:
upsampler.append(pa_upconv_block(nf, unf, 3, 1, 1, self.ups_inter_mode))
upsampler.append(
pa_upconv_block(nf, unf, 3, 1, 1,
self.ups_inter_mode))
else:
upsampler.append(pa_upconv_block(unf, unf, 3, 1, 1, self.ups_inter_mode))
upsampler.append(
pa_upconv_block(unf, unf, 3, 1, 1,
self.ups_inter_mode))
self.upsample = B.sequential(*upsampler)

self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

def forward(self, x):

fea = self.conv_first(x)
trunk = self.trunk_conv(self.SCPA_trunk(fea))
if self.double_scpa:
trunk = self.trunk_conv2(self.SCPA_trunk2(trunk))

# fea = fea + trunk
# Elementwise sum, with FSA if enabled
if self.self_attention:
fea = self.FSA(fea + trunk)
else:
fea = fea + trunk

'''
#original upsample
"""
# original upsample
if self.scale == 2 or self.scale == 3:
fea = self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode=self.ups_inter_mode, align_corners=True))
fea = self.lrelu(self.att1(fea))
@@ -194,7 +213,7 @@ def forward(self, x):
fea = self.upconv2(F.interpolate(fea, scale_factor=2, mode=self.ups_inter_mode, align_corners=True))
fea = self.lrelu(self.att2(fea))
fea = self.lrelu(self.HRconv2(fea))
'''
"""

# new upsample
fea = self.upsample(fea)
@@ -205,6 +224,158 @@ def forward(self, x):
ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True)
else:
ILR = x

out = out + ILR
return out
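
Not part of the diff: a hedged usage sketch of PAN with the defaults shown in its signature; the input size and expected output shape are assumptions based on the forward pass above.

import torch

model = PAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4,
            self_attention=True, double_scpa=False, ups_inter_mode='nearest')
lr = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 128, 128]) for scale=4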


class AttentionBranch(nn.Module):
"""Attention Branch."""

def __init__(self, nf, k_size=3):

super(AttentionBranch, self).__init__()
self.k1 = nn.Conv2d(
            nf, nf, kernel_size=k_size,
padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.k2 = nn.Conv2d(nf, nf, 1) # 1x1 convolution nf->nf
self.sigmoid = nn.Sigmoid()
self.k3 = nn.Conv2d(
nf, nf, kernel_size=k_size,
padding=(k_size - 1) // 2, bias=False) # 3x3 convolution
self.k4 = nn.Conv2d(
nf, nf, kernel_size=k_size,
padding=(k_size - 1) // 2, bias=False) # 3x3 convolution

def forward(self, x):
y = self.k1(x)
y = self.lrelu(y)
y = self.k2(y)
y = self.sigmoid(y)

out = torch.mul(self.k3(x), y)
out = self.k4(out)

return out


class AAB(nn.Module):
""" Attention in Attention Network for Image Super-Resolution (A2N).
Modified from: https://github.com/haoyuc/A2N
"""

def __init__(self, nf, reduction=4, K=2, t=30, mode:str="n"):
super(AAB, self).__init__()
self.t = t
self.K = K

self.conv_first = nn.Conv2d(nf, nf, kernel_size=1, bias=False)
self.conv_last = nn.Conv2d(nf, nf, kernel_size=1, bias=False)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

self.avg_pool = nn.AdaptiveAvgPool2d(1)

# Attention Dropout Module
self.ADM = nn.Sequential(
nn.Linear(nf, nf // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(nf // reduction, self.K, bias=False),
)

# attention branch
self.attention = AttentionBranch(nf)

# non-attention branch
if mode == "m":
# 1x1 conv for A2N-M (Recommended, fewer parameters)
self.non_attention = nn.Conv2d(nf, nf, kernel_size=1, bias=False)
else:
# 3x3 conv for A2N
self.non_attention = nn.Conv2d(nf, nf, kernel_size=3,
padding=(3 - 1) // 2, bias=False)


def forward(self, x):
residual = x
a, b, c, d = x.shape

x = self.conv_first(x)
x = self.lrelu(x)

# Attention Dropout
y = self.avg_pool(x).view(a,b)
y = self.ADM(y)
ax = F.softmax(y/self.t, dim = 1)

attention = self.attention(x)
non_attention = self.non_attention(x)

x = attention * ax[:,0].view(a,1,1,1) + non_attention * ax[:,1].view(a,1,1,1)
x = self.lrelu(x)

out = self.conv_last(x)
out += residual

return out
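
Not part of the diff: a small numeric sketch of the Attention Dropout weighting used in the forward above; the logits are made up.

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.9, 0.3]])  # hypothetical ADM output, shape (batch, K=2)
w = F.softmax(logits / 30.0, dim=1)  # t = 30 keeps both weights close to 0.5
# w is roughly tensor([[0.5050, 0.4950]]); the block then mixes the branches as
# out = attention * w[:, 0] + non_attention * w[:, 1], broadcast over C, H, W.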


class AAN(nn.Module):

def __init__(self, in_nc:int=3, out_nc:int=3,
nf:int=40, unf:int=24, nb:int=16, scale:int=4,
mode:str="n"):
super(AAN, self).__init__()

# AAB
AAB_block_f = functools.partial(AAB, nf=nf, mode=mode)
self.scale = scale

### first convolution
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)

### main blocks
self.AAB_trunk = B.make_layer(AAB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

#### upsampling
self.upconv1 = nn.Conv2d(nf, unf, 3, 1, 1, bias=True)
self.att1 = PA(unf)
self.HRconv1 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)

if self.scale == 4:
self.upconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)
self.att2 = PA(unf)
self.HRconv2 = nn.Conv2d(unf, unf, 3, 1, 1, bias=True)

self.conv_last = nn.Conv2d(unf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

def forward(self, x):

fea = self.conv_first(x)
trunk = self.trunk_conv(self.AAB_trunk(fea))
fea = fea + trunk

if self.scale == 2 or self.scale == 3:
fea = self.upconv1(F.interpolate(fea, scale_factor=self.scale, mode='nearest'))
fea = self.lrelu(self.att1(fea))
fea = self.lrelu(self.HRconv1(fea))
elif self.scale == 4:
fea = self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(self.att1(fea))
fea = self.lrelu(self.HRconv1(fea))
fea = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
fea = self.lrelu(self.att2(fea))
fea = self.lrelu(self.HRconv2(fea))

out = self.conv_last(fea)

if self.scale > 1:
ILR = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
else:
ILR = x

out = out + ILR

return out
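
Not part of the diff: a hedged usage sketch of the new A2N network (AAN) with the defaults from its signature; the input size and expected output shape are assumptions.

import torch

model = AAN(in_nc=3, out_nc=3, nf=40, unf=24, nb=16, scale=4, mode="n")
lr = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 128, 128])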
3 changes: 3 additions & 0 deletions codes/models/networks.py
@@ -147,6 +147,9 @@ def get_network(opt, step=0, selector=None):
elif kind == 'pan_net':
from models.modules.architectures import PAN_arch
net = PAN_arch.PAN
elif kind == 'a2n_net':
from models.modules.architectures import PAN_arch
net = PAN_arch.AAN
elif kind == 'sofvsr_net':
from models.modules.architectures import SOFVSR_arch
net = SOFVSR_arch.SOFVSR
