Permalink
Browse files

Add wide_deep to the official models (#2554)

  • Loading branch information...
nealwu committed Oct 19, 2017
1 parent fb12693 commit b4cf230287c534bac10c78b25c4412ff209cb417
@@ -1,4 +1,4 @@
official/* @nealwu @jhseu @itsmeolivia
official/* @nealwu @k-w-w @jhseu @itsmeolivia
research/adversarial_crypto/* @dave-andersen
research/adversarial_text/* @rsepassi
research/adv_imagenet_models/* @AlexeyKurakin
@@ -8,6 +8,8 @@ Below is the list of models contained in the garden:

[resnet](resnet): A deep residual network that can be used to classify both CIFAR-10 and ImageNet's dataset of 1000 classes.

[wide_deep](wide_deep): A model that combines a wide model and deep network to classify census income data.

More models to come!

If you would like to make any fixes or improvements to the models, please [submit a pull request](https://github.com/tensorflow/models/compare).
@@ -0,0 +1,52 @@
# Predicting Income with the Census Income Dataset
## Overview
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) contains over 48,000 samples with attributes including age, occupation, education, and income (a binary label, either `>50K` or `<=50K`). The dataset is split into roughly 32,000 training and 16,000 testing samples.

Here, we use the [wide and deep model](https://research.googleblog.com/2016/06/wide-deep-learning-better-together-with.html) to predict the income labels. The **wide model** is able to memorize interactions with data with a large number of features but not able to generalize these learned interactions on new data. The **deep model** generalizes well but is unable to learn exceptions within the data. The **wide and deep model** combines the two models and is able to generalize while learning exceptions.

For the purposes of this example code, the Census Income Data Set was chosen to allow the model to train in a reasonable amount of time. You'll notice that the deep model performs almost as well as the wide and deep model on this dataset. The wide and deep model truly shines on larger data sets with high-cardinality features, where each feature has millions/billions of unique possible values (which is the specialty of the wide model).

---

The code sample in this directory uses the high level `tf.estimator.Estimator` API. This API is great for fast iteration and quickly adapting models to your own datasets without major code overhauls. It allows you to move from single-worker training to distributed training, and it makes it easy to export model binaries for prediction.

The input function for the `Estimator` uses `tf.contrib.data.TextLineDataset`, which creates a `Dataset` object. The `Dataset` API makes it easy to apply transformations (map, batch, shuffle, etc.) to the data. [Read more here](https://www.tensorflow.org/programmers_guide/datasets).

The `Estimator` and `Dataset` APIs are both highly encouraged for fast development and efficient training.

## Running the code
### Setup
The [Census Income Data Set](https://archive.ics.uci.edu/ml/datasets/Census+Income) that this sample uses for training is hosted by the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/). We have provided a script that downloads and cleans the necessary files.

```
python data_download.py
```

This will download the files to `/tmp/census_data`. To change the directory, set the `--data_dir` flag.

### Training
You can run the code locally as follows:

```
python wide_deep.py
```

The model is saved to `/tmp/census_model` by default, which can be changed using the `--model_dir` flag.

To run the *wide* or *deep*-only models, set the `--model_type` flag to `wide` or `deep`. Other flags are configurable as well; see `wide_deep.py` for details.

The final accuracy should be over 83% with any of the three model types.

### TensorBoard

Run TensorBoard to inspect the details about the graph and training progression.

```
tensorboard --logdir=/tmp/census_model
```

## Additional Links

If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).

You can also [run this model on Cloud ML Engine](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction), which provides [hyperparameter tuning](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#hyperparameter_tuning) to maximize your model's results and enables [deploying your model for prediction](https://cloud.google.com/ml-engine/docs/getting-started-training-prediction#deploy_a_model_to_support_prediction).
No changes.
@@ -0,0 +1,69 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Download and clean the Census Income Dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

from six.moves import urllib
import tensorflow as tf

DATA_URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/adult'
TRAINING_FILE = 'adult.data'
TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)
EVAL_FILE = 'adult.test'
EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)

parser = argparse.ArgumentParser()
parser.add_argument(
'--data_dir', type=str, default='/tmp/census_data',
help='Directory to download census data')


def _download_and_clean_file(filename, url):
"""Downloads data from url, and makes changes to match the CSV format."""
temp_file, _ = urllib.request.urlretrieve(url)
with tf.gfile.Open(temp_file, 'r') as temp_eval_file:
with tf.gfile.Open(filename, 'w') as eval_file:
for line in temp_eval_file:
line = line.strip()
line = line.replace(', ', ',')
if not line or ',' not in line:
continue
if line[-1] == '.':
line = line[:-1]
line += '\n'
eval_file.write(line)
tf.gfile.Remove(temp_file)


def main(unused_argv):
if not tf.gfile.Exists(FLAGS.data_dir):
tf.gfile.MkDir(FLAGS.data_dir)

training_file_path = os.path.join(FLAGS.data_dir, TRAINING_FILE)
_download_and_clean_file(training_file_path, TRAINING_URL)

eval_file_path = os.path.join(FLAGS.data_dir, EVAL_FILE)
_download_and_clean_file(eval_file_path, EVAL_URL)


if __name__ == '__main__':
FLAGS = parser.parse_args()
tf.app.run()
@@ -0,0 +1,231 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import shutil
import sys

import tensorflow as tf

_CSV_COLUMNS = [
'age', 'workclass', 'fnlwgt', 'education', 'education_num',
'marital_status', 'occupation', 'relationship', 'race', 'gender',
'capital_gain', 'capital_loss', 'hours_per_week', 'native_country',
'income_bracket'
]

_CSV_COLUMN_DEFAULTS = [[0], [''], [0], [''], [0], [''], [''], [''], [''], [''],
[0], [0], [0], [''], ['']]

parser = argparse.ArgumentParser()

parser.add_argument(
'--model_dir', type=str, default='/tmp/census_model',
help='Base directory for the model.')

parser.add_argument(
'--model_type', type=str, default='wide_deep',
help="Valid model types: {'wide', 'deep', 'wide_deep'}.")

parser.add_argument(
'--train_epochs', type=int, default=20, help='Number of training epochs.')

parser.add_argument(
'--epochs_per_eval', type=int, default=2,
help='The number of training epochs to run between evaluations.')

parser.add_argument(
'--batch_size', type=int, default=40, help='Number of examples per batch.')

parser.add_argument(
'--train_data', type=str, default='/tmp/census_data/adult.data',
help='Path to the training data.')

parser.add_argument(
'--test_data', type=str, default='/tmp/census_data/adult.test',
help='Path to the test data.')


def build_model_columns():
"""Builds a set of wide and deep feature columns."""
# Continuous columns
age = tf.feature_column.numeric_column('age')
education_num = tf.feature_column.numeric_column('education_num')
capital_gain = tf.feature_column.numeric_column('capital_gain')
capital_loss = tf.feature_column.numeric_column('capital_loss')
hours_per_week = tf.feature_column.numeric_column('hours_per_week')

education = tf.feature_column.categorical_column_with_vocabulary_list(
'education', [
'Bachelors', 'HS-grad', '11th', 'Masters', '9th', 'Some-college',
'Assoc-acdm', 'Assoc-voc', '7th-8th', 'Doctorate', 'Prof-school',
'5th-6th', '10th', '1st-4th', 'Preschool', '12th'])

marital_status = tf.feature_column.categorical_column_with_vocabulary_list(
'marital_status', [
'Married-civ-spouse', 'Divorced', 'Married-spouse-absent',
'Never-married', 'Separated', 'Married-AF-spouse', 'Widowed'])

relationship = tf.feature_column.categorical_column_with_vocabulary_list(
'relationship', [
'Husband', 'Not-in-family', 'Wife', 'Own-child', 'Unmarried',
'Other-relative'])

workclass = tf.feature_column.categorical_column_with_vocabulary_list(
'workclass', [
'Self-emp-not-inc', 'Private', 'State-gov', 'Federal-gov',
'Local-gov', '?', 'Self-emp-inc', 'Without-pay', 'Never-worked'])

# To show an example of hashing:
occupation = tf.feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=1000)

# Transformations.
age_buckets = tf.feature_column.bucketized_column(
age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

# Wide columns and deep columns.
base_columns = [
education, marital_status, relationship, workclass, occupation,
age_buckets,
]

crossed_columns = [
tf.feature_column.crossed_column(
['education', 'occupation'], hash_bucket_size=1000),
tf.feature_column.crossed_column(
[age_buckets, 'education', 'occupation'], hash_bucket_size=1000),
]

wide_columns = base_columns + crossed_columns

deep_columns = [
age,
education_num,
capital_gain,
capital_loss,
hours_per_week,
tf.feature_column.indicator_column(workclass),
tf.feature_column.indicator_column(education),
tf.feature_column.indicator_column(marital_status),
tf.feature_column.indicator_column(relationship),
# To show an example of embedding
tf.feature_column.embedding_column(occupation, dimension=8),
]

return wide_columns, deep_columns


def build_estimator(model_dir, model_type):
"""Build an estimator appropriate for the given model type."""
wide_columns, deep_columns = build_model_columns()
hidden_units = [100, 75, 50, 25]

# Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
# trains faster than GPU for this model.
run_config = tf.estimator.RunConfig().replace(
session_config=tf.ConfigProto(device_count={'GPU': 0}))

if model_type == 'wide':
return tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=wide_columns,
config=run_config)
elif model_type == 'deep':
return tf.estimator.DNNClassifier(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
config=run_config)
else:
return tf.estimator.DNNLinearCombinedClassifier(
model_dir=model_dir,
linear_feature_columns=wide_columns,
dnn_feature_columns=deep_columns,
dnn_hidden_units=hidden_units,
config=run_config)


def input_fn(data_file, num_epochs, shuffle, batch_size):
"""Generate an input function for the Estimator."""
assert tf.gfile.Exists(data_file), (
'%s not found. Please make sure you have either run data_download.py or '
'set both arguments --train_data and --test_data.' % data_file)
def parse_csv(value):
print('Parsing', data_file)
columns = tf.decode_csv(value, record_defaults=_CSV_COLUMN_DEFAULTS)
features = dict(zip(_CSV_COLUMNS, columns))
labels = features.pop('income_bracket')
return features, tf.equal(labels, '>50K')

# Extract lines from input files using the Dataset API.
dataset = tf.contrib.data.TextLineDataset(data_file)
dataset = dataset.map(parse_csv, num_threads=5)

# Apply transformations to the Dataset
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(num_epochs)

# Input function that is called by the Estimator
def _input_fn():
if shuffle:
# Apply shuffle transformation to re-shuffle the dataset in each call.
shuffled_dataset = dataset.shuffle(buffer_size=100000)
iterator = shuffled_dataset.make_one_shot_iterator()
else:
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
return _input_fn


def main(unused_argv):
# Clean up the model directory if present
shutil.rmtree(FLAGS.model_dir, ignore_errors=True)

model = build_estimator(FLAGS.model_dir, FLAGS.model_type)

# Set up input function generators for the train and test data files.
train_input_fn = input_fn(
data_file=FLAGS.train_data,
num_epochs=FLAGS.epochs_per_eval,
shuffle=True,
batch_size=FLAGS.batch_size)
eval_input_fn = input_fn(
data_file=FLAGS.test_data,
num_epochs=1,
shuffle=False,
batch_size=FLAGS.batch_size)

# Train and evaluate the model every `FLAGS.epochs_per_eval` epochs.
for n in range(FLAGS.train_epochs // FLAGS.epochs_per_eval):
model.train(input_fn=train_input_fn)
results = model.evaluate(input_fn=eval_input_fn)

# Display evaluation metrics
print('Results at epoch', (n + 1) * FLAGS.epochs_per_eval)
print('-' * 30)
for key in sorted(results):
print('%s: %s' % (key, results[key]))


if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@@ -0,0 +1,20 @@
39,State-gov,77516,Bachelors,13,Never-married,Adm-clerical,Not-in-family,,,2174,0,40,,<=50K
50,Self-emp-not-inc,83311,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,0,0,13,,<=50K
38,Private,215646,HS-grad,9,Divorced,Handlers-cleaners,Not-in-family,,,0,0,40,,<=50K
53,Private,234721,11th,7,Married-civ-spouse,Handlers-cleaners,Husband,,,0,0,40,,<=50K
28,Private,338409,Bachelors,13,Married-civ-spouse,Prof-specialty,Wife,,,0,0,40,,<=50K
37,Private,284582,Masters,14,Married-civ-spouse,Exec-managerial,Wife,,,0,0,40,,<=50K
49,Private,160187,9th,5,Married-spouse-absent,Other-service,Not-in-family,,,0,0,16,,<=50K
52,Self-emp-not-inc,209642,HS-grad,9,Married-civ-spouse,Exec-managerial,Husband,,,0,0,45,,>50K
31,Private,45781,Masters,14,Never-married,Prof-specialty,Not-in-family,,,14084,0,50,,>50K
42,Private,159449,Bachelors,13,Married-civ-spouse,Exec-managerial,Husband,,,5178,0,40,,>50K
37,Private,280464,Some-college,10,Married-civ-spouse,Exec-managerial,Husband,,,0,0,80,,>50K
30,State-gov,141297,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,,,0,0,40,,>50K
23,Private,122272,Bachelors,13,Never-married,Adm-clerical,Own-child,,,0,0,30,,<=50K
32,Private,205019,Assoc-acdm,12,Never-married,Sales,Not-in-family,,,0,0,50,,<=50K
40,Private,121772,Assoc-voc,11,Married-civ-spouse,Craft-repair,Husband,,,0,0,40,,>50K
34,Private,245487,7th-8th,4,Married-civ-spouse,Transport-moving,Husband,,,0,0,45,,<=50K
25,Self-emp-not-inc,176756,HS-grad,9,Never-married,Farming-fishing,Own-child,,,0,0,35,,<=50K
32,Private,186824,HS-grad,9,Never-married,Machine-op-inspct,Unmarried,,,0,0,40,,<=50K
38,Private,28887,11th,7,Married-civ-spouse,Sales,Husband,,,0,0,50,,<=50K
43,Self-emp-not-inc,292175,Masters,14,Divorced,Exec-managerial,Unmarried,,,0,0,45,,>50K
Oops, something went wrong.

0 comments on commit b4cf230

Please sign in to comment.