Skip to content

Commit

Permalink
support python 3.9&3.10
Browse files Browse the repository at this point in the history
- support python 3.9 and 3.10
- support `cos` and `ln` attention_type in transformer
- polish docstring
  • Loading branch information
shenweichen committed Oct 16, 2022
1 parent 9564e05 commit ec78b9b
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 42 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/bug_report.md
Expand Up @@ -19,8 +19,8 @@ Steps to reproduce the behavior:

**Operating environment(运行环境):**
- python version [e.g. 3.6, 3.7]
- tensorflow version [e.g. 1.4.0, 1.15.0, 2.5.0]
- deepctr version [e.g. 0.9.0,]
- tensorflow version [e.g. 1.4.0, 1.15.0, 2.10.0]
- deepctr version [e.g. 0.9.2,]

**Additional context**
Add any other context about the problem here.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/question.md
Expand Up @@ -16,5 +16,5 @@ Add any other context about the problem here.

**Operating environment(运行环境):**
- python version [e.g. 3.6]
- tensorflow version [e.g. 1.4.0, 1.15.0, 2.5.0]
- deepctr version [e.g. 0.9.0,]
- tensorflow version [e.g. 1.4.0, 1.15.0, 2.10.0]
- deepctr version [e.g. 0.9.2,]
24 changes: 22 additions & 2 deletions .github/workflows/ci.yml
Expand Up @@ -17,8 +17,8 @@ jobs:
timeout-minutes: 180
strategy:
matrix:
python-version: [3.6,3.7,3.8]
tf-version: [1.4.0,1.15.0,2.5.0,2.6.0,2.7.0,2.8.0,2.9.0]
python-version: [3.6,3.7,3.8,3.9,3.10.7]
tf-version: [1.4.0,1.15.0,2.6.0,2.7.0,2.8.0,2.9.0,2.10.0]

exclude:
- python-version: 3.7
Expand All @@ -37,12 +37,32 @@ jobs:
tf-version: 2.8.0
- python-version: 3.6
tf-version: 2.9.0
- python-version: 3.6
tf-version: 2.10.0
- python-version: 3.9
tf-version: 1.4.0
- python-version: 3.9
tf-version: 1.15.0
- python-version: 3.9
tf-version: 2.2.0
- python-version: 3.9
tf-version: 2.5.0
- python-version: 3.9
tf-version: 2.6.0
- python-version: 3.9
tf-version: 2.7.0
- python-version: 3.10.7
tf-version: 1.4.0
- python-version: 3.10.7
tf-version: 1.15.0
- python-version: 3.10.7
tf-version: 2.2.0
- python-version: 3.10.7
tf-version: 2.5.0
- python-version: 3.10.7
tf-version: 2.6.0
- python-version: 3.10.7
tf-version: 2.7.0
steps:

