Skip to content

Commit

Permalink
Add multi-task models: SharedBottom, ESMM, MMOE, PLE
Browse files Browse the repository at this point in the history
* add multitask mdoels

1. Add multi-task models: SharedBottom, ESMM, MMOE, PLE
2. Bugfix:
#240
#232

* support python 3.9/3.10 (#259)
* fix: variable name typo (#257)
Co-authored-by: Jason Zan <zanshuxun@aliyun.com>
Co-authored-by: Yi-Xuan Xu <xuyx@lamda.nju.edu.cn>
  • Loading branch information
shenweichen committed Oct 21, 2022
2 parents 2cd84f3 + 70aa7ab commit f685425
Show file tree
Hide file tree
Showing 41 changed files with 1,322 additions and 48 deletions.
6 changes: 3 additions & 3 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ Steps to reproduce the behavior:
4. See error

**Operating environment(运行环境):**
- python version [e.g. 3.5, 3.6]
- torch version [e.g. 1.6.0, 1.7.0]
- deepctr-torch version [e.g. 0.2.7,]
- python version [e.g. 3.6, 3.7]
- torch version [e.g. 1.9.0, 1.10.0]
- deepctr-torch version [e.g. 0.2.9,]

**Additional context**
Add any other context about the problem here.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/question.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ Add any other context about the problem here.

**Operating environment(运行环境):**
- python version [e.g. 3.6]
- torch version [e.g. 1.7.0,]
- deepctr-torch version [e.g. 0.2.7,]
- torch version [e.g. 1.10.0,]
- deepctr-torch version [e.g. 0.2.9,]
39 changes: 34 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,48 @@ jobs:
timeout-minutes: 120
strategy:
matrix:
python-version: [3.6,3.7,3.8]
torch-version: [1.1.0,1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.1,1.8.1,1.9.0,1.10.2,1.11.0]
python-version: [3.6,3.7,3.8,3.9,3.10.7]
torch-version: [1.2.0,1.3.0,1.4.0,1.5.0,1.6.0,1.7.1,1.8.1,1.9.0,1.10.2,1.11.0,1.12.1]

exclude:
- python-version: 3.6
torch-version: 1.11.0
- python-version: 3.8
torch-version: 1.1.0
- python-version: 3.6
torch-version: 1.12.1
- python-version: 3.8
torch-version: 1.2.0
- python-version: 3.8
torch-version: 1.3.0

- python-version: 3.9
torch-version: 1.2.0
- python-version: 3.9
torch-version: 1.3.0
- python-version: 3.9
torch-version: 1.4.0
- python-version: 3.9
torch-version: 1.5.0
- python-version: 3.9
torch-version: 1.6.0
- python-version: 3.9
torch-version: 1.7.1
- python-version: 3.10.7
torch-version: 1.2.0
- python-version: 3.10.7
torch-version: 1.3.0
- python-version: 3.10.7
torch-version: 1.4.0
- python-version: 3.10.7
torch-version: 1.5.0
- python-version: 3.10.7
torch-version: 1.6.0
- python-version: 3.10.7
torch-version: 1.7.1
- python-version: 3.10.7
torch-version: 1.8.1
- python-version: 3.10.7
torch-version: 1.9.0
- python-version: 3.10.7
torch-version: 1.10.2
steps:

- uses: actions/checkout@v3
Expand Down
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ Let's [**Get Started!**](https://deepctr-torch.readthedocs.io/en/latest/Quick-St
| AutoInt | [CIKM 2019][AutoInt: Automatic Feature Interaction Learning via Self-Attentive Neural Networks](https://arxiv.org/abs/1810.11921) |
| ONN | [arxiv 2019][Operation-aware Neural Networks for User Response Prediction](https://arxiv.org/pdf/1904.12579.pdf) |
| FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) |
| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) |
| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) |
| AFN | [AAAI 2020][Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions](https://arxiv.org/pdf/1909.03276) |
| IFM | [IJCAI 2019][An Input-aware Factorization Machine for Sparse Prediction](https://www.ijcai.org/Proceedings/2019/0203.pdf) |
| DCN V2 | [arxiv 2020][DCN V2: Improved Deep & Cross Network and Practical Lessons for Web-scale Learning to Rank Systems](https://arxiv.org/abs/2008.13535) |
| DIFM | [IJCAI 2020][A Dual Input-aware Factorization Machine for CTR Prediction](https://www.ijcai.org/Proceedings/2020/0434.pdf) |
| AFN | [AAAI 2020][Adaptive Factorization Network: Learning Adaptive-Order Feature Interactions](https://arxiv.org/pdf/1909.03276) |
| SharedBottom | [arxiv 2017][An Overview of Multi-Task Learning in Deep Neural Networks](https://arxiv.org/pdf/1706.05098.pdf) |
| ESMM | [SIGIR 2018][Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate](https://dl.acm.org/doi/10.1145/3209978.3210104) |
| MMOE | [KDD 2018][Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts](https://dl.acm.org/doi/abs/10.1145/3219819.3220007) |
| PLE | [RecSys 2020][Progressive Layered Extraction (PLE): A Novel Multi-Task Learning (MTL) Model for Personalized Recommendations](https://dl.acm.org/doi/10.1145/3383313.3412236) |



Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
from . import models
from .utils import check_version

__version__ = '0.2.8'
__version__ = '0.2.9'
check_version(__version__)
3 changes: 2 additions & 1 deletion deepctr_torch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
from .ccpm import CCPM
from .dien import DIEN
from .din import DIN
from .afn import AFN
from .afn import AFN
from .multitask import SharedBottom, ESMM, MMOE, PLE
30 changes: 21 additions & 9 deletions deepctr_torch/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
y_pred = model(x).squeeze()

optim.zero_grad()
loss = loss_func(y_pred, y.squeeze(), reduction='sum')
if isinstance(loss_func, list):
assert len(loss_func) == self.num_tasks,\
"the length of `loss_func` should be equal with `self.num_tasks`"
loss = sum(
[loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(self.num_tasks)])
else:
loss = loss_func(y_pred, y.squeeze(), reduction='sum')
reg_loss = self.get_regularization_loss()

total_loss = loss + reg_loss + self.aux_loss
Expand Down Expand Up @@ -456,18 +462,24 @@ def _get_optim(self, optimizer):

def _get_loss_func(self, loss):
if isinstance(loss, str):
if loss == "binary_crossentropy":
loss_func = F.binary_cross_entropy
elif loss == "mse":
loss_func = F.mse_loss
elif loss == "mae":
loss_func = F.l1_loss
else:
raise NotImplementedError
loss_func = self._get_loss_func_single(loss)
elif isinstance(loss, list):
loss_func = [self._get_loss_func_single(loss_single) for loss_single in loss]
else:
loss_func = loss
return loss_func

def _get_loss_func_single(self, loss):
if loss == "binary_crossentropy":
loss_func = F.binary_cross_entropy
elif loss == "mse":
loss_func = F.mse_loss
elif loss == "mae":
loss_func = F.l1_loss
else:
raise NotImplementedError
return loss_func

def _log_loss(self, y_true, y_pred, eps=1e-7, normalize=True, sample_weight=None, labels=None):
# change eps to improve calculation accuracy
return log_loss(y_true,
Expand Down
7 changes: 4 additions & 3 deletions deepctr_torch/models/dcnmix.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,10 @@ def __init__(self, linear_feature_columns,
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear)
self.add_regularization_weight(self.crossnet.U_list, l2=l2_reg_cross)
self.add_regularization_weight(self.crossnet.V_list, l2=l2_reg_cross)
self.add_regularization_weight(self.crossnet.C_list, l2=l2_reg_cross)
regularization_modules = [self.crossnet.U_list, self.crossnet.V_list, self.crossnet.C_list]
for module in regularization_modules:
self.add_regularization_weight(module, l2=l2_reg_cross)

self.to(device)

def forward(self, X):
Expand Down
10 changes: 5 additions & 5 deletions deepctr_torch/models/dien.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def forward(self, keys, keys_length, neg_keys=None):

masked_keys = torch.masked_select(keys, mask.view(-1, 1, 1)).view(-1, max_length, dim)

packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length, batch_first=True,
packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length.cpu(), batch_first=True,
enforce_sorted=False)
packed_interests, _ = self.gru(packed_keys)
interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0,
Expand Down Expand Up @@ -353,7 +353,7 @@ def forward(self, query, keys, keys_length, mask=None):
query = torch.masked_select(query, mask.view(-1, 1)).view(-1, dim).unsqueeze(1)

if self.gru_type == 'GRU':
packed_keys = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, enforce_sorted=False)
packed_keys = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True, enforce_sorted=False)
packed_interests, _ = self.interest_evolution(packed_keys)
interests, _ = pad_packed_sequence(packed_interests, batch_first=True, padding_value=0.0,
total_length=max_length)
Expand All @@ -362,15 +362,15 @@ def forward(self, query, keys, keys_length, mask=None):
elif self.gru_type == 'AIGRU':
att_scores = self.attention(query, keys, keys_length.unsqueeze(1)) # [b, 1, T]
interests = keys * att_scores.transpose(1, 2) # [b, T, H]
packed_interests = pack_padded_sequence(interests, lengths=keys_length, batch_first=True,
packed_interests = pack_padded_sequence(interests, lengths=keys_length.cpu(), batch_first=True,
enforce_sorted=False)
_, outputs = self.interest_evolution(packed_interests)
outputs = outputs.squeeze(0) # [b, H]
elif self.gru_type == 'AGRU' or self.gru_type == 'AUGRU':
att_scores = self.attention(query, keys, keys_length.unsqueeze(1)).squeeze(1) # [b, T]
packed_interests = pack_padded_sequence(keys, lengths=keys_length, batch_first=True,
packed_interests = pack_padded_sequence(keys, lengths=keys_length.cpu(), batch_first=True,
enforce_sorted=False)
packed_scores = pack_padded_sequence(att_scores, lengths=keys_length, batch_first=True,
packed_scores = pack_padded_sequence(att_scores, lengths=keys_length.cpu(), batch_first=True,
enforce_sorted=False)
outputs = self.interest_evolution(packed_interests, packed_scores)
outputs, _ = pad_packed_sequence(outputs, batch_first=True, padding_value=0.0, total_length=max_length)
Expand Down
6 changes: 3 additions & 3 deletions deepctr_torch/models/fibinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def __init__(self, linear_feature_columns, dnn_feature_columns, bilinear_type='i
device=device, gpus=gpus)
self.linear_feature_columns = linear_feature_columns
self.dnn_feature_columns = dnn_feature_columns
self.filed_size = len(self.embedding_dict)
self.SE = SENETLayer(self.filed_size, reduction_ratio, seed, device)
self.Bilinear = BilinearInteraction(self.filed_size, self.embedding_size, bilinear_type, seed, device)
self.field_size = len(self.embedding_dict)
self.SE = SENETLayer(self.field_size, reduction_ratio, seed, device)
self.Bilinear = BilinearInteraction(self.field_size, self.embedding_size, bilinear_type, seed, device)
self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout, use_bn=False,
init_std=init_std, device=device)
Expand Down
4 changes: 4 additions & 0 deletions deepctr_torch/models/multitask/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sharedbottom import SharedBottom
from .esmm import ESMM
from .mmoe import MMOE
from .ple import PLE
94 changes: 94 additions & 0 deletions deepctr_torch/models/multitask/esmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -*- coding:utf-8 -*-
"""
Author:
zanshuxun, zanshuxun@aliyun.com
Reference:
[1] Ma X, Zhao L, Huang G, et al. Entire space multi-task model: An effective approach for estimating post-click conversion rate[C]//The 41st International ACM SIGIR Conference on Research & Development in Information Retrieval. 2018.(https://dl.acm.org/doi/10.1145/3209978.3210104)
"""
import torch
import torch.nn as nn

from ..basemodel import BaseModel
from ...inputs import combined_dnn_input
from ...layers import DNN


class ESMM(BaseModel):
"""Instantiates the Entire Space Multi-Task Model architecture.
:param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
:param tower_dnn_hidden_units: list, list of positive integer or empty list, the layer number and units in each layer of task-specific DNN.
:param l2_reg_linear: float, L2 regularizer strength applied to linear part.
:param l2_reg_embedding: float, L2 regularizer strength applied to embedding vector.
:param l2_reg_dnn: float, L2 regularizer strength applied to DNN.
:param init_std: float, to use as the initialize std of embedding vector.
:param seed: integer, to use as random seed.
:param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
:param dnn_activation: Activation function to use in DNN.
:param dnn_use_bn: bool, Whether use BatchNormalization before activation or not in DNN.
:param task_types: list of str, indicating the loss of each tasks, ``"binary"`` for binary logloss or ``"regression"`` for regression loss. e.g. ['binary', 'regression'].
:param task_names: list of str, indicating the predict target of each tasks.
:param device: str, ``"cpu"`` or ``"cuda:0"``.
:param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`.
:return: A PyTorch model instance.
"""

def __init__(self, dnn_feature_columns, tower_dnn_hidden_units=(256, 128),
l2_reg_linear=0.00001, l2_reg_embedding=0.00001, l2_reg_dnn=0, init_std=0.0001, seed=1024,
dnn_dropout=0, dnn_activation='relu', dnn_use_bn=False, task_types=('binary', 'binary'),
task_names=('ctr', 'ctcvr'), device='cpu', gpus=None):
super(ESMM, self).__init__(linear_feature_columns=[], dnn_feature_columns=dnn_feature_columns,
l2_reg_linear=l2_reg_linear, l2_reg_embedding=l2_reg_embedding, init_std=init_std,
seed=seed, task='binary', device=device, gpus=gpus)
self.num_tasks = len(task_names)
if self.num_tasks != 2:
raise ValueError("the length of task_names must be equal to 2")
if len(dnn_feature_columns) == 0:
raise ValueError("dnn_feature_columns is null!")
if len(task_types) != self.num_tasks:
raise ValueError("num_tasks must be equal to the length of task_types")

for task_type in task_types:
if task_type != 'binary':
raise ValueError("task must be binary in ESMM, {} is illegal".format(task_type))

input_dim = self.compute_input_dim(dnn_feature_columns)

self.ctr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation,
dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
init_std=init_std, device=device)
self.cvr_dnn = DNN(input_dim, tower_dnn_hidden_units, activation=dnn_activation,
dropout_rate=dnn_dropout, use_bn=dnn_use_bn,
init_std=init_std, device=device)

self.ctr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False)
self.cvr_dnn_final_layer = nn.Linear(tower_dnn_hidden_units[-1], 1, bias=False)

self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.ctr_dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(
filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.cvr_dnn.named_parameters()), l2=l2_reg_dnn)
self.add_regularization_weight(self.ctr_dnn_final_layer.weight, l2=l2_reg_dnn)
self.add_regularization_weight(self.cvr_dnn_final_layer.weight, l2=l2_reg_dnn)
self.to(device)

def forward(self, X):
sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
self.embedding_dict)
dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)

ctr_output = self.ctr_dnn(dnn_input)
cvr_output = self.cvr_dnn(dnn_input)

ctr_logit = self.ctr_dnn_final_layer(ctr_output)
cvr_logit = self.cvr_dnn_final_layer(cvr_output)

ctr_pred = self.out(ctr_logit)
cvr_pred = self.out(cvr_logit)

ctcvr_pred = ctr_pred * cvr_pred # CTCVR = CTR * CVR

task_outs = torch.cat([ctr_pred, ctcvr_pred], -1)
return task_outs
Loading

0 comments on commit f685425

Please sign in to comment.