3 changes: 3 additions & 0 deletions official/README-TPU.md
@@ -26,4 +26,7 @@
* [shapemask](vision/detection): An object detection and instance segmentation model using shape priors. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/ZbXgVoc6Rf6mBRlPj0JpLA).

## Recommendation
* [dlrm](recommendation/ranking): [Deep Learning Recommendation Model for
Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091).
* [dcn v2](recommendation/ranking): [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535).
* [ncf](recommendation): Neural Collaborative Filtering. See [Tensorboard.dev training metrics](https://tensorboard.dev/experiment/0k3gKjZlR1ewkVTRyLB6IQ).
8 changes: 5 additions & 3 deletions official/README.md
@@ -64,9 +64,11 @@ In the near future, we will add:

### Recommendation

| Model | Reference (Paper) |
|-------|-------------------|
| [NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031) |
Model | Reference (Paper)
-------------------------------- | -----------------
[DLRM](recommendation/ranking) | [Deep Learning Recommendation Model for Personalization and Recommendation Systems](https://arxiv.org/abs/1906.00091)
[DCN v2](recommendation/ranking) | [Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535)
[NCF](recommendation) | [Neural Collaborative Filtering](https://arxiv.org/abs/1708.05031)

## How to get started with the official models

39 changes: 22 additions & 17 deletions official/core/train_lib.py
@@ -15,29 +15,31 @@
"""TFM common training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, Mapping, Tuple
from typing import Any, Mapping, Tuple, Optional

# Import libraries
from absl import logging
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import train_utils

BestCheckpointExporter = train_utils.BestCheckpointExporter
maybe_create_best_ckpt_exporter = train_utils.maybe_create_best_ckpt_exporter


def run_experiment(distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True) \
-> Tuple[tf.keras.Model, Mapping[str, Any]]:
def run_experiment(
distribution_strategy: tf.distribute.Strategy,
task: base_task.Task,
mode: str,
params: config_definitions.ExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[base_trainer.Trainer] = None
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.

Args:
@@ -50,6 +52,8 @@ def run_experiment(distribution_strategy: tf.distribute.Strategy,
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
    trainer: An optional base_trainer.Trainer instance. It should be created
      within the strategy.scope().

Returns:
A 2-tuple of (model, eval_logs).
@@ -59,13 +63,14 @@
"""

with distribution_strategy.scope():
trainer = train_utils.create_trainer(
params,
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))
if not trainer:
trainer = train_utils.create_trainer(
params,
task,
train='train' in mode,
evaluate=('eval' in mode) or run_post_eval,
checkpoint_exporter=maybe_create_best_ckpt_exporter(
params, model_dir))

if trainer.checkpoint:
checkpoint_manager = tf.train.CheckpointManager(
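For illustration, a minimal sketch of calling `run_experiment` with a caller-supplied trainer under the new signature; `my_task`, `my_params`, and `my_model_dir` are hypothetical placeholders, not names from this PR:

```python
import tensorflow as tf

from official.core import train_lib
from official.core import train_utils

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Per the docstring, a caller-supplied trainer must be created inside
  # strategy.scope(). my_params and my_task are assumed to exist.
  trainer = train_utils.create_trainer(
      my_params, my_task, train=True, evaluate=True)

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=my_task,
    mode='train_and_eval',
    params=my_params,
    model_dir=my_model_dir,
    trainer=trainer)
```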
164 changes: 164 additions & 0 deletions official/recommendation/ranking/README.md
@@ -0,0 +1,164 @@
# TF Model Garden Ranking Models

## Overview
This is an implementation of [DLRM](https://arxiv.org/abs/1906.00091) and
[DCN v2](https://arxiv.org/abs/2008.13535) ranking models that can be used for
tasks such as CTR prediction.

The model inputs are numerical and categorical features, and the output is a
scalar (for example, a click probability).
The model can be trained and evaluated on GPU, TPU, and CPU. Deep ranking
models are both memory intensive (for the embedding tables and lookups) and
compute intensive (for the deep MLP networks). CPUs are best suited for large
sparse embedding lookups, GPUs for fast dense compute, and TPUs are designed
for both.

When training on TPUs we use the
[TPUEmbedding layer](https://github.com/tensorflow/recommenders/blob/main/tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py)
for categorical features. TPU embedding supports large embedding tables with
fast lookup; the embedding table capacity scales linearly with the size of the
TPU pod. We can have up to 90 GB of embedding tables on a TPU v3-8, 5.6 TB on a
v3-512, and 22.4 TB on a TPU v3-2048 pod.

The model code is in the
[TensorFlow Recommenders](https://github.com/tensorflow/recommenders/tree/main/tensorflow_recommenders/experimental/models)
library, while the input pipeline, configuration, and training loop are here.

## Prerequisites
To get started, download the code from TensorFlow models GitHub repository or
use the pre-installed Google Cloud VM.

```bash
git clone https://github.com/tensorflow/models.git
export PYTHONPATH=$PYTHONPATH:$(pwd)/models
```

We also need to install the
[TensorFlow Recommenders](https://www.tensorflow.org/recommenders) library.
If you are using [tf-nightly](https://pypi.org/project/tf-nightly/), make
sure to install
[tensorflow-recommenders](https://pypi.org/project/tensorflow-recommenders/)
without its dependencies by passing the `--no-deps` argument.

For tf-nightly:
```bash
pip install tensorflow-recommenders --no-deps
```

For stable TensorFlow 2.4+ [releases](https://pypi.org/project/tensorflow/):
```bash
pip install tensorflow-recommenders
```


## Dataset

The models can be trained on various datasets. Two commonly used ones are the
[Criteo Terabyte](https://labs.criteo.com/2013/12/download-terabyte-click-logs/)
and [Criteo Kaggle](https://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/)
datasets.
We can also train on synthetic data by setting the flag
`use_synthetic_data=True`.

### Download

The dataset is the Terabyte click logs dataset provided by Criteo. Follow the
[instructions](https://labs.criteo.com/2013/12/download-terabyte-click-logs/) at
the Criteo website to download the data.

Note that the dataset is large (~1TB).

### Preprocess the data

Data preprocessing steps are summarized below.

Integer features are processed sequentially as follows:

1. Missing values are replaced with zeros.
2. Negative values are replaced with zeros.
3. Integer features are transformed by log(x+1) and are hence tf.float32.

Categorical features:

1. Categorical data is bucketized to tf.int32.
2. Optionally, the resulting integers are hashed to a lower dimensionality.
   This is necessary to reduce the sizes of the large tables. A simple hashing
   function such as the modulus will suffice, i.e. `feature_value % MAX_INDEX`,
   as sketched below.
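As a concrete illustration of the steps above, here is a minimal NumPy sketch (not the preprocessing script used for published results; it assumes features are already loaded into arrays, with missing integer values as NaN, and `MAX_INDEX` is a placeholder):

```python
import numpy as np

MAX_INDEX = 40_000_000  # placeholder hash size for one large table


def process_integer_features(x: np.ndarray) -> np.ndarray:
  """Applies steps 1-3: fill missing values, clip negatives, log(x + 1)."""
  x = np.nan_to_num(x, nan=0.0)  # 1. Missing values are replaced with zeros.
  x = np.maximum(x, 0.0)         # 2. Negative values are replaced with zeros.
  return np.log(x + 1.0).astype(np.float32)  # 3. log(x+1), stored as float32.


def process_categorical_features(ids: np.ndarray) -> np.ndarray:
  """Applies step 2: hashes bucketized ids down via a simple modulus."""
  return (ids % MAX_INDEX).astype(np.int32)
```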

The vocabulary sizes resulting from pre-processing are passed to the model
trainer via the `model.vocab_sizes` config.

The full dataset is composed of 24 directories. Partition the data into training
and eval sets, for example days 1-23 for training and day 24 for evaluation.

Training and eval datasets are expected to be saved as a number of
tab-separated values (TSV) files in the following column order: numerical
features, categorical features, and the label.

On each row of a TSV file, the first `num_dense_features` columns are numerical
features, the next columns are the categorical features (one per entry in
`vocab_sizes`), and the last column is the label (either 0 or 1). The i-th
categorical feature is expected to be an integer in the range
`[0, vocab_sizes[i])`.
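For illustration, a minimal `tf.data` sketch that parses rows in this layout (the actual input pipeline lives in the ranking code; the file name and feature-dictionary keys below are assumptions, and the counts follow the Criteo configuration used later in this README):

```python
import tensorflow as tf

NUM_DENSE = 13   # num_dense_features for Criteo
NUM_SPARSE = 26  # one categorical feature per entry in vocab_sizes


def parse_tsv_line(line: tf.Tensor):
  # Columns: NUM_DENSE float features, NUM_SPARSE int features, then the label.
  defaults = [[0.0]] * NUM_DENSE + [[0]] * NUM_SPARSE + [[0]]
  fields = tf.io.decode_csv(line, record_defaults=defaults, field_delim='\t')
  dense = tf.stack(fields[:NUM_DENSE])
  sparse = tf.stack(fields[NUM_DENSE:NUM_DENSE + NUM_SPARSE])
  label = fields[-1]
  return {'dense_features': dense, 'sparse_features': sparse}, label


dataset = tf.data.TextLineDataset('eval/day_24.tsv').map(parse_tsv_line)
```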

## Train and Evaluate

To train the DLRM model we use the dot-product feature interaction, i.e.
`interaction: 'dot'`; to train the DCN v2 model we use `interaction: 'cross'`.


### Training on TPU

```shell
export TPU_NAME=my-dlrm-tpu
export EXPERIMENT_NAME=my_experiment_name
export BUCKET_NAME="gs://my_dlrm_bucket"
export DATA_DIR="${BUCKET_NAME}/data"

python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime:
distribution_strategy: 'tpu'
task:
use_synthetic_data: false
train_data:
input_path: '${DATA_DIR}/train/*'
global_batch_size: 16384
validation_data:
input_path: '${DATA_DIR}/eval/*'
global_batch_size: 16384
model:
num_dense_features: 13
bottom_mlp: [512,256,128]
embedding_dim: 128
top_mlp: [1024,1024,512,256,1]
interaction: 'dot'
vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
38532951, 2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14,
39979771, 25641295, 39664984, 585935, 12972, 108, 36]
trainer:
use_orbit: true
validation_interval: 90000
checkpoint_interval: 100000
validation_steps: 5440
train_steps: 256054
steps_per_loop: 1000
"
```

The data directory should have two subdirectories:

* $DATA_DIR/train
* $DATA_DIR/eval

### Training on GPU

Training on GPUs is similar to TPU training. Only the distribution strategy
needs to be updated and the number of GPUs provided (here, for 4 GPUs):

```shell
python3 models/official/recommendation/ranking/train.py --mode=train_and_eval \
--model_dir=${BUCKET_NAME}/model_dirs/${EXPERIMENT_NAME} --params_override="
runtime:
distribution_strategy: 'mirrored'
num_gpus: 4
...
"
```
14 changes: 14 additions & 0 deletions official/recommendation/ranking/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

113 changes: 113 additions & 0 deletions official/recommendation/ranking/common.py
@@ -0,0 +1,113 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flags and common definitions for Ranking Models."""

from absl import flags
import tensorflow as tf

from official.common import flags as tfm_flags

FLAGS = flags.FLAGS


def define_flags() -> None:
"""Defines flags for training the Ranking model."""
tfm_flags.define_flags()

FLAGS.set_default(name='experiment', value='dlrm_criteo')
FLAGS.set_default(name='mode', value='train_and_eval')

flags.DEFINE_integer(
name='seed',
default=None,
help='This value will be used to seed both NumPy and TensorFlow.')
flags.DEFINE_string(
name='profile_steps',
default='20,40',
help='Save profiling data to model dir at given range of global steps. '
'The value must be a comma separated pair of positive integers, '
'specifying the first and last step to profile. For example, '
'"--profile_steps=2,4" triggers the profiler to process 3 steps, starting'
' from the 2nd step. Note that profiler has a non-trivial performance '
'overhead, and the output file can be gigantic if profiling many steps.')


@tf.keras.utils.register_keras_serializable(package='RANKING')
class WarmUpAndPolyDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Learning rate callable for the embeddings.

Linear warmup on [0, warmup_steps] then
Constant on [warmup_steps, decay_start_steps]
And polynomial decay on [decay_start_steps, decay_start_steps + decay_steps].
"""

def __init__(self,
batch_size: int,
decay_exp: float = 2.0,
learning_rate: float = 40.0,
warmup_steps: int = 8000,
decay_steps: int = 12000,
decay_start_steps: int = 10000):
super(WarmUpAndPolyDecay, self).__init__()
self.batch_size = batch_size
self.decay_exp = decay_exp
self.learning_rate = learning_rate
self.warmup_steps = warmup_steps
self.decay_steps = decay_steps
self.decay_start_steps = decay_start_steps

def __call__(self, step):
decay_exp = self.decay_exp
learning_rate = self.learning_rate
warmup_steps = self.warmup_steps
decay_steps = self.decay_steps
decay_start_steps = self.decay_start_steps

    # Scale the base learning rate linearly with the global batch size,
    # using 2048 as the reference batch size.
    scal = self.batch_size / 2048

    adj_lr = learning_rate * scal
    if warmup_steps == 0:
      return adj_lr

    # Linear warmup from 0 to adj_lr over the first warmup_steps steps.
    warmup_lr = step / warmup_steps * adj_lr
    global_step = tf.cast(step, tf.float32)
    decay_steps = tf.cast(decay_steps, tf.float32)
    decay_start_step = tf.cast(decay_start_steps, tf.float32)
    warmup_lr = tf.cast(warmup_lr, tf.float32)

    # Polynomial decay from adj_lr towards a floor of 1e-4, starting at
    # decay_start_step and lasting decay_steps steps.
    steps_since_decay_start = global_step - decay_start_step
    already_decayed_steps = tf.minimum(steps_since_decay_start, decay_steps)
    decay_lr = adj_lr * (
        (decay_steps - already_decayed_steps) / decay_steps)**decay_exp
    decay_lr = tf.maximum(0.0001, decay_lr)

    # Piecewise schedule: warmup, then constant adj_lr, then decay; the final
    # rate is clipped below at 0.01.
    lr = tf.where(
        global_step < warmup_steps, warmup_lr,
        tf.where(
            tf.logical_and(decay_steps > 0, global_step > decay_start_step),
            decay_lr, adj_lr))

    lr = tf.maximum(0.01, lr)
    return lr

def get_config(self):
return {
'batch_size': self.batch_size,
'decay_exp': self.decay_exp,
'learning_rate': self.learning_rate,
'warmup_steps': self.warmup_steps,
'decay_steps': self.decay_steps,
'decay_start_steps': self.decay_start_steps
}
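

# Example usage (illustrative, untuned values): linear warmup to step 8000,
# constant until step 10000, then polynomial decay over the next 12000 steps,
# with the base rate scaled by the global batch size:
#
#   schedule = WarmUpAndPolyDecay(
#       batch_size=16384,
#       learning_rate=40.0,
#       warmup_steps=8000,
#       decay_start_steps=10000,
#       decay_steps=12000)
#   optimizer = tf.keras.optimizers.SGD(learning_rate=schedule)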