Skip to content

Commit

Permalink
feat:add FGCNN&FGCNNLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
浅梦 committed Apr 27, 2019
1 parent 7d5bfd9 commit ab7cdd8
Show file tree
Hide file tree
Showing 20 changed files with 311 additions and 42 deletions.
3 changes: 2 additions & 1 deletion README.md
Expand Up @@ -36,7 +36,8 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star
| Deep Interest Network | [KDD 2018][Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf) |
| Deep Interest Evolution Network | [AAAI 2019][Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1809.03672.pdf) |
| AutoInt | [arxiv 2018][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) |
| NFFM | [arxiv 2019][Field-aware Neural Factorization Machine for Click-Through Rate Prediction ](https://arxiv.org/pdf/1902.09096.pdf)(The original NFFM was first used by Yi Yang(yangyi868@gmail.com) in TSA competition in 2017.) |
| NFFM | [arxiv 2019][Field-aware Neural Factorization Machine for Click-Through Rate Prediction ](https://arxiv.org/pdf/1902.09096.pdf) (The original NFFM was first used by Yi Yang(yangyi868@gmail.com) in TSA competition in 2017.) |
| FGCNN | [WWW 2019][Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction ](https://arxiv.org/pdf/1904.04447) |



2 changes: 1 addition & 1 deletion deepctr/__init__.py
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version, SingleFeat, VarLenFeat

__version__ = '0.3.3'
__version__ = '0.3.4'
check_version(__version__)
6 changes: 3 additions & 3 deletions deepctr/input_embedding.py
Expand Up @@ -155,10 +155,10 @@ def get_inputs_list(inputs):

def get_inputs_embedding(feature_dim_dict, embedding_size, l2_reg_embedding, l2_reg_linear, init_std, seed,
sparse_input_dict, dense_input_dict, sequence_input_dict, sequence_input_len_dict,
sequence_max_len_dict, sequence_pooling_dict, include_linear):
sequence_max_len_dict, sequence_pooling_dict, include_linear,prefix=""):

deep_sparse_emb_dict = create_embedding_dict(
feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding)
feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding,prefix=prefix+'sparse')

deep_emb_list = get_embedding_vec_list(
deep_sparse_emb_dict, sparse_input_dict)
Expand All @@ -171,7 +171,7 @@ def get_inputs_embedding(feature_dim_dict, embedding_size, l2_reg_embedding, l2_

if include_linear:
linear_sparse_emb_dict = create_embedding_dict(
feature_dim_dict, 1, init_std, seed, l2_reg_linear, 'linear')
feature_dim_dict, 1, init_std, seed, l2_reg_linear,prefix+ 'linear')
linear_emb_list = get_embedding_vec_list(
linear_sparse_emb_dict, sparse_input_dict)
linear_emb_list = merge_sequence_input(linear_sparse_emb_dict, linear_emb_list, sequence_input_dict,
Expand Down
5 changes: 3 additions & 2 deletions deepctr/layers/__init__.py
Expand Up @@ -4,7 +4,7 @@
from .core import MLP, LocalActivationUnit, PredictionLayer
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
InnerProductLayer, InteractingLayer,
OutterProductLayer)
OutterProductLayer,FGCNNLayer)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, Position_Embedding, SequencePoolingLayer,
Expand All @@ -31,4 +31,5 @@
'Transformer': Transformer,
'NoMask': NoMask,
'BiasEncoding': BiasEncoding,
'KMaxPooling': KMaxPooling}
'KMaxPooling': KMaxPooling,
'FGCNNLayer':FGCNNLayer}
131 changes: 111 additions & 20 deletions deepctr/layers/interaction.py
Expand Up @@ -16,6 +16,7 @@
from tensorflow.python.keras.regularizers import l2

from .activation import activation_fun
from .utils import concat_fun


class AFMLayer(Layer):
Expand Down Expand Up @@ -113,7 +114,7 @@ def call(self, inputs, **kwargs):
self.normalized_att_score = tf.nn.softmax(tf.tensordot(
attention_temp, self.projection_h, axes=(-1, 0)), dim=1)
attention_output = tf.reduce_sum(
self.normalized_att_score*bi_interaction, axis=1)
self.normalized_att_score * bi_interaction, axis=1)

attention_output = tf.nn.dropout(
attention_output, self.keep_prob, seed=1024)
Expand All @@ -130,7 +131,7 @@ def compute_output_shape(self, input_shape):
'on a list of inputs.')
return (None, 1)

