diff --git a/model.py b/model.py
index 4ee7cfc1..93331ad9 100644
--- a/model.py
+++ b/model.py
@@ -3,6 +3,7 @@
 from tensorflow.python.ops import seq2seq
 
 import numpy as np
+import itertools
 
 class Model():
     def __init__(self, args, infer=False):
@@ -58,7 +59,7 @@ def loop(prev, _):
         optimizer = tf.train.AdamOptimizer(self.lr)
         self.train_op = optimizer.apply_gradients(zip(grads, tvars))
 
-    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
+    def stream(self, sess, chars, vocab, prime=u'The ', sampling_type=1):
         state = self.cell.zero_state(1, tf.float32).eval()
         for char in prime[:-1]:
             x = np.zeros((1, 1))
@@ -71,9 +72,11 @@ def weighted_pick(weights):
             s = np.sum(weights)
             return(int(np.searchsorted(t, np.random.rand(1)*s)))
 
-        ret = prime
+        for char in prime:
+            yield char
+
         char = prime[-1]
-        for n in range(num):
+        while True:
             x = np.zeros((1, 1))
             x[0, 0] = vocab[char]
             feed = {self.input_data: x, self.initial_state:state}
@@ -90,9 +93,10 @@ def weighted_pick(weights):
             else: # sampling_type == 1 default:
                 sample = weighted_pick(p)
 
-            pred = chars[sample]
-            ret += pred
-            char = pred
-        return ret
+            char = chars[sample]
+            yield char
+
+    def sample(self, sess, chars, vocab, num=200, prime=u'The ', sampling_type=1):
+        stream = self.stream(sess, chars, vocab, prime=prime, sampling_type=sampling_type)
+        return u''.join(itertools.islice(stream, num))
diff --git a/sample.py b/sample.py
index d3f04285..56f5a3a4 100644
--- a/sample.py
+++ b/sample.py
@@ -9,6 +9,7 @@
 
 from utils import TextLoader
 from model import Model
+import sys
 
 from six import text_type
 
@@ -17,7 +18,7 @@ def main():
     parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
     parser.add_argument('-n', type=int, default=500,
-                       help='number of characters to sample')
+                       help='number of characters to sample, -1 to sample forever')
     parser.add_argument('--prime', type=text_type, default=u' ',
                        help='prime text')
     parser.add_argument('--sample', type=int, default=1,
@@ -38,7 +39,11 @@ def sample(args):
         ckpt = tf.train.get_checkpoint_state(args.save_dir)
         if ckpt and ckpt.model_checkpoint_path:
             saver.restore(sess, ckpt.model_checkpoint_path)
-            print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample))
+            if args.n > 0:
+                print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample))
+            else:
+                for char in model.stream(sess, chars, vocab, args.prime, args.sample):
+                    sys.stdout.write(char)
 
 if __name__ == '__main__':
     main()
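
For context, a minimal sketch of how the two entry points relate after this change. It assumes `model`, `sess`, `chars`, and `vocab` have already been set up the way sample.py does (config and vocabulary unpickled, checkpoint restored); only calls introduced by the diff are exercised, and the `flush()` call is an extra assumption for immediate output rather than something the diff itself adds.

import itertools
import sys

# Bounded generation: sample() is now just a slice of the infinite generator,
# so these two lines produce text of the same form.
print(model.sample(sess, chars, vocab, num=200, prime=u'The ', sampling_type=1))
print(u''.join(itertools.islice(model.stream(sess, chars, vocab, prime=u'The '), 200)))

# Unbounded generation: consume the generator directly, writing each character
# as it is produced (this is what `sample.py -n -1` does).
for char in model.stream(sess, chars, vocab, prime=u'The ', sampling_type=1):
    sys.stdout.write(char)
    sys.stdout.flush()  # assumed here for immediate output; not part of the diff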