Skip to content

Commit

Permalink
Add type check and replace arguments order
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Feb 11, 2017
1 parent ec1cabc commit ef87593
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
2 changes: 2 additions & 0 deletions lib/iproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def read_image_rgb_uint8(path):


def array_to_wand(src):
assert isinstance(src, np.ndarray)
with io.BytesIO() as buf:
tmp = Image.fromarray(src).convert('RGB')
tmp.save(buf, 'PNG', compress_level=0)
Expand All @@ -22,6 +23,7 @@ def array_to_wand(src):


def wand_to_array(src):
assert isinstance(src, WandImage)
with io.BytesIO(src.make_blob('PNG')) as buf:
tmp = Image.open(buf).convert('RGB')
dst = np.array(tmp, dtype=np.uint8)
Expand Down
27 changes: 13 additions & 14 deletions lib/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lib import utils


def blockwise(model, src, block_size, batch_size):
def blockwise(src, model, block_size, batch_size):
if src.ndim == 2:
src = src[:, :, np.newaxis]
h, w, ch = src.shape
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_tta_patterns(src, n):
return [patterns[0]]


def scale_tta(model, src, tta_level, block_size, batch_size):
def scale_tta(src, model, tta_level, block_size, batch_size):
patterns = get_tta_patterns(src, tta_level)
dst = np.zeros((src.size[1] * 2, src.size[0] * 2, 3))
if model.ch == 1:
Expand All @@ -83,7 +83,7 @@ def scale_tta(model, src, tta_level, block_size, batch_size):
pat = np.array(pat.convert('YCbCr'), dtype=np.uint8)
if i == 0:
cbcr = pat[:, :, 1:]
tmp = blockwise(model, pat[:, :, 0], block_size, batch_size)
tmp = blockwise(pat[:, :, 0], model, block_size, batch_size)
if not inv is None:
tmp = inv(tmp)
dst[:, :, 0] += tmp[:, :, 0]
Expand All @@ -97,7 +97,7 @@ def scale_tta(model, src, tta_level, block_size, batch_size):
six.print_(i, end=' ', flush=True)
pat = pat.resize((pat.size[0] * 2, pat.size[1] * 2), Image.NEAREST)
pat = np.array(pat.convert('RGB'), dtype=np.uint8)
tmp = blockwise(model, pat, block_size, batch_size)
tmp = blockwise(pat, model, block_size, batch_size)
if not inv is None:
tmp = inv(tmp)
dst += tmp
Expand All @@ -107,7 +107,7 @@ def scale_tta(model, src, tta_level, block_size, batch_size):
return dst


def noise_tta(model, src, tta_level, block_size, batch_size):
def noise_tta(src, model, tta_level, block_size, batch_size):
patterns = get_tta_patterns(src)[:tta_level]
dst = np.zeros((src.size[1], src.size[0], 3))
if model.ch == 1:
Expand All @@ -117,7 +117,7 @@ def noise_tta(model, src, tta_level, block_size, batch_size):
pat = np.array(pat.convert('YCbCr'), dtype=np.uint8)
if i == 0:
cbcr = pat[:, :, 1:]
tmp = blockwise(model, pat[:, :, 0], block_size, batch_size)
tmp = blockwise(pat[:, :, 0], model, block_size, batch_size)
if not inv is None:
tmp = inv(tmp)
dst[:, :, 0] += tmp[:, :, 0]
Expand All @@ -130,7 +130,7 @@ def noise_tta(model, src, tta_level, block_size, batch_size):
for i, (pat, inv) in enumerate(patterns):
six.print_(i, end=' ', flush=True)
pat = np.array(pat.convert('RGB'), dtype=np.uint8)
tmp = blockwise(model, pat, block_size, batch_size)
tmp = blockwise(pat, model, block_size, batch_size)
if not inv is None:
tmp = inv(tmp)
dst += tmp
Expand All @@ -140,33 +140,32 @@ def noise_tta(model, src, tta_level, block_size, batch_size):
return dst


def scale(model, src, block_size, batch_size):
def scale(src, model, block_size, batch_size):
src = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
if model.ch == 1:
src = np.array(src.convert('YCbCr'), dtype=np.uint8)
dst = blockwise(model, src[:, :, 0], block_size, batch_size)
dst = blockwise(src[:, :, 0], model, block_size, batch_size)
dst = np.clip(dst, 0, 1) * 255
src[:, :, 0] = dst[:, :, 0]
dst = Image.fromarray(src, mode='YCbCr').convert('RGB')
elif model.ch == 3:
src = np.array(src.convert('RGB'), dtype=np.uint8)
dst = blockwise(model, src, block_size, batch_size)
dst = blockwise(src, model, block_size, batch_size)
dst = np.clip(dst, 0, 1) * 255
dst = Image.fromarray(dst.astype(np.uint8))
return dst


def noise(model, src, block_size, batch_size):
def noise(src, model, block_size, batch_size):
if model.ch == 1:
src = np.array(src.convert('YCbCr'), dtype=np.uint8)
dst = blockwise(model, src[:, :, 0], block_size, batch_size)
dst = blockwise(src[:, :, 0], model, block_size, batch_size)
dst = np.clip(dst, 0, 1) * 255
src[:, :, 0] = dst[:, :, 0]
dst = Image.fromarray(src, mode='YCbCr').convert('RGB')
elif model.ch == 3:
src = np.array(src.convert('RGB'), dtype=np.uint8)
dst = blockwise(model, src, block_size, batch_size)
dst = blockwise(src, model, block_size, batch_size)
dst = np.clip(dst, 0, 1) * 255
dst = Image.fromarray(dst.astype(np.uint8))
return dst

17 changes: 8 additions & 9 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,28 @@
from lib import reconstruct


def denoise_image(model, src, cfg):
def denoise_image(src, model, cfg):
six.print_('Level %d denoising...' % cfg.noise_level, end=' ', flush=True)
if cfg.tta:
dst = reconstruct.noise_tta(model, src, tta_level,
dst = reconstruct.noise_tta(src, model, tta_level,
cfg.block_size, cfg.batch_size)
else:
dst = reconstruct.noise(model, src,
dst = reconstruct.noise(src, model,
cfg.block_size, cfg.batch_size)
six.print_('OK')
return dst


def upscale_image(model, src, cfg):
def upscale_image(src, model, cfg):
iter = 0
while iter < int(np.ceil(cfg.scale_factor / 2)):
iter += 1
six.print_('2.0x upscaling...', end=' ', flush=True)
if cfg.tta:
dst = reconstruct.scale_tta(model, src, cfg.tta_level,
dst = reconstruct.scale_tta(src, model, cfg.tta_level,
cfg.block_size, cfg.batch_size)
else:
dst = reconstruct.scale(model, src,
dst = reconstruct.scale(src, model,
cfg.block_size, cfg.batch_size)
six.print_('OK')
if np.round(cfg.scale_factor % 2.0, 6) != 0:
Expand Down Expand Up @@ -118,13 +118,12 @@ def upscale_image(model, src, cfg):
dst = src.copy()
if args.noise:
oname += '(noise%d)' % args.noise_level
dst = denoise_image(model_noise, dst, args)
dst = denoise_image(dst, model_noise, args)
if args.scale:
oname += '(scale%.1fx)' % args.scale_factor
dst = upscale_image(model_scale, dst, args)
dst = upscale_image(dst, model_scale, args)

oname += '(%s).png' % args.arch.lower()
opath = os.path.join(args.dir, oname)
dst.save(opath, icc_profile=icc_profile)
six.print_('Saved as \'%s\'' % opath)

0 comments on commit ef87593

Please sign in to comment.