Skip to content

Commit

Permalink
better model and config
Browse files Browse the repository at this point in the history
  • Loading branch information
ychfan committed Apr 13, 2017
1 parent 4bc36f9 commit 7609d1c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
11 changes: 10 additions & 1 deletion model_pixel_up.py
Expand Up @@ -10,7 +10,16 @@ def build_model(x, scale, training, reuse):
for i in range(6):
x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'lr_conv'+str(i), reuse)
x = util.lrelu(x)
x = tf.image.resize_nearest_neighbor(x, tf.shape(x)[1:3] * scale) + tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up', reuse=reuse)
if (scale == 4):
scale = 2
x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up1', reuse=reuse)
x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'up_conv', reuse)
x = util.lrelu(x)
hidden_size = 64
x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up2', reuse=reuse)
else:
hidden_size = 64
x = tf.layers.conv2d_transpose(x, hidden_size, scale, strides=scale, activation=None, name='up', reuse=reuse)
for i in range(4):
x = util.crop_by_pixel(x, 1) + conv(x, hidden_size, bottleneck_size, training, 'hr_conv'+str(i), reuse)
x = util.lrelu(x)
Expand Down
10 changes: 5 additions & 5 deletions run.sh
Expand Up @@ -24,10 +24,10 @@ set -x

EXPR_NAME="try"
TRAIN_DIR="tmp"
MODEL_NAME="model_resnet_up"
MODEL_NAME="model_pixel_up"
DATA_NAME="data_residual"
HR_FLIST="flist/hr.flist"
LR_FLIST="flist/lrX2.flist"
HR_FLIST="flist/hr_tv.flist"
LR_FLIST="flist/lrX2_bicubic_tv.flist"
SCALE=2
LEARNING_RATE=0.001

Expand All @@ -49,10 +49,10 @@ ARGS="--data_name=$DATA_NAME --hr_flist=$HR_FLIST --lr_flist=$LR_FLIST --model_n

iter=0
rate=$LEARNING_RATE
for i in `seq 1 8`;
for i in `seq 1 16`;
do
python $SCRIPT $ARGS --model_file_in=$MODEL_FILE-$iter --model_file_out=$MODEL_FILE-$((iter+1)) --learning_rate=$rate
iter=$((iter+1))
rate=$(echo "$rate" | awk '{print $1*0.5}')
rate=$(echo "$rate" | awk '{print $1*0.618}')
echo "Iteration $iter Finished"
done

0 comments on commit 7609d1c

Please sign in to comment.