From 37f38b462e412f69ef0e11619803184329f55ff6 Mon Sep 17 00:00:00 2001
From: Amin Mantrach
Date: Thu, 29 Mar 2018 19:23:14 -0700
Subject: [PATCH] add criteo example

---
 examples/criteo/README.md              | 124 +++++++++++
 examples/criteo/spark/__init__.py      |   0
 examples/criteo/spark/criteo_dist.py   | 294 +++++++++++++++++++++++++
 examples/criteo/spark/criteo_spark.py  |  66 ++++++
 examples/criteo/spark/requirements.txt |   5 +
 5 files changed, 489 insertions(+)
 create mode 100644 examples/criteo/README.md
 create mode 100644 examples/criteo/spark/__init__.py
 create mode 100644 examples/criteo/spark/criteo_dist.py
 create mode 100644 examples/criteo/spark/criteo_spark.py
 create mode 100644 examples/criteo/spark/requirements.txt

diff --git a/examples/criteo/README.md b/examples/criteo/README.md
new file mode 100644
index 00000000..8f73967e
--- /dev/null
+++ b/examples/criteo/README.md
@@ -0,0 +1,124 @@
+# Learning Click-Through Rate at Scale with TensorFlow on Spark
+
+## Introduction
+This project shows how to learn a click-through rate (CTR) model at scale using TensorFlowOnSpark.
+Criteo released a 1TB click-log dataset: http://labs.criteo.com/2013/12/download-terabyte-click-logs/
+To promote its cloud technology, Google published a solution for training a model on this dataset at scale using its
+proprietary platform: https://cloud.google.com/blog/big-data/2017/02/using-google-cloud-machine-learning-to-predict-clicks-at-scale
+
+Here, we instead propose a solution based on open-source technology that can be leveraged on any cloud
+or private cluster running Spark.
+
+We demonstrate how TensorFlowOnSpark (https://github.com/yahoo/TensorFlowOnSpark) can be used to reach the state of the art in predicting the probability of a click at scale.
+Note that the goal is not to produce the best possible pCTR predictor, but rather to establish an open method that still matches the best performance published so far on this dataset.
+Hence, our solution remains very simple and relies solely on basic feature extraction, cross-features, and hashing, all fed into a logistic regression model.
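+
+As a quick illustration of that representation, the sketch below shows how a single raw line of the dataset
+could be mapped to sparse indicator features: the 13 integer columns are bucketized and the 26 categorical
+columns are hashed, mirroring the constants used in `spark/criteo_dist.py` (the `encode` and `bucketize`
+helpers are illustrative names only, and the cross-features are omitted for brevity):
+
+```
+import mmh3
+
+NB_BUCKETS = 40                 # buckets per integer feature
+NB_HASHES = 2 ** 15             # hash space per categorical feature
+BOUNDARIES = [1.5 ** j - 0.51 for j in range(NB_BUCKETS)]
+
+def bucketize(value):
+    # Index of the first boundary greater than the value (last bucket otherwise).
+    for i, boundary in enumerate(BOUNDARIES):
+        if value < boundary:
+            return i
+    return NB_BUCKETS - 1
+
+def encode(line):
+    # A record is: label, 13 integer features, 26 categorical features, tab-separated.
+    fields = line.rstrip("\n").split("\t")
+    label, ints, cats = int(fields[0]), fields[1:14], fields[14:40]
+    active = []                 # indices of the non-zero entries of the sparse feature vector
+    for i, value in enumerate(ints):
+        if value:
+            active.append(i * NB_BUCKETS + bucketize(int(value)))
+    offset = len(ints) * NB_BUCKETS
+    for i, value in enumerate(cats):
+        if value:
+            active.append(offset + i * NB_HASHES + mmh3.hash(value) % NB_HASHES)
+    return label, active
+```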
+
+## Install and test TensorFlowOnSpark
+Before using this code, please make sure you can install TensorFlowOnSpark on your cluster and
+run the MNIST example as illustrated here:
+https://github.com/yahoo/TensorFlowOnSpark/wiki/GetStarted_YARN
+Doing so ensures that you have set up the following environment variables correctly:
+
+```
+export JAVA_HOME=
+export HADOOP_HOME=
+export SPARK_HOME=
+export HADOOP_HDFS_HOME=
+export PYTHON_ROOT=./Python
+export PATH=${PATH}:${HADOOP_HOME}/bin:${SPARK_HOME}/bin:${HADOOP_HDFS_HOME}/bin:${PYTHON_ROOT}/bin
+export PYSPARK_PYTHON=${PYTHON_ROOT}/bin/python
+export SPARK_YARN_USER_ENV="PYSPARK_PYTHON=/usr/bin/python"
+export QUEUE=default
+export LIB_HDFS=
+export LIB_JVM=
+```
+
+## Data set
+
+The raw data can be accessed here: http://labs.criteo.com/2013/12/download-terabyte-click-logs/
+
+### Download the data set
+```
+for i in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23; do
+  curl -O http://azuremlsampleexperiments.blob.core.windows.net/criteo/day_${i}.gz
+  aws s3 mv day_${i}.gz s3://criteo-display-ctr-dataset/released/
+done
+```
+
+### Upload the training data to AWS S3 using Pig
+
+```
+%declare awskey yourkey
+%declare awssecretkey yoursecretkey
+SET mapred.output.compress 'true';
+SET mapred.output.compression.codec 'org.apache.hadoop.io.compress.BZip2Codec';
+train_data = load 's3n://${awskey}:${awssecretkey}@criteo-display-ctr-dataset/released/day_{[0-9],1[0-9],2[0-2]}.gz';
+train_data = FOREACH (GROUP train_data BY ROUND(10000 * RANDOM()) PARALLEL 10000) GENERATE FLATTEN(train_data);
+store train_data into 's3n://${awskey}:${awssecretkey}@criteo-display-ctr-dataset/data/training/' using PigStorage();
+```
+Here we split the training data into 10,000 chunks, which allows TensorFlowOnSpark to reduce its memory usage. An optional sanity check of the uploaded data is sketched below.
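+
+This is a minimal sketch (assuming a `pyspark` shell that can read the uploaded chunks; the input path below is a
+placeholder to replace with your own). It verifies that every record has 1 label plus 39 feature fields, which is
+the layout expected by `criteo_dist.py`:
+
+```
+# Run inside a pyspark shell; `sc` is the SparkContext provided by the shell.
+path = "hdfs:///path/to/criteo/training"  # placeholder: the HDFS or S3 location of the chunks
+field_counts = (sc.textFile(path)
+                  .map(lambda line: len(line.split("\t")))
+                  .countByValue())
+print(field_counts)  # expected: {40: <number of records>}
+```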
+
+### Upload the validation data to AWS S3 using Pig
+```
+%declare awskey yourkey
+%declare awssecretkey yoursecretkey
+SET mapred.output.compress 'true';
+SET mapred.output.compression.codec 'org.apache.hadoop.io.compress.BZip2Codec';
+validation_data = load 's3n://${awskey}:${awssecretkey}@criteo-display-ctr-dataset/released/day_23.gz';
+validation_data = FOREACH (GROUP validation_data BY ROUND(100 * RANDOM()) PARALLEL 100) GENERATE FLATTEN(validation_data);
+store validation_data into 's3n://${awskey}:${awssecretkey}@criteo-display-ctr-dataset/data/validation' using PigStorage();
+```
+
+## Running the example
+
+Set up the task variables:
+```
+export TRAINING_DATA=hdfs_path_to_training_data_directory
+export VALIDATION_DATA=hdfs_path_to_validation_data_directory
+export MODEL_OUTPUT=hdfs://default/tmp/criteo_ctr_prediction
+```
+Run the command:
+
+```
+${SPARK_HOME}/bin/spark-submit \
+--master yarn \
+--deploy-mode cluster \
+--queue ${QUEUE} \
+--num-executors 12 \
+--executor-memory 27G \
+--py-files TensorFlowOnSpark/tfspark.zip,TensorFlowOnSpark/examples/criteo/spark/criteo_dist.py \
+--conf spark.dynamicAllocation.enabled=false \
+--conf spark.yarn.maxAppAttempts=1 \
+--archives hdfs:///user/${USER}/Python.zip#Python \
+--conf spark.executorEnv.LD_LIBRARY_PATH="$LIB_HDFS:$LIB_JVM" \
+--conf spark.executorEnv.HADOOP_HDFS_HOME="$HADOOP_HDFS_HOME" \
+--conf spark.executorEnv.CLASSPATH="$($HADOOP_HOME/bin/hadoop classpath --glob):${CLASSPATH}" \
+TensorFlowOnSpark/examples/criteo/spark/criteo_spark.py \
+--mode train \
+--data ${TRAINING_DATA} \
+--validation ${VALIDATION_DATA} \
+--steps 1000000 \
+--model ${MODEL_OUTPUT} --tensorboard \
+--tensorboardlogdir ${MODEL_OUTPUT}
+```
+
+## TensorBoard tracking
+
+By connecting to the web UI of your application,
+you will be able to retrieve the TensorBoard URL from the stdout of the driver:
+
+```
+ TensorBoard running at: http://10.4.112.234:36911
+```
+
+You can then track the training loss and validation loss:
+
+![TensorBoard loss curves](resources/data/TensorBoard-TFonSpark-Criteo-04.png)
diff --git a/examples/criteo/spark/__init__.py b/examples/criteo/spark/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/criteo/spark/criteo_dist.py b/examples/criteo/spark/criteo_dist.py
new file mode 100644
index 00000000..f482d706
--- /dev/null
+++ b/examples/criteo/spark/criteo_dist.py
@@ -0,0 +1,294 @@
+# Copyright 2018 Criteo
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
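+
+# Model overview (a summary of the code below, added for orientation): the 13 integer
+# features are bucketized, the 26 categorical features and all pairwise feature crosses
+# are hashed with MurmurHash3 into fixed-size spaces, and a logistic regression
+# (sigmoid cross-entropy loss, Adam optimizer) is trained on the resulting sparse
+# indicator vector.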
+
+# Distributed Criteo Display CTR prediction on grid based on TensorFlow on Spark
+# https://github.com/yahoo/TensorFlowOnSpark
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import nested_scopes
+from __future__ import print_function
+
+validation_file = None
+
+
+def print_log(worker_num, arg):
+    print("{0}: {1}".format(worker_num, arg))
+
+
+def map_fun(args, ctx):
+    from datetime import datetime
+    import math
+    import tensorflow as tf
+    import numpy as np
+    import time
+    from sklearn.metrics import roc_auc_score
+    import mmh3
+
+    class CircularFile(object):
+        """File reader that loops back to the beginning once the end of the file is reached."""
+
+        def __init__(self, filename):
+            self.filename = filename
+            self.file = None
+
+        def readline(self):
+            if self.file is None:
+                self.file = tf.gfile.GFile(self.filename, "r")
+
+            p_line = self.file.readline()
+
+            if p_line == "":
+                # End of file: reopen and start over.
+                self.file.close()
+                self.file = tf.gfile.GFile(self.filename, "r")
+                p_line = self.file.readline()
+            return p_line
+
+        def close(self):
+            self.file.close()
+            self.file = None
+
+    worker_num = ctx.worker_num
+    job_name = ctx.job_name
+    task_index = ctx.task_index
+
+    # Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
+    if job_name == "ps":
+        time.sleep((worker_num + 1) * 5)
+
+    vocabulary_size = 39
+    # Feature indexes as defined in the input file: the first 13 features are integers,
+    # the remaining 26 are categorical.
+    INDEX_CAT_FEATURES = 13
+
+    # These parameter values have been selected for illustration purposes and have not been tuned.
+    learning_rate = 0.0005
+    dropout_rate = 0.4
+    NB_OF_HASHES_CAT = 2 ** 15
+    NB_OF_HASHES_CROSS = 2 ** 15
+    NB_BUCKETS = 40
+
+    boundaries_bucket = [1.5 ** j - 0.51 for j in range(NB_BUCKETS)]
+    # Same as in:
+    # https://github.com/GoogleCloudPlatform/cloudml-samples/blob/c272e9f3bf670404fb1570698d8808ab62f0fc9a/criteo_tft/trainer/task.py#L163
+
+    nb_input_features = (INDEX_CAT_FEATURES * NB_BUCKETS) + (
+        (vocabulary_size - INDEX_CAT_FEATURES) * NB_OF_HASHES_CAT) + NB_OF_HASHES_CROSS
+
+    batch_size = args.batch_size
+
+    # Get TF cluster and server instances
+    cluster, server = ctx.start_cluster_server(1, args.rdma)
+
+    def get_index_bucket(feature_value):
+        """
+        Maps the input feature to a one-hot encoding index.
+        :param feature_value: the value of the feature
+        :return: the index of the one-hot encoding that activates for the input value
+        """
+        for index, boundary_value in enumerate(boundaries_bucket):
+            if feature_value < boundary_value:
+                return index
+        return index
+
+    def get_batch_validation(batch_size):
+        """
+        :param batch_size: number of validation records to read
+        :return: a list of read lines, each line being a list of the features as read from the input file
+        """
+        global validation_file
+        if validation_file is None:
+            validation_file = CircularFile(args.validation)
+        return [validation_file.readline().split('\t') for _ in range(batch_size)]
+
+    def get_cross_feature_name(index, features):
+        if index < INDEX_CAT_FEATURES:
+            index_str = str(index) + "_" + str(get_index_bucket(int(features[index])))
+        else:
+            index_str = str(index) + "_" + features[index]
+
+        return index_str
+
+    def get_next_batch(batch):
+        """
+        Maps the batch read from the input file to a data array and a label array that are fed to
+        the tf placeholders.
+        :param batch: list of records, each being [label, feature_1, ..., feature_39]
+        :return: (data, labels) numpy arrays
+        """
+        data = np.zeros((batch_size, nb_input_features))
+        labels = np.zeros(batch_size)
+
+        index = 0
+        while True:
+
+            features = batch[index][1:]
+
+            # Skip malformed records, but still advance in the batch so the loop terminates.
+            if len(features) != vocabulary_size:
+                index += 1
+                if index == batch_size:
+                    break
+                continue
+
+            # BUCKETIZE CONTINUOUS FEATURES
+            for f_index in range(0, INDEX_CAT_FEATURES):
+                if features[f_index]:
+                    bucket_index = get_index_bucket(int(features[f_index]))
+                    bucket_number_index = f_index * NB_BUCKETS
+                    bucket_index_offset = bucket_index + bucket_number_index
+                    data[index, bucket_index_offset] = 1
+
+            # BUCKETIZE CATEGORY FEATURES
+            offset = INDEX_CAT_FEATURES * NB_BUCKETS
+            for f_index in range(INDEX_CAT_FEATURES, vocabulary_size):
+                if features[f_index]:
+                    hash_index = mmh3.hash(features[f_index]) % NB_OF_HASHES_CAT
+                    hash_number_index = (f_index - INDEX_CAT_FEATURES) * NB_OF_HASHES_CAT + offset
+                    hash_index_offset = hash_index + hash_number_index
+                    data[index, hash_index_offset] = 1
+
+            # BUCKETIZE CROSSES OF CATEGORY AND CONTINUOUS FEATURES
+            offset = INDEX_CAT_FEATURES * NB_BUCKETS + (vocabulary_size - INDEX_CAT_FEATURES) * NB_OF_HASHES_CAT
+
+            for index_i in range(0, vocabulary_size - 1):
+                for index_j in range(index_i + 1, vocabulary_size):
+                    if features[index_i].rstrip() == '' or features[index_j].rstrip() == '':
+                        continue
+
+                    index_str_i = get_cross_feature_name(index_i, features)
+                    index_str_j = get_cross_feature_name(index_j, features)
+
+                    hash_index = mmh3.hash(index_str_i + "_" + index_str_j) % NB_OF_HASHES_CROSS + offset
+                    data[index, hash_index] = 1
+
+            labels[index] = batch[index][0]
+            index += 1
+            if index == batch_size:
+                break
+
+        return data.astype(int), labels.astype(int)
+
+    if job_name == "ps":
+        server.join()
+    elif job_name == "worker":
+        is_chiefing = (task_index == 0)
+        with tf.device(tf.train.replica_device_setter(
+                worker_device="/job:worker/task:%d" % task_index,
+                cluster=cluster)):
+
+            def lineartf(x, dropout_rate, is_training, name=None, reuse=None, dropout=None):
+                """
+                Apply a simple linear transformation A*x + b to the input.
+                """
+                n_output = 1
+                if len(x.get_shape()) != 2:
+                    x = tf.contrib.layers.flatten(x)
+
+                n_input = x.get_shape().as_list()[1]
+
+                with tf.variable_scope(name, reuse=reuse):
+                    W = tf.get_variable(
+                        name='W',
+                        shape=[n_input, n_output],
+                        dtype=tf.float32,
+                        initializer=tf.contrib.layers.xavier_initializer())
+
+                    b = tf.get_variable(
+                        name='b',
+                        shape=[n_output],
+                        dtype=tf.float32,
+                        initializer=tf.constant_initializer(0.0))
+
+                    h = tf.nn.bias_add(
+                        name='h',
+                        value=tf.matmul(x, W),
+                        bias=b)
+
+                    if dropout:
+                        h = tf.cond(is_training,
+                                    lambda: tf.layers.dropout(h, rate=dropout_rate, training=True),
+                                    lambda: tf.layers.dropout(h, rate=0.0, training=True))
+
+                return h, W
+
+            is_training = tf.placeholder(tf.bool, shape=())
+            input_features = tf.placeholder(tf.float32, [None, nb_input_features], name="input_features")
+            input_features_lineartf, _ = lineartf(input_features, dropout_rate=dropout_rate,
+                                                  name='linear_layer',
+                                                  is_training=is_training,
+                                                  dropout=None)
+
+            y_true = tf.placeholder(tf.float32, shape=None)
+            y_prediction = input_features_lineartf
+            pCTR = tf.nn.sigmoid(y_prediction, name="pCTR")
+            global_step = tf.Variable(0)
+            cross_entropy = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_true, logits=y_prediction))
+            tf.summary.scalar('cross_entropy', cross_entropy)
+            adam_train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy,
+                                                                                           global_step=global_step)
+
+        saver = tf.train.Saver()
+        summary_op = tf.summary.merge_all()
+        init_op = tf.global_variables_initializer()
+
+        logdir = ctx.absolute_path(args.model)
+        print("Tensorflow model path: {0}".format(logdir))
+
+        if job_name == "worker" and is_chiefing:
+            summary_writer = tf.summary.FileWriter(logdir + "/train", graph=tf.get_default_graph())
+            summary_val_writer = tf.summary.FileWriter(logdir + "/validation", graph=tf.get_default_graph())
+
+        options = dict(is_chief=is_chiefing,
+                       logdir=logdir,
+                       summary_op=None,
+                       saver=saver,
+                       global_step=global_step,
+                       stop_grace_secs=300,
+                       save_model_secs=0)
+
+        if args.mode == "train":
+            options['save_model_secs'] = 120
+            options['init_op'] = init_op
+            options['summary_writer'] = None
+
+        sv = tf.train.Supervisor(**options)
+
+        with sv.managed_session(server.target) as sess:
+
+            print("{0} session ready".format(datetime.now().isoformat()))
+
+            tf_feed = ctx.get_data_feed(args.mode == "train")
+            step = 0
+            while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
+                batch_data, batch_labels = get_next_batch(tf_feed.next_batch(batch_size))
+
+                if len(batch_data) > 0:
+
+                    if args.mode == "train":
+
+                        if sv.is_chief:
+                            # Evaluate the current state of the model on the next batch of validation data.
+                            # The validation batch is kept in separate variables so that the chief still
+                            # trains on its own training batch below.
+                            batch_val = get_batch_validation(batch_size)
+                            val_data, val_labels = get_next_batch(batch_val)
+                            feed = {input_features: val_data, y_true: val_labels, is_training: False}
+                            logloss, summary, step = sess.run([cross_entropy, summary_op, global_step], feed_dict=feed)
+                            summary_val_writer.add_summary(summary, step)
+                            print("validation loss: {0}".format(logloss))
+
+                        feed = {input_features: batch_data, y_true: batch_labels, is_training: True}
+                        _, logloss, summary, step = sess.run([adam_train_step, cross_entropy, summary_op, global_step],
+                                                             feed_dict=feed)
+
+                    else:
+                        feed = {input_features: batch_data, y_true: batch_labels, is_training: False}
+                        yscore = sess.run(pCTR, feed_dict=feed)
+                        tf_feed.batch_results(yscore)
+
+            if sv.should_stop() or step >= args.steps:
+                tf_feed.terminate()
+                if is_chiefing:
+                    summary_writer.close()
+                    summary_val_writer.close()
+
+        print("{0} stopping supervisor".format(datetime.now().isoformat()))
+        sv.stop()
diff --git a/examples/criteo/spark/criteo_spark.py b/examples/criteo/spark/criteo_spark.py
new file mode 100644
index 00000000..aad2bd66
--- /dev/null
+++ b/examples/criteo/spark/criteo_spark.py
@@ -0,0 +1,66 @@
+# Copyright 2018 Criteo
+# Licensed under the terms of the Apache 2.0 license.
+# Please see LICENSE file in the project root for terms.
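+
+# Spark driver for the Criteo example (overview comment added for orientation): it parses
+# the command-line arguments, builds an RDD of raw tab-separated records from --data, and
+# hands it to TFCluster.run() in InputMode.SPARK, where criteo_dist.map_fun consumes it.
+# See examples/criteo/README.md for the full spark-submit command used to launch this script.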
+
+# Distributed Criteo Display CTR prediction on grid based on TensorFlow on Spark
+# https://github.com/yahoo/TensorFlowOnSpark
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from pyspark.context import SparkContext
+from pyspark.conf import SparkConf
+
+import argparse
+from datetime import datetime
+
+from tensorflowonspark import TFCluster
+
+import criteo_dist
+
+
+if __name__ == "__main__":
+    sc = SparkContext(conf=SparkConf().setAppName("criteo_spark"))
+    executors = sc._conf.get("spark.executor.instances")
+    if executors is None:
+        raise Exception("Could not retrieve the number of executors from the SparkContext")
+    num_executors = int(executors)
+    num_ps = 1
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-b", "--batch_size", help="number of records per batch", type=int, default=100)
+    parser.add_argument("-e", "--epochs", help="number of epochs", type=int, default=1)
+    parser.add_argument("-i", "--data", help="HDFS path to data in parallelized format")
+    parser.add_argument("-m", "--model", help="HDFS path to save/load model during train/inference", default="criteo_model")
+    parser.add_argument("-v", "--validation", help="HDFS path to validation data")
+    parser.add_argument("-n", "--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
+    parser.add_argument("-o", "--output", help="HDFS path to save test/inference output", default="predictions")
+    parser.add_argument("-r", "--readers", help="number of reader/enqueue threads", type=int, default=1)
+    parser.add_argument("-s", "--steps", help="maximum number of steps", type=int, default=1000)
+    parser.add_argument("-tb", "--tensorboard", help="launch tensorboard process", action="store_true")
+    parser.add_argument("-X", "--mode", help="train|inference", default="train")
+    parser.add_argument("-c", "--rdma", help="use rdma connection", default=False)
+    parser.add_argument("-tbld", "--tensorboardlogdir",
+                        help="TensorBoard log directory. It should be on HDFS, i.e. prefixed with hdfs://default")
+
+    args = parser.parse_args()
+    print("args:", args)
+
+    print("{0} ===== Start".format(datetime.now().isoformat()))
+
+    # Each record is a tab-separated line: label followed by 39 features.
+    dataRDD = sc.textFile(args.data).map(lambda ln: ln.split('\t'))
+
+    cluster = TFCluster.run(sc, criteo_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard,
+                            TFCluster.InputMode.SPARK, log_dir=args.model)
+    if args.mode == "train":
+        cluster.train(dataRDD, args.epochs)
+    else:
+        labelRDD = cluster.inference(dataRDD)
+        labelRDD.saveAsTextFile(args.output)
+    cluster.shutdown()
+    print("{0} ===== Stop".format(datetime.now().isoformat()))
diff --git a/examples/criteo/spark/requirements.txt b/examples/criteo/spark/requirements.txt
new file mode 100644
index 00000000..4603de1e
--- /dev/null
+++ b/examples/criteo/spark/requirements.txt
@@ -0,0 +1,5 @@
+mmh3==2.5.1
+tensorflow==1.2.1
+numpy==1.13.1
+scipy==1.0.0
+scikit-learn==0.19.1