Skip to content

Commit

Permalink
Support to reconstruct with UpConv models
Browse files Browse the repository at this point in the history
  • Loading branch information
tsurumeso committed Mar 14, 2017
1 parent 396a330 commit d002d92
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 37 deletions.
44 changes: 25 additions & 19 deletions lib/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,30 @@ def get_outer_padding(size, block_size, offset):
def blockwise(src, model, block_size, batch_size):
if src.ndim == 2:
src = src[:, :, np.newaxis]
h, w, ch = src.shape
scale = 1. / 255.
offset = model.offset
xp = utils.get_model_module(model)
ph = get_outer_padding(h, block_size, offset)
pw = get_outer_padding(w, block_size, offset)
psrc = np.pad(src, ((offset, ph), (offset, pw), (0, 0)), 'edge')
nh = (psrc.shape[0] - offset * 2) // block_size
nw = (psrc.shape[1] - offset * 2) // block_size

inner_scale = model.inner_scale
in_offset = model.offset // inner_scale

in_h, in_w, ch = src.shape
out_h, out_w = in_h * inner_scale, in_w * inner_scale
in_block_size = block_size // inner_scale
block_offset = in_block_size + in_offset * 2
in_ph = get_outer_padding(in_h, in_block_size, in_offset)
in_pw = get_outer_padding(in_w, in_block_size, in_offset)
out_ph = get_outer_padding(out_h, block_size, model.offset)
out_pw = get_outer_padding(out_w, block_size, model.offset)
psrc = np.pad(
src, ((in_offset, in_ph), (in_offset, in_pw), (0, 0)), 'edge')
nh = (psrc.shape[0] - in_offset * 2) // in_block_size
nw = (psrc.shape[1] - in_offset * 2) // in_block_size
psrc = psrc.transpose(2, 0, 1)
block_offset = block_size + offset * 2

x = np.zeros((nh * nw, ch, block_offset, block_offset), dtype=np.uint8)
for i in range(0, nh):
ih = i * block_size
ih = i * in_block_size
for j in range(0, nw):
jw = j * block_size
jw = j * in_block_size
psrc_ij = psrc[:, ih:ih + block_offset, jw:jw + block_offset]
x[(i * nw) + j, :, :, :] = psrc_ij

Expand All @@ -46,14 +53,14 @@ def blockwise(src, model, block_size, batch_size):
batch_y = model(batch_x)
y[i:i + batch_size] = cuda.to_cpu(batch_y.data)

dst = np.zeros((ch, h + ph, w + pw), dtype=np.float32)
dst = np.zeros((ch, out_h + out_ph, out_w + out_pw), dtype=np.float32)
for i in range(0, nh):
ih = i * block_size
for j in range(0, nw):
jw = j * block_size
dst[:, ih:ih + block_size, jw:jw + block_size] = y[(i * nw) + j]

dst = dst[:, :h, :w]
dst = dst[:, :out_h, :out_w]
return dst.transpose(1, 2, 0)


Expand Down Expand Up @@ -85,12 +92,11 @@ def get_tta_patterns(src, n):
return [patterns[0]]


def image_tta(src, model, scale, tta_level, block_size, batch_size):
if scale:
def image_tta(src, model, upsampling, tta_level, block_size, batch_size):
dst = np.zeros((src.size[1] * 2, src.size[0] * 2, 3))
if upsampling:
src = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
patterns = get_tta_patterns(src, tta_level)
dst = np.zeros((src.size[1], src.size[0], 3))
cbcr = np.zeros((src.size[1], src.size[0], 2))
if model.ch == 1:
for i, (pat, inv) in enumerate(patterns):
six.print_(i, end=' ', flush=True)
Expand Down Expand Up @@ -120,8 +126,8 @@ def image_tta(src, model, scale, tta_level, block_size, batch_size):
return dst


def image(src, model, scale, block_size, batch_size):
if scale:
def image(src, model, upsampling, block_size, batch_size):
if upsampling:
src = src.resize((src.size[0] * 2, src.size[1] * 2), Image.NEAREST)
if model.ch == 1:
src = np.array(src.convert('YCbCr'), dtype=np.uint8)
Expand Down
48 changes: 30 additions & 18 deletions waifu2x.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ def denoise_image(src, model, cfg):
def upscale_image(src, model, cfg):
dst = src
log_scale = np.log2(cfg.scale_factor)
upsampling = (model.inner_scale == 1)
for _ in range(int(np.ceil(log_scale))):
six.print_('2.0x upscaling...', end=' ', flush=True)
if cfg.tta:
dst = reconstruct.image_tta(
dst, model, True,
dst, model, upsampling,
cfg.tta_level, cfg.block_size, cfg.batch_size)
else:
dst = reconstruct.image(
dst, model, True, cfg.block_size, cfg.batch_size)
dst, model, upsampling, cfg.block_size, cfg.batch_size)
six.print_('OK')
if np.round(log_scale % 1.0, 6) != 0:
six.print_('Resizing...', end=' ', flush=True)
Expand All @@ -61,6 +62,7 @@ def upscale_image(src, model, cfg):
p.add_argument('--scale', '-s', action='store_true')
p.add_argument('--scale_factor', '-S', type=float, default=2.0)
p.add_argument('--noise', '-n', action='store_true')
p.add_argument('--noise_scale', action='store_true')
p.add_argument('--noise_level', '-N', type=int, choices=[0, 1, 2, 3],
default=1)
p.add_argument('--color', '-c', choices=['y', 'rgb'], default='rgb')
Expand Down Expand Up @@ -88,18 +90,24 @@ def upscale_image(src, model, cfg):
if not os.path.exists(args.output):
os.makedirs(args.output)

if args.scale:
model_name = 'anime_style_scale_%s.npz' % args.color
model_path = os.path.join(model_dir, model_name)
model_scale = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, model_scale)

if args.noise:
model_name = ('anime_style_noise%d_%s.npz'
if args.noise_scale:
model_name = ('anime_style_noise%d_scale_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
model_noise = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, model_noise)
model_noise_scale = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, model_noise_scale)
else:
if args.scale:
model_name = 'anime_style_scale_%s.npz' % args.color
model_path = os.path.join(model_dir, model_name)
model_scale = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, model_scale)
if args.noise:
model_name = ('anime_style_noise%d_%s.npz'
% (args.noise_level, args.color))
model_path = os.path.join(model_dir, model_name)
model_noise = srcnn.archs[args.arch](ch)
chainer.serializers.load_npz(model_path, model_noise)

if args.gpu >= 0:
cuda.check_cuda_available()
Expand All @@ -123,12 +131,16 @@ def upscale_image(src, model, cfg):
oname += ('_(tta%d)' % args.tta_level if args.tta else '_')

dst = src.copy()
if args.noise:
oname += '(noise%d)' % args.noise_level
dst = denoise_image(dst, model_noise, args)
if args.scale:
oname += '(scale%.1fx)' % args.scale_factor
dst = upscale_image(dst, model_scale, args)
if args.noise_scale:
oname += '(noise%d_scale)' % args.noise_level
dst = upscale_image(dst, model_noise_scale, args)
else:
if args.noise:
oname += '(noise%d)' % args.noise_level
dst = denoise_image(dst, model_noise, args)
if args.scale:
oname += '(scale%.1fx)' % args.scale_factor
dst = upscale_image(dst, model_scale, args)

if args.model_dir is None:
oname += '(%s_%s).png' % (args.arch.lower(), args.color)
Expand Down

0 comments on commit d002d92

Please sign in to comment.