- uses: actions/checkout@v3
Expand Down
8 changes: 3 additions & 5 deletions README.md
Expand Up @@ -18,14 +18,12 @@
<!-- [![Gitter](https://badges.gitter.im/DeepCTR/community.svg)](https://gitter.im/DeepCTR/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) -->


DeepCTR is a **Easy-to-use**,**Modular** and **Extendible** package of deep-learning based CTR models along with lots of
DeepCTR is a **Easy-to-use**, **Modular** and **Extendible** package of deep-learning based CTR models along with lots of
core components layers which can be used to easily build custom models.You can use any complex model with `model.fit()`
,and `model.predict()` .

- Provide `tf.keras.Model` like interface for **quick experiment**
. [example](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html#getting-started-4-steps-to-deepctr)
- Provide `tensorflow estimator` interface for **large scale data** and **distributed training**
. [example](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html#getting-started-4-steps-to-deepctr-estimator-with-tfrecord)
- Provide `tf.keras.Model` like interfaces for **quick experiment**. [example](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html#getting-started-4-steps-to-deepctr)
- Provide `tensorflow estimator` interface for **large scale data** and **distributed training**. [example](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html#getting-started-4-steps-to-deepctr-estimator-with-tfrecord)
- It is compatible with both `tf 1.x` and `tf 2.x`.

Some related projects:
Expand Down
2 changes: 1 addition & 1 deletion deepctr/__init__.py
@@ -1,4 +1,4 @@
from .utils import check_version

__version__ = '0.9.1'
__version__ = '0.9.2'
check_version(__version__)
2 changes: 1 addition & 1 deletion deepctr/feature_column.py
Expand Up @@ -95,7 +95,7 @@ def __hash__(self):
class DenseFeat(namedtuple('DenseFeat', ['name', 'dimension', 'dtype', 'transform_fn'])):
""" Dense feature
Args:
name: feature name,
name: feature name.
dimension: dimension of the feature, default = 1.
dtype: dtype of the feature, default="float32".
transform_fn: If not `None` , a function that can be used to transform
Expand Down
51 changes: 34 additions & 17 deletions deepctr/layers/sequence.py
Expand Up @@ -442,7 +442,7 @@ class Transformer(Layer):
- **blinding**: bool. Whether or not use blinding.
- **seed**: A Python integer to use as random seed.
- **supports_masking**:bool. Whether or not support masking.
- **attention_type**: str, Type of attention, the value must be one of { ``'scaled_dot_product'`` , ``'additive'`` }.
- **attention_type**: str, Type of attention, the value must be one of { ``'scaled_dot_product'`` , ``'cos'`` , ``'ln'`` , ``'additive'`` }.
- **output_type**: ``'mean'`` , ``'sum'`` or `None`. Whether or not use average/sum pooling for output.
References
Expand Down Expand Up @@ -490,6 +490,9 @@ def build(self, input_shape):
initializer=glorot_uniform(seed=self.seed))
self.v = self.add_weight('v', shape=[self.att_embedding_size], dtype=tf.float32,
initializer=glorot_uniform(seed=self.seed))
elif self.attention_type == "ln":
self.att_ln_q = LayerNormalization()
self.att_ln_k = LayerNormalization()
# if self.use_res:
# self.W_Res = self.add_weight(name='res', shape=[embedding_size, self.att_embedding_size * self.head_num], dtype=tf.float32,
# initializer=TruncatedNormal(seed=self.seed))
Expand Down Expand Up @@ -529,28 +532,42 @@ def call(self, inputs, mask=None, training=None, **kwargs):
queries = self.query_pe(queries)
keys = self.key_pe(queries)

querys = tf.tensordot(queries, self.W_Query,
axes=(-1, 0)) # None T_q D*head_num
keys = tf.tensordot(keys, self.W_key, axes=(-1, 0))
values = tf.tensordot(keys, self.W_Value, axes=(-1, 0))
Q = tf.tensordot(queries, self.W_Query,
axes=(-1, 0)) # N T_q D*h
K = tf.tensordot(keys, self.W_key, axes=(-1, 0))
V = tf.tensordot(keys, self.W_Value, axes=(-1, 0))

# head_num*None T_q D
querys = tf.concat(tf.split(querys, self.head_num, axis=2), axis=0)
keys = tf.concat(tf.split(keys, self.head_num, axis=2), axis=0)
values = tf.concat(tf.split(values, self.head_num, axis=2), axis=0)
# h*N T_q D
Q_ = tf.concat(tf.split(Q, self.head_num, axis=2), axis=0)
K_ = tf.concat(tf.split(K, self.head_num, axis=2), axis=0)
V_ = tf.concat(tf.split(V, self.head_num, axis=2), axis=0)

if self.attention_type == "scaled_dot_product":
# head_num*None T_q T_k
outputs = tf.matmul(querys, keys, transpose_b=True)
# h*N T_q T_k
outputs = tf.matmul(Q_, K_, transpose_b=True)

outputs = outputs / (keys.get_shape().as_list()[-1] ** 0.5)
outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
elif self.attention_type == "cos":
Q_cos = tf.nn.l2_normalize(Q_, dim=-1)
K_cos = tf.nn.l2_normalize(K_, dim=-1)

outputs = tf.matmul(Q_cos, K_cos, transpose_b=True) # h*N T_q T_k

outputs = outputs * 20 # Scale
elif self.attention_type == 'ln':
Q_ = self.att_ln_q(Q_)
K_ = self.att_ln_k(K_)

outputs = tf.matmul(Q_, K_, transpose_b=True) # h*N T_q T_k
# Scale
outputs = outputs / (K_.get_shape().as_list()[-1] ** 0.5)
elif self.attention_type == "additive":
querys_reshaped = tf.expand_dims(querys, axis=-2)
keys_reshaped = tf.expand_dims(keys, axis=-3)
outputs = tf.tanh(tf.nn.bias_add(querys_reshaped + keys_reshaped, self.b))
Q_reshaped = tf.expand_dims(Q_, axis=-2)
K_reshaped = tf.expand_dims(K_, axis=-3)
outputs = tf.tanh(tf.nn.bias_add(Q_reshaped + K_reshaped, self.b))
outputs = tf.squeeze(tf.tensordot(outputs, tf.expand_dims(self.v, axis=-1), axes=[-1, 0]), axis=-1)
else:
raise ValueError("attention_type must be scaled_dot_product or additive")
raise ValueError("attention_type must be [scaled_dot_product,cos,ln,additive]")

key_masks = tf.tile(key_masks, [self.head_num, 1])

Expand Down Expand Up @@ -583,7 +600,7 @@ def call(self, inputs, mask=None, training=None, **kwargs):
outputs = self.dropout(outputs, training=training)
# Weighted sum
# ( h*N, T_q, C/h)
result = tf.matmul(outputs, values)
result = tf.matmul(outputs, V_)
result = tf.concat(tf.split(result, self.head_num, axis=0), axis=2)

if self.use_res:
Expand Down
4 changes: 2 additions & 2 deletions deepctr/models/deepfm.py
Expand Up @@ -24,8 +24,8 @@ def DeepFM(linear_feature_columns, dnn_feature_columns, fm_group=(DEFAULT_GROUP_
dnn_activation='relu', dnn_use_bn=False, task='binary'):
"""Instantiates the DeepFM Network architecture.
:param linear_feature_columns: An iterable containing all the features used by linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param linear_feature_columns: An iterable containing all the features used by the linear part of the model.
:param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
:param fm_group: list, group_name of features that will be used to do feature interactions.
:param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of DNN
:param l2_reg_linear: float. L2 regularizer strength applied to linear part
Expand Down
1 change: 1 addition & 0 deletions docs/source/History.md
@@ -1,4 +1,5 @@
# History
- 10/15/2022 : [v0.9.2](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2) released.Support python `3.9`,`3.10`.
- 06/11/2022 : [v0.9.1](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1) released.Improve compatibility with tensorflow `2.x`.
- 09/03/2021 : [v0.9.0](https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0) released.Add multitask learning models:[SharedBottom](./Features.html#sharedbottom),[ESMM](./Features.html#esmm-entire-space-multi-task-model),[MMOE](./Features.html#mmoe-multi-gate-mixture-of-experts) and [PLE](./Features.html#ple-progressive-layered-extraction). [running example](./Examples.html#multitask-learning-mmoe)
- 07/18/2021 : [v0.8.7](https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.7) released.Support pre-defined key-value vocabulary in `Hash` Layer. [example](./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.9.1'
release = '0.9.2'


# -- General configuration ---------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Expand Up @@ -42,12 +42,12 @@ You can read the latest code and related projects

News
-----
10/15/2022 : Support python `3.9`,`3.10`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.2>`_

06/11/2022 : Improve compatibility with tensorflow `2.x`. `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.1>`_

09/03/2021 : Add multitask learning models: `SharedBottom <./Features.html#sharedbottom>`_ , `ESMM <./Features.html#esmm-entire-space-multi-task-model>`_ , `MMOE <./Features.html#mmoe-multi-gate-mixture-of-experts>`_ , `PLE <./Features.html#ple-progressive-layered-extraction>`_ . `running example <./Examples.html#multitask-learning-mmoe>`_ `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.9.0>`_

07/18/2021 : Support pre-defined key-value vocabulary in `Hash` Layer. `example <./Examples.html#hash-layer-with-pre-defined-key-value-vocabulary>`_ `Changelog <https://github.com/shenweichen/DeepCTR/releases/tag/v0.8.7>`_

DisscussionGroup
-----------------------

Expand Down
17 changes: 12 additions & 5 deletions setup.py
@@ -1,15 +1,21 @@
import setuptools

with open("README.md", "r") as fh:
with open("README.md", "r",encoding='utf-8') as fh:
long_description = fh.read()

REQUIRED_PACKAGES = [
import sys
if sys.version_info < (3, 9):
REQUIRED_PACKAGES = [
'h5py==2.10.0', 'requests'
]
]
else:
REQUIRED_PACKAGES = [
'h5py==3.7.0', 'requests'
]

setuptools.setup(
name="deepctr",
version="0.9.1",
version="0.9.2",
author="Weichen Shen",
author_email="weichenswc@163.com",
description="Easy-to-use,Modular and Extendible package of deep learning based CTR(Click Through Rate) prediction models with tensorflow 1.x and 2.x .",
Expand All @@ -35,10 +41,11 @@
'Intended Audience :: Science/Research',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
Expand Down
8 changes: 6 additions & 2 deletions tests/layers/sequence_test.py
Expand Up @@ -81,11 +81,15 @@ def test_BiLSTM(merge_mode):
input_shape=(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_SIZE))


def test_Transformer():
@pytest.mark.parametrize(
'attention_type',
['scaled_dot_product', 'cos', 'ln', 'additive']
)
def test_Transformer(attention_type):
with CustomObjectScope({'Transformer': sequence.Transformer}):
layer_test(sequence.Transformer,
kwargs={'att_embedding_size': 1, 'head_num': 8, 'use_layer_norm': True, 'supports_masking': False,
'attention_type': 'additive', 'dropout_rate': 0.5, 'output_type': 'sum'},
'attention_type': attention_type, 'dropout_rate': 0.5, 'output_type': 'sum'},
input_shape=[(BATCH_SIZE, SEQ_LENGTH, EMBEDDING_SIZE), (BATCH_SIZE, SEQ_LENGTH, EMBEDDING_SIZE),
(BATCH_SIZE, 1), (BATCH_SIZE, 1)])

Expand Down

0 comments on commit ec78b9b

Please sign in to comment.