Skip to content

Commit

Permalink
add EDCN model
Browse files Browse the repository at this point in the history
- add EDCN model
- fix savedmodel error in tf2.4
- fix typo in position encoding layer
  • Loading branch information
shenweichen committed Nov 10, 2022
2 parents ec78b9b + 6789597 commit 4db36c2
Show file tree
Hide file tree
Showing 25 changed files with 491 additions and 132 deletions.
66 changes: 33 additions & 33 deletions .github/workflows/ci.yml
@@ -1,6 +1,6 @@
name: CI
name: CI_TF2

on:
on:
push:
path:
- 'deepctr/*'
Expand All @@ -9,17 +9,17 @@ on:
path:
- 'deepctr/*'
- 'tests/*'

jobs:
build:

runs-on: ubuntu-latest
timeout-minutes: 180
strategy:
matrix:
python-version: [3.6,3.7,3.8,3.9,3.10.7]
tf-version: [1.4.0,1.15.0,2.6.0,2.7.0,2.8.0,2.9.0,2.10.0]
python-version: [ 3.6,3.7,3.8, 3.9,3.10.7 ]
tf-version: [ 2.6.0,2.7.0,2.8.0,2.9.0,2.10.0 ]

exclude:
- python-version: 3.7
tf-version: 1.4.0
Expand Down Expand Up @@ -64,31 +64,31 @@ jobs:
- python-version: 3.10.7
tf-version: 2.7.0
steps:

- uses: actions/checkout@v3

- name: Setup python environment
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip3 install -q tensorflow==${{ matrix.tf-version }}
pip install -q protobuf==3.19.0
pip install -q requests
pip install -e .
- name: Test with pytest
timeout-minutes: 180
run: |
pip install -q pytest
pip install -q pytest-cov
pip install -q python-coveralls
pytest --cov=deepctr --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3.1.0
with:
token: ${{secrets.CODECOV_TOKEN}}
file: ./coverage.xml
flags: pytest
name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }}
- uses: actions/checkout@v3

- name: Setup python environment
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip3 install -q tensorflow==${{ matrix.tf-version }}
pip install -q protobuf==3.19.0
pip install -q requests
pip install -e .
- name: Test with pytest
timeout-minutes: 180
run: |
pip install -q pytest
pip install -q pytest-cov
pip install -q python-coveralls
pytest --cov=deepctr --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3.1.0
with:
token: ${{secrets.CODECOV_TOKEN}}
file: ./coverage.xml
flags: pytest
name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }}
96 changes: 96 additions & 0 deletions .github/workflows/ci2.yml
@@ -0,0 +1,96 @@
name: CI_TF1

on:
push:
path:
- 'deepctr/*'
- 'tests/*'
pull_request:
path:
- 'deepctr/*'
- 'tests/*'

jobs:
build:

runs-on: ubuntu-latest
timeout-minutes: 180
strategy:
matrix:
python-version: [ 3.6,3.7 ]
tf-version: [ 1.15.0 ]

exclude:
- python-version: 3.7
tf-version: 1.4.0
- python-version: 3.7
tf-version: 1.12.0
- python-version: 3.7
tf-version: 1.15.0
- python-version: 3.8
tf-version: 1.4.0
- python-version: 3.8
tf-version: 1.14.0
- python-version: 3.8
tf-version: 1.15.0
- python-version: 3.6
tf-version: 2.7.0
- python-version: 3.6
tf-version: 2.8.0
- python-version: 3.6
tf-version: 2.9.0
- python-version: 3.6
tf-version: 2.10.0
- python-version: 3.9
tf-version: 1.4.0
- python-version: 3.9
tf-version: 1.15.0
- python-version: 3.9
tf-version: 2.2.0
- python-version: 3.9
tf-version: 2.5.0
- python-version: 3.9
tf-version: 2.6.0
- python-version: 3.9
tf-version: 2.7.0
- python-version: 3.10.7
tf-version: 1.4.0
- python-version: 3.10.7
tf-version: 1.15.0
- python-version: 3.10.7
tf-version: 2.2.0
- python-version: 3.10.7
tf-version: 2.5.0
- python-version: 3.10.7
tf-version: 2.6.0
- python-version: 3.10.7
tf-version: 2.7.0
steps:

