# TensorFlow Lattice estimators
In this tutorial, we will TensorFlow Lattice estimators.
The more detailed version of this notebook can be found in
https://github.com/tensorflow/lattice/blob/master/g3doc/tutorial/index.md

In [1]:
# import libraries
import pandas as pd
import tensorflow as tf
import tensorflow_lattice as tfl
import tempfile
#import urllib
from six.moves.urllib.request import urlretrieve
import os

In [2]:
def download_if_not_exists(train_data, test_data):
    """Maybe downloads training data and returns train and test file names."""
    train_file_name = train_data
    if not os.path.exists(train_file_name):
        train_file = tempfile.NamedTemporaryFile(delete=False)
        urlretrieve(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data",
            train_file.name)  # pylint: disable=line-too-long
        train_file_name = train_file.name
        train_file.close()
        print("Training data is downloaded to %s" % train_file_name)

    test_file_name = test_data
    if not os.path.exists(test_file_name):
        test_file = tempfile.NamedTemporaryFile(delete=False)
        urlretrieve(
            "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test",
            test_file.name)  # pylint: disable=line-too-long
        test_file_name = test_file.name
        test_file.close()
        print("Test data is downloaded to %s"% test_file_name)
    
    return (train_file_name, test_file_name)

# Specify the dataset
(TRAIN_DATA, TEST_DATA) = download_if_not_exists("/tmp/tfl-data/adult.data",
                                                 "/tmp/tfl-data/adult.test") 

Training data is downloaded to /tmp/tmplaao4fkt
Test data is downloaded to /tmp/tmpomu_76ao


# Define features

In [3]:
CSV_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education_num",
    "marital_status", "occupation", "relationship", "race", "gender",
    "capital_gain", "capital_loss", "hours_per_week", "native_country",
    "income_bracket"
]

def get_input_fn(file_path, batch_size, num_epochs, shuffle):
    df_data = pd.read_csv(
        tf.gfile.Open(file_path),
        names=CSV_COLUMNS,
        skipinitialspace=True,
        engine="python",
        skiprows=1)
    # Drop missing for the time being.
    df_data = df_data.dropna(how="any", axis=0)
    labels = df_data["income_bracket"].apply(lambda x: ">50K" in x).astype(int)
    return tf.estimator.inputs.pandas_input_fn(
        x=df_data,
        y=labels,
        batch_size=batch_size,
        num_epochs=num_epochs,
        shuffle=shuffle,
        num_threads=5)

def get_train_input_fn(batch_size, num_epochs=1, shuffle=False):
    train_data = TRAIN_DATA
    return get_input_fn(train_data, batch_size, num_epochs, shuffle)


def densify(fc, make_dense):
    if not make_dense:
        return fc
    return tf.feature_column.embedding_column(fc, 4)


def get_feature_columns(make_dense=False):
    # Categorical features.
    gender = densify(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "gender", ["Female", "Male"]), make_dense)
    education = densify(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "education", [
                "Bachelors", "HS-grad", "11th", "Masters", "9th", "Some-college",
                "Assoc-acdm", "Assoc-voc", "7th-8th", "Doctorate", "Prof-school",
                "5th-6th", "10th", "1st-4th", "Preschool", "12th"
            ]), make_dense)
    marital_status = densify(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "marital_status", [
                "Married-civ-spouse", "Divorced", "Married-spouse-absent",
                "Never-married", "Separated", "Married-AF-spouse", "Widowed"
            ]), make_dense)
    relationship = densify(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "relationship", [
                "Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
                "Other-relative"
            ]), make_dense)
    workclass = densify(
        tf.feature_column.categorical_column_with_vocabulary_list(
            "workclass", [
                "Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
                "Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
            ]), make_dense)

    # To show an example of hashing:
    occupation = densify(
        tf.feature_column.categorical_column_with_hash_bucket(
            "occupation", hash_bucket_size=1000), make_dense)
    native_country = densify(
        tf.feature_column.categorical_column_with_hash_bucket(
            "native_country", hash_bucket_size=1000), make_dense)

    # Continuous base columns.
    age = tf.feature_column.numeric_column("age")
    education_num = tf.feature_column.numeric_column("education_num")
    capital_gain = tf.feature_column.numeric_column("capital_gain")
    capital_loss = tf.feature_column.numeric_column("capital_loss")
    hours_per_week = tf.feature_column.numeric_column("hours_per_week")
    
    return [
        age,
        education_num,
        capital_gain,
        capital_loss,
        hours_per_week,
        gender,
        education,
        marital_status,
    ]

# Create a histogram
This information will be used to initialize the calibrator input keypoints.

In [4]:
quantiles_dir = tempfile.mkdtemp()

def create_histogram(quantiles_dir):
    input_fn = get_train_input_fn(batch_size=10000, num_epochs=1, shuffle=False)
    tfl.save_quantiles_for_keypoints(
        input_fn=input_fn,
        save_dir=quantiles_dir,
        feature_columns=get_feature_columns(make_dense=False),
        num_steps=10)
    
