Skip to content

Commit

Permalink
add upsample resnet model
Browse files Browse the repository at this point in the history
  • Loading branch information
ychfan committed Mar 29, 2017
1 parent 9b38703 commit 22d59af
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions model_resnet_up.py
@@ -0,0 +1,32 @@
import tensorflow as tf
import util

upsample = True

def build_model(x, scale, training):
hidden_size = 128
bottleneck_size = 32
x = tf.layers.conv2d(x, hidden_size, 3, activation=None, name='in')
for i in range(5):
x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'lr_conv'+str(i), False)
x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up')
for i in range(5):
x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'hr_conv'+str(i), False)
x = tf.layers.conv2d(x, 3, 1, activation=None, name='out')
return x

def conv(x, hidden_size, bottleneck_size, training, name, reuse):
x = tf.layers.batch_normalization(x, training=training)
x = tf.nn.relu(x)
x = tf.layers.conv2d(x, bottleneck_size, 1, activation=None, name=name+'_proj', reuse=reuse)

x = tf.layers.batch_normalization(x, training=training)
x = tf.nn.relu(x)
x = tf.layers.conv2d(x, bottleneck_size, 3, activation=None, name=name+'_filt', reuse=reuse)

x = tf.layers.batch_normalization(x, training=training)
x = tf.nn.relu(x)
x = tf.layers.conv2d(x, hidden_size, 1, activation=None, name=name+'_recv', reuse=reuse)
return x


2 comments on commit 22d59af

@WeiHan3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When scale>2, composition of multiple stride=2 conv2d_transpose layers may work better than a single stride=scale conv2d_transpose layer.

@ychfan
Copy link
Owner Author

@ychfan ychfan commented on 22d59af Mar 29, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can only be applied to x4, x8 and so on. How about the non-linearity between them? Just batch_norm+relu? Do you think I need to add a projection from 128 to 32?
And do you have any comment on the filter size for deconv? 2x2 with stride=2 means that there is no overlap between kernels.

Please sign in to comment.