Skip to content

Commit

Permalink
update discriminator and vggfeatureextractor network
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Apr 21, 2018
1 parent daa122d commit 6359609
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
20 changes: 10 additions & 10 deletions codes/models/modules/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def forward(self, x):
####################
# VGG style Discriminator with input size 128*128
class Discriminaotr_VGG_128(nn.Module):
def __init__(self, in_nc, out_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
super(Discriminaotr_VGG_128, self).__init__()
# features
# hxw, c
Expand All @@ -83,17 +83,17 @@ def __init__(self, in_nc, out_nc, base_nf, norm_type='batch', act_type='leakyrel
mode=mode)
conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 64, 128
# 64, 64
conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 32, 256
# 32, 128
conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 16, 512
# 16, 256
conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
act_type=act_type, mode=mode)
conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
Expand All @@ -103,7 +103,7 @@ def __init__(self, in_nc, out_nc, base_nf, norm_type='batch', act_type='leakyrel
act_type=act_type, mode=mode)
conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
act_type=act_type, mode=mode)
# 4, 256
# 4, 512
self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8,\
conv9)

Expand All @@ -124,24 +124,24 @@ def forward(self, x):
####################
# Perceptual Network
####################
# Assume input in range [0, 1]
# Assume input range is [0, 1]
class VGGFeatureExtractor(nn.Module):
def __init__(self, feature_layer=11, use_bn=True, use_input_norm=True):
def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, tensor=torch.FloatTensor):
super(VGGFeatureExtractor, self).__init__()
if use_bn:
model = torchvision.models.vgg19_bn(pretrained=True)
else:
model = torchvision.models.vgg19(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = Variable(torch.Tensor([0.485, 0.456, 0.406]).view(1,3,1,1), requires_grad=False)
mean = Variable(tensor([0.485, 0.456, 0.406]).view(1,3,1,1), requires_grad=False)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = Variable(torch.Tensor([0.229, 0.224, 0.225]).view(1,3,1,1), requires_grad=False)
std = Variable(tensor([0.229, 0.224, 0.225]).view(1,3,1,1), requires_grad=False)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
# No need to bp to variable
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False

Expand Down
4 changes: 2 additions & 2 deletions codes/models/modules/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

# helper selecting activation
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
# neg_slope: for leaky_relu and init of prelu
# neg_slope: for leakyrelu and init of prelu
# n_prelu: for p_relu num_parameters
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type == 'leaky_relu':
elif act_type == 'leakyrelu':
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
Expand Down
35 changes: 27 additions & 8 deletions codes/models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import torch.nn as nn
from torch.nn import init
import functools
import models.modules.generator as G
import models.modules.perceptual_network as F
import models.modules.architecture as Arch

####################
# initialize
Expand Down Expand Up @@ -74,16 +73,17 @@ def init_weights(net, init_type='kaiming', scale=1, std=0.02):
####################
# define network
####################
# Generator
def define_G(opt):
gpu_ids = opt['gpu_ids']
opt = opt['network']
opt = opt['network_G']
which_model = opt['which_model_G']

if which_model == 'sr_resnet_torch':
netG = G.SRResNet_torch(in_nc=opt['in_nc'], out_nc=opt['out_nc'], nf=opt['nf'], \
netG = Arch.SRResNet_torch(in_nc=opt['in_nc'], out_nc=opt['out_nc'], nf=opt['nf'], \
nb=opt['nb'], upscale=opt['scale'], norm_type=opt['norm_type'], mode=opt['mode'])
elif which_model == 'degradation_net':
netG = G.DegradationNet(in_nc=opt['in_nc'], out_nc=opt['out_nc'], nf=opt['nf'], \
netG = Arch.DegradationNet(in_nc=opt['in_nc'], out_nc=opt['out_nc'], nf=opt['nf'], \
nb=opt['nb'], upscale=opt['scale'], norm_type=opt['norm_type'], mode=opt['mode'])

else:
Expand All @@ -96,15 +96,34 @@ def define_G(opt):
return netG


# Discriminator
def define_D(opt):
gpu_ids = opt['gpu_ids']
opt = opt['network_D']
which_model = opt['which_model_D']

if which_model == 'discriminaotr_vgg_128':
netD = Arch.Discriminaotr_VGG_128(in_nc=opt['in_nc'], base_nf=opt['nf'], \
norm_type=opt['norm_type'], mode=opt['mode'] ,act_type=opt['act_type'])
else:
raise NotImplementedError('Discriminator model [%s] is not recognized' % which_model)
if opt['is_train']:
init_weights(netD, init_type='kaiming', scale=1)
if gpu_ids:
netD = nn.DataParallel(netD).cuda()
return netD


def define_F(opt, use_bn=False):
gpu_ids = opt['gpu_ids']
# pytorch pretrained VGG19, with BN
# VGG19-54, before ReLU.
tensor = torch.cuda.FloatTensor if gpu_ids else torch.FloatTensor
# pytorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = F.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, use_input_norm=True)
netF = Arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \
use_input_norm=True, tensor=tensor)
if gpu_ids:
netF = nn.DataParallel(netF).cuda()
netF.eval() # No need to train
Expand Down

0 comments on commit 6359609

Please sign in to comment.