Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed May 1, 2018
1 parent ecf4011 commit bb62746
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from lib import utils


def denoise_image(src, model, cfg):
def denoise_image(cfg, src, model):
dst, alpha = split_alpha(src, model.offset)
six.print_('Level %d denoising...' % cfg.noise_level, end=' ', flush=True)
if cfg.tta:
Expand All @@ -29,7 +29,7 @@ def denoise_image(src, model, cfg):
return dst


def upscale_image(src, scale_model, cfg, alpha_model=None):
def upscale_image(cfg, src, scale_model, alpha_model=None):
dst, alpha = split_alpha(src, scale_model.offset)
log_scale = np.log2(cfg.scale_ratio)
for i in range(int(np.ceil(log_scale))):
Expand Down Expand Up @@ -79,47 +79,47 @@ def split_alpha(src, offset):
return rgb, alpha


def load_models(args):
ch = 3 if args.color == 'rgb' else 1
if args.model_dir is None:
model_dir = 'models/%s' % args.arch.lower()
def load_models(cfg):
ch = 3 if cfg.color == 'rgb' else 1
if cfg.model_dir is None:
model_dir = 'models/%s' % cfg.arch.lower()
else:
model_dir = args.model_dir
model_dir = cfg.model_dir

models = {}
flag = False
if args.method == 'noise_scale':
if cfg.method == 'noise_scale':
model_name = ('anime_style_noise%d_scale_%s.npz'
% (args.noise_level, args.color))
% (cfg.noise_level, cfg.color))
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
models['noise_scale'] = srcnn.archs[args.arch](ch)
models['noise_scale'] = srcnn.archs[cfg.arch](ch)
chainer.serializers.load_npz(model_path, models['noise_scale'])
alpha_model_name = 'anime_style_scale_%s.npz' % args.color
alpha_model_name = 'anime_style_scale_%s.npz' % cfg.color
alpha_model_path = os.path.join(model_dir, alpha_model_name)
models['alpha'] = srcnn.archs[args.arch](ch)
models['alpha'] = srcnn.archs[cfg.arch](ch)
chainer.serializers.load_npz(alpha_model_path, models['alpha'])
else:
flag = True
if args.method == 'scale' or flag:
model_name = 'anime_style_scale_%s.npz' % args.color
if cfg.method == 'scale' or flag:
model_name = 'anime_style_scale_%s.npz' % cfg.color
model_path = os.path.join(model_dir, model_name)
models['scale'] = srcnn.archs[args.arch](ch)
models['scale'] = srcnn.archs[cfg.arch](ch)
chainer.serializers.load_npz(model_path, models['scale'])
if args.method == 'noise' or flag:
if cfg.method == 'noise' or flag:
model_name = ('anime_style_noise%d_%s.npz'
% (args.noise_level, args.color))
% (cfg.noise_level, cfg.color))
model_path = os.path.join(model_dir, model_name)
if not os.path.exists(model_path):
model_name = ('anime_style_noise%d_scale_%s.npz'
% (args.noise_level, args.color))
% (cfg.noise_level, cfg.color))
model_path = os.path.join(model_dir, model_name)
models['noise'] = srcnn.archs[args.arch](ch)
models['noise'] = srcnn.archs[cfg.arch](ch)
chainer.serializers.load_npz(model_path, models['noise'])

if args.gpu >= 0:
if cfg.gpu >= 0:
cuda.check_cuda_available()
cuda.get_device(args.gpu).use()
cuda.get_device(cfg.gpu).use()
for _, model in models.items():
model.to_gpu()
return models
Expand Down Expand Up @@ -188,14 +188,14 @@ def load_models(args):
outname += ('(noise%d_scale%.1fx)'
% (args.noise_level, args.scale_ratio))
dst = upscale_image(
dst, models['noise_scale'], args, models['alpha'])
args, dst, models['noise_scale'], models['alpha'])
else:
if 'noise' in models:
outname += '(noise%d)' % args.noise_level
dst = denoise_image(dst, models['noise'], args)
dst = denoise_image(args, dst, models['noise'])
if 'scale' in models:
outname += '(scale%.1fx)' % args.scale_ratio
dst = upscale_image(dst, models['scale'], args)
dst = upscale_image(args, dst, models['scale'])

outname += '(%s_%s).png' % (args.arch.lower(), args.color)
if os.path.isdir(args.output):
Expand Down

0 comments on commit bb62746

Please sign in to comment.