def get_config(self,):
def get_config(self, ):
config = {'attention_factor': self.attention_factor,
'l2_reg_w': self.l2_reg_w, 'keep_prob': self.keep_prob, 'seed': self.seed}
base_config = super(AFMLayer, self).get_config()
Expand Down Expand Up @@ -175,7 +176,7 @@ def call(self, inputs, **kwargs):
concated_embeds_value, axis=1, keep_dims=True))
sum_of_square = tf.reduce_sum(
concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
cross_term = 0.5*(square_of_sum - sum_of_square)
cross_term = 0.5 * (square_of_sum - sum_of_square)

return cross_term

Expand Down Expand Up @@ -206,7 +207,7 @@ class CIN(Layer):
- [Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018.] (https://arxiv.org/pdf/1803.05170.pdf)
"""

def __init__(self, layer_size=(128, 128), activation='relu',split_half=True, l2_reg=1e-5,seed=1024, **kwargs):
def __init__(self, layer_size=(128, 128), activation='relu', split_half=True, l2_reg=1e-5, seed=1024, **kwargs):
if len(layer_size) == 0:
raise ValueError(
"layer_size must be a list(tuple) of length greater than 1")
Expand All @@ -230,7 +231,8 @@ def build(self, input_shape):
self.filters.append(self.add_weight(name='filter' + str(i),
shape=[1, self.field_nums[-1]
* self.field_nums[0], size],
dtype=tf.float32, initializer=glorot_uniform(seed=self.seed + i), regularizer=l2(self.l2_reg)))
dtype=tf.float32, initializer=glorot_uniform(seed=self.seed + i),
regularizer=l2(self.l2_reg)))

self.bias.append(self.add_weight(name='bias' + str(i), shape=[size], dtype=tf.float32,
initializer=tf.keras.initializers.Zeros()))
Expand Down Expand Up @@ -346,13 +348,13 @@ def build(self, input_shape):
"Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))

dim = input_shape[-1].value
self.kernels = [self.add_weight(name='kernel'+str(i),
self.kernels = [self.add_weight(name='kernel' + str(i),
shape=(dim, 1),
initializer=glorot_normal(
seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(self.layer_num)]
self.bias = [self.add_weight(name='bias'+str(i),
self.bias = [self.add_weight(name='bias' + str(i),
shape=(dim, 1),
initializer=Zeros(),
trainable=True) for i in range(self.layer_num)]
Expand All @@ -373,7 +375,7 @@ def call(self, inputs, **kwargs):
x_l = tf.squeeze(x_l, axis=2)
return x_l

def get_config(self,):
def get_config(self, ):

config = {'layer_num': self.layer_num,
'l2_reg': self.l2_reg, 'seed': self.seed}
Expand Down Expand Up @@ -436,7 +438,7 @@ class InnerProductLayer(Layer):
product or inner product between feature vectors.
Input shape
- A list of N 3D tensor with shape: ``(batch_size,1,embedding_size)``.
- a list of 3D tensor with shape: ``(batch_size,1,embedding_size)``.
Output shape
- 3D tensor with shape: ``(batch_size, N*(N-1)/2 ,1)`` if use reduce_sum. or 3D tensor with shape: ``(batch_size, N*(N-1)/2, embedding_size )`` if not use reduce_sum.
Expand Down Expand Up @@ -492,7 +494,9 @@ def call(self, inputs, **kwargs):
col.append(j)
p = tf.concat([embed_list[idx]
for idx in row], axis=1) # batch num_pairs k
q = tf.concat([embed_list[idx] for idx in col], axis=1)
q = tf.concat([embed_list[idx]
for idx in col], axis=1)

inner_product = p * q
if self.reduce_sum:
inner_product = tf.reduce_sum(
Expand All @@ -509,7 +513,7 @@ def compute_output_shape(self, input_shape):
else:
return (input_shape[0], num_pairs, embed_size)

def get_config(self,):
def get_config(self, ):
config = {'reduce_sum': self.reduce_sum, }
base_config = super(InnerProductLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
Expand Down Expand Up @@ -549,14 +553,18 @@ def build(self, input_shape):
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(input_shape)))
embedding_size = input_shape[-1].value
self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
self.W_Query = self.add_weight(name='query', shape=[embedding_size, self.att_embedding_size * self.head_num],
dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed))
self.W_key = self.add_weight(name='key', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed+1))
self.W_Value = self.add_weight(name='value', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed+2))
self.W_key = self.add_weight(name='key', shape=[embedding_size, self.att_embedding_size * self.head_num],
dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed + 1))
self.W_Value = self.add_weight(name='value', shape=[embedding_size, self.att_embedding_size * self.head_num],
dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed + 2))
if self.use_res:
self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num],
dtype=tf.float32,
initializer=tf.keras.initializers.TruncatedNormal(seed=self.seed))

# Be sure to call this somewhere!
Expand Down Expand Up @@ -656,10 +664,12 @@ def build(self, input_shape):
embed_size = input_shape[-1].value
if self.kernel_type == 'mat':

self.kernel = self.add_weight(shape=(embed_size, num_pairs, embed_size), initializer=glorot_uniform(seed=self.seed),
self.kernel = self.add_weight(shape=(embed_size, num_pairs, embed_size),
initializer=glorot_uniform(seed=self.seed),
name='kernel')
elif self.kernel_type == 'vec':
self.kernel = self.add_weight(shape=(num_pairs, embed_size,), initializer=glorot_uniform(self.seed), name='kernel'
self.kernel = self.add_weight(shape=(num_pairs, embed_size,), initializer=glorot_uniform(self.seed),
name='kernel'
)
elif self.kernel_type == 'num':
self.kernel = self.add_weight(
Expand Down Expand Up @@ -737,7 +747,88 @@ def compute_output_shape(self, input_shape):
num_pairs = int(num_inputs * (num_inputs - 1) / 2)
return (None, num_pairs)

def get_config(self,):
def get_config(self, ):
config = {'kernel_type': self.kernel_type, 'seed': self.seed}
base_config = super(OutterProductLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class FGCNNLayer(Layer):
    """Feature Generation Layer used in FGCNN: stacked Convolution, MaxPooling and Recombination stages.

      Input shape
        - A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.

      Output shape
        - 3D tensor with shape: ``(batch_size,new_feature_num,embedding_size)``.

      Arguments
        - **filters** : sequence of int, number of convolution filters used in each stage.

        - **kernel_width** : sequence of int, kernel size along the field axis for each stage's convolution.

        - **new_maps** : sequence of int, per-stage multiplier for the number of features produced by recombination.

        - **pooling_width** : sequence of int, pooling size along the field axis for each stage.

      References
        - [Liu B, Tang R, Chen Y, et al. Feature Generation by Convolutional Neural Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1904.04447, 2019.](https://arxiv.org/pdf/1904.04447)
    """

    def __init__(self, filters=(14, 16,), kernel_width=(7, 7,), new_maps=(3, 3,), pooling_width=(2, 2),
                 **kwargs):
        # All four per-stage settings must describe the same number of stages.
        same_length = len(filters) == len(kernel_width) == len(new_maps) == len(pooling_width)
        if not same_length:
            raise ValueError("length of argument must be equal")

        self.filters = filters
        self.kernel_width = kernel_width
        self.new_maps = new_maps
        self.pooling_width = pooling_width

        super(FGCNNLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        """Validate that the input is 3-dimensional and mark the layer as built."""
        rank = len(input_shape)
        if rank != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (rank))

        # Be sure to call this somewhere!
        super(FGCNNLayer, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Run every convolution/pooling/recombination stage and concatenate the generated features on axis 1."""
        rank = K.ndim(inputs)
        if rank != 3:
            raise ValueError(
                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (rank))

        embedding_size = inputs.shape[-1].value
        # Append a trailing channel axis so the 2D convolution can be applied.
        stage_input = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=3))(inputs)

        generated = []
        stage_settings = zip(self.filters, self.kernel_width, self.new_maps, self.pooling_width)

        # NOTE(review): the Conv2D / MaxPooling2D / Dense sub-layers below are instantiated
        # inside call() rather than in build(); confirm their weights are tracked by this
        # layer as intended.
        for n_filters, width, n_new_maps, pool_width in stage_settings:
            convolved = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(width, 1), strides=(1, 1),
                                               padding='same',
                                               activation='tanh', use_bias=True, )(stage_input)
            # The pooled map feeds both the recombination below and the next stage.
            stage_input = tf.keras.layers.MaxPooling2D(pool_size=(pool_width, 1))(convolved)
            pooled_rows = stage_input.shape[1].value

            flattened = tf.keras.layers.Flatten()(stage_input)
            recombined = tf.keras.layers.Dense(pooled_rows * embedding_size * n_new_maps,
                                               activation='tanh', use_bias=True)(flattened)
            as_features = tf.keras.layers.Reshape((pooled_rows * n_new_maps, embedding_size))(recombined)
            generated.append(as_features)

        return concat_fun(generated, axis=1)

    def compute_output_shape(self, input_shape):
        """Return ``(None, new_features_num, embedding_size)``.

        Each stage keeps ``features_num // pooling_width`` rows and contributes
        ``new_maps`` times that many rows to the total.
        """
        total_new_features = 0
        remaining = input_shape[1]

        for n_new_maps, pool_width in zip(self.new_maps, self.pooling_width):
            remaining = remaining // pool_width
            total_new_features += n_new_maps * remaining

        return (None, total_new_features, input_shape[-1])

    def get_config(self, ):
        """Return the base layer config extended with this layer's constructor arguments."""
        merged = dict(super(FGCNNLayer, self).get_config())
        merged.update({'kernel_width': self.kernel_width, 'filters': self.filters, 'new_maps': self.new_maps,
                       'pooling_width': self.pooling_width})
        return merged
3 changes: 2 additions & 1 deletion deepctr/models/__init__.py
Expand Up @@ -12,6 +12,7 @@
from .pnn import PNN
from .wdl import WDL
from .xdeepfm import xDeepFM
from .fgcnn import FGCNN

__all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM",
"MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "NFFM"]
"MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", "WDL", "xDeepFM", "AutoInt", "NFFM","FGCNN"]
2 changes: 2 additions & 0 deletions deepctr/models/ccpm.py
Expand Up @@ -39,6 +39,8 @@ def CCPM(feature_dim_dict, embedding_size=8, conv_kernel_width=(6, 5), conv_filt
"""

check_feature_config_dict(feature_dim_dict)
if len(conv_kernel_width)!=len(conv_filters):
raise ValueError("conv_kernel_width must have same element with conv_filters")

deep_emb_list, linear_logit, inputs_list = preprocess_input_embedding(feature_dim_dict, embedding_size,
l2_reg_embedding, l2_reg_linear, init_std,
Expand Down

0 comments on commit ab7cdd8

Please sign in to comment.