Skip to content

Commit

Permalink
no message
Browse files Browse the repository at this point in the history
  • Loading branch information
shenweichen committed Nov 3, 2022
1 parent 2405b38 commit 3c64524
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 44 deletions.
8 changes: 4 additions & 4 deletions deepctr/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import tensorflow as tf

from .activation import Dice
from .core import DNN, LocalActivationUnit, PredictionLayer, RegulationLayer
from .core import DNN, LocalActivationUnit, PredictionLayer, RegulationModule
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
InnerProductLayer, InteractingLayer,
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
FieldWiseBiInteraction, FwFMLayer, FEFMLayer, BridgeLayer)
FieldWiseBiInteraction, FwFMLayer, FEFMLayer, BridgeModule)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
Expand All @@ -28,7 +28,7 @@
'SequencePoolingLayer': SequencePoolingLayer,
'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer,
'CIN': CIN,
'RegulationLayer': RegulationLayer,
'RegulationLayer': RegulationModule,
'InteractingLayer': InteractingLayer,
'LayerNormalization': LayerNormalization,
'BiLSTM': BiLSTM,
Expand All @@ -50,5 +50,5 @@
'FEFMLayer': FEFMLayer,
'reduce_sum': reduce_sum,
'PositionEncoding': PositionEncoding,
'BridgeLayer': BridgeLayer
'BridgeLayer': BridgeModule
}
12 changes: 5 additions & 7 deletions deepctr/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def get_config(self, ):
return dict(list(base_config.items()) + list(config.items()))


class RegulationLayer(Layer):
class RegulationModule(Layer):
"""Regulation module used in EDCN.
Input shape
Expand All @@ -280,17 +280,15 @@ class RegulationLayer(Layer):
- **tau** : Positive float, the temperature coefficient to control
distribution of field-wise gating unit.
- **seed** : A Python integer to use as random seed.
References
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
"""

def __init__(self, tau=1.0, **kwargs):
if tau == 0:
raise ValueError("RegulationLayer tau can not be zero.")
raise ValueError("RegulationModule tau can not be zero.")
self.tau = 1.0 / tau
super(RegulationLayer, self).__init__(**kwargs)
super(RegulationModule, self).__init__(**kwargs)

def build(self, input_shape):
self.field_size = int(input_shape[1])
Expand All @@ -301,7 +299,7 @@ def build(self, input_shape):
name=self.name + '_field_weight')

# Be sure to call this somewhere!
super(RegulationLayer, self).build(input_shape)
super(RegulationModule, self).build(input_shape)

def call(self, inputs, **kwargs):

Expand All @@ -318,6 +316,6 @@ def compute_output_shape(self, input_shape):

def get_config(self):
config = {'tau': self.tau}
base_config = super(RegulationLayer, self).get_config()
base_config = super(RegulationModule, self).get_config()
base_config.update(config)
return base_config
37 changes: 15 additions & 22 deletions deepctr/layers/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,8 +1493,8 @@ def get_config(self):
return config


