In [1]:
from __future__ import print_function
import os

import numpy as np
import zipfile
import tarfile
from six.moves.urllib.request import urlretrieve
import shutil 
import random

import string
import tensorflow as tf

# Local dir where PTB files will be stored.
PTB_DIR = '/home/tkornuta/data/ptb/'

# Filenames.
TRAIN = "ptb.train.txt"
VALID = "ptb.valid.txt"
TEST = "ptb.test.txt"


### Check/maybe download PTB.

In [2]:
def maybe_download_ptb(path, 
                       filename='simple-examples.tgz', 
                       url='http://www.fit.vutbr.cz/~imikolov/rnnlm/', 
                       expected_bytes =34869662):
  """Download a file if not present, and make sure it's the right size."""
  _filename = path+filename
  if not os.path.exists(_filename):
    print('Downloading %s...' % filename)
    _filename, _ = urlretrieve(url+filename, _filename)
  statinfo = os.stat(_filename)
  if statinfo.st_size == expected_bytes:
    print('Found and verified', (_filename), '(', statinfo.st_size, ')')
  else:
    print(statinfo.st_size)
    raise Exception(
      'Failed to verify ' + _filename + '. Can you get to it with a browser?')
  return filename

filename = maybe_download_ptb(PTB_DIR)

Found and verified /home/tkornuta/data/ptb/simple-examples.tgz ( 34869662 )


### Extract dataset-related files from the PTB archive.

In [3]:
def extract_ptb(path, filename='simple-examples.tgz', files=["ptb.train.txt", "ptb.valid.txt", "ptb.test.txt", 
                                       "ptb.char.train.txt", "ptb.char.valid.txt", "ptb.char.test.txt"]):
    """Extracts files from PTB archive."""
    # Extract
    tar = tarfile.open(path+filename)
    tar.extractall(path)
    tar.close()
    # Copy files
    for file in files:
        shutil.copyfile(PTB_DIR+"simple-examples/data/"+file, PTB_DIR+file)
    # Delete directory
    shutil.rmtree(PTB_DIR+"simple-examples/")        

extract_ptb(PTB_DIR)
 

### Load train, valid and test texts.

In [4]:
def read_data(filename, path):
    with open(path+filename, 'r') as myfile:
        data=myfile.read()# .replace('\n', '')
        return data

train_text = read_data(TRAIN, PTB_DIR)
train_size=len(train_text)
print(train_size, train_text[:64])

valid_text = read_data(VALID, PTB_DIR)
valid_size=len(train_text)
print(valid_size, valid_text[:64])

test_text = read_data(TEST, PTB_DIR)
test_size=len(train_text)
print(test_size, test_text[:64])

5101618  aer banknote berlitz calloway centrust cluett fromstein gitano 
5101618  consumers may want to move their telephones a little closer to 
5101618  no it was n't black monday 
 but while the new york stock excha


### Utility functions to map characters to vocabulary IDs and back.

In [33]:
vocabulary_size = 59 + 32 + 10 # [A-Z] + [a-z] + ' ' +few 'in between; + punctuation
first_letter = ord(string.punctuation[0]) # uppercase before lowercase! 
print(vocabulary_size)
print(first_letter)

def char2id(char):
  """ Converts char to id (int) with chandling of unexpected characters"""
  if char in string.ascii_letters or char in string.punctuation or char in string.digits:
    return ord(char) - first_letter + 1
  elif char == ' ':
    return 0
  else:
    print('Unexpected character: %s' % char)
    return 0
  
def id2char(dictid):
  """ Converts single id (int) to character"""
  if dictid > 0:
    return chr(dictid + first_letter - 1)
  else:
    return ' '

def characters(probabilities):
  """Turn a 1-hot encoding or a probability distribution over the possible
  characters back into its (most likely) character representation."""
  return [id2char(c) for c in np.argmax(probabilities, 1)]

def batches2string(batches):
  """Convert a sequence of batches back into their (most likely) string
  representation."""
  s = [''] * batches[0].shape[0]
  for b in batches:
    s = [''.join(x) for x in zip(s, characters(b))]
  return s

#print(len(string.punctuation))
#for i in string.ascii_letters:
#    print (i, char2id(i))


print(char2id('a'), char2id('A'), char2id('z'), char2id('Z'), char2id(' '), char2id('ï'))
print(id2char(char2id('a')), id2char(char2id('A')))
print(id2char(65), id2char(33), id2char(90), id2char(58), id2char(0))

101
33
Unexpected character: ï
65 33 90 58 0 0
a A
a A z Z  


### Helper class

In [29]:
batch_size=64
num_unrollings=10