- uses: actions/checkout@v3

- name: Setup python environment
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip3 install -q tensorflow==${{ matrix.tf-version }}
pip install -q protobuf==3.19.0
pip install -q requests
pip install -e .
- name: Test with pytest
timeout-minutes: 180
run: |
pip install -q pytest
pip install -q pytest-cov
pip install -q python-coveralls
pytest --cov=deepctr --cov-report=xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3.1.0
with:
token: ${{secrets.CODECOV_TOKEN}}
file: ./coverage.xml
flags: pytest
name: py${{ matrix.python-version }}-tf${{ matrix.tf-version }}
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -66,6 +66,7 @@ Introduction](https://zhuanlan.zhihu.com/p/53231955)) and [welcome to join us!](
| ESMM | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://arxiv.org/abs/1804.07931) |
| MMOE | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
| PLE | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) |
| EDCN | [KDD 2021][Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf) |

## Citation

Expand Down
2 changes: 1 addition & 1 deletion deepctr/__init__.py
@@ -1,4 +1,4 @@
from .utils import check_version

__version__ = '0.9.2'
__version__ = '0.9.3'
check_version(__version__)
14 changes: 8 additions & 6 deletions deepctr/layers/__init__.py
@@ -1,17 +1,16 @@
import tensorflow as tf

from .activation import Dice
from .core import DNN, LocalActivationUnit, PredictionLayer
from .core import DNN, LocalActivationUnit, PredictionLayer, RegulationModule
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, CrossNetMix,
InnerProductLayer, InteractingLayer,
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
FieldWiseBiInteraction, FwFMLayer, FEFMLayer)
FieldWiseBiInteraction, FwFMLayer, FEFMLayer, BridgeModule)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, SequencePoolingLayer, WeightedSequenceLayer,
Transformer, DynamicGRU,PositionEncoding)

from .utils import NoMask, Hash, Linear, _Add, combined_dnn_input, softmax, reduce_sum
Transformer, DynamicGRU, PositionEncoding)
from .utils import NoMask, Hash, Linear, _Add, combined_dnn_input, softmax, reduce_sum, Concat

custom_objects = {'tf': tf,
'InnerProductLayer': InnerProductLayer,
Expand All @@ -38,6 +37,7 @@
'FGCNNLayer': FGCNNLayer,
'Hash': Hash,
'Linear': Linear,
'Concat': Concat,
'DynamicGRU': DynamicGRU,
'SENETLayer': SENETLayer,
'BilinearInteraction': BilinearInteraction,
Expand All @@ -48,5 +48,7 @@
'softmax': softmax,
'FEFMLayer': FEFMLayer,
'reduce_sum': reduce_sum,
'PositionEncoding':PositionEncoding
'PositionEncoding': PositionEncoding,
'RegulationModule': RegulationModule,
'BridgeModule': BridgeModule
}
58 changes: 56 additions & 2 deletions deepctr/layers/core.py
Expand Up @@ -10,9 +10,9 @@
from tensorflow.python.keras import backend as K

try:
from tensorflow.python.ops.init_ops_v2 import Zeros, glorot_normal
from tensorflow.python.ops.init_ops_v2 import Zeros, Ones, glorot_normal
except ImportError:
from tensorflow.python.ops.init_ops import Zeros, glorot_normal_initializer as glorot_normal
from tensorflow.python.ops.init_ops import Zeros, Ones, glorot_normal_initializer as glorot_normal

from tensorflow.python.keras.layers import Layer, Dropout