class BridgeLayer(Layer):
"""BridgeLayer layer used in EDCN
class BridgeModule(Layer):
"""Bridge Module used in EDCN
Input shape
- A list of two 2D tensor with shape: ``(batch_size, units)``.
Expand All @@ -1506,37 +1506,32 @@ class BridgeLayer(Layer):
- **bridge_type**: The type of bridge interaction, one of 'pointwise_addition', 'hadamard_product', 'concatenation', 'attention_pooling'
- **activation**: Activation function to use.
- **l2_reg**: float between 0 and 1. L2 regularizer strength applied to the kernel weights matrix.
- **seed**: A Python integer to use as random seed.
References
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
"""

def __init__(self, bridge_type='hadamard_product', activation='relu', l2_reg=0, seed=1024, **kwargs):
def __init__(self, bridge_type='hadamard_product', activation='relu', **kwargs):
self.bridge_type = bridge_type
self.activation = activation
self.l2_reg = l2_reg
self.seed = seed

super(BridgeLayer, self).__init__(**kwargs)
super(BridgeModule, self).__init__(**kwargs)

def build(self, input_shape):
if not isinstance(input_shape, list) or len(input_shape) < 2:
raise ValueError(
'A `BridgeLayer` layer should be called '
'on a list of at least 2 inputs')
'A `BridgeModule` layer should be called '
'on a list of 2 inputs')

self.dnn_dim = int(input_shape[0][-1])
if self.bridge_type == "concatenation":
self.dense = Dense(self.dnn_dim, self.activation)
elif self.bridge_type == "attention_pooling":
self.dense_x = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')
self.dense_h = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')

self.dense = Dense(self.dnn_dim, self.activation)
self.dense_x = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')
self.dense_h = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')

super(BridgeLayer, self).build(input_shape) # Be sure to call this somewhere!
super(BridgeModule, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs, **kwargs):
x, h = inputs
Expand All @@ -1545,7 +1540,7 @@ def call(self, inputs, **kwargs):
elif self.bridge_type == "hadamard_product":
return x * h
elif self.bridge_type == "concatenation":
return self.dense(tf.concat(inputs, axis=-1))
return self.dense(tf.concat([x, h], axis=-1))
elif self.bridge_type == "attention_pooling":
a_x = self.dense_x(x)
a_h = self.dense_h(h)
Expand All @@ -1555,12 +1550,10 @@ def compute_output_shape(self, input_shape):
return (None, self.dnn_dim)

def get_config(self):
base_config = super(BridgeLayer, self).get_config().copy()
base_config = super(BridgeModule, self).get_config().copy()
config = {
'bridge_type': self.bridge_type,
'l2_reg': self.l2_reg,
'activation': self.activation,
'seed': self.seed
'activation': self.activation
}
config.update(base_config)
return config
18 changes: 9 additions & 9 deletions deepctr/models/edcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
from tensorflow.python.keras.models import Model

from ..feature_column import build_input_features, get_linear_logit, input_from_feature_columns
from ..layers.core import PredictionLayer, DNN, RegulationLayer
from ..layers.interaction import CrossNet, BridgeLayer
from ..layers.core import PredictionLayer, DNN, RegulationModule
from ..layers.interaction import CrossNet, BridgeModule
from ..layers.utils import add_func, concat_func


def EDCN(linear_feature_columns,
dnn_feature_columns,
cross_num=2,
cross_parameterization='vector',
bridge_type='hadamard_product',
bridge_type='concatenation',
tau=1.0,
l2_reg_linear=1e-5,
l2_reg_embedding=1e-5,
Expand All @@ -31,6 +31,7 @@ def EDCN(linear_feature_columns,
dnn_activation='relu',
task='binary'):
"""Instantiates the Enhanced Deep&Cross Network architecture.
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param cross_num: positive integet,cross layer number
Expand Down Expand Up @@ -63,25 +64,24 @@ def EDCN(linear_feature_columns,
features, dnn_feature_columns, l2_reg_embedding, seed, support_dense=False)

emb_input = concat_func(sparse_embedding_list, axis=1)
deep_in = RegulationLayer(tau)(emb_input)
cross_in = RegulationLayer(tau)(emb_input)
deep_in = RegulationModule(tau)(emb_input)
cross_in = RegulationModule(tau)(emb_input)

field_size = len(sparse_embedding_list)
embedding_size = int(sparse_embedding_list[0].shape[-1])
cross_dim = field_size * embedding_size

for i in range(cross_num):

cross_out = CrossNet(1, parameterization=cross_parameterization,
l2_reg=l2_reg_cross)(cross_in)
deep_out = DNN([cross_dim], dnn_activation, l2_reg_dnn,
dnn_dropout, dnn_use_bn, seed=seed)(deep_in)
print(cross_out, deep_out)
bridge_out = BridgeLayer(bridge_type)([cross_out, deep_out])
bridge_out = BridgeModule(bridge_type)([cross_out, deep_out])
if i + 1 < cross_num:
bridge_out_list = Reshape([field_size, embedding_size])(bridge_out)
deep_in = RegulationLayer(tau)(bridge_out_list)
cross_in = RegulationLayer(tau)(bridge_out_list)
deep_in = RegulationModule(tau)(bridge_out_list)
cross_in = RegulationModule(tau)(bridge_out_list)