create_histogram(quantiles_dir)



# Estimator!

In [5]:
def _build_linear_estimator(model_dir, feature_columns, learning_rate):
    """Build linear estimator."""
    feature_names = [fc.name for fc in feature_columns]
    hparams = tfl.CalibratedLinearHParams(
        feature_names=feature_names,
        learning_rate=learning_rate,
        num_keypoints=20)

    m = tfl.calibrated_linear_classifier(
        model_dir=model_dir,
        quantiles_dir=quantiles_dir,
        feature_columns=feature_columns,
        hparams=hparams)
    return m

def _build_rtl_estimator(model_dir, feature_columns, learning_rate):
    """Build rtl estimator."""
    feature_names = [fc.name for fc in feature_columns]
    # Create 100 number of 2 x 2 x 2 x 2 lattices.
    hparams = tfl.CalibratedRtlHParams(
        feature_names=feature_names,
        learning_rate=learning_rate,
        lattice_rank=4,
        num_lattices=100,
        num_keypoints=20)
    
    m = tfl.calibrated_rtl_classifier(
        model_dir=model_dir,
        quantiles_dir=quantiles_dir,
        feature_columns=feature_columns,
        hparams=hparams)
    return m

def build_estimator(model_dir, learning_rate, model_type='rtl'):
    """Build an estimator."""
    if not tf.gfile.Exists(model_dir):
        tf.gfile.MkDir(model_dir)
    feature_columns = get_feature_columns(make_dense=False)
    if model_type == 'rtl':
        return _build_rtl_estimator(model_dir, feature_columns, learning_rate)
    elif model_type == 'linear':
        return _build_linear_estimator(model_dir, feature_columns, learning_rate)
    else:
        raise ValueError('unsupported model type: %s' % model_type)

# Train linear model

In [6]:
model_dir = tempfile.mkdtemp()
learning_rate = 0.01
batch_size = 100

m = build_estimator(model_dir, learning_rate, model_type='linear')
m.train(input_fn=get_train_input_fn(batch_size=100, num_epochs=1, shuffle=True))

print('=====Training set=====')
results = m.evaluate(input_fn=get_train_input_fn(batch_size=batch_size))
for key in sorted(results):
    print('%s: %s' % (key, results[key]))

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_every_n_hours': 10000, '_model_dir': '/tmp/tmpoo1em44h', '_task_id': 0, '_keep_checkpoint_max': 5, '_save_checkpoints_secs': 600, '_task_type': 'worker', '_service': None, '_save_checkpoints_steps': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f42be948ef0>, '_is_chief': True, '_num_worker_replicas': 1, '_master': '', '_session_config': None, '_save_summary_steps': 100, '_num_ps_replicas': 0, '_log_step_count_steps': 100, '_tf_random_seed': None}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpoo1em44h/model.ckpt.
INFO:tensorflow:loss = 48.070786, step = 1
INFO:tensorflow:global_step/sec: 207.008
INFO:tensorflow:loss = 29.85246, step = 101 (0.484 sec)
INFO:tensorflow:global_step/sec: 233.638
INFO:tensorflow:loss = 31.883923, step = 201 (0.429 sec)
INFO:tensorflow:global_step/sec: 235.962
INFO:tensorflow:loss = 

# Train RTL model

In [7]:
model_dir = tempfile.mkdtemp()
learning_rate = 0.01
batch_size = 100

m = build_estimator(model_dir, learning_rate, model_type='rtl')
m.train(input_fn=get_train_input_fn(batch_size=100, num_epochs=1, shuffle=True))

print('=====Training set=====')
results = m.evaluate(input_fn=get_train_input_fn(batch_size=batch_size))
for key in sorted(results):
    print('%s: %s' % (key, results[key]))

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_keep_checkpoint_every_n_hours': 10000, '_model_dir': '/tmp/tmpzlxizk_c', '_task_id': 0, '_keep_checkpoint_max': 5, '_save_checkpoints_secs': 600, '_task_type': 'worker', '_service': None, '_save_checkpoints_steps': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f42bec8cf28>, '_is_chief': True, '_num_worker_replicas': 1, '_master': '', '_session_config': None, '_save_summary_steps': 100, '_num_ps_replicas': 0, '_log_step_count_steps': 100, '_tf_random_seed': None}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpzlxizk_c/model.ckpt.
INFO:tensorflow:loss = 69.31474, step = 1
INFO:tensorflow:global_step/sec: 42.3544
INFO:tensorflow:loss = 36.53876, step = 101 (2.362 sec)
INFO:tensorflow:global_step/sec: 54.2894
INFO:tensorflow:loss = 45.341515, step = 201 (1.842 sec)
INFO:tensorflow:global_step/sec: 69.8215
INFO:tensorflow:loss = 3