add a script to compute the perplexity of test data #56

Status: Open · wants to merge 2 commits into base: master
47 changes: 47 additions & 0 deletions eval.py
@@ -0,0 +1,47 @@
from __future__ import print_function
import numpy as np
import tensorflow as tf

import argparse
import codecs
import time
import os
from six.moves import cPickle

from utils import TextLoader
from model import Model

from six import text_type

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    parser.add_argument('--text', type=str,
                        help='filename of text to evaluate on')
    args = parser.parse_args()
    eval(args)

def eval(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    saved_args.batch_size = 1
    saved_args.seq_length = 200
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, False)

    with codecs.open(args.text, 'r', encoding='utf-8') as f:
        text = f.read()

    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            ppl = model.eval(sess, chars, vocab, text)
            print('perplexity: {0}'.format(ppl))

if __name__ == '__main__':
    main()
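For reference, a typical invocation of the new script would look something like the line below, assuming a model checkpoint has already been written to save/ (the --save_dir default) by a prior training run, and with test.txt as a placeholder name for a UTF-8 evaluation file:

python eval.py --save_dir save --text test.txt

The script reloads the pickled training config, rebuilds the model with batch_size 1 and seq_length 200 (matching the hard-coded value in model.eval), restores the newest checkpoint, and prints the perplexity of the given text.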
31 changes: 27 additions & 4 deletions model.py
@@ -33,8 +33,8 @@ def __init__(self, args, infer=False):
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
-                inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
-                inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
+                input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
+                inputs = tf.unpack(input_embeddings, axis=1)

        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
@@ -45,11 +45,11 @@ def loop(prev, _):
        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
-        loss = seq2seq.sequence_loss_by_example([self.logits],
+        self.loss = seq2seq.sequence_loss_by_example([self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])],
                args.vocab_size)
-        self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
+        self.cost = tf.reduce_sum(self.loss) / args.batch_size / args.seq_length
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
@@ -58,6 +58,29 @@ def loop(prev, _):
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def eval(self, sess, chars, vocab, text):

Review comment: It's probably better to move this to eval.py

        batch_size = 200

Review comment: seq_length you mean?

        state = sess.run(self.cell.zero_state(1, tf.float32))
        x = [vocab[c] if c in vocab else vocab['UNK'] for c in text]
        x = [vocab['<S>']] + x + [vocab['</S>']]
        total_len = len(x) - 1
        # pad x so that batch_size evenly divides the number of predictions (len(x) - 1)
        while len(x) % batch_size != 1:
            x.append(vocab[' '])
        y = np.array(x[1:]).reshape((-1, batch_size))
        x = np.array(x[:-1]).reshape((-1, batch_size))

        total_loss = 0.0
        for i in range(x.shape[0]):
            feed = {self.input_data: x[i:i+1, :], self.targets: y[i:i+1, :],
                    self.initial_state: state}
            [state, loss] = sess.run([self.final_state, self.loss], feed)
            total_loss += loss.sum()
        # subtract off the loss from the padding tokens in the final batch
        num_pad = -total_len % batch_size
        total_loss -= loss[len(loss) - num_pad:].sum()
        avg_entropy = total_loss / total_len  # mean loss over the real predictions
        return np.exp(avg_entropy)  # this is the perplexity

    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
        state = sess.run(self.cell.zero_state(1, tf.float32))
        for char in prime[:-1]:
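A note on the perplexity arithmetic in eval above: seq2seq.sequence_loss_by_example returns one cross-entropy term (in nats) per predicted character, so the method sums those terms, discards the ones that correspond to padding, and exponentiates the mean, i.e. PPL = exp((1/N) * sum_i -log p(x_i | x_<i)) over the N real predictions. A tiny self-contained sketch of the same bookkeeping, with made-up loss values (no TensorFlow involved):

import numpy as np

# hypothetical per-character cross-entropies (nats) from a final batch:
# six real characters followed by two padding characters
loss = np.array([1.2, 0.8, 1.5, 0.9, 1.1, 1.0, 0.3, 0.2])
num_pad = 2
total_loss = loss.sum() - loss[len(loss) - num_pad:].sum()  # drop padding terms
perplexity = np.exp(total_loss / (len(loss) - num_pad))     # exp of mean loss
print(perplexity)                                           # ~2.95

Lower is better; a model that assigned uniform probability over a vocabulary of size V would score exactly V.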
123 changes: 123 additions & 0 deletions model.py~
@@ -0,0 +1,123 @@
import tensorflow as tf
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import seq2seq

import numpy as np

class Model():
    def __init__(self, args, infer=False):
        self.args = args
        if infer:
            args.batch_size = 1
            args.seq_length = 1

        if args.model == 'rnn':
            cell_fn = rnn_cell.BasicRNNCell
        elif args.model == 'gru':
            cell_fn = rnn_cell.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn_cell.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))

        cell = cell_fn(args.rnn_size, state_is_tuple=True)

        self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers, state_is_tuple=True)

        self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.targets = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                input_embeddings = tf.nn.embedding_lookup(embedding, self.input_data)
                inputs = tf.unpack(input_embeddings, axis=1)
                # inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
                # inputs = [tf.squeeze(input_, [1]) for input_ in inputs]

        def loop(prev, _):
            prev = tf.matmul(prev, softmax_w) + softmax_b
            prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
            return tf.nn.embedding_lookup(embedding, prev_symbol)

        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
        output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        self.loss = seq2seq.sequence_loss_by_example([self.logits],
                [tf.reshape(self.targets, [-1])],
                [tf.ones([args.batch_size * args.seq_length])],
                args.vocab_size)
        self.cost = tf.reduce_sum(self.loss) / args.batch_size / args.seq_length
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                                          args.grad_clip)
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def eval(self, sess, chars, vocab, text):
        batch_size = 200
        state = sess.run(self.cell.zero_state(1, tf.float32))
        x = [vocab[c] if c in vocab else vocab['UNK'] for c in text]
        x = [vocab['<S>']] + x + [vocab['</S>']]
        total_len = len(x) - 1
        # pad x so that batch_size evenly divides the number of predictions (len(x) - 1)
        while len(x) % batch_size != 1:
            x.append(vocab[' '])
        y = np.array(x[1:]).reshape((-1, batch_size))
        x = np.array(x[:-1]).reshape((-1, batch_size))

        total_loss = 0.0
        for i in range(x.shape[0]):
            feed = {self.input_data: x[i:i+1, :], self.targets: y[i:i+1, :],
                    self.initial_state: state}
            [state, loss] = sess.run([self.final_state, self.loss], feed)
            total_loss += loss.sum()
        # subtract off the loss from the padding tokens in the final batch
        num_pad = -total_len % batch_size
        total_loss -= loss[len(loss) - num_pad:].sum()
        avg_entropy = total_loss / total_len  # mean loss over the real predictions
        return np.exp(avg_entropy)  # this is the perplexity

    def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1):
        state = sess.run(self.cell.zero_state(1, tf.float32))
        for char in prime[:-1]:
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [state] = sess.run([self.final_state], feed)

        def weighted_pick(weights):
            t = np.cumsum(weights)
            s = np.sum(weights)
            return int(np.searchsorted(t, np.random.rand(1) * s))

        ret = prime
        char = prime[-1]
        for n in range(num):
            x = np.zeros((1, 1))
            x[0, 0] = vocab[char]
            feed = {self.input_data: x, self.initial_state: state}
            [probs, state] = sess.run([self.probs, self.final_state], feed)
            p = probs[0]

            if sampling_type == 0:
                sample = np.argmax(p)
            elif sampling_type == 2:
                if char == ' ':
                    sample = weighted_pick(p)
                else:
                    sample = np.argmax(p)
            else:  # sampling_type == 1, the default
                sample = weighted_pick(p)

            pred = chars[sample]
            ret += pred
            char = pred
        return ret


4 changes: 3 additions & 1 deletion utils.py
@@ -4,6 +4,7 @@
from six.moves import cPickle
import numpy as np


class TextLoader():
    def __init__(self, data_dir, batch_size, seq_length, encoding='utf-8'):
        self.data_dir = data_dir
@@ -28,13 +29,14 @@ def preprocess(self, input_file, vocab_file, tensor_file):
        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            data = f.read()
        counter = collections.Counter(data)
+        counter.update(('<S>', '</S>', 'UNK'))  # add tokens for start, end, and unknown
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        self.chars, _ = zip(*count_pairs)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
-        self.tensor = np.array(list(map(self.vocab.get, data)))
+        self.tensor = np.array(list(map(self.vocab.get, ['<S>'] + list(data) + ['</S>'])))
Review comment: Do you think it would be a better idea to write this after line 59, self.tensor = self.tensor[:self.num_batches * self.batch_size * self.seq_length], since it's unlikely that you will get the </S> character
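To make the concern in this comment concrete, here is a standalone sketch (array contents made up) of why a trailing </S> id is usually lost once the tensor is later truncated to a whole number of batches:

import numpy as np

tensor = np.arange(103)  # pretend ids; index 102 stands in for </S>
batch_size, seq_length = 2, 10
num_batches = tensor.size // (batch_size * seq_length)   # 5
tensor = tensor[:num_batches * batch_size * seq_length]  # keeps only 100 ids
print(tensor[-1])  # 99: the last three ids, </S> included, were cut off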

        np.save(tensor_file, self.tensor)

    def load_preprocessed(self, vocab_file, tensor_file):