stack_out = Concatenate()([cross_out, deep_out, bridge_out])
final_logit = Dense(1, use_bias=False)(stack_out)
Expand Down
Binary file added docs/pics/EDCN.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 11 additions & 0 deletions docs/source/Features.md
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,17 @@ information routing across tasks in a general setup.

[Tang H, Liu J, Zhao M, et al. Progressive layered extraction (ple): A novel multi-task learning (mtl) model for personalized recommendations[C]//Fourteenth ACM Conference on Recommender Systems. 2020.](https://dl.acm.org/doi/10.1145/3383313.3412236)

### EDCN(Enhancing Explicit and Implicit Feature Interactions DCN)

EDCN introduces two advanced modules, namelybridge moduleandregulation module, which work collaboratively tocapture the layer-wise interactive signals and learn discriminativefeature distributions for each hidden layer of the parallel networks.

[**EDCN Model API**](./deepctr.models.edcn.html)

![EDCN](../pics/EDCN.png)

[Chen B, Wang Y, Liu Z, et al. Enhancing explicit and implicit feature interactions via information sharing for parallel deep ctr models[C]//Proceedings of the 30th ACM International Conference on Information & Knowledge Management. 2021: 3757-3766.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)


## Layers

The models of deepctr are modular, so you can use different modules to build your own models.
Expand Down
1 change: 1 addition & 0 deletions docs/source/History.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# History
- 11/05/2022 : [v0.9.3](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.3) released.Add [EDCN]().
- 10/15/2022 : [v0.9.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2) released.Support python `3.9`,`3.10`.
- 06/11/2022 : [v0.9.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1) released.Improve compatibility with tensorflow `2.x`.
- 09/03/2021 : [v0.9.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0) released.Add multitask learning models:[SharedBottom](./Features.html#sharedbottom),[ESMM](./Features.html#esmm-entire-space-multi-task-model),[MMOE](./Features.html#mmoe-multi-gate-mixture-of-experts) and [PLE](./Features.html#ple-progressive-layered-extraction). [running example](./Examples.html#multitask-learning-mmoe)
Expand Down
1 change: 1 addition & 0 deletions docs/source/Models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ DeepCTR Models API
ESMM<deepctr.models.multitask.esmm>
MMOE<deepctr.models.multitask.mmoe>
PLE<deepctr.models.multitask.ple>
EDCN<deepctr.models.edcn>


7 changes: 7 additions & 0 deletions docs/source/deepctr.models.edcn.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
deepctr.models.edcn module
=========================

.. automodule:: deepctr.models.edcn
:members:
:no-undoc-members:
:no-show-inheritance:
1 change: 1 addition & 0 deletions docs/source/deepctr.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Submodules
deepctr.models.ccpm
deepctr.models.dcn
deepctr.models.dcnmix
deepctr.models.edcn
deepctr.models.deepfm
deepctr.models.dien
deepctr.models.din
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ You can read the latest code and related projects

News
-----
11/05/2022 : Add `EDCN` . `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.3>`_

10/15/2022 : Support python `3.9`,`3.10`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2>`_

06/11/2022 : Improve compatibility with tensorflow `2.x`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1>`_

09/03/2021 : Add multitask learning models: `SharedBottom <./Features.html#sharedbottom>`_ , `ESMM <./Features.html#esmm-entire-space-multi-task-model>`_ , `MMOE <./Features.html#mmoe-multi-gate-mixture-of-experts>`_ , `PLE <./Features.html#ple-progressive-layered-extraction>`_ . `running example <./Examples.html#multitask-learning-mmoe>`_ `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0>`_

DisscussionGroup
-----------------------

Expand Down

0 comments on commit 3c64524

Please sign in to comment.