
Finalize design based on review feedback.
Zhenyu Tan committed Jan 17, 2020
1 parent f86ee91 commit 4df736c
Showing 1 changed file with 34 additions and 31 deletions: rfcs/20191212-keras-categorical-inputs.md

## Objective

This document proposes 4 new preprocessing Keras layers (`Lookup`, `CategoryCrossing`, `Vectorize`, `FingerPrint`) and an extension to the existing op `tf.sparse.from_dense`, to allow users to:
* Perform feature engineering for categorical inputs
* Replace feature columns and `tf.keras.layers.DenseFeatures` with proposed layers
* Introduce sparse inputs that work with Keras linear models and other layers that support sparsity
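
As a concrete sketch of the last point, the proposed extension adds an `ignore_value` argument to `tf.sparse.from_dense` (the current op has no such argument; the signature below is the proposed one, mirroring the usage later in this document):

```python
import numpy as np
import tensorflow as tf

# Proposed extension, sketched: entries equal to ignore_value are dropped
# from the resulting sparse tensor rather than stored as empty strings.
dense = np.asarray([["Italy", "Italy"], ["Germany", ""]])
sp = tf.sparse.from_dense(dense, ignore_value="")
```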

Other proposed layers for the replacement of feature columns, such as `tf.feature_column.bucketized_column` and `tf.feature_column.numeric_column`, have been discussed [here](https://github.com/keras-team/governance/blob/master/rfcs/20190502-preprocessing-layers.md) and are not the focus of this document.

The proposed layers should support ragged tensors.
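
For example, a minimal sketch of the intended ragged path (the layer names and signatures here are the proposed ones, not a shipped API):

```python
import tensorflow as tf

# Proposed behavior, sketched: Lookup should accept a ragged input just as it
# accepts dense and sparse ones, returning ragged indices of the same shape.
ragged_inp = tf.ragged.constant([["Italy", "Italy"], ["Germany"]])
lookup = tf.keras.layers.Lookup(vocabulary=["Italy", "France", "Germany"])
ragged_out = lookup(ragged_inp)  # expected: [[0, 0], [2]]
```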

## Motivation

Specifically, by introducing the 4 layers, we aim to address these pain points:
```python
for feature_name in CATEGORICAL_COLUMNS:
feature_input = tf.keras.Input(shape=(1,), dtype=tf.string, name=feature_name, sparse=True)
vocab_list = sorted(dftrain[feature_name].unique())
# Map string values to indices
x = tf.keras.layers.Lookup(vocabulary=vocab_list, name=feature_name)(feature_input)
x = tf.keras.layers.Vectorize(num_categories=len(vocab_list))(x)
linear_inputs.append(x)
model_inputs.append(feature_input)

Expand Down Expand Up @@ -87,7 +89,7 @@ fc = tf.feature_column.categorical_feature_column_with_vocabulary_list(
Proposed:
```python
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype)
layer = tf.keras.layers.Lookup(
vocabulary=vocabulary_list, num_oov_tokens=num_oov_buckets)
out = layer(x)
```
```python
fc = tf.feature_column.categorical_column_with_vocabulary_file(
    key, vocabulary_file, vocabulary_size, dtype, default_value, num_oov_buckets)
```
Proposed:
```python
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype)
layer = tf.keras.layers.Lookup(
vocabulary=vocabulary_file, num_oov_tokens=num_oov_buckets)
out = layer(x)
```
```python
fc = tf.feature_column.categorical_column_with_hash_bucket(
    key, hash_bucket_size, dtype)
```
Proposed:
```python
x = tf.keras.Input(shape=(1,), name=key, dtype=dtype)
layer = tf.keras.layers.FingerPrint(num_bins=hash_bucket_size)
out = layer(x)
```
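
For intuition, this binning can be approximated today with a deterministic string hash; whether `FingerPrint` is implemented on top of this op is an assumption, not part of the proposal text:

```python
import tensorflow as tf

# Rough functional equivalent of FingerPrint(num_bins=100) on string input,
# assuming a stable, non-cryptographic hash is used for binning.
hashed = tf.strings.to_hash_bucket_fast(["UA302", "DL123"], num_buckets=100)
```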

```python
fc = tf.feature_column.crossed_column(
    [categorical_column1, categorical_column2], hash_bucket_size, hash_key)
```
Proposed:
```python
x1 = tf.keras.Input(shape=(1,), name=key_1, dtype=dtype)
x2 = tf.keras.Input(shape=(1,), name=key_2, dtype=dtype)
layer1 = tf.keras.layers.Lookup(
vocabulary=vocabulary_list,
num_oov_tokens=num_oov_buckets)
x1 = layer1(x1)
layer2 = tf.keras.layers.FingerPrint(
num_bins=hash_bucket_size)
x2 = layer2(x2)
layer = tf.keras.layers.CategoryCrossing(num_bins=hash_bucket_size)
out = layer([x1, x2])
```

```python
fc = tf.feature_column.categorical_column_with_vocabulary_list(
    key, vocabulary_list, dtype, default_value, num_oov_buckets)
weighted_fc = tf.feature_column.weighted_categorical_column(
    fc, weight_feature_key, weight_dtype)
```
Proposed:
```python
x1 = tf.keras.Input(shape=(1,), name=key, dtype=dtype)
x2 = tf.keras.Input(shape=(1,), name=weight_feature_key, dtype=weight_dtype)
layer = tf.keras.layers.Lookup(
vocabulary=vocabulary_list,
num_oov_tokens=num_oov_buckets)
x1 = layer(x1)
x = tf.keras.layers.Vectorize(num_categories=len(vocabulary_list)+num_oov_buckets)([x1, x2])
linear_model = tf.keras.premade.LinearModel(units)
linear_logits = linear_model(x)
```

## Design Proposal
We propose a `Lookup` layer to replace `tf.feature_column.categorical_column_with_vocabulary_list` and `tf.feature_column.categorical_column_with_vocabulary_file`, a `FingerPrint` layer to replace `tf.feature_column.categorical_column_with_hash_bucket`, a `CategoryCrossing` layer to replace `tf.feature_column.crossed_column`, and a `Vectorize` layer to convert the sparse input to the format required by linear models.

```python
`tf.keras.layers.Lookup`
Lookup(PreprocessingLayer):
"""This layer transforms categorical inputs to index space.
If input is dense/sparse, then output is dense/sparse."""

def __init__(self, max_tokens=None, num_oov_tokens=1, vocabulary=None,
name=None, **kwargs):
"""Constructs a CategoryLookup layer.
"""Constructs a Lookup layer.
Args:
max_tokens: The maximum size of the vocabulary for this layer. If None,
there is no cap on the size of the vocabulary.
`tf.keras.layers.CategoryCrossing`
CategoryCrossing(PreprocessingLayer):
"""
pass

`tf.keras.layers.Vectorize`
Vectorize(PreprocessingLayer):
"""This layer transforms categorical inputs from index space to category space.
If input is dense/sparse, then output is dense/sparse."""

def __init__(self, num_categories, mode="sum", axis=-1, name=None, **kwargs):
"""Constructs a CategoryEncoding layer.
def __init__(self, num_categories, mode="count", axis=-1, sparse_out=True, name=None, **kwargs):
"""Constructs a Vectorize layer.
Args:
num_categories: Number of elements in the vocabulary.
mode: how to reduce a categorical input if multivalent; can be one of "count",
"avg_count", "binary", or "tfidf". It can also be None if the input is not multivalent
and simply needs to be converted from index space to category space. "tfidf" is only
valid when `adapt` is called on this layer.
axis: the axis to reduce; defaults to the last axis, which is especially relevant
for sequential feature columns.
sparse_out: boolean indicating whether the output should be a dense or a sparse tensor.
name: Name to give to the layer.
**kwargs: Keyword arguments to construct a layer.
Expand All @@ -274,18 +277,18 @@ CategoryEncoding(PreprocessingLayer):
Example:
If the input is a 2 by 2 dense integer tensor '[[0, 2], [2, 2]]' with `num_categories=3`, then the
output is a 2 by 3 dense integer tensor '[[1, 0, 1], [0, 0, 2]]' with a `count` encoding, a
dense float tensor '[[.5, 0, .5], [0, 0, 1.]]' with an `avg_count` encoding, or a dense integer
tensor '[[1, 0, 1], [0, 0, 1]]' with a `binary` encoding.
"""
pass

`tf.keras.layers.FingerPrint`
FingerPrint(PreprocessingLayer):
"""This layer transforms categorical inputs to hashed output.
If input is dense/sparse, then output is dense/sparse."""
def __init__(self, num_bins, name=None, **kwargs):
"""Constructs a CategoryHashing layer.
"""Constructs a FingerPrint layer.
Args:
num_bins: Number of hash bins.
"""
pass
```

Below is a more detailed illustration of how each layer works. If there is a vocabulary list of countries:
```python
vocabulary_list = ["Italy", "France", "England", "Austria", "Germany"]
inp = np.asarray([["Italy", "Italy"], ["Germany", ""]])
sp_inp = tf.sparse.from_dense(inp, ignore_value="")
cat_layer = tf.keras.layers.Lookup(vocabulary=vocabulary_list)
sp_out = cat_layer(sp_inp)
```

The expected output is:
```python
sp_out.indices = <tf.Tensor: shape=(3, 2), dtype=int64, numpy=
array([[0, 0], [0, 1], [1, 0]])>
sp_out.values = <tf.Tensor: id=28, shape=(3,), dtype=int64,
numpy=array([0, 0, 4])>
```

The `Vectorize` layer will then convert the input from index space to category space, e.g., from a sparse tensor with indices of shape [batch_size, n_columns] and values in the range [0, n_categories) to a sparse tensor with indices of shape [batch_size, n_categories] and values equal to the frequency of each category occurring in the example:
```python
encoding_layer = Vectorize(num_categories=len(vocabulary_list))
sp_encoded_out = encoding_layer(sp_out)
sp_encoded_out.indices = <tf.Tensor: id=8, shape=(2, 2), dtype=int64, numpy=
array([[0, 0], [1, 4]])>
sp_encoded_out.values = <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 1])>
```

If this input needs to be crossed with another categorical input, say a vocabulary list of days:
```python
days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
inp_days = tf.sparse.from_dense(np.asarray([["Sunday"], [""]]), ignore_value="")
layer_days = Lookup(vocabulary=days)
sp_out_2 = layer_days(inp_days)

sp_out_2.indices = <tf.Tensor: id=161, shape=(1, 2), dtype=int64, numpy=array([[0, 0]])>
sp_out_2.values = <tf.Tensor: id=181, shape=(1,), dtype=int64, numpy=array([6])>

cross_layer = CategoryCrossing(num_bins=5)
# Use the output from Lookup (sp_out), not Vectorize (sp_encoded_out)
crossed_out = cross_layer([sp_out, sp_out_2])

crossed_out.indices = <tf.Tensor: id=186, shape=(2, 2), dtype=int64, numpy=
array([[0, 0], [0, 1]])>
crossed_out.values = <tf.Tensor: id=187, shape=(2,), dtype=int64, numpy=array([3, 3])>
```
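
As an aside, an existing op with closely related semantics is `tf.sparse.cross_hashed`; whether `CategoryCrossing` is backed by it is an assumption, not part of the proposal text:

```python
import tensorflow as tf

# Illustrative only: a hashed Cartesian product of the two sparse inputs
# above, binned into 5 buckets, mirroring CategoryCrossing(num_bins=5).
crossed = tf.sparse.cross_hashed([sp_out, sp_out_2], num_buckets=5)
```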

## Questions and Discussion Topics
We'd like to gather feedback on `Lookup`; specifically, we propose migrating off the mutually exclusive `num_oov_buckets` and `default_value` arguments and replacing them with `num_oov_tokens`.
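
To make the migration concrete, a hedged before/after sketch (parameter values are illustrative):

```python
import tensorflow as tf

# Before: feature columns expose two mutually exclusive OOV controls.
fc = tf.feature_column.categorical_column_with_vocabulary_list(
    "country", ["Italy", "France"], num_oov_buckets=2)  # or default_value=-1

# After (proposed): a single num_oov_tokens argument on Lookup.
layer = tf.keras.layers.Lookup(
    vocabulary=["Italy", "France"], num_oov_tokens=2)
```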
