Skip to content

Commit

Permalink
check before release
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Apr 26, 2018
1 parent 2471c7a commit de3ba70
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 40 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# folder
.vscode
experiments
results

# file type
Expand Down
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,26 @@ The repo is still under development. There may be some bugs :-)

### Codes descriptions

Please see Wiki pages(https://github.com/xinntao/BasicSR/wiki), which contains:
Please see [Wiki pages](https://github.com/xinntao/BasicSR/wiki), which contains:
- [[data] instructions](https://github.com/xinntao/BasicSR/wiki/%5Bdata%5D-instructions)
- [[options] instructions](https://github.com/xinntao/BasicSR/wiki/%5Boptions%5D-instructions) (including all configuration descriptions)


### Getting Started
- How to test a model
1. prepare your data and pretrained model
1. SRResNet pretrained model can be downloaded from [Google Drive](https://drive.google.com/file/d/18yHStj3INmQ7AD0JlcyedMJ1ENhoBtUl/view?usp=sharing)
The model is not the best and we may provide a new one later.
1. Put the downloaded model in `BasicSR/experiments/pretrained_models/`.
1. modify the corresponding testing json file in `options/test/test.json`
1. test the model with the command `python3 test.py -opt options/test/test.json`

- How to train a model
1. prepare your data (it's better to test whether the data is ok using `test_dataloader`)
1. modify the corresponding training json file in `options/train/xxx.json`
1. modify the corresponding training json file in `options/train/SRResNet(or SRGAN).json`
1. train the model with the command `python3 train.py -opt options/train/SRResNet.json`

- How to test a model
1. prepare your data and pretrained model
1. modify the corresponding testing json file in `options/test/test.json`
1. test the model with the command `python3 test_LRinput.py -opt options/test/test.json`




Expand Down
1 change: 1 addition & 0 deletions codes/options/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def parse(opt_path, is_train=True):
for key, path in opt['path'].items():
opt['path'][key] = os.path.expanduser(path)
for phase, dataset in opt['datasets'].items():
dataset['phase'] = phase
if dataset['dataroot_HR'] is not None:
dataset['dataroot_HR'] = os.path.expanduser(dataset['dataroot_HR'])
if dataset['dataroot_LR'] is not None:
Expand Down
11 changes: 5 additions & 6 deletions codes/options/test/test.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name": "debug_000_exp_name"
"name": "pretrained_models"
,"model": "sr"
,"gpu_ids": [0]

Expand All @@ -8,16 +8,15 @@
"name": "test"
,"data_type": "img"
,"mode": "LRHRref"
,"phase": "test"
,"dataroot_HR": null
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5_bicLRx4"
,"dataroot_HR": "/mnt/SSD/BasicSR_datasets/val_set5/Set5"
,"dataroot_LR": "/mnt/SSD/BasicSR_datasets/val_set5/Set5_bicLRx4"
,"scale": 4
}
}

,"path": {
"root": "/home/xtwang/Projects/BasicSR"
,"pretrain_model_G": "../experiments/001_SRResNet_DIV2K_bic/models/280000_G.pth"
"root": "/mnt/3TB/Projects/BasicSR"
,"pretrain_model_G": "../experiments/pretrained_models/SRResNet-torch.pth"
}

,"network_G": {
Expand Down
18 changes: 8 additions & 10 deletions codes/options/train/SRGAN.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name":"debug_000_exp_name"
"name":"debug_002_SRGAN_DIV2K_pair_vanilla" // please remove "debub_" during formal training
,"model":"srgan"
,"gpu_ids": [0]

Expand All @@ -8,9 +8,8 @@
"name": "DIV2K"
,"data_type": "img"
,"mode": "LRHRref"
,"phase": "train"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4"
,"dataroot_HR": "/mnt/SSD/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"dataroot_LR": "/mnt/SSD/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4"
,"dataroot_ref": null
,"subset_file": null
,"use_shuffle": true
Expand All @@ -26,18 +25,17 @@
"name": "Set14_part"
,"data_type": "img"
,"mode": "LRHRref"
,"phase": "val"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set14_part/Set14_part"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set14_part/Set14_part_bicLRx4"
,"dataroot_HR": "/mnt/SSD/BasicSR_datasets/val_set14_part/Set14_part"
,"dataroot_LR": "/mnt/SSD/BasicSR_datasets/val_set14_part/Set14_part_bicLRx4"
,"scale": 4
,"metric_mode": "rgb"
,"reverse": false
}
}

,"path": {
"root": "/home/xtwang/Projects/BasicSR"
,"pretrain_model_G": "../experiments/001_SRResNet_DIV2K_bic/models/280000_G.pth"
"root": "/mnt/3TB/Projects/BasicSR"
,"pretrain_model_G": "../experiments/pretrained_models/SRResNet-torch.pth"
}

,"network_G": {
Expand Down Expand Up @@ -76,7 +74,7 @@
,"feature_criterion": "l1"
,"feature_weight": 1
,"gan_type": "vanilla" // "vanilla" | "lsgan" | "wgan-gp"
,"gan_weight": 1e-2
,"gan_weight": 5e-3

//for wgan-gp
,"D_update_ratio": 1
Expand Down
20 changes: 9 additions & 11 deletions codes/options/train/SRResNet.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name":"debug_000_exp_name"
"name":"debug_001_SRResNet_DIV2K_bic" // please remove "debub_" during formal training
,"model":"sr"
,"gpu_ids": [0]

Expand All @@ -8,9 +8,8 @@
"name": "DIV2K"
,"data_type": "img"
,"mode": "LRHRref"
,"phase": "train"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4"
,"dataroot_HR": "/mnt/SSD/BasicSR_datasets/DIV2K800/DIV2K800_sub"
,"dataroot_LR": "/mnt/SSD/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4"
,"subset_file": null
,"use_shuffle": true
,"n_workers": 6
Expand All @@ -19,24 +18,23 @@
,"scale": 4
,"use_flip": true
,"use_rot": true
//,"reverse": true
,"reverse": false
}
,"val": {
"name": "val_set5"
,"data_type": "img"
,"mode": "LRHRref"
,"phase": "val"
,"dataroot_HR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5"
,"dataroot_LR": "/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5_bicLRx4"
,"dataroot_HR": "/mnt/SSD/BasicSR_datasets/val_set5/Set5"
,"dataroot_LR": "/mnt/SSD/BasicSR_datasets/val_set5/Set5_bicLRx4"
,"scale": 4
,"metric_mode": "rgb"
// ,"reverse": true
,"reverse": false
}
}

,"path": {
"root": "/home/xtwang/Projects/BasicSR"
// ,"pretrain_model_G": "../experiments/001_SRResNet_DIV2K_bic/models/280000_G.pth"
"root": "/mnt/3TB/Projects/BasicSR"
,"pretrain_model_G": "../experiments/pretrained_models/SRResNet-torch.pth"
}

,"network_G": {
Expand Down
38 changes: 32 additions & 6 deletions codes/test_LRinput.py → codes/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def flush(self):
model = create_model(opt)
model.eval()

# Path for log file
test_log_path = os.path.join(opt['path']['log'], 'test_log.txt')
if os.path.exists(test_log_path):
os.remove(test_log_path)
print('Old test log is removed.')
# # Path for log file
# test_log_path = os.path.join(opt['path']['log'], 'test_log.txt')
# if os.path.exists(test_log_path):
# os.remove(test_log_path)
# print('Old test log is removed.')

print('Start Testing ...')

Expand All @@ -67,6 +67,10 @@ def flush(self):
dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
util.mkdir(dataset_dir)

test_results = OrderedDict()
test_results['psnr'] = []
test_results['ssim'] = []

for data in test_loader:
need_HR = True
if test_loader.dataset.opt['dataroot_HR'] is None:
Expand All @@ -75,12 +79,34 @@ def flush(self):
img_path = data['LR_path'][0]

img_name = os.path.splitext(os.path.basename(img_path))[0]
print(img_name)

model.test() # test
visuals = model.get_current_visuals(need_HR=need_HR)

sr_img = util.tensor2img_np(visuals['SR']) # uint8

if need_HR: # load GT image and calculate psnr
gt_img = util.tensor2img_np(visuals['HR'])
scale = test_loader.dataset.opt['scale']
crop_border = scale
cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :]
cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
psnr = metric.psnr(cropped_sr_img, cropped_gt_img)
ssim = metric.ssim(cropped_sr_img, cropped_gt_img, multichannel=True)

test_results['psnr'].append(psnr)
test_results['ssim'].append(ssim)

print('{:20s} - PSNR: {:.2f} dB; SSIM: {:.2f}'.format(img_name, psnr, ssim))
else:
print(img_name)

save_img_path = os.path.join(dataset_dir, img_name+'.png')
util.save_img_np(sr_img, save_img_path)

# Average PSNR/SSIM results
ave_psnr = sum(test_results['psnr'])/len(test_results['psnr'])
ave_ssim = sum(test_results['ssim'])/len(test_results['ssim'])
print('-----\nAverage PSNR/SSIM results for {}\n\tPSNR: {:.2f} dB; SSIM: {:.2f}\n-----'.format(\
test_set_name, ave_psnr, ave_ssim))

Empty file.

0 comments on commit de3ba70

Please sign in to comment.