Add the CSV hash table in Hash layer and fix a bug. #385

Merged
merged 14 commits on Jul 3, 2021

10 changes: 7 additions & 3 deletions deepctr/feature_column.py
@@ -15,12 +15,12 @@


class SparseFeat(namedtuple('SparseFeat',
['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embeddings_initializer',
['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer',
'embedding_name',
'group_name', 'trainable'])):
__slots__ = ()

def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embeddings_initializer=None,
def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, vocabulary_path=None, dtype="int32", embeddings_initializer=None,
embedding_name=None,
group_name=DEFAULT_GROUP_NAME, trainable=True):

@@ -32,7 +32,7 @@ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="
if embedding_name is None:
embedding_name = name

return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype,
embeddings_initializer,
embedding_name, group_name, trainable)

@@ -64,6 +64,10 @@ def embedding_dim(self):
def use_hash(self):
return self.sparsefeat.use_hash

@property
def vocabulary_path(self):
return self.sparsefeat.vocabulary_path

@property
def dtype(self):
return self.sparsefeat.dtype
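For context, here is a minimal usage sketch of the new `vocabulary_path` field carried by `SparseFeat` (the CSV path and feature names below are placeholders, not part of this PR):

```python
from deepctr.feature_column import SparseFeat, VarLenSparseFeat

# vocabulary_path is simply stored on the namedtuple; it only takes effect
# when use_hash=True and the feature is routed through the Hash layer.
user_id = SparseFeat('user_id', vocabulary_size=3 + 1, embedding_dim=8,
                     use_hash=True, vocabulary_path='./user_vocab.csv')

# VarLenSparseFeat wraps a SparseFeat, so the property added above simply
# forwards to the wrapped feature.
hist_user_id = VarLenSparseFeat(user_id, maxlen=6)
assert hist_user_id.vocabulary_path == './user_vocab.csv'
```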
6 changes: 3 additions & 3 deletions deepctr/inputs.py
@@ -51,7 +51,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
feat_name = fg.name
if len(return_feat_list) == 0 or feat_name in return_feat_list:
if fg.use_hash:
lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list))(input_dict[feat_name])
lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list), vocabulary_path=fg.vocabulary_path)(input_dict[feat_name])
else:
lookup_idx = input_dict[feat_name]

@@ -80,7 +80,7 @@ def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_co
embedding_name = fc.embedding_name
if (len(return_feat_list) == 0 or feature_name in return_feat_list):
if fc.use_hash:
lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))(
lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list), vocabulary_path=fc.vocabulary_path)(
sparse_input_dict[feature_name])
else:
lookup_idx = sparse_input_dict[feature_name]
@@ -97,7 +97,7 @@ def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_f
feature_name = fc.name
embedding_name = fc.embedding_name
if fc.use_hash:
lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name])
lookup_idx = Hash(fc.vocabulary_size, mask_zero=True, vocabulary_path=fc.vocabulary_path)(sequence_input_dict[feature_name])
else:
lookup_idx = sequence_input_dict[feature_name]
varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx)
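All three helpers above now share the same pattern: when `use_hash` is set, the raw input is first passed through `Hash` (with the new `vocabulary_path` forwarded), and the resulting index feeds the embedding. A rough standalone sketch, assuming TF 2.x eager mode and an existing vocabulary CSV at an illustrative path:

```python
import tensorflow as tf
from deepctr.layers.utils import Hash

raw = tf.constant([['lake'], ['johnson'], ['lakemerson']])

# './vocab.csv' is a placeholder; its first column holds int values and its
# second column the string keys (see the Hash docstring further down).
lookup_idx = Hash(num_buckets=3 + 1, vocabulary_path='./vocab.csv')(raw)

embedding = tf.keras.layers.Embedding(input_dim=3 + 1, output_dim=4)
vec = embedding(lookup_idx)  # shape (3, 1, 4), like embedding_dict[name](lookup_idx)
```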
9 changes: 3 additions & 6 deletions deepctr/layers/core.py
@@ -68,8 +68,8 @@ def build(self, input_shape):
'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)'
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
size = 4 * \
int(input_shape[0][-1]
) if len(self.hidden_units) == 0 else self.hidden_units[-1]
int(input_shape[0][-1]
) if len(self.hidden_units) == 0 else self.hidden_units[-1]
self.kernel = self.add_weight(shape=(size, 1),
initializer=glorot_normal(
seed=self.seed),
@@ -78,9 +78,6 @@ def build(self, input_shape):
shape=(1,), initializer=Zeros(), name="bias")
self.dnn = DNN(self.hidden_units, self.activation, self.l2_reg, self.dropout_rate, self.use_bn, seed=self.seed)

self.dense = tf.keras.layers.Lambda(lambda x: tf.nn.bias_add(tf.tensordot(
x[0], x[1], axes=(-1, 0)), x[2]))

super(LocalActivationUnit, self).build(
input_shape) # Be sure to call this somewhere!

@@ -96,7 +93,7 @@ def call(self, inputs, training=None, **kwargs):

att_out = self.dnn(att_input, training=training)

attention_score = self.dense([att_out, self.kernel, self.bias])
attention_score = tf.nn.bias_add(tf.tensordot(att_out, self.kernel, axes=(-1, 0)), self.bias)

return attention_score

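In `LocalActivationUnit`, the `Lambda` wrapper around the final projection is dropped in favor of calling the ops directly in `call`; the two forms compute the same score. A small sketch with illustrative shapes (not the layer's real weights):

```python
import tensorflow as tf

# Illustrative shapes: att_out (batch, T, hidden), kernel (hidden, 1), bias (1,)
att_out = tf.random.normal([2, 5, 8])
kernel = tf.random.normal([8, 1])
bias = tf.zeros([1])

# What LocalActivationUnit.call now does:
score_direct = tf.nn.bias_add(tf.tensordot(att_out, kernel, axes=(-1, 0)), bias)

# Equivalent to the removed Lambda-based form:
dense = tf.keras.layers.Lambda(
    lambda x: tf.nn.bias_add(tf.tensordot(x[0], x[1], axes=(-1, 0)), x[2]))
score_lambda = dense([att_out, kernel, bias])

assert score_direct.shape == (2, 5, 1)  # same shape for both forms
```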
13 changes: 4 additions & 9 deletions deepctr/layers/sequence.py
@@ -560,10 +560,10 @@ def call(self, inputs, mask=None, training=None, **kwargs):
if self.blinding:
try:
outputs = tf.matrix_set_diag(outputs, tf.ones_like(outputs)[
:, :, 0] * (-2 ** 32 + 1))
:, :, 0] * (-2 ** 32 + 1))
except:
outputs = tf.compat.v1.matrix_set_diag(outputs, tf.ones_like(outputs)[
:, :, 0] * (-2 ** 32 + 1))
:, :, 0] * (-2 ** 32 + 1))

outputs -= reduce_max(outputs, axis=-1, keep_dims=True)
outputs = softmax(outputs)
@@ -640,7 +640,8 @@ def build(self, input_shape):
# Second part, apply the cosine to even columns and sin to odds.
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i
position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1

if self.zero_pad:
position_enc[0, :] = np.zeros(num_units)
self.lookup_table = self.add_weight("lookup_table", (T, num_units),
initializer=tf.initializers.identity(position_enc),
trainable=self.pos_embedding_trainable)
@@ -651,13 +652,7 @@
def call(self, inputs, mask=None):
_, T, num_units = inputs.get_shape().as_list()
position_ind = tf.expand_dims(tf.range(T), 0)

if self.zero_pad:
self.lookup_table = tf.concat((tf.zeros(shape=[1, num_units]),
self.lookup_table[1:, :]), 0)

outputs = tf.nn.embedding_lookup(self.lookup_table, position_ind)

if self.scale:
outputs = outputs * num_units ** 0.5
return outputs + inputs
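In `PositionEncoding`, the `zero_pad` handling moves from `call` into `build`, so the lookup table is fixed once instead of being re-concatenated on every forward pass. A small numpy sketch of one standard form of the sinusoidal table that `build` now initializes (`T` and `num_units` are illustrative, and the outer formula is an assumption consistent with the sin/cos assignment shown above):

```python
import numpy as np

T, num_units = 6, 4  # illustrative sequence length and embedding size

position_enc = np.array([
    [pos / np.power(10000, 2. * (i // 2) / num_units) for i in range(num_units)]
    for pos in range(T)])
position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1

zero_pad = True
if zero_pad:
    # zero the padding row once here, instead of rebuilding the table in call()
    position_enc[0, :] = np.zeros(num_units)
```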
52 changes: 46 additions & 6 deletions deepctr/layers/utils.py
@@ -25,14 +25,47 @@ def compute_mask(self, inputs, mask):


class Hash(tf.keras.layers.Layer):
"""
hash the input to [0,num_buckets)
if mask_zero = True,0 or 0.0 will be set to 0,other value will be set in range[1,num_buckets)
"""Looks up keys in a table when setup `vocabulary_path`, which outputs the corresponding values.
If `vocabulary_path` is not set, `Hash` will hash the input to [0,num_buckets). When `mask_zero` = True,
input value `0` or `0.0` will be set to `0`, and other value will be set in range [1,num_buckets).

The following snippet initializes a `Hash` with `vocabulary_path` file with the first column as keys and
second column as values:

* `1,emerson`
* `2,lake`
* `3,palmer`

>>> hash = Hash(
... num_buckets=3+1,
... vocabulary_path=filename,
... default_value=0)
>>> hash(tf.constant('lake')).numpy()
2
>>> hash(tf.constant('lakeemerson')).numpy()
0

Args:
num_buckets: An `int` that is >= 1. The number of hash buckets, or the vocabulary size + 1
when `vocabulary_path` is set.
mask_zero: default `False`. When `True`, input `0` or `0.0` is hashed to value `0`.
`mask_zero` is ignored when `vocabulary_path` is set.
vocabulary_path: default `None`. Path of the `CSV` text file holding the vocabulary table, which contains
two columns separated by a comma: the first column is the value and the second is
the key. The key data type is `string`, the value data type is `int`. The path must
be accessible from wherever `Hash` is initialized.
default_value: default `0`. The value returned when a key is missing from the table.
**kwargs: Additional keyword arguments.
"""

def __init__(self, num_buckets, mask_zero=False, **kwargs):
def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs):
self.num_buckets = num_buckets
self.mask_zero = mask_zero
self.vocabulary_path = vocabulary_path
self.default_value = default_value
if self.vocabulary_path:
initializer = tf.lookup.TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',')
self.hash_table = tf.lookup.StaticHashTable(initializer, default_value=self.default_value)
super(Hash, self).__init__(**kwargs)

def build(self, input_shape):
@@ -41,13 +74,16 @@ def build(self, input_shape):

def call(self, x, mask=None, **kwargs):


if x.dtype != tf.string:
zero = tf.as_string(tf.zeros([1], dtype=x.dtype))
x = tf.as_string(x, )
else:
zero = tf.as_string(tf.zeros([1], dtype='int32'))

if self.vocabulary_path:
hash_x = self.hash_table.lookup(x)
return hash_x

num_buckets = self.num_buckets if not self.mask_zero else self.num_buckets - 1
try:
hash_x = tf.string_to_hash_bucket_fast(x, num_buckets,
@@ -60,8 +96,12 @@ def call(self, x, mask=None, **kwargs):
hash_x = (hash_x + 1) * mask

return hash_x

def compute_output_shape(self, input_shape):
return input_shape

def get_config(self, ):
config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, }
config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, 'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value}
base_config = super(Hash, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

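A self-contained sketch of the lookup path added above, assuming TF 2.x and a throwaway CSV written on the fly (the file name is illustrative); it mirrors the `TextFileInitializer`/`StaticHashTable` setup in `__init__`:

```python
import tensorflow as tf

# value column (index 0) holds int ids, key column (index 1) holds the strings
with open('vocab_example.csv', 'w') as f:
    f.write('1,lake\n2,merson\n3,johnson\n')

initializer = tf.lookup.TextFileInitializer(
    'vocab_example.csv', 'string', 1, 'int64', 0, delimiter=',')
table = tf.lookup.StaticHashTable(initializer, default_value=0)

print(table.lookup(tf.constant(['lake', 'johnson', 'lakemerson'])).numpy())
# -> [1 3 0]   (a missing key falls back to default_value)
```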
5 changes: 3 additions & 2 deletions docs/source/Features.md
@@ -23,12 +23,13 @@ DNN based CTR prediction models usually have following 4 modules:

## Feature Columns
### SparseFeat
``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype, embeddings_initializer, embedding_name, group_name, trainable)``
``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype, embeddings_initializer, embedding_name, group_name, trainable)``

- name : feature name
- vocabulary_size : number of unique feature values for sparse feature or hashing space when `use_hash=True`
- embedding_dim : embedding dimension
- use_hash : defualt `False`.If `True` the input will be hashed to space of size `vocabulary_size`.
- use_hash : default `False`. If `True`, the input will be hashed to a space of size `vocabulary_size`.
- vocabulary_path : default `None`. The `CSV` text file path of the vocabulary table used by `tf.lookup.TextFileInitializer`, which assigns one entry in the table for each line in the file. One entry contains two columns separated by a comma: the first is the value column, the second is the key column. The value `0` is reserved for keys missing from the table, so the values in the file need to start from `1`.
- dtype : default `int32`.dtype of input tensor.
- embeddings_initializer : initializer for the `embeddings` matrix.
- embedding_name : default `None`. If None, the embedding_name will be same as `name`.
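As a rough illustration of the documented option (the file path and feature name are placeholders), a feature backed by a three-entry vocabulary could be declared as below; `vocabulary_size` covers ids 1..3 plus the reserved `0` default:

```python
from deepctr.feature_column import SparseFeat

# use_hash=True is still required so the input is routed through the Hash
# layer, which performs the table lookup when vocabulary_path is set.
item_id = SparseFeat('item_id', vocabulary_size=3 + 1, embedding_dim=4,
                     use_hash=True, vocabulary_path='./item_vocab.csv',
                     dtype='string')
```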
14 changes: 12 additions & 2 deletions tests/feature_test.py
@@ -1,6 +1,8 @@
from deepctr.models import DeepFM
from deepctr.feature_column import SparseFeat, DenseFeat,get_feature_names
from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat, get_feature_names
import numpy as np


def test_long_dense_vector():

feature_columns = [SparseFeat('user_id', 4, ), SparseFeat('item_id', 5, ), DenseFeat("pic_vec", 5)]
@@ -16,4 +18,12 @@ def test_long_dense_vector():

model = DeepFM(feature_columns, feature_columns[:-1])
model.compile('adagrad', 'binary_crossentropy')
model.fit(model_input, label)
model.fit(model_input, label)


def test_feature_column_sparsefeat_vocabulary_path():
vocab_path = "./dummy_test.csv"
sf = SparseFeat('user_id', 4, vocabulary_path=vocab_path)
assert sf.vocabulary_path == vocab_path
vlsf = VarLenSparseFeat(sf, 6)
assert vlsf.vocabulary_path == vocab_path
4 changes: 1 addition & 3 deletions tests/layers/sequence_test.py
@@ -79,8 +79,6 @@ def test_BiLSTM(merge_mode):


def test_Transformer():
if tf.__version__ >= '2.0.0':
tf.compat.v1.disable_eager_execution() # todo
with CustomObjectScope({'Transformer': sequence.Transformer}):
layer_test(sequence.Transformer,
kwargs={'att_embedding_size': 1, 'head_num': 8, 'use_layer_norm': True, 'supports_masking': False,
@@ -102,7 +100,7 @@ def test_KMaxPooling():
]
)
def test_PositionEncoding(pos_embedding_trainable, zero_pad):
with CustomObjectScope({'PositionEncoding': sequence.PositionEncoding}):
with CustomObjectScope({'PositionEncoding': sequence.PositionEncoding, "tf": tf}):
layer_test(sequence.PositionEncoding,
kwargs={'pos_embedding_trainable': pos_embedding_trainable, 'zero_pad': zero_pad},
input_shape=(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_SIZE))
27 changes: 27 additions & 0 deletions tests/layers/utils_test.py
@@ -0,0 +1,27 @@
import pytest
import numpy as np
import tensorflow as tf
from deepctr.layers.utils import Hash
from tests.utils import layer_test
try:
from tensorflow.python.keras.utils import CustomObjectScope
except:
from tensorflow.keras.utils import CustomObjectScope


@pytest.mark.parametrize(
'num_buckets,mask_zero,vocabulary_path,input_data,expected_output',
[
(3+1, False, None, ['lakemerson'], None),
(3+1, True, None, ['lakemerson'], None),
(3+1, False, "./tests/layers/vocabulary_example.csv", [['lake'], ['johnson'], ['lakemerson']], [[1], [3], [0]])
]
)
def test_Hash(num_buckets, mask_zero, vocabulary_path, input_data, expected_output):
if not hasattr(tf, 'version') or tf.version.VERSION < '2.0.0':
return

with CustomObjectScope({'Hash': Hash}):
layer_test(Hash, kwargs={'num_buckets': num_buckets, 'mask_zero': mask_zero, 'vocabulary_path': vocabulary_path},
input_dtype=tf.string, input_data=np.array(input_data, dtype='str'),
expected_output_dtype=tf.int64, expected_output=expected_output)
3 changes: 3 additions & 0 deletions tests/layers/vocabulary_example.csv
@@ -0,0 +1,3 @@
1,lake
2,merson
3,johnson
17 changes: 10 additions & 7 deletions tests/utils.py
@@ -18,6 +18,7 @@
VOCABULARY_SIZE = 4
Estimator_TEST_TF1 = True


def gen_sequence(dim, max_len, sample_size):
return np.array([np.random.randint(0, dim, max_len) for _ in range(sample_size)]), np.random.randint(1, max_len + 1,
sample_size)
@@ -44,15 +45,15 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens

for i in range(sparse_feature_num):
if use_group:
group_name = str(i%3)
group_name = str(i % 3)
else:
group_name = DEFAULT_GROUP_NAME
dim = np.random.randint(1, 10)
feature_columns.append(
SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32,group_name=group_name))
SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32, group_name=group_name))

for i in range(dense_feature_num):
transform_fn = lambda x: (x - 0.0)/ 1.0
def transform_fn(x): return (x - 0.0) / 1.0
feature_columns.append(
DenseFeat(
prefix + 'dense_feature_' + str(i),
@@ -363,6 +364,7 @@ def check_model(model, model_name, x, y, check_model_io=True):

print(model_name + " test pass!")


def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_num=1, dense_feature_num=1, classification=True):

x = {}
@@ -372,7 +374,7 @@ def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_n
for i in range(sparse_feature_num):
name = 's_'+str(i)
x[name] = np.random.randint(0, voc_size, sample_size)
dnn_feature_columns.append(tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(name,voc_size),embedding_size))
dnn_feature_columns.append(tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity(name, voc_size), embedding_size))
linear_feature_columns.append(tf.feature_column.categorical_column_with_identity(name, voc_size))

for i in range(dense_feature_num):
@@ -390,8 +392,9 @@ def get_test_data_estimator(sample_size=1000, embedding_size=4, sparse_feature_n
else:
input_fn = tf.estimator.inputs.numpy_input_fn(x, y, shuffle=False)

return linear_feature_columns,dnn_feature_columns,input_fn
return linear_feature_columns, dnn_feature_columns, input_fn


def check_estimator(model,input_fn):
def check_estimator(model, input_fn):
model.train(input_fn)
model.evaluate(input_fn)
model.evaluate(input_fn)