Skip to content

Commit

Permalink
add wgan-gp and other matters
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Apr 26, 2018
1 parent c5dd974 commit c845ea5
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 95 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# folder
.vscode
experiments
results

# file type
*.svg

*.pyc
*.t7
*.pth
Expand Down
10 changes: 5 additions & 5 deletions codes/data_preprocess/extract_subimgs_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import time

def worker(GT_paths, save_GT_dir):
crop_sz = 500
step = 400
thres_sz = 100
crop_sz = 500//4
step = 400//4
thres_sz = 50//4
for GT_path in GT_paths:
base_name = os.path.basename(GT_path)
print(base_name, os.getpid())
Expand Down Expand Up @@ -44,8 +44,8 @@ def worker(GT_paths, save_GT_dir):

if __name__=='__main__':

GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/GOPRO/test/blur_gamma'
save_GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/GOPRO/test/blur_gamma_sub'
GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/Semi_pair/PHONE'
save_GT_dir = '/mnt/SSD/xtwang/BasicSR_datasets/Semi_pair/PHONE_sub'
n_thread = 20

print('Parent process %s.' % os.getpid())
Expand Down
120 changes: 77 additions & 43 deletions codes/models/SRGAN_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,32 +18,39 @@ def name(self):
def __init__(self, opt):
super(SRGANModel, self).__init__(opt)
train_opt = opt['train']

self.input_L = self.Tensor()
self.input_H = self.Tensor()
self.input_ref = self.Tensor() # for Discriminator

self.need_pixel_loss = True
self.need_feature_loss = True
if train_opt['pixel_weight'] == 0:
print('Set pixel loss to zero.')
self.need_pixel_loss = False
if train_opt['feature_weight'] == 0:
print('Set feature loss to zero.')
self.need_feature_loss = False
assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'

# define network and load pretrained models
# Generator - SR network
self.netG = networks.define_G(opt)
self.load_path_G = opt['path']['pretrain_model_G']
if self.is_train:
self.need_pixel_loss = True
self.need_feature_loss = True
if train_opt['pixel_weight'] == 0:
print('Set pixel loss to zero.')
self.need_pixel_loss = False
if train_opt['feature_weight'] == 0:
print('Set feature loss to zero.')
self.need_feature_loss = False
assert self.need_pixel_loss or self.need_feature_loss, 'pixel and feature loss are both 0.'
# Discriminator
self.netD = networks.define_D(opt)
self.load_path_D = opt['path']['pretrain_model_D']
if self.need_feature_loss:
self.netF = networks.define_F(opt, use_bn=False) # perceptual loss
self.load() # load G and D if needed

if self.is_train:
# for wgan-gp
self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
if train_opt['gan_type'] == 'wgan-gp':
self.random_pt = Variable(self.Tensor(1, 1, 1, 1))

# define loss function
# pixel loss
pixel_loss_type = train_opt['pixel_criterion']
Expand All @@ -53,6 +60,8 @@ def __init__(self, opt):
self.criterion_pixel = nn.MSELoss()
else:
raise NotImplementedError('Loss type [%s] is not recognized.' % pixel_loss_type)
self.loss_pixel_weight = train_opt['pixel_weight']

# feature loss
feature_loss_type = train_opt['feature_criterion']
if feature_loss_type == 'l1':
Expand All @@ -61,18 +70,25 @@ def __init__(self, opt):
self.criterion_feature = nn.MSELoss()
else:
raise NotImplementedError('Loss type [%s] is not recognized.' % feature_loss_type)
self.loss_feature_weight = train_opt['feature_weight']

# gan loss
gan_type = train_opt['gan_type']
self.criterion_gan = GANLoss(gan_type, real_label_val=1.0, fake_label_val=0.0, \
tensor=self.Tensor)
self.loss_gan_weight = train_opt['gan_weight']

# gradient penalty loss
if train_opt['gan_type'] == 'wgan-gp':
self.criterion_gp = GradientPenaltyLoss(tensor=self.Tensor)
self.loss_gp_weight = train_opt['gp_weigth']

if self.use_gpu:
self.criterion_pixel.cuda()
self.criterion_feature.cuda()
self.criterion_gan.cuda()
self.loss_pixel_weight = train_opt['pixel_weight']
self.loss_feature_weight = train_opt['feature_weight']
self.loss_gan_weight = train_opt['gan_weight']
if train_opt['gan_type'] == 'wgan-gp':
self.criterion_gp.cuda()

# initialize optimizers
self.optimizers = [] # G and D
Expand Down Expand Up @@ -126,37 +142,39 @@ def feed_data(self, data, volatile=False, need_HR=True):
self.input_L.resize_(input_L.size()).copy_(input_L)
self.real_L = Variable(self.input_L, volatile=volatile)

