Skip to content

Commit

Permalink
update option.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed May 30, 2019
1 parent b3f6acb commit 8df14f0
Showing 1 changed file with 44 additions and 47 deletions.
91 changes: 44 additions & 47 deletions codes/options/options.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,49 @@
import os
import os.path as osp
import logging
from collections import OrderedDict
import json

import yaml
from utils.util import OrderedYaml
Loader, Dumper = OrderedYaml()

def parse(opt_path, is_train=True):
# remove comments starting with '//'
json_str = ''
with open(opt_path, 'r') as f:
for line in f:
line = line.split('//')[0] + '\n'
json_str += line
opt = json.loads(json_str, object_pairs_hook=OrderedDict)
with open(opt_path, mode='r') as f:
opt = yaml.load(f, Loader=Loader)
# export CUDA_VISIBLE_DEVICES
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

opt['is_train'] = is_train
scale = opt['scale']
if opt['distortion'] == 'sr':
scale = opt['scale']

# datasets
for phase, dataset in opt['datasets'].items():
phase = phase.split('_')[0]
dataset['phase'] = phase
dataset['scale'] = scale
if opt['distortion'] == 'sr':
dataset['scale'] = scale
is_lmdb = False
if 'dataroot_HR' in dataset and dataset['dataroot_HR'] is not None:
dataset['dataroot_HR'] = os.path.expanduser(dataset['dataroot_HR'])
if dataset['dataroot_HR'].endswith('lmdb'):
if dataset.get('dataroot_GT', None) is not None:
dataset['dataroot_GT'] = os.path.expanduser(dataset['dataroot_GT'])
if dataset['dataroot_GT'].endswith('lmdb'):
is_lmdb = True
if 'dataroot_HR_bg' in dataset and dataset['dataroot_HR_bg'] is not None:
dataset['dataroot_HR_bg'] = os.path.expanduser(dataset['dataroot_HR_bg'])
if 'dataroot_LR' in dataset and dataset['dataroot_LR'] is not None:
# if dataset.get('dataroot_GT_bg', None) is not None:
# dataset['dataroot_GT_bg'] = os.path.expanduser(dataset['dataroot_GT_bg'])
if dataset.get('dataroot_LR', None) is not None:
dataset['dataroot_LR'] = os.path.expanduser(dataset['dataroot_LR'])
if dataset['dataroot_LR'].endswith('lmdb'):
is_lmdb = True
dataset['data_type'] = 'lmdb' if is_lmdb else 'img'

if phase == 'train' and 'subset_file' in dataset and dataset['subset_file'] is not None:
dataset['subset_file'] = os.path.expanduser(dataset['subset_file'])
if dataset['mode'].endswith('mc'): # for memcached
dataset['data_type'] = 'mc'
dataset['mode'] = dataset['mode'].replace('_mc', '')

# path
for key, path in opt['path'].items():
if path and key in opt['path']:
opt['path'][key] = os.path.expanduser(path)
opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
if is_train:
experiments_root = os.path.join(opt['path']['root'], 'experiments', opt['name'])
opt['path']['experiments_root'] = experiments_root
Expand All @@ -53,25 +55,33 @@ def parse(opt_path, is_train=True):
# change some options for debug mode
if 'debug' in opt['name']:
opt['train']['val_freq'] = 8
opt['logger']['print_freq'] = 2
opt['logger']['print_freq'] = 1
opt['logger']['save_checkpoint_freq'] = 8
opt['train']['lr_decay_iter'] = 10
else: # test
results_root = os.path.join(opt['path']['root'], 'results', opt['name'])
opt['path']['results_root'] = results_root
opt['path']['log'] = results_root

# network
opt['network_G']['scale'] = scale

# export CUDA_VISIBLE_DEVICES
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
if opt['distortion'] == 'sr':
opt['network_G']['scale'] = scale

return opt


def dict2str(opt, indent_l=1):
'''dict to string for logger'''
msg = ''
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_l * 2) + k + ':[\n'
msg += dict2str(v, indent_l + 1)
msg += ' ' * (indent_l * 2) + ']\n'
else:
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
return msg


class NoneDict(dict):
def __missing__(self, key):
return None
Expand All @@ -90,31 +100,18 @@ def dict_to_nonedict(opt):
return opt


def dict2str(opt, indent_l=1):
'''dict to string for logger'''
msg = ''
for k, v in opt.items():
if isinstance(v, dict):
msg += ' ' * (indent_l * 2) + k + ':[\n'
msg += dict2str(v, indent_l + 1)
msg += ' ' * (indent_l * 2) + ']\n'
else:
msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
return msg


def check_resume(opt):
def check_resume(opt, resume_iter):
'''Check resume states and pretrain_model paths'''
logger = logging.getLogger('base')
if opt['path']['resume_state']:
if opt['path']['pretrain_model_G'] or opt['path']['pretrain_model_D']:
if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
'pretrain_model_D', None) is not None:
logger.warning('pretrain_model path will be ignored when resuming training.')

state_idx = osp.basename(opt['path']['resume_state']).split('.')[0]
opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
'{}_G.pth'.format(state_idx))
'{}_G.pth'.format(resume_iter))
logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
if 'gan' in opt['model']:
opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
'{}_D.pth'.format(state_idx))
'{}_D.pth'.format(resume_iter))
logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])

0 comments on commit 8df14f0

Please sign in to comment.