Expand Down Expand Up @@ -265,3 +265,57 @@ def get_config(self, ):
config = {'task': self.task, 'use_bias': self.use_bias}
base_config = super(PredictionLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class RegulationModule(Layer):
"""Regulation module used in EDCN.
Input shape
- 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- 2D tensor with shape: ``(batch_size,field_size * embedding_size)``.
Arguments
- **tau** : Positive float, the temperature coefficient to control
distribution of field-wise gating unit.
References
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
"""

def __init__(self, tau=1.0, **kwargs):
if tau == 0:
raise ValueError("RegulationModule tau can not be zero.")
self.tau = 1.0 / tau
super(RegulationModule, self).__init__(**kwargs)

def build(self, input_shape):
self.field_size = int(input_shape[1])
self.embedding_size = int(input_shape[2])
self.g = self.add_weight(
shape=(1, self.field_size, 1),
initializer=Ones(),
name=self.name + '_field_weight')

# Be sure to call this somewhere!
super(RegulationModule, self).build(input_shape)

def call(self, inputs, **kwargs):

if K.ndim(inputs) != 3:
raise ValueError(
"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))

feild_gating_score = tf.nn.softmax(self.g * self.tau, 1)
E = inputs * feild_gating_score
return tf.reshape(E, [-1, self.field_size * self.embedding_size])

def compute_output_shape(self, input_shape):
return (None, self.field_size * self.embedding_size)

def get_config(self):
config = {'tau': self.tau}
base_config = super(RegulationModule, self).get_config()
base_config.update(config)
return base_config
70 changes: 69 additions & 1 deletion deepctr/layers/interaction.py
Expand Up @@ -3,7 +3,8 @@
Authors:
Weichen Shen,weichenswc@163.com,
Harshit Pande
Harshit Pande,
Yi He, heyi_jack@163.com
"""

Expand All @@ -26,6 +27,7 @@

from .activation import activation_layer
from .utils import concat_func, reduce_sum, softmax, reduce_mean
from .core import DNN


class AFMLayer(Layer):
Expand Down Expand Up @@ -1489,3 +1491,69 @@ def get_config(self):
'regularizer': self.regularizer,
})
return config


class BridgeModule(Layer):
"""Bridge Module used in EDCN
Input shape
- A list of two 2D tensor with shape: ``(batch_size, units)``.
Output shape
- 2D tensor with shape: ``(batch_size, units)``.
Arguments
- **bridge_type**: The type of bridge interaction, one of 'pointwise_addition', 'hadamard_product', 'concatenation', 'attention_pooling'
- **activation**: Activation function to use.
References
- [Enhancing Explicit and Implicit Feature Interactions via Information Sharing for Parallel Deep CTR Models.](https://dlp-kdd.github.io/assets/pdf/DLP-KDD_2021_paper_12.pdf)
"""

def __init__(self, bridge_type='hadamard_product', activation='relu', **kwargs):
self.bridge_type = bridge_type
self.activation = activation

super(BridgeModule, self).__init__(**kwargs)

def build(self, input_shape):
if not isinstance(input_shape, list) or len(input_shape) < 2:
raise ValueError(
'A `BridgeModule` layer should be called '
'on a list of 2 inputs')

self.dnn_dim = int(input_shape[0][-1])
if self.bridge_type == "concatenation":
self.dense = Dense(self.dnn_dim, self.activation)
elif self.bridge_type == "attention_pooling":
self.dense_x = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')
self.dense_h = DNN([self.dnn_dim, self.dnn_dim], self.activation, output_activation='softmax')

super(BridgeModule, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs, **kwargs):
x, h = inputs
if self.bridge_type == "pointwise_addition":
return x + h
elif self.bridge_type == "hadamard_product":
return x * h
elif self.bridge_type == "concatenation":
return self.dense(tf.concat([x, h], axis=-1))
elif self.bridge_type == "attention_pooling":
a_x = self.dense_x(x)
a_h = self.dense_h(h)
return a_x * x + a_h * h

def compute_output_shape(self, input_shape):
return (None, self.dnn_dim)

def get_config(self):
base_config = super(BridgeModule, self).get_config().copy()
config = {
'bridge_type': self.bridge_type,
'activation': self.activation
}
config.update(base_config)
return config

0 comments on commit 4db36c2

Please sign in to comment.