
Commit

better handle params
ppwwyyxx committed May 1, 2016
1 parent 0ca9fef commit a3d6c93
Showing 3 changed files with 51 additions and 46 deletions.
73 changes: 42 additions & 31 deletions examples/char-rnn.py
@@ -9,6 +9,8 @@
import argparse
from collections import Counter
import operator
import six
from six.moves import map, range

from tensorpack import *
from tensorpack.models import *
@@ -20,16 +22,26 @@
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import seq2seq

BATCH_SIZE = 128
RNN_SIZE = 128 # hidden state size
NUM_RNN_LAYER = 2
SEQ_LEN = 50
VOCAB_SIZE = None # will be initialized by CharRNNData
CORPUS = 'input.txt'
if six.PY2:
class NS: pass # this is a hack
else:
import types
NS = types.SimpleNamespace # this is what I wanted..
param = NS()
# some model hyperparams to set
param.batch_size = 128
param.rnn_size = 128
param.num_rnn_layer = 2
param.seq_len = 50
param.grad_clip = 5.
param.vocab_size = None
param.softmax_temprature = 1
param.corpus = 'input.txt'
# Get corpus to play with at: http://cs.stanford.edu/people/karpathy/char-rnn/
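
The point of this change is to keep all hyperparameters on one mutable object instead of module-level constants, so code such as sample() can reassign param.seq_len without a global statement. A minimal sketch of the same cross-version pattern, using try/except instead of the six.PY2 check (the name HyperParams below is illustrative, not part of tensorpack):

    import types

    try:
        HyperParams = types.SimpleNamespace   # Python 3.3+: a ready-made attribute bag
    except AttributeError:
        class HyperParams(object):            # Python 2 fallback, same idea as the NS hack above
            pass

    param = HyperParams()
    param.seq_len = 50    # set once for training
    param.seq_len = 1     # reassigned later (e.g. before sampling); no `global` needed

Because every module reads param.<attr> at call time, reassignments are visible everywhere, which removes the global SEQ_LEN / global VOCAB_SIZE dance the old code needed.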

class CharRNNData(DataFlow):
def __init__(self, input_file, size):
self.seq_length = SEQ_LEN
self.seq_length = param.seq_len
self._size = size
self.rng = get_rng(self)

@@ -40,8 +52,7 @@ def __init__(self, input_file, size):
char_cnt = sorted(counter.items(), key=operator.itemgetter(1), reverse=True)
self.chars = [x[0] for x in char_cnt]
self.vocab_size = len(self.chars)
global VOCAB_SIZE
VOCAB_SIZE = self.vocab_size
param.vocab_size = self.vocab_size
self.lut = LookUpTable(self.chars)
self.whole_seq = np.array(list(map(self.lut.get_idx, data)), dtype='int32')
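
For context, the constructor builds the vocabulary by character frequency and encodes the whole corpus as integer indices. LookUpTable is tensorpack's helper; a rough, self-contained stand-in (assuming it only maps character to index) looks like this:

    from collections import Counter
    import numpy as np

    def build_vocab(text):
        # most frequent characters get the smallest indices, mirroring the sort above
        chars = [c for c, _ in Counter(text).most_common()]
        char2idx = {c: i for i, c in enumerate(chars)}
        whole_seq = np.array([char2idx[c] for c in text], dtype='int32')
        return chars, char2idx, whole_seq

    chars, char2idx, whole_seq = build_vocab(u"the quick brown fox")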

@@ -61,50 +72,48 @@ def get_data(self):

class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.int32, (None, SEQ_LEN), 'input'),
InputVar(tf.int32, (None, SEQ_LEN), 'nextinput')
]
return [InputVar(tf.int32, (None, param.seq_len), 'input'),
InputVar(tf.int32, (None, param.seq_len), 'nextinput') ]

def _get_cost(self, input_vars, is_training):
input, nextinput = input_vars

cell = rnn_cell.BasicLSTMCell(RNN_SIZE)
cell = rnn_cell.MultiRNNCell([cell] * NUM_RNN_LAYER)
cell = rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
cell = rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)

self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)

embeddingW = tf.get_variable('embedding', [VOCAB_SIZE, RNN_SIZE])
embeddingW = tf.get_variable('embedding', [param.vocab_size, param.rnn_size])
input_feature = tf.nn.embedding_lookup(embeddingW, input) # B x seqlen x rnnsize

input_list = tf.split(1, SEQ_LEN, input_feature) #seqlen x (Bx1xrnnsize)
input_list = tf.split(1, param.seq_len, input_feature) #seqlen x (Bx1xrnnsize)
input_list = [tf.squeeze(x, [1]) for x in input_list]

# seqlen is 1 in inference. don't need loop_function
outputs, last_state = seq2seq.rnn_decoder(input_list, initial, cell, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state')
# seqlen x (Bxrnnsize)
output = tf.reshape(tf.concat(1, outputs), [-1, RNN_SIZE]) # (seqlenxB) x rnnsize
logits = FullyConnected('fc', output, VOCAB_SIZE, nl=tf.identity)
self.prob = tf.nn.softmax(logits)
output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (seqlenxB) x rnnsize
logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
self.prob = tf.nn.softmax(logits / param.softmax_temprature)

xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, symbolic_functions.flatten(nextinput))
xent_loss = tf.reduce_mean(xent_loss, name='xent_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, xent_loss)

summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([xent_loss], name='cost')

def get_gradient_processor(self):
return [MapGradient(lambda grad: tf.clip_by_global_norm([grad], 5.)[0][0])]
return [MapGradient(lambda grad: tf.clip_by_global_norm(
[grad], param.grad_clip)[0][0])]
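
The clipping threshold is now a parameter (param.grad_clip) instead of the literal 5. Since tf.clip_by_global_norm is called on a single-element list, each gradient is clipped by its own norm. The rule it applies, shown as a NumPy sketch rather than the old TF 0.x API:

    import numpy as np

    def clip_by_global_norm(grads, clip_norm=5.0):
        # t * clip_norm / max(global_norm, clip_norm): unchanged when the norm
        # is already below the threshold, rescaled to clip_norm otherwise
        global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
        scale = clip_norm / max(global_norm, clip_norm)
        return [g * scale for g in grads], global_norm

    clipped, norm = clip_by_global_norm([np.array([3.0, 4.0])])   # norm is 5.0 -> unchanged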

def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))

ds = CharRNNData(CORPUS, 100000)
ds = BatchData(ds, 128)
ds = CharRNNData(param.corpus, 100000)
ds = BatchData(ds, param.batch_size)
step_per_epoch = ds.size()

lr = tf.Variable(2e-3, trainable=False, name='learning_rate')
@@ -130,9 +139,8 @@ def sample(path, start, length):
:param length: a `int`. the length of text to generate
"""
# initialize vocabulary and sequence length
global SEQ_LEN
SEQ_LEN = 1
ds = CharRNNData(CORPUS, 100000)
param.seq_len = 1
ds = CharRNNData(param.corpus, 100000)

model = Model()
input_vars = model.get_input_vars()
@@ -170,17 +178,20 @@ def pick(prob):
parser.add_argument('--load', help='load model')
subparsers = parser.add_subparsers(title='command', dest='command')
parser_sample = subparsers.add_parser('sample', help='sample a trained model')
parser_sample.add_argument('-n', '--num', type=int, default=300,
help='length of text to generate')
parser_sample.add_argument('-s', '--start', required=True, default='The ',
help='initial text sequence')
parser_sample.add_argument('-n', '--num', type=int,
default=300, help='length of text to generate')
parser_sample.add_argument('-s', '--start',
default='The ', help='initial text sequence')
parser_sample.add_argument('-t', '--temperature', type=float,
default=1, help='softmax temperature')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

if args.command == 'sample':
param.softmax_temprature = args.temperature
sample(args.load, args.start, args.num)
sys.exit()
else:
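
The new -t/--temperature option feeds param.softmax_temprature, which divides the logits before the softmax: temperatures below 1 make sampling greedier, above 1 more diverse. A small NumPy sketch of temperature sampling, roughly what the pick() helper (not shown in this hunk) is assumed to do with self.prob:

    import numpy as np

    def sample_char(logits, temperature=1.0, rng=np.random):
        scaled = logits / temperature
        scaled -= scaled.max()                        # numerical stability
        prob = np.exp(scaled) / np.exp(scaled).sum()  # softmax(logits / T)
        return rng.choice(len(prob), p=prob)          # draw one character index

    idx = sample_char(np.array([2.0, 1.0, 0.1]), temperature=0.5)
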
23 changes: 9 additions & 14 deletions examples/mnist-convnet.py
@@ -8,7 +8,7 @@
import os, sys
import argparse

import tensorpack as tp
from tensorpack import *
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.callbacks import *
@@ -56,7 +56,7 @@ def _get_cost(self, input_vars, is_training):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

# compute the number of failed samples, for ClassificationError to use at test time
wrong = tp.symbolic_functions.prediction_incorrect(logits, label)
wrong = symbolic_functions.prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
@@ -68,7 +68,7 @@ def _get_cost(self, input_vars, is_training):
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

tp.summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
summary.add_param_summary([('.*/W', ['histogram'])]) # monitor histogram of all W
return tf.add_n([wd_cost, cost], name='cost')

def get_config():
@@ -77,22 +77,18 @@ def get_config():
os.path.join('train_log', basename[:basename.rfind('.')]))

# prepare dataset
dataset_train = tp.BatchData(tp.dataset.Mnist('train'), 128)
dataset_test = tp.BatchData(tp.dataset.Mnist('test'), 256, remainder=True)
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()

# prepare session
sess_config = tp.get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5

lr = tf.train.exponential_decay(
learning_rate=1e-3,
global_step=tp.get_global_step_var(),
global_step=get_global_step_var(),
decay_steps=dataset_train.size() * 10,
decay_rate=0.3, staircase=True, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
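
For reference, tf.train.exponential_decay with staircase=True computes the schedule below; with decay_steps = dataset_train.size() * 10 the rate drops by a factor of 0.3 every 10 epochs (1e-3, then 3e-4, 9e-5, ...):

    def staircase_exponential_decay(base_lr, global_step, decay_steps, decay_rate):
        # staircase=True makes the exponent an integer, so the rate changes in jumps
        return base_lr * decay_rate ** (global_step // decay_steps)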

return tp.TrainConfig(
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([
@@ -101,7 +97,7 @@ def get_config():
InferenceRunner(dataset_test,
[ScalarStats('cost'), ClassificationError() ])
]),
session_config=sess_config,
session_config=get_default_sess_config(0.5),
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=100,
@@ -121,6 +117,5 @@ def get_config():
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
#tp.SimpleTrainer(config).train()
tp.QueueInputTrainer(config).train()
QueueInputTrainer(config).train()
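
The other substantive change in this file is dropping the hand-built sess_config in favor of get_default_sess_config(0.5). Assuming the helper does little more than what the deleted lines did by hand (an assumption about tensorpack, not shown in this diff), the equivalent is roughly:

    import tensorflow as tf

    def default_sess_config(mem_fraction=0.5):
        # hypothetical stand-in for tensorpack's get_default_sess_config(0.5)
        conf = tf.ConfigProto()
        conf.gpu_options.per_process_gpu_memory_fraction = mem_fraction
        return conf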

1 change: 0 additions & 1 deletion tensorpack/tfutils/gradproc.py
@@ -96,7 +96,6 @@ def _process(self, grads):
ret = []
for grad, var in grads:
if re.match(self.regex, var.op.name):
logger.info("DEBUG {}".format(var.op.name))
ret.append((self.func(grad), var))
else:
ret.append((grad, var))
