Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions namignizer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Remove the pyc files
*.pyc

# Ignore the model and the data
model/
data/
82 changes: 82 additions & 0 deletions namignizer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Namignizer

Use a variation of the [PTB](https://www.tensorflow.org/versions/r0.8/tutorials/recurrent/index.html#recurrent-neural-networks) model to recognize and generate names using the [Kaggle Baby Name Database](https://www.kaggle.com/kaggle/us-baby-names).

### API
Namignizer is implemented in Tensorflow 0.8r and uses the python package `pandas` for some data processing.

#### How to use
Download the data from Kaggle and place it in your data directory (or use the small training data provided). The example data looks like so:

```
Id,Name,Year,Gender,Count
1,Mary,1880,F,7065
2,Anna,1880,F,2604
3,Emma,1880,F,2003
4,Elizabeth,1880,F,1939
5,Minnie,1880,F,1746
6,Margaret,1880,F,1578
7,Ida,1880,F,1472
8,Alice,1880,F,1414
9,Bertha,1880,F,1320
```

But any data with the two columns: `Name` and `Count` will work.

With the data, we can then train the model:

```python
train("data/SmallNames.txt", "model/namignizer", SmallConfig)
```

And you will get the output:

```
Reading Name data in data/SmallNames.txt
Epoch: 1 Learning rate: 1.000
0.090 perplexity: 18.539 speed: 282 lps
...
0.890 perplexity: 1.478 speed: 285 lps
0.990 perplexity: 1.477 speed: 284 lps
Epoch: 13 Train Perplexity: 1.477
```

This will as a side effect write model checkpoints to the `model` directory. With this you will be able to determine the perplexity your model will give you for any arbitrary set of names like so:

```python
namignize(["mary", "ida", "gazorpazorp", "houyhnhnms", "bob"],
tf.train.latest_checkpoint("model"), SmallConfig)
```
You will provide the same config and the same checkpoint directory. This will allow you to use a the model you just trained. You will then get a perplexity output for each name like so:

```
Name mary gives us a perplexity of 1.03105580807
Name ida gives us a perplexity of 1.07770049572
Name gazorpazorp gives us a perplexity of 175.940353394
Name houyhnhnms gives us a perplexity of 9.53870773315
Name bob gives us a perplexity of 6.03938627243
```

Finally, you will also be able generate names using the model like so:

```python
namignator(tf.train.latest_checkpoint("model"), SmallConfig)
```

Again, you will need to provide the same config and the same checkpoint directory. This will allow you to use a the model you just trained. You will then get a single generated name. Examples of output that I got when using the provided data are:

```
['b', 'e', 'r', 't', 'h', 'a', '`']
['m', 'a', 'r', 'y', '`']
['a', 'n', 'n', 'a', '`']
['m', 'a', 'r', 'y', '`']
['b', 'e', 'r', 't', 'h', 'a', '`']
['a', 'n', 'n', 'a', '`']
['e', 'l', 'i', 'z', 'a', 'b', 'e', 't', 'h', '`']
```

Notice that each name ends with a backtick. This marks the end of the name.

### Contact Info

Feel free to reach out to me at knt(at google) or k.nathaniel.tucker(at gmail)
119 changes: 119 additions & 0 deletions namignizer/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing Kaggle baby names files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os

import numpy as np
import tensorflow as tf
import pandas as pd

# the default end of name rep will be zero
_EON = 0


def read_names(names_path):
"""read data from downloaded file. See SmallNames.txt for example format
or go to https://www.kaggle.com/kaggle/us-baby-names for full lists

Args:
names_path: path to the csv file similar to the example type
Returns:
Dataset: a namedtuple of two elements: deduped names and their associated
counts. The names contain only 26 chars and are all lower case
"""
names_data = pd.read_csv(names_path)
names_data.Name = names_data.Name.str.lower()

name_data = names_data.groupby(by=["Name"])["Count"].sum()
name_counts = np.array(name_data.tolist())
names_deduped = np.array(name_data.index.tolist())

Dataset = collections.namedtuple('Dataset', ['Name', 'Count'])
return Dataset(names_deduped, name_counts)


def _letter_to_number(letter):
"""converts letters to numbers between 1 and 27"""
# ord of lower case 'a' is 97
return ord(letter) - 96


def namignizer_iterator(names, counts, batch_size, num_steps, epoch_size):
"""Takes a list of names and counts like those output from read_names, and
makes an iterator yielding a batch_size by num_steps array of random names
separated by an end of name token. The names are choosen randomly according
to their counts. The batch may end mid-name

Args:
names: a set of lowercase names composed of 26 characters
counts: a list of the frequency of those names
batch_size: int
num_steps: int
epoch_size: number of batches to yield
Yields:
(x, y): a batch_size by num_steps array of ints representing letters, where
x will be the input and y will be the target
"""
name_distribution = counts / counts.sum()

for i in range(epoch_size):
data = np.zeros(batch_size * num_steps + 1)
samples = np.random.choice(names, size=batch_size * num_steps // 2,
replace=True, p=name_distribution)

data_index = 0
for sample in samples:
if data_index >= batch_size * num_steps:
break
for letter in map(_letter_to_number, sample) + [_EON]:
if data_index >= batch_size * num_steps:
break
data[data_index] = letter
data_index += 1

x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))

yield (x, y)


def name_to_batch(name, batch_size, num_steps):
""" Takes a single name and fills a batch with it

Args:
name: lowercase composed of 26 characters
batch_size: int
num_steps: int
Returns:
x, y: a batch_size by num_steps array of ints representing letters, where
x will be the input and y will be the target. The array is filled up
to the length of the string, the rest is filled with zeros
"""
data = np.zeros(batch_size * num_steps + 1)

data_index = 0
for letter in map(_letter_to_number, name) + [_EON]:
data[data_index] = letter
data_index += 1

x = data[:batch_size * num_steps].reshape((batch_size, num_steps))
y = data[1:batch_size * num_steps + 1].reshape((batch_size, num_steps))

return x, y
133 changes: 133 additions & 0 deletions namignizer/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RNN model with embeddings"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


class NamignizerModel(object):
"""The Namignizer model ~ strongly based on PTB"""

def __init__(self, is_training, config):
self.batch_size = batch_size = config.batch_size
self.num_steps = num_steps = config.num_steps
size = config.hidden_size
# will always be 27
vocab_size = config.vocab_size

# placeholders for inputs
self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])
self._targets = tf.placeholder(tf.int32, [batch_size, num_steps])
# weights for the loss function
self._weights = tf.placeholder(tf.float32, [batch_size * num_steps])

# lstm for our RNN cell (GRU supported too)
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(size, forget_bias=0.0)
if is_training and config.keep_prob < 1:
lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
lstm_cell, output_keep_prob=config.keep_prob)
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * config.num_layers)

self._initial_state = cell.zero_state(batch_size, tf.float32)

with tf.device("/cpu:0"):
embedding = tf.get_variable("embedding", [vocab_size, size])
inputs = tf.nn.embedding_lookup(embedding, self._input_data)

if is_training and config.keep_prob < 1:
inputs = tf.nn.dropout(inputs, config.keep_prob)

outputs = []
state = self._initial_state
with tf.variable_scope("RNN"):
for time_step in range(num_steps):
if time_step > 0:
tf.get_variable_scope().reuse_variables()
(cell_output, state) = cell(inputs[:, time_step, :], state)
outputs.append(cell_output)

output = tf.reshape(tf.concat(1, outputs), [-1, size])
softmax_w = tf.get_variable("softmax_w", [size, vocab_size])
softmax_b = tf.get_variable("softmax_b", [vocab_size])
logits = tf.matmul(output, softmax_w) + softmax_b
loss = tf.nn.seq2seq.sequence_loss_by_example(
[logits],
[tf.reshape(self._targets, [-1])],
[self._weights])
self._loss = loss
self._cost = cost = tf.reduce_sum(loss) / batch_size
self._final_state = state

# probabilities of each letter
self._activations = tf.nn.softmax(logits)

# ability to save the model
self.saver = tf.train.Saver(tf.all_variables())

if not is_training:
return

self._lr = tf.Variable(0.0, trainable=False)
tvars = tf.trainable_variables()
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
config.max_grad_norm)
optimizer = tf.train.GradientDescentOptimizer(self.lr)
self._train_op = optimizer.apply_gradients(zip(grads, tvars))

def assign_lr(self, session, lr_value):
session.run(tf.assign(self.lr, lr_value))

@property
def input_data(self):
return self._input_data

@property
def targets(self):
return self._targets

@property
def activations(self):
return self._activations

@property
def weights(self):
return self._weights

@property
def initial_state(self):
return self._initial_state

@property
def cost(self):
return self._cost

@property
def loss(self):
return self._loss

@property
def final_state(self):
return self._final_state

@property
def lr(self):
return self._lr

@property
def train_op(self):
return self._train_op
Loading