Skip to content

Commit

Permalink
refactor for image residual
Browse files Browse the repository at this point in the history
  • Loading branch information
ychfan committed Apr 1, 2017
1 parent ea44226 commit 91d9567
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 73 deletions.
28 changes: 19 additions & 9 deletions data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tensorflow as tf
import util

resize_func = None
resize = False
residual = False

def dataset(hr_flist, lr_flist, scale, resize_func=None):
def dataset(hr_flist, lr_flist, scale, resize=resize, residual=residual):
with open(hr_flist) as f:
hr_filename_list = f.read().splitlines()
with open(lr_flist) as f:
Expand All @@ -15,19 +16,28 @@ def dataset(hr_flist, lr_flist, scale, resize_func=None):
lr_image = tf.image.decode_image(lr_image_file, channels=3)
hr_image = tf.image.convert_image_dtype(hr_image, tf.float32)
lr_image = tf.image.convert_image_dtype(lr_image, tf.float32)
hr_patches0, lr_patches0 = make_patches(hr_image, lr_image, scale, resize_func)
hr_patches1, lr_patches1 = make_patches(tf.image.rot90(hr_image), tf.image.rot90(lr_image), scale, resize_func)
if (residual):
hr_image = make_residual(hr_image, lr_image)
hr_patches0, lr_patches0 = make_patches(hr_image, lr_image, scale, resize)
hr_patches1, lr_patches1 = make_patches(tf.image.rot90(hr_image), tf.image.rot90(lr_image), scale, resize)
return tf.concat([hr_patches0, hr_patches1], 0), tf.concat([lr_patches0, lr_patches1], 0)

def make_patches(hr_image, lr_image, scale, resize_func):
def make_residual(hr_image, lr_image):
hr_image = tf.expand_dims(hr_image, 0)
lr_image = tf.expand_dims(lr_image, 0)
hr_image_shape = tf.shape(hr_image)[1:3]
res_image = hr_image - util.resize_func(lr_image, hr_image_shape)
return tf.reshape(res_image, [hr_image_shape[0], hr_image_shape[1], 3])

def make_patches(hr_image, lr_image, scale, resize):
hr_image = tf.stack(flip([hr_image]))
lr_image = tf.stack(flip([lr_image]))
hr_patches = util.image_to_patches(hr_image)
if (resize_func is None):
lr_patches = util.image_to_patches(lr_image, scale)
else:
lr_image = resize_func(lr_image, tf.shape(hr_image)[1:3])
if (resize):
lr_image = util.resize_func(lr_image, tf.shape(hr_image)[1:3])
lr_patches = util.image_to_patches(lr_image)
else:
lr_patches = util.image_to_patches(lr_image, scale)
return hr_patches, lr_patches

def flip(img_list):
Expand Down
8 changes: 8 additions & 0 deletions data_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import tensorflow as tf
import data

resize = False
residual = True

def dataset(hr_flist, lr_flist, scale):
return data.dataset(hr_flist, lr_flist, scale, resize, residual)
7 changes: 0 additions & 7 deletions data_resize.py

This file was deleted.

8 changes: 8 additions & 0 deletions data_resize_residual.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import tensorflow as tf
import data

resize = True
residual = True

def dataset(hr_flist, lr_flist, scale):
return data.dataset(hr_flist, lr_flist, scale, resize, residual)
12 changes: 0 additions & 12 deletions model_conv_res.py

This file was deleted.

28 changes: 0 additions & 28 deletions model_resnet_res.py

This file was deleted.

17 changes: 11 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('data_name', 'data_resize', 'Directory to put the training data.')
flags.DEFINE_string('data_name', 'data_resize_residual', 'Directory to put the training data.')
flags.DEFINE_string('hr_flist', 'flist/set5_predict.flist', 'file_list put the training data.')
flags.DEFINE_string('lr_flist', 'flist/set5_lrX2.flist', 'Directory to put the training data.')
flags.DEFINE_integer('scale', '2', 'batch size for training')
Expand All @@ -13,7 +13,7 @@

data = __import__(FLAGS.data_name)
model = __import__(FLAGS.model_name)
if ((data.resize_func is None) != model.upsample):
if (data.resize == model.upsample):
print "Config Error"
quit()