def optimize_parameters(self):
def optimize_parameters(self, step):
# G
self.optimizer_G.zero_grad()
# forward G
# self.real_L: leaf, not requires_grad; self.fake_H: not leaf, requires_grad
self.fake_H = self.netG(self.real_L)
if self.need_pixel_loss:
loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(self.fake_H, self.real_H)
# forward F
if self.need_feature_loss:

if step % self.D_update_ratio == 0 and step > self.D_init_iters:
if self.need_pixel_loss:
loss_g_pixel = self.loss_pixel_weight * self.criterion_pixel(self.fake_H, self.real_H)
# forward F
# self.real_fea: leaf, not requires_grad (gt features, do not need bp)
real_fea = self.netF(self.real_H).detach()
# self.fake_fea: not leaf, requires_grad (need bp, in the graph)
# self.real_fea and self.fake_fea are not the same, since features is independent to conv
fake_fea = self.netF(self.fake_H)
loss_g_fea = self.loss_feature_weight * self.criterion_feature(fake_fea, real_fea)
# forward D
pred_g_fake = self.netD(self.fake_H)
loss_g_gan = self.loss_gan_weight * self.criterion_gan(pred_g_fake, True)

# total loss
if self.need_pixel_loss:
if self.need_feature_loss:
loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
# forward F
# self.real_fea: leaf, not requires_grad (gt features, do not need bp)
real_fea = self.netF(self.real_H).detach()
# self.fake_fea: not leaf, requires_grad (need bp, in the graph)
# self.real_fea and self.fake_fea are not the same, since features is independent to conv
fake_fea = self.netF(self.fake_H)
loss_g_fea = self.loss_feature_weight * self.criterion_feature(fake_fea, real_fea)
# forward D
pred_g_fake = self.netD(self.fake_H)
loss_g_gan = self.loss_gan_weight * self.criterion_gan(pred_g_fake, True)

# total loss
if self.need_pixel_loss:
if self.need_feature_loss:
loss_g_total = loss_g_pixel + loss_g_fea + loss_g_gan
else:
loss_g_total = loss_g_pixel + loss_g_gan
else:
loss_g_total = loss_g_pixel + loss_g_gan
else:
loss_g_total = loss_g_fea + loss_g_gan
loss_g_total.backward()
self.optimizer_G.step()
loss_g_total = loss_g_fea + loss_g_gan
loss_g_total.backward()
self.optimizer_G.step()

# D
self.optimizer_D.zero_grad()
Expand All @@ -166,8 +184,20 @@ def optimize_parameters(self):
# fake data
pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
loss_d_fake = self.criterion_gan(pred_d_fake, False)
# total loss
loss_d_total = loss_d_real + loss_d_fake
if self.opt['train']['gan_type'] == 'wgan-gp':
n = self.real_ref.size(0)
if not self.random_pt.size(0) == n:
self.random_pt.data.resize_(n, 1, 1, 1)
self.random_pt.data.uniform_() # Draw random interpolation points
interp = (self.random_pt * self.fake_H + (1 - self.random_pt) * self.real_ref).detach()
interp.requires_grad = True
interp_crit = self.netD(interp)
loss_d_gp = self.loss_gp_weight * self.criterion_gp(interp, interp_crit)
# total loss
loss_d_total = loss_d_real + loss_d_fake + loss_d_gp
else:
# total loss
loss_d_total = loss_d_real + loss_d_fake
loss_d_total.backward()
self.optimizer_D.step()

Expand All @@ -178,11 +208,14 @@ def optimize_parameters(self):

# set losses
self.loss_dict = OrderedDict()
self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[0] if self.need_pixel_loss else 0
self.loss_dict['loss_g_fea'] = loss_g_fea.data[0] if self.need_feature_loss else 0
self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
if step % self.D_update_ratio == 0 and step > self.D_init_iters:
self.loss_dict['loss_g_pixel'] = loss_g_pixel.data[0] if self.need_pixel_loss else 0
self.loss_dict['loss_g_fea'] = loss_g_fea.data[0] if self.need_feature_loss else 0
self.loss_dict['loss_g_gan'] = loss_g_gan.data[0]
self.loss_dict['loss_d_real'] = loss_d_real.data[0]
self.loss_dict['loss_d_fake'] = loss_d_fake.data[0]
if self.opt['train']['gan_type'] == 'wgan-gp':
self.loss_dict['loss_d_gp'] = loss_d_gp.data[0]

def val(self):
self.fake_H = self.netG(self.real_L)
Expand Down Expand Up @@ -233,7 +266,7 @@ def load(self):
if self.load_path_G is not None:
print('loading model for G [%s] ...' % self.load_path_G)
self.load_network(self.load_path_G, self.netG)
if self.load_path_D is not None:
if self.opt['is_train'] and self.load_path_D is not None:
print('loading model for D [%s] ...' % self.load_path_D)
self.load_network(self.load_path_D, self.netD)

Expand All @@ -247,4 +280,5 @@ def train(self):

