Commit 1e39fde
update models/modules
xinntao committed Sep 9, 2018
1 parent c91609d
Showing 6 changed files with 80 additions and 169 deletions.
codes/models/modules/architecture.py (175 changes: 57 additions & 118 deletions)
@@ -4,7 +4,6 @@
import torchvision
from . import block as B
from . import spectral_norm as SN
# import block as B

####################
# Generator
@@ -45,10 +44,10 @@ def forward(self, x):
return x


class RRDB_Net(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
mode='CNA', res_scale=1, upsample_mode='upconv'):
super(RRDB_Net, self).__init__()
class RRDBNet(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
super(RRDBNet, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
@@ -79,49 +78,58 @@ def forward(self, x):
return x

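For context, the rename from RRDB_Net to RRDBNet (and the dropped res_scale argument) does not change how the generator is built. A minimal usage sketch, with ESRGAN-style defaults assumed for nf, nb, and gc rather than taken from this diff:

import torch
from models.modules.architecture import RRDBNet  # import path assumed

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4,
              norm_type=None, act_type='leakyrelu', mode='CNA',
              upsample_mode='upconv')
lr = torch.rand(1, 3, 32, 32)  # a 32x32 low-resolution RGB patch
sr = net(lr)                   # expected shape: (1, 3, 128, 128) for upscale=4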

class RRRDB_Net(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
mode='CNA', res_scale=1, upsample_mode='upconv'):
super(RRRDB_Net, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
####################
# Discriminator
####################

fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [B.RRRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
rb_blocks.append(B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA'))
rb_blocks.append(B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA'))
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
# VGG style Discriminator with input size 128*128
class Discriminaotr_VGG_128(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminaotr_VGG_128, self).__init__()
# features
# hxw, c
# 128, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 64, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)

self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
*upsampler, HR_conv0, HR_conv1)
# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

def forward(self, x):
x = self.model(x)
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

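The five stride-2 convolutions halve the 128x128 input five times (128 -> 64 -> 32 -> 16 -> 8 -> 4), so the flattened tensor entering the classifier has 512 * 4 * 4 = 8192 elements, which is exactly what the first nn.Linear expects. A quick shape check, assuming base_nf=64 as the channel comments above suggest:

import torch

d = Discriminaotr_VGG_128(in_nc=3, base_nf=64)
x = torch.rand(1, 3, 128, 128)
print(d.features(x).shape)  # expected: torch.Size([1, 512, 4, 4])
print(d(x).shape)           # expected: torch.Size([1, 1])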
####################
# Discriminator
####################


# VGG style Discriminator with input size 128*128
# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminaotr_VGG_128_SN(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminaotr_VGG_128_SN, self).__init__()
@@ -167,55 +175,6 @@ def forward(self, x):
return x

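The SN variant's body is collapsed in this view. The usual pattern, sketched here as an assumption rather than a copy of the hidden hunk, is to wrap each convolution (and linear layer) in spectral normalization so the discriminator's Lipschitz constant stays bounded; torch.nn.utils.spectral_norm plays the same role as the file's imported SN module:

import torch.nn as nn
from torch.nn.utils import spectral_norm

# One spectrally normalized downsampling conv, illustrative only.
conv1_sn = spectral_norm(nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1))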

# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminaotr_VGG_128(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminaotr_VGG_128, self).__init__()
# features
# hxw, c
# 128, 64
conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 64, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 8, 512
conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)

# classifier
self.classifier = nn.Sequential(
nn.Linear(512 * 4 * 4, 100),
nn.LeakyReLU(0.2, True),
nn.Linear(100, 1)
)

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x


class Discriminaotr_VGG_96(nn.Module):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminaotr_VGG_96, self).__init__()
@@ -310,6 +269,7 @@ def forward(self, x):
x = self.classifier(x)
return x


####################
# Perceptual Network
####################
@@ -347,6 +307,7 @@ def forward(self, x):
return output


# Assume input range is [0, 1]
class ResNet101FeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
super(ResNet101FeatureExtractor, self).__init__()
@@ -371,9 +332,9 @@ def forward(self, x):
return output

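Both extractors take inputs in [0, 1]. With use_input_norm enabled, the standard pattern (an assumption about the collapsed bodies, consistent with the comments) is to register the ImageNet statistics as buffers and normalize before the frozen backbone:

import torch
import torch.nn as nn

class NormalizedBackbone(nn.Module):  # hypothetical wrapper for illustration
    def __init__(self, backbone):
        super(NormalizedBackbone, self).__init__()
        # ImageNet mean/std, valid for inputs already scaled to [0, 1]
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.backbone = backbone

    def forward(self, x):
        return self.backbone((x - self.mean) / self.std)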

class minc(nn.Module):
class MINCNet(nn.Module):
def __init__(self):
super(minc, self).__init__()
super(MINCNet, self).__init__()
self.ReLU = nn.ReLU(True)
self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
@@ -415,15 +376,12 @@ def forward(self, x):


# Assume input range is [0, 1]
class MincFeatureExtractor(nn.Module):
def __init__(self,
feature_layer=34,
use_bn=False,
use_input_norm=True,
device=torch.device('cpu')):
super(MincFeatureExtractor, self).__init__()
class MINCFeatureExtractor(nn.Module):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
device=torch.device('cpu')):
super(MINCFeatureExtractor, self).__init__()

self.features = minc()
self.features = MINCNet()
self.features.load_state_dict(
torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
self.features.eval()
@@ -434,22 +392,3 @@ def __init__(self,
def forward(self, x):
output = self.features(x)
return output


if __name__ == '__main__':
net = minc()
net.load_state_dict(torch.load('VGG16minc_53.pth'), strict=True)
net.eval()
net = net.cuda()

import cv2
import numpy as np
img = cv2.imread('/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5_bicLRx4/butterfly_bicLRx4.png',
cv2.IMREAD_UNCHANGED)
img = img.astype(np.float32) / 255.
img = img[:, :, [2, 1, 0]]
img = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1)))).float()
input = img.unsqueeze(0).cuda()
out = net.forward(input)
print(out.float())
print(out.shape)
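In training, these extractors are kept frozen and used to compare deep features of the generated and ground-truth images. A hedged sketch of that usage (the L1 criterion and the detach on the target side are conventional choices, not taken from this diff):

import torch.nn as nn

l1 = nn.L1Loss()

def feature_loss(extractor, fake, real):
    # Only the generator receives gradients; the target features are detached.
    return l1(extractor(fake), extractor(real).detach())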
codes/models/modules/block.py (58 changes: 16 additions & 42 deletions)
@@ -1,5 +1,4 @@
from collections import OrderedDict

import torch
import torch.nn as nn

@@ -107,13 +106,13 @@ def sequential(*args):
return nn.Sequential(*modules)


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
"""
'''
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
'''
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [%s]' % mode
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
@@ -136,19 +135,17 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
return sequential(n, a, p, c)

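In CNA mode the helper simply chains Conv -> Norm -> Act. For example, conv_block(64, 64, 3, norm_type='batch', act_type='leakyrelu', mode='CNA') roughly expands to the following (padding per get_valid_padding; the 0.2 LeakyReLU slope is assumed):

import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
    nn.BatchNorm2d(64),
    nn.LeakyReLU(0.2, inplace=True),
)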

# TODO: add deconv block

####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
"""
'''
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
"""
'''

def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
@@ -177,11 +174,11 @@ def forward(self, x):


class ResidualDenseBlock_5C(nn.Module):
"""
'''
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
"""
'''

def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
@@ -212,9 +209,10 @@ def forward(self, x):


class RRDB(nn.Module):
"""
'''
Residual in Residual Dense Block
"""
(ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
'''

def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
@@ -233,43 +231,19 @@ def forward(self, x):
return out.mul(0.2) + x

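The trailing out.mul(0.2) + x is the residual scaling from the ESRGAN paper: each block contributes x + 0.2 * F(x), which empirically stabilizes very deep trunks. The same idea in isolation:

# beta=0.2 follows the constant used above; illustrative helper only.
def scaled_residual(x, block, beta=0.2):
    return block(x) * beta + x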

class RRRDB(nn.Module):
"""
Residual in Residual Dense Block
"""

def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(RRRDB, self).__init__()
self.RRDB1 = RRDB(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RRDB2 = RRDB(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RRDB3 = RRDB(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RRDB4 = RRDB(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)

def forward(self, x):
out = self.RRDB1(x)
out = self.RRDB2(out)
out = self.RRDB3(out)
out = self.RRDB4(out)
return out.mul(0.2) + x

####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu'):
"""
'''
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
'''
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=None, act_type=None)
pixel_shuffle = nn.PixelShuffle(upscale_factor)

@@ -278,11 +252,11 @@ def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1,
return sequential(conv, pixel_shuffle, n, a)

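The conv first expands the channels by upscale_factor**2, then nn.PixelShuffle rearranges those channels into space. A quick shape check with assumed values:

import torch
import torch.nn as nn

ps = nn.PixelShuffle(2)
x = torch.rand(1, 64 * 4, 24, 24)  # out_nc * r^2 channels from the conv
print(ps(x).shape)                 # expected: torch.Size([1, 64, 48, 48])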

def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
pad_type=pad_type, norm_type=norm_type, act_type=act_type)
return sequential(upsample, conv)
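Upsampling first and then convolving avoids the checkerboard artifacts that ConvTranspose2d tends to produce (see the Distill article linked above). The same block written directly in plain PyTorch, as an illustration:

import torch.nn as nn

up = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
    nn.LeakyReLU(0.2, inplace=True),
)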
codes/models/modules/loss.py (2 changes: 1 addition & 1 deletion)
@@ -51,7 +51,7 @@ def get_grad_outputs(self, input):

def forward(self, interp, interp_crit):
grad_outputs = self.get_grad_outputs(interp_crit)
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \
grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]
grad_interp = grad_interp.view(grad_interp.size(0), -1)
grad_interp_norm = grad_interp.norm(2, dim=1)
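The hunk is truncated here; a WGAN-GP penalty conventionally continues by pushing this per-sample gradient norm toward 1. A hedged sketch of the full computation (the interpolation scheme and the weight of 10 are the paper's conventions, not taken from this file):

import torch

def gradient_penalty(critic, real, fake, weight=10.0):
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    interp_crit = critic(interp)
    grad = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                               grad_outputs=torch.ones_like(interp_crit),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
    grad_norm = grad.view(grad.size(0), -1).norm(2, dim=1)
    return weight * ((grad_norm - 1.0) ** 2).mean()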
codes/models/modules/seg_arch.py (5 changes: 2 additions & 3 deletions)
@@ -1,7 +1,6 @@
"""
'''
architecture for segmentation
"""
import torch
'''
import torch.nn as nn
from . import block as B

