diff --git a/model_conv.py b/model_conv.py index e813646..68d4997 100644 --- a/model_conv.py +++ b/model_conv.py @@ -2,9 +2,9 @@ upsample = False -def build_model(x, scale, training): - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv1') - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv2') - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv3') - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') +def build_model(x, scale, training, reuse): + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv1', reuse=reuse) + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv2', reuse=reuse) + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv3', reuse=reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x \ No newline at end of file diff --git a/model_conv_res.py b/model_conv_res.py index 95aca24..525a662 100644 --- a/model_conv_res.py +++ b/model_conv_res.py @@ -3,10 +3,10 @@ upsample = False -def build_model(x, scale, training): +def build_model(x, scale, training, reuse): origin_x = x - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv1') - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv2') - x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv3') - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv1', reuse=reuse) + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv2', reuse=reuse) + x = tf.layers.conv2d(x, 64, 3, activation=tf.sigmoid, name='conv3', reuse=reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x + util.crop_center(origin_x, tf.shape(x)[1:3]) \ No newline at end of file diff --git a/model_pixelcnn.py b/model_pixelcnn.py index c2de2ea..e5e0451 100644 --- a/model_pixelcnn.py +++ b/model_pixelcnn.py @@ -3,13 +3,13 @@ upsample = False -def build_model(x, scale, training): +def build_model(x, scale, training, reuse): hidden_size = 128 projection_size = 32 - x = conv_gated(x, hidden_size, projection_size, 'conv00', False) + x = conv_gated(x, hidden_size, projection_size, 'conv00', reuse) for i in range(10): - x = util.crop_by_pixel(x, 1) + conv_gated(x, hidden_size, projection_size, 'conv'+str(i), False) - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') + x = util.crop_by_pixel(x, 1) + conv_gated(x, hidden_size, projection_size, 'conv'+str(i), reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x def conv_gated(x, hidden_size, projection_size, name, reuse): diff --git a/model_resnet.py b/model_resnet.py index 8350017..96384c7 100644 --- a/model_resnet.py +++ b/model_resnet.py @@ -3,27 +3,25 @@ upsample = False -def build_model(x, scale, training): +def build_model(x, scale, training, reuse): hidden_size = 128 bottleneck_size = 32 - x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in') + x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in', reuse=reuse) for i in range(10): - x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'conv'+str(i), False) - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') + x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'conv'+str(i), reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x def conv(x, hidden_size, bottleneck_size, training, name, reuse): - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_proj', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 1, activation=None, name=name+'_proj', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_filt', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 3, activation=None, name=name+'_filt', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_recv', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, hidden_size, 1, activation=None, name=name+'_recv', reuse=reuse) return x - - diff --git a/model_resnet_res.py b/model_resnet_res.py index c8491d7..348e575 100644 --- a/model_resnet_res.py +++ b/model_resnet_res.py @@ -3,26 +3,26 @@ upsample = False -def build_model(x, scale, training): +def build_model(x, scale, training, reuse): hidden_size = 128 bottleneck_size = 32 origin_x = x - x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in') + x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in', reuse=reuse) for i in range(10): - x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'conv'+str(i), False) - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') + x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'conv'+str(i), reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x + util.crop_center(origin_x, tf.shape(x)[1:3]) def conv(x, hidden_size, bottleneck_size, training, name, reuse): - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_proj', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 1, activation=None, name=name+'_proj', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_filt', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 3, activation=None, name=name+'_filt', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_recv', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, hidden_size, 1, activation=None, name=name+'_recv', reuse=reuse) return x diff --git a/model_resnet_up.py b/model_resnet_up.py index 6081cbd..b5cfc70 100644 --- a/model_resnet_up.py +++ b/model_resnet_up.py @@ -3,30 +3,28 @@ upsample = True -def build_model(x, scale, training): +def build_model(x, scale, training, reuse): hidden_size = 128 bottleneck_size = 32 - x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in') + x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in', reuse=reuse) for i in range(5): - x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'lr_conv'+str(i), False) - x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up') + x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'lr_conv'+str(i), reuse) + x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up', reuse=reuse) for i in range(5): - x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'hr_conv'+str(i), False) - x = tf.layers.conv2d(x, 3, 1, activation=None, name='out') + x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'hr_conv'+str(i), reuse) + x = tf.layers.conv2d(x, 3, 1, activation=None, name='out', reuse=reuse) return x def conv(x, hidden_size, bottleneck_size, training, name, reuse): - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_proj', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 1, activation=None, name=name+'_proj', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_filt', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, bottleneck_size, 3, activation=None, name=name+'_filt', reuse=reuse) - x = tf.layers.batch_normalization(x, training=training) + x = tf.layers.batch_normalization(x, training=training, name=name+'_norm_recv', reuse=reuse) x = tf.nn.relu(x) x = tf.layers.conv2d(x, hidden_size, 1, activation=None, name=name+'_recv', reuse=reuse) return x - - diff --git a/predict.py b/predict.py index 2ebaee5..86afac0 100644 --- a/predict.py +++ b/predict.py @@ -36,7 +36,7 @@ else: lr_image = tf.reshape(lr_image, [1, lr_image_shape[0], lr_image_shape[1], 3]) lr_image = util.pad_boundary(lr_image) - hr_image = model.build_model(lr_image, FLAGS.scale, False) + hr_image = model.build_model(lr_image, FLAGS.scale, training=False, reuse=False) hr_image = util.crop_center(hr_image, hr_image_shape) hr_image = tf.image.convert_image_dtype(hr_image, tf.uint8, saturate=True) hr_image = tf.reshape(hr_image, [hr_image_shape[0], hr_image_shape[1], 3]) diff --git a/train.py b/train.py index 2300fac..51cabd7 100644 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ stager = data_flow_ops.StagingArea([tf.float32, tf.float32], shapes=[[None, None, None, 3], [None, None, None, 3]]) stage = stager.put([target_batch_staging, source_batch_staging]) target_batch, source_batch = stager.get() - predict_batch = model.build_model(source_batch, FLAGS.scale, True) + predict_batch = model.build_model(source_batch, FLAGS.scale, training=True, reuse=False) target_cropped_batch = util.crop_center(target_batch, tf.shape(predict_batch)[1:3]) loss = tf.losses.mean_squared_error(target_cropped_batch, predict_batch) optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(loss) diff --git a/validate.py b/validate.py index ccced1f..0cd706a 100644 --- a/validate.py +++ b/validate.py @@ -39,7 +39,7 @@ else: lr_image = tf.reshape(lr_image, [1, lr_image_shape[0], lr_image_shape[1], 3]) lr_image = util.pad_boundary(lr_image) - lr_image = model.build_model(lr_image, FLAGS.scale, False) + lr_image = model.build_model(lr_image, FLAGS.scale, training=False, reuse=False) lr_image = util.crop_center(lr_image, hr_image_shape) error = tf.losses.mean_squared_error(hr_image, lr_image)