Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Mar 20, 2017
1 parent 51b3e0e commit 3242ab1
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 15 deletions.
14 changes: 4 additions & 10 deletions lib/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,9 @@ def get_tta_patterns(src, n):
return [patterns[0]]


def image_tta(src, model, scale, tta_level, block_size, batch_size):
if scale:
dst = np.zeros((src.size[1] * 2, src.size[0] * 2, 3))
if model.inner_scale == 1:
src = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
else:
dst = np.zeros_like(src, dtype=np.float32)
def image_tta(src, model, tta_level, block_size, batch_size):
inner_scale = model.inner_scale
dst = np.zeros((src.size[1] * inner_scale, src.size[0] * inner_scale, 3))
patterns = get_tta_patterns(src, tta_level)
if model.ch == 1:
for i, (pat, inv) in enumerate(patterns):
Expand Down Expand Up @@ -129,9 +125,7 @@ def image_tta(src, model, scale, tta_level, block_size, batch_size):
return dst


def image(src, model, scale, block_size, batch_size):
if scale and model.inner_scale == 1:
src = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
def image(src, model, block_size, batch_size):
if model.ch == 1:
src = np.array(src.convert('YCbCr'), dtype=np.uint8)
dst = blockwise(src[:, :, 0], model, block_size, batch_size)
Expand Down
11 changes: 6 additions & 5 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,25 @@ def denoise_image(src, model, cfg):
six.print_('Level %d denoising...' % cfg.noise_level, end=' ', flush=True)
if cfg.tta:
dst = reconstruct.image_tta(
src, model, False, cfg.tta_level, cfg.block_size, cfg.batch_size)
src, model, cfg.tta_level, cfg.block_size, cfg.batch_size)
else:
dst = reconstruct.image(src, model, False, cfg.block_size, cfg.batch_size)
dst = reconstruct.image(src, model, cfg.block_size, cfg.batch_size)
six.print_('OK')
return dst


def upscale_image(src, model, cfg):
dst = src
if model.inner_scale == 1:
dst = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
log_scale = np.log2(cfg.scale_factor)
for _ in range(int(np.ceil(log_scale))):
six.print_('2.0x upscaling...', end=' ', flush=True)
if cfg.tta:
dst = reconstruct.image_tta(
dst, model, True,
cfg.tta_level, cfg.block_size, cfg.batch_size)
dst, model, cfg.tta_level, cfg.block_size, cfg.batch_size)
else:
dst = reconstruct.image(dst, model, True, cfg.block_size, cfg.batch_size)
dst = reconstruct.image(dst, model, cfg.block_size, cfg.batch_size)
six.print_('OK')
if np.round(log_scale % 1.0, 6) != 0:
six.print_('Resizing...', end=' ', flush=True)
Expand Down

0 comments on commit 3242ab1

Please sign in to comment.