class BatchGenerator(object):
  def __init__(self, text, batch_size, num_unrollings):
    self._text = text
    self._text_size = len(text)
    self._batch_size = batch_size
    self._num_unrollings = num_unrollings
    segment = self._text_size // batch_size
    self._cursor = [ offset * segment for offset in range(batch_size)]
    self._last_batch = self._next_batch()
  
  def _next_batch(self):
    """Generate a single batch from the current cursor position in the data."""
    batch = np.zeros(shape=(self._batch_size, vocabulary_size), dtype=np.float)
    for b in range(self._batch_size):
      batch[b, char2id(self._text[self._cursor[b]])] = 1.0
      self._cursor[b] = (self._cursor[b] + 1) % self._text_size
    return batch
  
  def next(self):
    """Generate the next array of batches from the data. The array consists of
    the last batch of the previous array, followed by num_unrollings new ones.
    """
    batches = [self._last_batch]
    for step in range(self._num_unrollings):
      batches.append(self._next_batch())
    self._last_batch = batches[-1]
    return batches


train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
valid_batches = BatchGenerator(valid_text, 1, 1)

batch = train_batches.next()
#print(batch)
print(batches2string(batch))


Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

[' aer bankno', 'as the six ', 'ack said he', 'significant', "age does n'", 'rk stock ex', 's goodman t', 'rs who are ', ' <unk> tast', ' university', 'ite house a', 'ole over <u', ' <unk> elec', 'omputers   ', ' north <unk', '<unk> aircr', 'drug compan', 'annual rate', ' <unk>   <u', 's although ', 'cturing of ', 'tional coun', "'t add to v", 'na which ha', ' <unk>   hi', 'on the amen', ' while many', 'happens the', ' teams and ', ' board argu', 'ts a share ', 'r $ N a sha', ' <unk> to a', ' sang the d', 're from $ N', 'ackluster m', 'the least l', 'luding prof', 'eral norieg', 'unk> acquis', 'ood as well', ' arrest the', 'gh court ma', 'en a big lo', 'liquid crys', 'cle   durin', 'rrowed indi', 'police depa', ' tables sho', 'mr. baldwin', 'r the year ', 'rease despi', ' speaker th', 'N N N days ', 'ers general', 'ig customer', '> flight to', '> back in s', ' should do ', 'ill face an', 'by

In [30]:
def logprob(predictions, labels):
  """Log-probability of the true labels in a predicted batch."""
  predictions[predictions < 1e-10] = 1e-10
  return np.sum(np.multiply(labels, -np.log(predictions))) / labels.shape[0]

def sample_distribution(distribution):
  """Sample one element from a distribution assumed to be an array of normalized
  probabilities.
  """
  r = random.uniform(0, 1)
  s = 0
  for i in range(len(distribution)):
    s += distribution[i]
    if s >= r:
      return i
  return len(distribution) - 1

def sample(prediction):
  """Turn a (column) prediction into 1-hot encoded samples."""
  p = np.zeros(shape=[1, vocabulary_size], dtype=np.float)
  p[0, sample_distribution(prediction[0])] = 1.0
  return p

def random_distribution():
  """Generate a random column of probabilities."""
  b = np.random.uniform(0.0, 1.0, size=[1, vocabulary_size])
  return b/np.sum(b, 1)[:,None]

In [31]:
num_nodes = 64

graph = tf.Graph()
with graph.as_default():
  
  # Parameters:
  # Input gate: input, previous output, and bias.
  ix = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  im = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  ib = tf.Variable(tf.zeros([1, num_nodes]))
  # Forget gate: input, previous output, and bias.
  fx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  fm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  fb = tf.Variable(tf.zeros([1, num_nodes]))
  # Memory cell: input, state and bias.                             
  cx = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  cm = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  cb = tf.Variable(tf.zeros([1, num_nodes]))
  # Output gate: input, previous output, and bias.
  ox = tf.Variable(tf.truncated_normal([vocabulary_size, num_nodes], -0.1, 0.1))
  om = tf.Variable(tf.truncated_normal([num_nodes, num_nodes], -0.1, 0.1))
  ob = tf.Variable(tf.zeros([1, num_nodes]))
  # Variables saving state across unrollings.
  saved_output = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  saved_state = tf.Variable(tf.zeros([batch_size, num_nodes]), trainable=False)
  # Classifier weights and biases.
  w = tf.Variable(tf.truncated_normal([num_nodes, vocabulary_size], -0.1, 0.1))
  b = tf.Variable(tf.zeros([vocabulary_size]))
  
  # Definition of the cell computation.
  def lstm_cell(i, o, state):
    """Create a LSTM cell. See e.g.: http://arxiv.org/pdf/1402.1128v1.pdf
    Note that in this formulation, we omit the various connections between the
    previous state and the gates."""
    input_gate = tf.sigmoid(tf.matmul(i, ix) + tf.matmul(o, im) + ib)
    forget_gate = tf.sigmoid(tf.matmul(i, fx) + tf.matmul(o, fm) + fb)
    update = tf.matmul(i, cx) + tf.matmul(o, cm) + cb
    state = forget_gate * state + input_gate * tf.tanh(update)
    output_gate = tf.sigmoid(tf.matmul(i, ox) + tf.matmul(o, om) + ob)
    return output_gate * tf.tanh(state), state

  # Input data.
  train_data = list()
  for _ in range(num_unrollings + 1):
    train_data.append(
      tf.placeholder(tf.float32, shape=[batch_size,vocabulary_size]))
  train_inputs = train_data[:num_unrollings]
  train_labels = train_data[1:]  # labels are inputs shifted by one time step.

  # Unrolled LSTM loop.
  outputs = list()
  output = saved_output
  state = saved_state
  for i in train_inputs:
    output, state = lstm_cell(i, output, state)
    outputs.append(output)

  # State saving across unrollings.
  with tf.control_dependencies([saved_output.assign(output),
                                saved_state.assign(state)]):
    # Classifier.
    logits = tf.nn.xw_plus_b(tf.concat(outputs, 0), w, b)
    loss = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.concat(train_labels, 0), logits=logits))

  # Optimizer.
  global_step = tf.Variable(0)
  learning_rate = tf.train.exponential_decay(
    10.0, global_step, 5000, 0.1, staircase=True)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  gradients, v = zip(*optimizer.compute_gradients(loss))
  gradients, _ = tf.clip_by_global_norm(gradients, 1.25)
  optimizer = optimizer.apply_gradients(
    zip(gradients, v), global_step=global_step)

  # Predictions.
  train_prediction = tf.nn.softmax(logits)
  
  # Sampling and validation eval: batch 1, no unrolling.
  sample_input = tf.placeholder(tf.float32, shape=[1, vocabulary_size])
  saved_sample_output = tf.Variable(tf.zeros([1, num_nodes]))
  saved_sample_state = tf.Variable(tf.zeros([1, num_nodes]))
  reset_sample_state = tf.group(
    saved_sample_output.assign(tf.zeros([1, num_nodes])),
    saved_sample_state.assign(tf.zeros([1, num_nodes])))
  sample_output, sample_state = lstm_cell(
    sample_input, saved_sample_output, saved_sample_state)
  with tf.control_dependencies([saved_sample_output.assign(sample_output),
                                saved_sample_state.assign(sample_state)]):
    sample_prediction = tf.nn.softmax(tf.nn.xw_plus_b(sample_output, w, b))

