Skip to content

Commit

Permalink
use rnn.rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed May 1, 2016
1 parent a3d6c93 commit aed3438
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions examples/char-rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from tensorpack.utils.lut import LookUpTable
from tensorpack.callbacks import *

from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import seq2seq
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn

if six.PY2:
class NS: pass # this is a hack
Expand All @@ -30,7 +30,7 @@ class NS: pass # this is a hack
param = NS()
# some model hyperparams to set
param.batch_size = 128
param.rnn_size = 128
param.rnn_size = 256
param.num_rnn_layer = 2
param.seq_len = 50
param.grad_clip = 5.
Expand Down Expand Up @@ -90,7 +90,7 @@ def _get_cost(self, input_vars, is_training):
input_list = [tf.squeeze(x, [1]) for x in input_list]

# seqlen is 1 in inference. don't need loop_function
outputs, last_state = seq2seq.rnn_decoder(input_list, initial, cell, scope='rnnlm')
outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state')
# seqlen x (Bxrnnsize)
output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (seqlenxB) x rnnsize
Expand Down Expand Up @@ -125,7 +125,8 @@ def get_config():
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
HumanHyperParamSetter('learning_rate', 'hyper.txt')
#HumanHyperParamSetter('learning_rate', 'hyper.txt')
SeduledHyperParamSetter('learning_rate', [(25, 2e-4)])
]),
model=Model(),
step_per_epoch=step_per_epoch,
Expand Down Expand Up @@ -184,6 +185,7 @@ def pick(prob):
default='The ', help='initial text sequence')
parser_sample.add_argument('-t', '--temperature', type=float,
default=1, help='softmax temperature')
parser_train = subparsers.add_parser('train', help='train')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
Expand Down

0 comments on commit aed3438

Please sign in to comment.