Commit d7487a6

Update
Weichen Shen committed Jan 9, 2019
1 parent 18cf3c0 commit d7487a6
Showing 4 changed files with 23 additions and 14 deletions.
32 changes: 19 additions & 13 deletions deepctr/layers.py
@@ -355,10 +355,9 @@ def call(self, inputs, **kwargs):
         x_0 = tf.expand_dims(inputs, axis=2)
         x_l = x_0
         for i in range(self.layer_num):
-            xl_w = tf.tensordot(tf.transpose(
-                x_l, [0, 2, 1]), self.kernels[i], axes=(-1, 0))
+            xl_w = tf.tensordot(x_l, self.kernels[i], axes=(1, 0))
             dot_ = tf.matmul(x_0, xl_w)
-            x_l = dot_ + x_l + self.bias[i]
+            x_l = dot_ + self.bias[i] + x_l
         x_l = tf.squeeze(x_l, axis=2)
         return x_l
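A note on the CrossNet change above: the removed transpose was unnecessary, since contracting the feature axis of `x_l` directly yields the same `(batch, 1, 1)` tensor, and the bias reordering is purely cosmetic (addition commutes). A minimal equivalence check in the TF 1.x style of this repo; the `(dim, 1)` kernel shape is an assumption inferred from the surrounding `tf.matmul`, and the random data is illustrative only:

```python
import numpy as np
import tensorflow as tf

batch, dim = 4, 8
x_l = tf.constant(np.random.rand(batch, dim, 1), dtype=tf.float32)
kernel = tf.constant(np.random.rand(dim, 1), dtype=tf.float32)  # assumed shape

# old: transpose to (batch, 1, dim), then contract the last axis
old = tf.tensordot(tf.transpose(x_l, [0, 2, 1]), kernel, axes=(-1, 0))
# new: contract axis 1 of x_l directly -> same (batch, 1, 1) result
new = tf.tensordot(x_l, kernel, axes=(1, 0))

with tf.Session() as sess:
    print(sess.run(tf.reduce_max(tf.abs(old - new))))  # ~0.0
```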

@@ -504,7 +503,6 @@ def get_config(self,):
         return dict(list(base_config.items()) + list(config.items()))
 
 
-
 class InteractingLayer(Layer):
     """A Layer used in AutoInt that model the correlations between different feature fields by multi-head self-attention mechanism.
@@ -524,6 +522,7 @@ class InteractingLayer(Layer):
     References
         - [Song W, Shi C, Xiao Z, et al. AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks[J]. arXiv preprint arXiv:1810.11921, 2018.](https://arxiv.org/abs/1810.11921)
     """
+
     def __init__(self, att_embedding_size=8, head_num=2, use_res=True, seed=1024, **kwargs):
         if head_num <= 0:
             raise ValueError('head_num must be a int > 0')
@@ -535,7 +534,8 @@ def __init__(self, att_embedding_size=8, head_num=2, use_res=True, seed=1024, **kwargs):
 
     def build(self, input_shape):
         if len(input_shape) != 3:
-            raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
+            raise ValueError(
+                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
         embedding_size = input_shape[-1].value
         self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
                                        initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
@@ -547,26 +547,32 @@ def build(self, input_shape):
         self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
                                      initializer=tf.keras.initializers.glorot_uniform(seed=self.seed))
 
-        super(InteractingLayer, self).build(input_shape)  # Be sure to call this somewhere!
+        # Be sure to call this somewhere!
+        super(InteractingLayer, self).build(input_shape)
 
     def call(self, inputs, **kwargs):
         if K.ndim(inputs) != 3:
-            raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
+            raise ValueError(
+                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))
 
-        querys = tf.tensordot(inputs, self.W_Query, axes=(-1, 0))  # None F D*head_num
+        querys = tf.tensordot(inputs, self.W_Query,
+                              axes=(-1, 0))  # None F D*head_num
         keys = tf.tensordot(inputs, self.W_key, axes=(-1, 0))
         values = tf.tensordot(inputs, self.W_Value, axes=(-1, 0))
 
-        querys = tf.stack(tf.split(querys, self.head_num, axis=2))  # head_num None F D
+        # head_num None F D
+        querys = tf.stack(tf.split(querys, self.head_num, axis=2))
         keys = tf.stack(tf.split(keys, self.head_num, axis=2))
         values = tf.stack(tf.split(values, self.head_num, axis=2))
 
-        inner_product = tf.matmul(querys, keys, transpose_b=True)  # head_num None F F
+        inner_product = tf.matmul(
+            querys, keys, transpose_b=True)  # head_num None F F
         self.normalized_att_scores = tf.nn.softmax(inner_product)
 
-        result = tf.matmul(self.normalized_att_scores, values)#head_num None F D
+        result = tf.matmul(self.normalized_att_scores,
+                           values)  # head_num None F D
         result = tf.concat(tf.split(result, self.head_num, ), axis=-1)
-        result = tf.squeeze(result, axis=0)#None F D*head_num
+        result = tf.squeeze(result, axis=0)  # None F D*head_num
 
         if self.use_res:
             result += tf.tensordot(inputs, self.W_Res, axes=(-1, 0))
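These call() edits only reflow long lines; the attention math itself is unchanged. For readers following the shape comments, a standalone walk-through with toy sizes (only the query projection is built here, and it stands in for keys and values purely for brevity):

```python
import numpy as np
import tensorflow as tf

B, F, E, D, H = 2, 5, 16, 8, 2  # batch, fields, emb size, att_embedding_size, head_num
inputs = tf.constant(np.random.rand(B, F, E), tf.float32)
W_Query = tf.constant(np.random.rand(E, D * H), tf.float32)

querys = tf.tensordot(inputs, W_Query, axes=(-1, 0))                  # (B, F, D*H)
querys = tf.stack(tf.split(querys, H, axis=2))                        # (H, B, F, D)
scores = tf.nn.softmax(tf.matmul(querys, querys, transpose_b=True))   # (H, B, F, F)
result = tf.matmul(scores, querys)                                    # (H, B, F, D)
result = tf.concat(tf.split(result, H), axis=-1)                      # (1, B, F, D*H)
result = tf.squeeze(result, axis=0)                                   # (B, F, D*H)
print(result.shape)  # (2, 5, 16)
```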
@@ -901,7 +907,7 @@ class PredictionLayer(Layer):
      Arguments
         - **activation**: Activation function to use.
-        - **use_bias**: bool.Whther add bias term.
+        - **use_bias**: bool.Whether add bias term or not.
     """
 
     def __init__(self, activation='sigmoid', use_bias=True, **kwargs):
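Beyond the docstring typo fix, nothing changes here. As a hypothetical usage sketch only (the import path and the Model wiring are assumptions, not part of this commit):

```python
import tensorflow as tf
from deepctr.layers import PredictionLayer  # assumed import path

logit = tf.keras.layers.Input(shape=(1,))
# apply the final activation (and optional bias) to a raw logit
output = PredictionLayer(activation='sigmoid', use_bias=True)(logit)
model = tf.keras.models.Model(inputs=logit, outputs=output)
```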
3 changes: 2 additions & 1 deletion deepctr/sequence.py
@@ -1,6 +1,7 @@
 from tensorflow.python.keras.layers import Layer
 from .layers import LocalActivationUnit
 import tensorflow as tf
+from tensorflow.python.keras.layers import GRU
 
 
 class SequencePoolingLayer(Layer):
@@ -9,7 +10,7 @@ class SequencePoolingLayer(Layer):
     Input shape
         - A list of two tensor [seq_value,seq_len]
-        - seq_value is a 3D tensor with shape: ``(batch_size, T, embedding_size``
+        - seq_value is a 3D tensor with shape: ``(batch_size, T, embedding_size)``
        - seq_len is a 2D tensor with shape : ``(batch_size, 1)``,indicate valid length of each sequence.
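The docstring fix just closes the shape tuple. For concreteness, the two documented inputs look like this (toy sizes, illustrative only):

```python
import numpy as np

batch_size, T, embedding_size = 3, 10, 4
seq_value = np.random.rand(batch_size, T, embedding_size).astype('float32')  # (batch_size, T, embedding_size)
seq_len = np.array([[3], [10], [7]])  # (batch_size, 1): valid length of each sequence
```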
1 change: 1 addition & 0 deletions docs/source/Examples.md
@@ -139,6 +139,7 @@ This example shows how to use ``DeepFM`` with sequence(multi-value) feature. You
 ```python
 import pandas as pd
 import numpy as np
+from sklearn.preprocessing import LabelEncoder
 from tensorflow.python.keras.preprocessing.sequence import pad_sequences
 from deepctr.models import DeepFM
 from deepctr.utils import VarLenFeature
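The added import is what the example uses to map raw categorical values to the integer ids that embedding layers expect; a toy illustration (data not taken from the example):

```python
from sklearn.preprocessing import LabelEncoder

lbe = LabelEncoder()
print(lbe.fit_transform(['b', 'a', 'b', 'c']))  # [1 0 1 2]
```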
1 change: 1 addition & 0 deletions examples/run_multivalue_movielens.py
@@ -1,6 +1,7 @@
 import pandas as pd
 import numpy as np
 from tensorflow.python.keras.preprocessing.sequence import pad_sequences
+from sklearn.preprocessing import LabelEncoder
 from deepctr.models import DeepFM
 from deepctr.utils import VarLenFeature
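Alongside the new LabelEncoder import, this script relies on pad_sequences to turn variable-length multi-value lists into a fixed-width array; a toy illustration (the 'post' padding and maxlen here are assumptions, not read from the script):

```python
from tensorflow.python.keras.preprocessing.sequence import pad_sequences

genres_list = [[1, 2], [3], [4, 5, 6]]
print(pad_sequences(genres_list, maxlen=3, padding='post'))
# [[1 2 0]
#  [3 0 0]
#  [4 5 6]]
```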