In [32]:
num_steps = 7001
summary_frequency = 100

with tf.Session(graph=graph) as session:
  tf.global_variables_initializer().run()
  print('Initialized')
  mean_loss = 0
  for step in range(num_steps):
    batches = train_batches.next()
    feed_dict = dict()
    for i in range(num_unrollings + 1):
      feed_dict[train_data[i]] = batches[i]
    _, l, predictions, lr = session.run(
      [optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
    mean_loss += l
    if step % summary_frequency == 0:
      if step > 0:
        mean_loss = mean_loss / summary_frequency
      # The mean loss is an estimate of the loss over the last few batches.
      print(
        'Average loss at step %d: %f learning rate: %f' % (step, mean_loss, lr))
      mean_loss = 0
      labels = np.concatenate(list(batches)[1:])
      print('Minibatch perplexity: %.2f' % float(
        np.exp(logprob(predictions, labels))))
      if step % (summary_frequency * 10) == 0:
        # Generate some samples.
        print('=' * 80)
        for _ in range(5):
          feed = sample(random_distribution())
          sentence = characters(feed)[0]
          reset_sample_state.run()
          for _ in range(79):
            prediction = sample_prediction.eval({sample_input: feed})
            feed = sample(prediction)
            sentence += characters(feed)[0]
          print(sentence)
        print('=' * 80)
      # Measure validation set perplexity.
      reset_sample_state.run()
      valid_logprob = 0
      for _ in range(valid_size):
        b = valid_batches.next()
        predictions = sample_prediction.eval({sample_input: b[0]})
        valid_logprob = valid_logprob + logprob(predictions, b[1])
      print('Validation set perplexity: %.2f' % float(np.exp(
        valid_logprob / valid_size)))

Initialized
Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Average loss at step 0: 4.619434 learning rate: 10.000000
Minibatch perplexity: 101.44
NNc1$mb_lHsM:z`qJae*9Ct+t({[HJeHmZ. IZQ!rr$'Wq7:cu#6sQBE^<Bwq:&Kq}vAts~atl
\)# d)-UtW\UY;EeAsrm0h@e05qoxT]#jko1?YaBkeJ{_[ZAFWa c<G7PQC E)ms/Cn Ri }=tZB
m!rojATu1o?ZkDaaM<PrnDtz+Vijta}C~oH2wDg ?LZ\#y9Fa9i6WoP Q5(Jd k{M4]Uiknk4eM2W
ovMbtixGsOg5EzkXQ h'{|L2d;d&n+F8bt{ C#m0 %t\]U)o7/Hfidi{9ByPe0t7X1  f}5 >we*
$#ys8n"0 O -b. [o0(OG9"nel}`-RaT/]w)(+W]g>dmy0GeaSILW\tI*@em}B*LYMlS'a"|( {
Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected character: 

Unexpected chara

KeyboardInterrupt: 