Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Feb 24, 2018
1 parent 26ceb80 commit 77c16db
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 52 deletions.
4 changes: 3 additions & 1 deletion lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def load_datalist(dir, shuffle=False):
files = os.listdir(dir)
datalist = []
for file in files:
datalist.append(os.path.join(dir, file))
path = os.path.join(dir, file)
if os.path.isfile(path):
datalist.append(path)
if shuffle:
random.shuffle(datalist)
return datalist
104 changes: 53 additions & 51 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,7 @@


def denoise_image(src, model, cfg):
alpha = None
dst = src.convert('RGB')
if src.mode == 'LA' or src.mode == 'RGBA':
six.print_('Splitting alpha channel...', end=' ', flush=True)
alpha = src.split()[-1]
dst = iproc.alpha_make_border(dst, alpha, model.offset)
six.print_('OK')

dst, alpha = split_alpha(src, model.offset)
six.print_('Level %d denoising...' % cfg.noise_level, end=' ', flush=True)
if cfg.tta:
dst = reconstruct.image_tta(
Expand All @@ -35,14 +28,7 @@ def denoise_image(src, model, cfg):


def upscale_image(src, model, cfg):
alpha = None
dst = src.convert('RGB')
if src.mode == 'LA' or src.mode == 'RGBA':
six.print_('Splitting alpha channel...', end=' ', flush=True)
alpha = src.split()[-1]
dst = iproc.alpha_make_border(dst, alpha, model.offset)
six.print_('OK')

dst, alpha = split_alpha(src, model.offset)
log_scale = np.log2(cfg.scale_factor)
for _ in range(int(np.ceil(log_scale))):
six.print_('2.0x upscaling...', end=' ', flush=True)
Expand Down Expand Up @@ -73,6 +59,55 @@ def upscale_image(src, model, cfg):
return dst


def split_alpha(src, offset):
alpha = None
rgb = src.convert('RGB')
if src.mode == 'LA' or src.mode == 'RGBA':
six.print_('Splitting alpha channel...', end=' ', flush=True)
alpha = src.split()[-1]
rgb = iproc.alpha_make_border(rgb, alpha, offset)
six.print_('OK')
return rgb, alpha


def load_models(args):
ch = 3 if args.color == 'rgb' else 1
if args.model_dir is None:
model_dir = 'models/%s' % args.arch.lower()
else:
model_dir = args.model_dir

models = {}
flag = False
if args.method == 'noise_scale':
model_name = ('anime_style_noise%d_scale_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
models['noise_scale'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['noise_scale'])
else:
flag = True
if args.method == 'scale' or flag:
model_name = 'anime_style_scale_%s.npz' % args.color
model_path = os.path.join(model_dir, model_name)
models['scale'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['scale'])
if args.method == 'noise' or flag:
model_name = ('anime_style_noise%d_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
models['noise'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['noise'])

if args.gpu >= 0:
cuda.check_cuda_available()
cuda.get_device(args.gpu).use()
for _, model in models.items():
model.to_gpu()
return models


p = argparse.ArgumentParser()
p.add_argument('--gpu', '-g', type=int, default=-1)
p.add_argument('--input', '-i', default='images/small.png')
Expand Down Expand Up @@ -104,11 +139,8 @@ def upscale_image(src, model, cfg):


if __name__ == '__main__':
ch = 3 if args.color == 'rgb' else 1
if args.model_dir is None:
model_dir = 'models/%s' % args.arch.lower()
else:
model_dir = args.model_dir
models = load_models(args)

if os.path.isdir(args.output):
if not os.path.exists(args.output):
os.makedirs(args.output)
Expand All @@ -118,36 +150,6 @@ def upscale_image(src, model, cfg):
dirname = './'
if not os.path.exists(dirname):
os.makedirs(dirname)

models = {}
flag = False
if args.method == 'noise_scale':
model_name = ('anime_style_noise%d_scale_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
if os.path.exists(model_path):
models['noise_scale'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['noise_scale'])
else:
flag = True
if args.method == 'scale' or flag:
model_name = 'anime_style_scale_%s.npz' % args.color
model_path = os.path.join(model_dir, model_name)
models['scale'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['scale'])
if args.method == 'noise' or flag:
model_name = ('anime_style_noise%d_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
models['noise'] = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, models['noise'])

if args.gpu >= 0:
cuda.check_cuda_available()
cuda.get_device(args.gpu).use()
for _, model in models.items():
model.to_gpu()

if os.path.isdir(args.input):
filelist = glob.glob(os.path.join(args.input, '*.*'))
else:
Expand Down

0 comments on commit 77c16db

Please sign in to comment.