merge to dev-dzc branch for continuous integration (#384)
* delete the Lambda sublayer in LocalActivationUnit Layer class

* add vocabulary_path in the SparseFeat to support the csv HashTable functionality

* update docs and add examples in doc

* Remove trailing whitespace
dengc367 committed Jun 29, 2021
1 parent 0df401c commit ab7b38a
Showing 5 changed files with 59 additions and 20 deletions.
deepctr/feature_column.py — 10 changes: 7 additions & 3 deletions
@@ -15,12 +15,12 @@


class SparseFeat(namedtuple('SparseFeat',
- ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embeddings_initializer',
+ ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'vocabulary_path', 'dtype', 'embeddings_initializer',
'embedding_name',
'group_name', 'trainable'])):
__slots__ = ()

- def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="int32", embeddings_initializer=None,
+ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, vocabulary_path=None, dtype="int32", embeddings_initializer=None,
embedding_name=None,
group_name=DEFAULT_GROUP_NAME, trainable=True):

@@ -32,7 +32,7 @@ def __new__(cls, name, vocabulary_size, embedding_dim=4, use_hash=False, dtype="
if embedding_name is None:
embedding_name = name

- return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
+ return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype,
embeddings_initializer,
embedding_name, group_name, trainable)

@@ -64,6 +64,10 @@ def embedding_dim(self):
def use_hash(self):
return self.sparsefeat.use_hash

+ @property
+ def vocabulary_path(self):
+     return self.sparsefeat.vocabulary_path

@property
def dtype(self):
return self.sparsefeat.dtype
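Taken together, the hunks above thread the new field through the `SparseFeat` namedtuple and expose it as a proxy property on the wrapper class that holds `self.sparsefeat` (`VarLenSparseFeat` in DeepCTR). A minimal usage sketch, assuming that API; the file path and sizes are illustrative:

```python
from deepctr.feature_column import SparseFeat, VarLenSparseFeat

# vocabulary_size is vocabulary size + 1: value 0 is reserved for keys
# missing from the table, so table values start at 1.
user = SparseFeat('user_id', vocabulary_size=3 + 1, embedding_dim=4,
                  use_hash=True,
                  vocabulary_path='user_vocab.csv',  # hypothetical path
                  dtype='string')
print(user.vocabulary_path)  # 'user_vocab.csv'

# The wrapper forwards the new property to its wrapped SparseFeat.
hist = VarLenSparseFeat(user, maxlen=10)
print(hist.vocabulary_path)  # 'user_vocab.csv'
```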
deepctr/inputs.py — 6 changes: 3 additions & 3 deletions
@@ -51,7 +51,7 @@ def get_embedding_vec_list(embedding_dict, input_dict, sparse_feature_columns, r
feat_name = fg.name
if len(return_feat_list) == 0 or feat_name in return_feat_list:
if fg.use_hash:
- lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list))(input_dict[feat_name])
+ lookup_idx = Hash(fg.vocabulary_size, mask_zero=(feat_name in mask_feat_list), vocabulary_path=fg.vocabulary_path)(input_dict[feat_name])
else:
lookup_idx = input_dict[feat_name]

@@ -80,7 +80,7 @@ def embedding_lookup(sparse_embedding_dict, sparse_input_dict, sparse_feature_co
embedding_name = fc.embedding_name
if (len(return_feat_list) == 0 or feature_name in return_feat_list):
if fc.use_hash:
- lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list))(
+ lookup_idx = Hash(fc.vocabulary_size, mask_zero=(feature_name in mask_feat_list), vocabulary_path=fc.vocabulary_path)(
sparse_input_dict[feature_name])
else:
lookup_idx = sparse_input_dict[feature_name]
@@ -97,7 +97,7 @@ def varlen_embedding_lookup(embedding_dict, sequence_input_dict, varlen_sparse_f
feature_name = fc.name
embedding_name = fc.embedding_name
if fc.use_hash:
- lookup_idx = Hash(fc.vocabulary_size, mask_zero=True)(sequence_input_dict[feature_name])
+ lookup_idx = Hash(fc.vocabulary_size, mask_zero=True, vocabulary_path=fc.vocabulary_path)(sequence_input_dict[feature_name])
else:
lookup_idx = sequence_input_dict[feature_name]
varlen_embedding_vec_dict[feature_name] = embedding_dict[embedding_name](lookup_idx)
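All three lookup helpers apply the same one-line change: the `Hash` layer now receives the column's `vocabulary_path`. A simplified sketch of the shared pattern (mask handling and the embedding dict are omitted; `fc` and `inputs` are illustrative stand-ins):

```python
from deepctr.layers.utils import Hash

def lookup_index(fc, inputs):
    # vocabulary_path=None keeps the old hashing behaviour; a non-None path
    # switches Hash to a static CSV-table lookup instead.
    if fc.use_hash:
        return Hash(fc.vocabulary_size, mask_zero=False,
                    vocabulary_path=fc.vocabulary_path)(inputs[fc.name])
    return inputs[fc.name]
```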
deepctr/layers/core.py — 9 changes: 3 additions & 6 deletions
@@ -68,8 +68,8 @@ def build(self, input_shape):
'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)'
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
size = 4 * \
- int(input_shape[0][-1]
- ) if len(self.hidden_units) == 0 else self.hidden_units[-1]
+ int(input_shape[0][-1]
+ ) if len(self.hidden_units) == 0 else self.hidden_units[-1]
self.kernel = self.add_weight(shape=(size, 1),
initializer=glorot_normal(
seed=self.seed),
@@ -78,9 +78,6 @@ def build(self, input_shape):
shape=(1,), initializer=Zeros(), name="bias")
self.dnn = DNN(self.hidden_units, self.activation, self.l2_reg, self.dropout_rate, self.use_bn, seed=self.seed)

- self.dense = tf.keras.layers.Lambda(lambda x: tf.nn.bias_add(tf.tensordot(
-     x[0], x[1], axes=(-1, 0)), x[2]))

super(LocalActivationUnit, self).build(
input_shape) # Be sure to call this somewhere!

@@ -96,7 +93,7 @@ def call(self, inputs, training=None, **kwargs):

att_out = self.dnn(att_input, training=training)

- attention_score = self.dense([att_out, self.kernel, self.bias])
+ attention_score = tf.nn.bias_add(tf.tensordot(att_out, self.kernel, axes=(-1, 0)), self.bias)

return attention_score

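The deleted Lambda sublayer and the inlined ops compute the same attention score; the change just drops an unnecessary sublayer. A quick standalone equivalence check (shapes are illustrative, not taken from the commit):

```python
import tensorflow as tf

att_out = tf.random.normal([2, 5, 8])   # (batch, T, size) DNN output
kernel = tf.random.normal([8, 1])
bias = tf.zeros([1])

# What the removed Lambda sublayer computed:
dense = tf.keras.layers.Lambda(
    lambda x: tf.nn.bias_add(tf.tensordot(x[0], x[1], axes=(-1, 0)), x[2]))
old = dense([att_out, kernel, bias])

# The inlined replacement:
new = tf.nn.bias_add(tf.tensordot(att_out, kernel, axes=(-1, 0)), bias)

print(tf.reduce_max(tf.abs(old - new)).numpy())  # 0.0 — identical ops
```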
deepctr/layers/utils.py — 49 changes: 43 additions & 6 deletions
@@ -25,14 +25,47 @@ def compute_mask(self, inputs, mask):


class Hash(tf.keras.layers.Layer):
"""
hash the input to [0,num_buckets)
if mask_zero = True,0 or 0.0 will be set to 0,other value will be set in range[1,num_buckets)
"""Looks up keys in a table when setup `vocabulary_path`, which outputs the corresponding values.
If `vocabulary_path` is not set, `Hash` will hash the input to [0,num_buckets). When `mask_zero` = True,
input value `0` or `0.0` will be set to `0`, and other value will be set in range [1,num_buckets).
The following snippet initializes a `Hash` with `vocabulary_path` file with the first column as keys and
second column as values:
* `1,emerson`
* `2,lake`
* `3,palmer`
>>> hash = Hash(
... num_buckets=3+1,
... vocabulary_path=filename,
... default_value=0)
>>> hash(tf.constant('lake')).numpy()
2
>>> hash(tf.constant('lakeemerson')).numpy()
0
Args:
num_buckets: An `int` that is >= 1. The number of buckets or the vocabulary size + 1
when `vocabulary_path` is setup.
mask_zero: default is False. The `Hash` value will hash input `0` or `0.0` to value `0` when
the `mask_zero` is `True`. `mask_zero` is not used when `vocabulary_path` is setup.
vocabulary_path: default `None`. The `CSV` text file path of the vocabulary hash, which contains
two columns seperated by delimiter `comma`, the first column is the value and the second is
the key. The key data type is `string`, the value data type is `int`. The path must
be accessible from wherever `Hash` is initialized.
default_value: default '0'. The default value if a key is missing in the table.
**kwargs: Additional keyword arguments.
"""

- def __init__(self, num_buckets, mask_zero=False, **kwargs):
+ def __init__(self, num_buckets, mask_zero=False, vocabulary_path=None, default_value=0, **kwargs):
self.num_buckets = num_buckets
self.mask_zero = mask_zero
+ self.vocabulary_path = vocabulary_path
+ self.default_value = default_value
+ if self.vocabulary_path:
+     initializer = tf.lookup.TextFileInitializer(vocabulary_path, 'string', 1, 'int64', 0, delimiter=',')
+     self.hash_table = tf.lookup.StaticHashTable(initializer, default_value=self.default_value)
super(Hash, self).__init__(**kwargs)

def build(self, input_shape):
@@ -41,13 +74,16 @@ def build(self, input_shape):

def call(self, x, mask=None, **kwargs):


if x.dtype != tf.string:
zero = tf.as_string(tf.zeros([1], dtype=x.dtype))
x = tf.as_string(x, )
else:
zero = tf.as_string(tf.zeros([1], dtype='int32'))

+ if self.vocabulary_path:
+     hash_x = self.hash_table.lookup(x)
+     return hash_x

num_buckets = self.num_buckets if not self.mask_zero else self.num_buckets - 1
try:
hash_x = tf.string_to_hash_bucket_fast(x, num_buckets,
@@ -60,8 +96,9 @@ def call(self, x, mask=None, **kwargs):
hash_x = (hash_x + 1) * mask

return hash_x

def get_config(self, ):
- config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, }
+ config = {'num_buckets': self.num_buckets, 'mask_zero': self.mask_zero, 'vocabulary_path': self.vocabulary_path, 'default_value': self.default_value}
base_config = super(Hash, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

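An end-to-end sketch of both `Hash` paths described in the new docstring (the file name, keys, and bucket counts are illustrative):

```python
import tensorflow as tf
from deepctr.layers.utils import Hash

# Table-lookup path: vocabulary_path set, one "value,key" pair per line.
with open('vocab.csv', 'w') as f:
    f.write('1,emerson\n2,lake\n3,palmer\n')

table_hash = Hash(num_buckets=3 + 1, vocabulary_path='vocab.csv', default_value=0)
print(table_hash(tf.constant(['lake', 'palmer', 'unknown'])).numpy())  # [2 3 0]

# Hashing path: no vocabulary_path. With mask_zero=True the layer hashes into
# num_buckets - 1 buckets, shifts the result into [1, num_buckets), and maps
# the literal '0' to 0.
bucket_hash = Hash(num_buckets=100, mask_zero=True)
print(bucket_hash(tf.constant(['0', 'lake'])).numpy())  # [0 k] with 1 <= k < 100
```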
docs/source/Features.md — 5 changes: 3 additions & 2 deletions
@@ -23,12 +23,13 @@ DNN based CTR prediction models usually have following 4 modules:
## Feature Columns
### SparseFeat
- ``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype, embeddings_initializer, embedding_name, group_name, trainable)``
+ ``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, vocabulary_path, dtype, embeddings_initializer, embedding_name, group_name, trainable)``

- name : feature name
- vocabulary_size : number of unique feature values for sparse feature or hashing space when `use_hash=True`
- embedding_dim : embedding dimension
- - use_hash : defualt `False`.If `True` the input will be hashed to space of size `vocabulary_size`.
+ - use_hash : default `False`. If `True` the input will be hashed to space of size `vocabulary_size`.
+ - vocabulary_path : default `None`. The CSV text file path of the vocabulary table used by `tf.lookup.TextFileInitializer`, which assigns one entry in the table for each line in the file. Each entry contains two columns separated by a comma: the first is the value column, the second is the key column. The value `0` is reserved for keys missing from the table, so hash values need to start from `1`.
- dtype : default `int32`. dtype of input tensor.
- embeddings_initializer : initializer for the `embeddings` matrix.
- embedding_name : default `None`. If None, the embedding_name will be same as `name`.