Expand All @@ -30,14 +30,19 @@
lr_image = tf.expand_dims(lr_image, 0)
lr_image_shape = tf.shape(lr_image)[1:3]
hr_image_shape = lr_image_shape * FLAGS.scale
if (data.resize_func is not None):
lr_image = data.resize_func(lr_image, hr_image_shape)
if (data.resize):
lr_image = util.resize_func(lr_image, hr_image_shape)
lr_image = tf.reshape(lr_image, [1, hr_image_shape[0], hr_image_shape[1], 3])
else:
lr_image = tf.reshape(lr_image, [1, lr_image_shape[0], lr_image_shape[1], 3])
lr_image = util.pad_boundary(lr_image)
hr_image = model.build_model(lr_image, FLAGS.scale, training=False, reuse=False)
lr_image_padded = util.pad_boundary(lr_image)
hr_image = model.build_model(lr_image_padded, FLAGS.scale, training=False, reuse=False)
hr_image = util.crop_center(hr_image, hr_image_shape)
if (data.residual):
if (data.resize):
hr_image += lr_image
else:
hr_image += util.resize_func(lr_image, hr_image_shape)
hr_image = tf.image.convert_image_dtype(hr_image, tf.uint8, saturate=True)
hr_image = tf.reshape(hr_image, [hr_image_shape[0], hr_image_shape[1], 3])
hr_image = tf.image.encode_png(hr_image)
Expand Down
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ set -x
EXPR_NAME="try"
TRAIN_DIR="tmp"
MODEL_NAME="model_resnet"
DATA_NAME="data_resize"
DATA_NAME="data_resize_residual"
HR_FLIST="flist/hr.flist"
LR_FLIST="flist/lrX2.flist"
SCALE=2
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('data_name', 'data_resize', 'Directory to put the training data.')
flags.DEFINE_string('data_name', 'data_resize_residual', 'Directory to put the training data.')
flags.DEFINE_string('hr_flist', 'flist/hr_debug.flist', 'file_list put the training data.')
flags.DEFINE_string('lr_flist', 'flist/lrX2_debug.flist', 'Directory to put the training data.')
flags.DEFINE_integer('scale', '2', 'batch size for training')
Expand All @@ -17,7 +17,7 @@

data = __import__(FLAGS.data_name)
model = __import__(FLAGS.model_name)
if ((data.resize_func is None) != model.upsample):
if (data.resize == model.upsample):
print "Config Error"
quit()

Expand Down
2 changes: 2 additions & 0 deletions util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tensorflow as tf

resize_func = tf.image.resize_nearest_neighbor

def image_to_patches(image, scale=1):
patch_height = 108 / scale
patch_width = 108 / scale
Expand Down
21 changes: 13 additions & 8 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('data_name', 'data_resize', 'Directory to put the training data.')
flags.DEFINE_string('data_name', 'data_resize_residual', 'Directory to put the training data.')
flags.DEFINE_string('hr_flist', 'flist/set5_hr.flist', 'file_list put the training data.')
flags.DEFINE_string('lr_flist', 'flist/set5_lrX2.flist', 'Directory to put the training data.')
flags.DEFINE_integer('scale', '2', 'batch size for training')
Expand All @@ -13,7 +13,7 @@

data = __import__(FLAGS.data_name)
model = __import__(FLAGS.model_name)
if ((data.resize_func is None) != model.upsample):
if (data.resize == model.upsample):
print "Config Error"
quit()

Expand All @@ -33,15 +33,20 @@
lr_image = tf.expand_dims(lr_image, 0)
lr_image_shape = tf.shape(lr_image)[1:3]
hr_image_shape = tf.shape(hr_image)[1:3]
if (data.resize_func is not None):
lr_image = data.resize_func(lr_image, hr_image_shape)
if (data.resize):
lr_image = util.resize_func(lr_image, hr_image_shape)
lr_image = tf.reshape(lr_image, [1, hr_image_shape[0], hr_image_shape[1], 3])
else:
lr_image = tf.reshape(lr_image, [1, lr_image_shape[0], lr_image_shape[1], 3])
lr_image = util.pad_boundary(lr_image)
lr_image = model.build_model(lr_image, FLAGS.scale, training=False, reuse=False)
lr_image = util.crop_center(lr_image, hr_image_shape)
error = tf.losses.mean_squared_error(hr_image, lr_image)
lr_image_padded = util.pad_boundary(lr_image)
sr_image = model.build_model(lr_image_padded, FLAGS.scale, training=False, reuse=False)
sr_image = util.crop_center(sr_image, hr_image_shape)
if (data.residual):
if (data.resize):
sr_image += lr_image
else:
sr_image += util.resize_func(lr_image, hr_image_shape)
error = tf.losses.mean_squared_error(hr_image, sr_image)

init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
Expand Down

0 comments on commit 91d9567

Please sign in to comment.