def eval(self):
self.netG.eval()
self.netD.eval()
if self.opt['is_train']:
self.netD.eval()
2 changes: 1 addition & 1 deletion codes/models/SR_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def backward_G(self):
self.loss_pixel = self.loss_pixel_weight * self.criterion_pixel(self.fake_H, self.real_H)
self.loss_pixel.backward()

def optimize_parameters(self):
def optimize_parameters(self, step):
self.forward_G()
self.optimizer_G.zero_grad()
self.backward_G()
Expand Down
4 changes: 2 additions & 2 deletions codes/models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def forward(self, input, target_is_real):


class GradientPenaltyLoss(nn.Module):
def __init__(self):
def __init__(self, tensor=torch.FloatTensor):
super(GradientPenaltyLoss, self).__init__()
self.register_buffer('grad_outputs', torch.Tensor())
self.register_buffer('grad_outputs', tensor())

def get_grad_outputs(self, input):
if self.grad_outputs.size() != input.size():
Expand Down
2 changes: 1 addition & 1 deletion codes/options/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_timestamp():


def parse(opt_path, is_train=True):
# remove comments start with '//'
# remove comments starting with '//'
json_str = ''
with open(opt_path, 'r') as f:
for line in f:
Expand Down
39 changes: 22 additions & 17 deletions codes/options/train/SRGAN.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name":"037_026ft_SRGAN_pixel1_fea0_adv1_DIV2Kbic_ISOunpair_DnoBN"
"name":"044_005ft_SRGAN_pixel0_fea1_adv1e-2_DIV2Kbic_DIV2Kunpair_wgan-gp"
,"model":"srgan"
,"gpu_ids": [0]

Expand All @@ -15,7 +15,7 @@
// ,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLR"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLR"
,"dataroot_ref": "/mnt/SSD/xtwang/BasicSR_datasets/ISO/fusion_merge_bicLRx2_100"
,"dataroot_ref": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"subset_file": null
,"use_shuffle": true
,"n_workers": 8
Expand All @@ -24,32 +24,32 @@
,"scale": 4
,"use_flip": true
,"use_rot": true
,"reverse": true // reverse LR and HR
,"reverse": false // reverse LR and HR
}
, "val": { // evaluation
"name": "DSLR_val" // Set5 | Set14_part | GOPRO
"name": "Set14_part" // Set5 | Set14_part | GOPRO
,"data_type": "img"
,"mode": "LRHR_pair"
,"phase": "val"
//,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/GOPRO/val_blur_gamma_sub"
//,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/GOPRO/val_blur_gamma_sub_bicLRx4"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/DSLRPHONE/val_sub"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/DSLRPHONE/val_sub_bicLRx4"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set14_part/Set14_part"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set14_part/Set14_part_bicLRx4"
,"scale": 4
,"metric_mode": "rgb"
,"reverse": true // reverse LR and HR
,"reverse": false // reverse LR and HR
}
}

,"path": {
"root": "/home/xtwang/Projects/BasicSR"
,"pretrain_model_G": "/home/xtwang/Projects/BasicSR/experiments/026_009ft_SRGAN_pixel0_fea1_adv1e-1_DIV2Kbic_ISOunpair/models/190000_G.pth"
,"pretrain_model_G": "/home/xtwang/Projects/BasicSR/experiments/005_001c_SRResNet_torch_DIV2K_bic_ft_halt/models/1000000_G.pth"
//,"pretrain_model_D": "xxx"
}

,"network_G": {
// "which_model_G": "sr_resnet_torch"
"which_model_G": "degradation_net"
"which_model_G": "sr_resnet_torch"
// "which_model_G": "degradation_net"
,"norm_type": "batch"
,"mode": "NAC"
,"scale": 4
Expand All @@ -60,7 +60,7 @@
,"group": 1
}
,"network_D": {
"which_model_D": "discriminaotr_vgg_32"
"which_model_D": "discriminaotr_vgg_128"
,"norm_type": null
,"act_type": "leakyrelu"
,"mode": "CNA"
Expand All @@ -80,19 +80,24 @@
,"lr_gamma": 0.5

,"pixel_criterion": "l1" // "l1" | "l2"
,"pixel_weight": 1
,"pixel_weight": 0
,"feature_criterion": "l1" // "l1" | "l2"
,"feature_weight": 0
,"gan_type": "vanilla" // "vanilla" | "lsgan" | "wgan-gp"
,"gan_weight": 1
,"feature_weight": 1
,"gan_type": "wgan-gp" // "vanilla" | "lsgan" | "wgan-gp"
,"gan_weight": 1e-2

//for wgan-gp
,"D_update_ratio": 5
,"D_init_iters": 50
,"gp_weigth": 10

,"manual_seed": 0
,"niter": 1e6
,"val_freq": 1e4
,"val_freq": 5e3
}

,"logger": {
"print_freq": 200
,"save_checkpoint_freq": 1e4
,"save_checkpoint_freq": 5e3
}
}

0 comments on commit c845ea5

Please sign in to comment.