2 changes: 1 addition & 1 deletion official/boosted_trees/README.md
@@ -39,7 +39,7 @@ Note that the model_dir is cleaned up before every time training starts.

Model parameters can be adjusted by flags, like `--n_trees`, `--max_depth`, `--learning_rate` and so on. Check out the code for details.

-The final accuacy will be around 74% and loss will be around 0.516 over the eval set, when trained with the default parameters.
+The final accuracy will be around 74% and loss will be around 0.516 over the eval set, when trained with the default parameters.

By default, the first 1 million examples of the 11 million are used for training, and the last 1 million examples are used for evaluation.
The training/evaluation data can be selected as index ranges by flags `--train_start`, `--train_count`, `--eval_start`, `--eval_count`, etc.
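
For illustration, one way to combine these flags on the command line; the script path comes from this PR, and the values shown are just the documented defaults, so adjust to taste:

```shell
python official/boosted_trees/train_higgs.py \
  --data_dir=/tmp/higgs_data --model_dir=/tmp/higgs_model \
  --n_trees=100 --max_depth=6 --learning_rate=0.1 \
  --train_start=0 --train_count=1000000 \
  --eval_start=10000000 --eval_count=1000000
```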
59 changes: 32 additions & 27 deletions official/boosted_trees/data_download.py
@@ -12,59 +12,54 @@
from __future__ import division
from __future__ import print_function

-import argparse
+import gzip
import os
-import sys
import tempfile

# pylint: disable=g-bad-import-order
import numpy as np
import pandas as pd
from six.moves import urllib
+from absl import app as absl_app
+from absl import flags
import tensorflow as tf

-URL_ROOT = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00280'
-INPUT_FILE = 'HIGGS.csv.gz'
-NPZ_FILE = 'HIGGS.csv.gz.npz'  # numpy compressed file to contain 'data' array.
-
-
-def parse_args():
-  """Parses arguments and returns a tuple (known_args, unparsed_args)."""
-  parser = argparse.ArgumentParser()
-  parser.add_argument(
-      '--data_dir', type=str, default='/tmp/higgs_data',
-      help='Directory to download higgs dataset and store training/eval data.')
-  return parser.parse_known_args()
+from official.utils.flags import core as flags_core
+
+
+URL_ROOT = "https://archive.ics.uci.edu/ml/machine-learning-databases/00280"
+INPUT_FILE = "HIGGS.csv.gz"
+NPZ_FILE = "HIGGS.csv.gz.npz"  # numpy compressed file to contain "data" array.


def _download_higgs_data_and_save_npz(data_dir):
  """Download higgs data and store as a numpy compressed file."""
  input_url = os.path.join(URL_ROOT, INPUT_FILE)
  np_filename = os.path.join(data_dir, NPZ_FILE)
  if tf.gfile.Exists(np_filename):
-    raise ValueError('data_dir already has the processed data file: {}'.format(
+    raise ValueError("data_dir already has the processed data file: {}".format(
        np_filename))
  if not tf.gfile.Exists(data_dir):
    tf.gfile.MkDir(data_dir)
  # 2.8 GB to download.
  try:
-    print('Data downloading..')
+    tf.logging.info("Data downloading...")
    temp_filename, _ = urllib.request.urlretrieve(input_url)

    # Reading and parsing 11 million csv lines takes 2~3 minutes.
-    print('Data processing.. taking multiple minutes..')
-    data = pd.read_csv(
-        temp_filename,
-        dtype=np.float32,
-        names=['c%02d' % i for i in range(29)]  # label + 28 features.
-    ).as_matrix()
+    tf.logging.info("Data processing... taking multiple minutes...")
+    with gzip.open(temp_filename, "rb") as csv_file:

Contributor: Just for my learning, pandas supports reading from .gz directly. Do we prefer to use gzip explicitly?

Contributor (author): Thanks for pointing it out! It's strange, then: when I tested the original code, I got the following error:

    pandas.errors.ParserError: Error tokenizing data. C error: Buffer overflow caught - possible malformed input file.

That's why we gzip it explicitly here. Any idea on the issue?

Contributor: Maybe it's related to the pandas version, but as gzip works, I think this change is fine.

Contributor (author): Which version are you using? I use 0.22.0.

Contributor: Hmm, the same 0.22.0. Do you get errors when running locally in virtualenv, or in Travis, or somewhere else? FYI, I'm using Linux with virtualenv (Python 2.7.13, numpy 1.14.3). I ran it just now and confirmed that pd.read_csv() reads and processes the csv.gz file properly.

Contributor (author): Aha, I see the problem: I ran it with Python 3. When I test it with Python 2, it works fine, as it does for you. So I will keep the explicit gzip for py2 and py3 compatibility. Thanks a lot! :)

Contributor: I see. Thanks for the fix!

+      data = pd.read_csv(
+          csv_file,
+          dtype=np.float32,
+          names=["c%02d" % i for i in range(29)]  # label + 28 features.
+      ).as_matrix()
  finally:
-    os.remove(temp_filename)
+    tf.gfile.Remove(temp_filename)

  # Writing to temporary location then copy to the data_dir (0.8 GB).
  f = tempfile.NamedTemporaryFile()
  np.savez_compressed(f, data=data)
  tf.gfile.Copy(f.name, np_filename)
-  print('Data saved to: {}'.format(np_filename))
+  tf.logging.info("Data saved to: {}".format(np_filename))
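
Following up on the review thread above, a minimal standalone sketch of the explicit-gzip read for py2/py3 compatibility; the local path is hypothetical, and `.values` stands in for the deprecated `.as_matrix()`:

```python
import gzip

import numpy as np
import pandas as pd

# Passing the .csv.gz path straight to pd.read_csv worked under Python 2 but
# raised a ParserError under Python 3 in the author's test, so the file is
# decompressed explicitly before parsing.
with gzip.open("/tmp/HIGGS.csv.gz", "rb") as csv_file:  # hypothetical path
  data = pd.read_csv(
      csv_file,
      dtype=np.float32,
      names=["c%02d" % i for i in range(29)]  # label + 28 features.
  ).values

print(data.shape)  # (11000000, 29) for the full HIGGS dataset
```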


def main(unused_argv):
@@ -73,6 +68,16 @@ def main(unused_argv):
  _download_higgs_data_and_save_npz(FLAGS.data_dir)


-if __name__ == '__main__':
-  FLAGS, unparsed = parse_args()
-  tf.app.run(argv=[sys.argv[0]] + unparsed)
+def define_data_download_flags():
+  """Add flags specifying data download arguments."""

Contributor: Note to ourselves: we should consider having a flags_core fn specifically for download-module flags, as I think we now have several separate data_dir definitions. No need to solve it here, though.

Contributor (author): Good point! @robieta Maybe we should add one in utils/flags?


+  flags.DEFINE_string(
+      name="data_dir", default="/tmp/higgs_data",
+      help=flags_core.help_wrap(
+          "Directory to download higgs dataset and store training/eval data."))


+if __name__ == "__main__":
+  tf.logging.set_verbosity(tf.logging.INFO)
+  define_data_download_flags()
+  FLAGS = flags.FLAGS
+  absl_app.run(main)
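
On the reviewers' note about consolidating data_dir definitions: a rough sketch of what a shared helper in utils/flags might look like. `define_download_flags` is hypothetical and not part of this PR or of flags_core:

```python
from absl import flags

from official.utils.flags import core as flags_core


def define_download_flags(default_data_dir="/tmp/data"):
  """Hypothetical shared flag definition for data-download scripts."""
  flags.DEFINE_string(
      name="data_dir", default=default_data_dir,
      help=flags_core.help_wrap(
          "Directory to download and store training/eval data."))
```

Each download script would then call this helper with its own default instead of redefining `data_dir` locally.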
139 changes: 83 additions & 56 deletions official/boosted_trees/train_higgs.py
@@ -29,64 +29,44 @@
from __future__ import division
from __future__ import print_function

-import argparse
import os
-import sys

-# pylint: disable=g-bad-import-order
-import numpy as np

Contributor: This import order seems wrong. Numpy should be below, and we need an enable= statement as well, right?

Contributor (author): I got lint errors in Kokoro checking if numpy goes after absl. :(


-import tensorflow as tf
-# pylint: enable=g-bad-import-order
+from absl import app as absl_app
+from absl import flags
+import numpy as np  # pylint: disable=wrong-import-order
+import tensorflow as tf  # pylint: disable=wrong-import-order

from official.utils.flags import core as flags_core
from official.utils.flags._conventions import help_wrap
+from official.utils.logs import logger

-NPZ_FILE = 'HIGGS.csv.gz.npz'  # numpy compressed file containing 'data' array
+NPZ_FILE = "HIGGS.csv.gz.npz"  # numpy compressed file containing "data" array


-def define_train_higgs_flags():
-  """Add tree related flags as well as training/eval configuration."""
-  flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
-  flags.adopt_module_key_flags(flags_core)
-
-  flags.DEFINE_integer(
-      name='train_start', default=0,
-      help=help_wrap('Start index of train examples within the data.'))
-  flags.DEFINE_integer(
-      name='train_count', default=1000000,
-      help=help_wrap('Number of train examples within the data.'))
-  flags.DEFINE_integer(
-      name='eval_start', default=10000000,
-      help=help_wrap('Start index of eval examples within the data.'))
-  flags.DEFINE_integer(
-      name='eval_count', default=1000000,
-      help=help_wrap('Number of eval examples within the data.'))
-
-  flags.DEFINE_integer(
-      'n_trees', default=100, help=help_wrap('Number of trees to build.'))
-  flags.DEFINE_integer(
-      'max_depth', default=6, help=help_wrap('Maximum depths of each tree.'))
-  flags.DEFINE_float(
-      'learning_rate', default=0.1,
-      help=help_wrap('Maximum depths of each tree.'))
-
-  flags_core.set_defaults(data_dir='/tmp/higgs_data',
-                          model_dir='/tmp/higgs_model')

def read_higgs_data(data_dir, train_start, train_count, eval_start, eval_count):
-  """Reads higgs data from csv and returns train and eval data."""
+  """Reads higgs data from csv and returns train and eval data.
+
+  Args:
+    data_dir: A string, the directory of higgs dataset.
+    train_start: An integer, the start index of train examples within the data.
+    train_count: An integer, the number of train examples within the data.
+    eval_start: An integer, the start index of eval examples within the data.
+    eval_count: An integer, the number of eval examples within the data.
+
+  Returns:
+    Numpy arrays of train data and eval data.
+  """
  npz_filename = os.path.join(data_dir, NPZ_FILE)
  try:
    # gfile allows numpy to read data from network data sources as well.
-    with tf.gfile.Open(npz_filename, 'rb') as npz_file:
+    with tf.gfile.Open(npz_filename, "rb") as npz_file:
      with np.load(npz_file) as npz:
-        data = npz['data']
+        data = npz["data"]
  except Exception as e:
    raise RuntimeError(
-        'Error loading data; use data_download.py to prepare the data:\n{}: {}'
+        "Error loading data; use data_download.py to prepare the data:\n{}: {}"
        .format(type(e).__name__, e))
  return (data[train_start:train_start+train_count],
          data[eval_start:eval_start+eval_count])
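
A hypothetical call to the function above, using the index ranges the training flags default to (first 1M rows for training, rows 10M-11M for eval):

```python
train_data, eval_data = read_higgs_data(
    "/tmp/higgs_data", train_start=0, train_count=1000000,
    eval_start=10000000, eval_count=1000000)
```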
@@ -105,18 +85,18 @@ def make_inputs_from_np_arrays(features_np, label_np):
  as a single tensor. Don't use batch.

  Args:
-    features_np: a numpy ndarray (shape=[batch_size, num_features]) for
+    features_np: A numpy ndarray (shape=[batch_size, num_features]) for
      float32 features.
-    label_np: a numpy ndarray (shape=[batch_size, 1]) for labels.
+    label_np: A numpy ndarray (shape=[batch_size, 1]) for labels.

  Returns:
-    input_fn: a function returning a Dataset of feature dict and label.
-    feature_column: a list of tf.feature_column.BucketizedColumn.
+    input_fn: A function returning a Dataset of feature dict and label.
+    feature_column: A list of tf.feature_column.BucketizedColumn.
  """
  num_features = features_np.shape[1]
  features_np_list = np.split(features_np, num_features, axis=1)
  # 1-based feature names.
-  feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
+  feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]

  # Create source feature_columns and bucketized_columns.
  def get_bucket_boundaries(feature):
@@ -155,16 +135,16 @@ def make_eval_inputs_from_np_arrays(features_np, label_np):
  num_features = features_np.shape[1]
  features_np_list = np.split(features_np, num_features, axis=1)
  # 1-based feature names.
-  feature_names = ['feature_%02d' % (i + 1) for i in range(num_features)]
+  feature_names = ["feature_%02d" % (i + 1) for i in range(num_features)]

  def input_fn():
    features = {
        feature_name: tf.constant(features_np_list[i])
        for i, feature_name in enumerate(feature_names)
    }
-    return tf.data.Dataset.zip(
-        (tf.data.Dataset.from_tensor_slices(features),
-         tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)
+    return tf.data.Dataset.zip((
+        tf.data.Dataset.from_tensor_slices(features),
+        tf.data.Dataset.from_tensor_slices(label_np),)).batch(1000)

  return input_fn

@@ -175,22 +155,37 @@ def train_boosted_trees(flags_obj):
  Args:
    flags_obj: An object containing parsed flag values.
  """
-
  # Clean up the model directory if present.
  if tf.gfile.Exists(flags_obj.model_dir):
    tf.gfile.DeleteRecursively(flags_obj.model_dir)
-  print('## data loading..')
+  tf.logging.info("## Data loading...")
  train_data, eval_data = read_higgs_data(
      flags_obj.data_dir, flags_obj.train_start, flags_obj.train_count,
      flags_obj.eval_start, flags_obj.eval_count)
-  print('## data loaded; train: {}{}, eval: {}{}'.format(
+  tf.logging.info("## Data loaded; train: {}{}, eval: {}{}".format(
      train_data.dtype, train_data.shape, eval_data.dtype, eval_data.shape))
-  # data consists of one label column and 28 feature columns following.
+
+  # Data consists of one label column followed by 28 feature columns.
  train_input_fn, feature_columns = make_inputs_from_np_arrays(
      features_np=train_data[:, 1:], label_np=train_data[:, 0:1])
  eval_input_fn = make_eval_inputs_from_np_arrays(
      features_np=eval_data[:, 1:], label_np=eval_data[:, 0:1])
-  print('## features prepared. training starts..')
+  tf.logging.info("## Features prepared. Training starts...")
+
+  # Create benchmark logger to log info about the training and metric values.
+  run_params = {
+      "train_start": flags_obj.train_start,
+      "train_count": flags_obj.train_count,
+      "eval_start": flags_obj.eval_start,
+      "eval_count": flags_obj.eval_count,
+      "n_trees": flags_obj.n_trees,
+      "max_depth": flags_obj.max_depth,
+  }
+  benchmark_logger = logger.config_benchmark_logger(flags_obj)
+  benchmark_logger.log_run_info(
+      model_name="boosted_trees",
+      dataset_name="higgs",
+      run_params=run_params)

  # Though BoostedTreesClassifier is under tf.estimator, faster in-memory
  # training is yet provided as a contrib library.
@@ -203,7 +198,9 @@
      learning_rate=flags_obj.learning_rate)

  # Evaluation.
-  eval_result = classifier.evaluate(eval_input_fn)
+  eval_results = classifier.evaluate(eval_input_fn)
+  # Benchmark the evaluation results.
+  benchmark_logger.log_evaluation_result(eval_results)

  # Exporting the savedmodel.
  if flags_obj.export_dir is not None:
@@ -216,7 +213,37 @@ def main(_):
  train_boosted_trees(flags.FLAGS)


-if __name__ == '__main__':
+def define_train_higgs_flags():
+  """Add tree related flags as well as training/eval configuration."""
+  flags_core.define_base(stop_threshold=False, batch_size=False, num_gpu=False)
+  flags.adopt_module_key_flags(flags_core)
+
+  flags.DEFINE_integer(
+      name="train_start", default=0,
+      help=help_wrap("Start index of train examples within the data."))
+  flags.DEFINE_integer(
+      name="train_count", default=1000000,
+      help=help_wrap("Number of train examples within the data."))
+  flags.DEFINE_integer(
+      name="eval_start", default=10000000,
+      help=help_wrap("Start index of eval examples within the data."))
+  flags.DEFINE_integer(
+      name="eval_count", default=1000000,
+      help=help_wrap("Number of eval examples within the data."))
+
+  flags.DEFINE_integer(
+      "n_trees", default=100, help=help_wrap("Number of trees to build."))
+  flags.DEFINE_integer(
+      "max_depth", default=6, help=help_wrap("Maximum depth of each tree."))
+  flags.DEFINE_float(
+      "learning_rate", default=0.1,
+      help=help_wrap("The learning rate."))
+
+  flags_core.set_defaults(data_dir="/tmp/higgs_data",
+                          model_dir="/tmp/higgs_model")


+if __name__ == "__main__":
  # Training progress and eval results are shown as logging.INFO; so enable it.
  tf.logging.set_verbosity(tf.logging.INFO)
  define_train_higgs_flags()