# Introduction to ESMM

The Entire Space Multi-task Model (ESMM) is a new multi-task joint training paradigm developed by the Alimama precision targeted advertising algorithm team.

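In industrial applications such as information retrieval, recommender systems, and online advertising, accurately estimating the post-click conversion rate (CVR) is critical. For example, in an e-commerce recommender system, maximizing the gross merchandise volume (GMV) of a scenario is one of the platform's main goals, and GMV can be decomposed as traffic × click-through rate × conversion rate × average order value, so the conversion rate is a key factor in the optimization objective. From a user-experience perspective, an accurately estimated conversion rate is used to balance users' click preferences against their purchase preferences.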
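As a rough illustration of the GMV decomposition above, the sketch below multiplies the four factors for made-up numbers. The function name `estimate_gmv` and every value are hypothetical and chosen only to show how a relative change in CVR carries through to GMV when the other factors stay fixed.

```python
def estimate_gmv(traffic, ctr, cvr, avg_order_value):
    """GMV = traffic x CTR x CVR x average order value."""
    return traffic * ctr * cvr * avg_order_value


# Hypothetical scenario: one million impressions, 2% CTR, 1% CVR, order value 50.
baseline = estimate_gmv(traffic=1_000_000, ctr=0.02, cvr=0.01, avg_order_value=50.0)

# Same scenario with CVR raised from 1.0% to 1.1%.
improved = estimate_gmv(traffic=1_000_000, ctr=0.02, cvr=0.011, avg_order_value=50.0)

# Because GMV is a product, a 10% relative gain in CVR is a 10% relative gain in GMV.
relative_lift = improved / baseline - 1.0
print(baseline, improved, relative_lift)
```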

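Traditional CVR estimation typically borrows techniques from CTR estimation, such as the deep learning models that have recently become popular. Unlike CTR estimation, however, CVR estimation faces several challenges of its own: 1) sample selection bias (SSB); 2) data sparsity (DS); 3) delayed feedback, among others.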

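ESMM models over the entire sample space by exploiting the sequential pattern of user behavior, which avoids the sample selection bias and data sparsity that traditional CVR models often suffer from, and it achieves significant gains. ESMM was also the first to propose learning CVR indirectly through the auxiliary tasks of learning CTR and CTCVR. The BASE sub-network in ESMM can be replaced with any learning model, so the ESMM framework integrates easily with other models and can absorb their strengths to further improve results, which leaves considerable room for extension.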

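## What this document covers
This document introduces ESMM and explains how to use the ESMM open-source project in production. After reading it, you will understand:

* The basic components of ESMM
* How to run and use the ESMM open-source code
* How to apply ESMM in practice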

Public dataset download
* Ali-CCP (Alibaba Click and Conversion Prediction): see [https://tianchi.aliyun.com/datalab/dataSet.html?dataId=408](https://tianchi.aliyun.com/datalab/dataSet.html?dataId=408)

## Problems ESMM addresses


<img src="C:/Users/zhangy/Desktop/ESMM手把手Guidebook/impression_click_buy.png"/>


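ESMM makes full use of the sequential pattern of user behavior. With the help of two auxiliary tasks, CTR and CTCVR, it elegantly addresses the two challenges encountered in practical CVR modeling: sample selection bias ($\textbf{SSB}$) and data sparsity ($\textbf{DS}$). ESMM extends readily to predicting other user behaviors with sequential dependence (browse, click, add to cart, purchase, and so on), which makes it possible to build cross-domain, multi-scenario, full-funnel prediction models.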



<img src="assets/system_overview.png"/>

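In an advertising or recommender system, the system pipeline of user behavior can be written as the sequence $\text{recall} \rightarrow \text{coarse ranking} \rightarrow \text{fine ranking} \rightarrow \text{impression} \rightarrow \text{click} \rightarrow \text{conversion} \rightarrow \text{repurchase}$. When the engine handles a request, it typically ranks candidates in multiple stages, passing the top subset on to the next stage each time, until the final results are shown to the user at the impression stage. The input volume of each stage shrinks step by step, either because the system filtered the previous stage (recall to coarse ranking, coarse ranking to fine ranking, fine ranking to impression) or because the user filtered it (impression to click, click to conversion, conversion to repurchase). ESMM is suited to mature e-commerce recommendation or advertising systems with full-funnel prediction.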

# The ESMM framework

## Algorithm principle

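ESMM introduces two auxiliary tasks: predicting the post-view click-through rate (CTR) and the post-view click-through and conversion rate (CTCVR). ESMM treats pCVR as an intermediate variable and multiplies it by pCTR to obtain pCTCVR, instead of training the CVR model directly on the biased subset of clicked samples. Both pCTCVR and pCTR are estimated over the entire space using all impression samples, so the derived pCVR also applies to the entire space, which mitigates the $\textbf{SSB}$ problem. In addition, the feature representation network of the CVR task is shared with the CTR task, which is trained on far more samples. This parameter sharing follows the feature representation transfer learning paradigm and helps considerably in alleviating the $\textbf{DS}$ problem.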
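The relationship pCTCVR = pCTR × pCVR can be checked on simple funnel counts. The sketch below uses hypothetical numbers (not taken from the dataset) to show that CTR and CTCVR are both defined over all impressions, while CVR is defined only on the clicked subset.

```python
# Hypothetical funnel counts for one traffic slice.
impressions = 10_000
clicks = 300
conversions = 6  # conversions that followed a click

ctr = clicks / impressions         # defined over all impressions
cvr = conversions / clicks         # defined only on clicked impressions
ctcvr = conversions / impressions  # defined over all impressions

# CTCVR factorizes as CTR * CVR, so CVR is determined by two entire-space quantities.
cvr_from_entire_space = ctcvr / ctr
print(ctr, cvr, ctcvr, cvr_from_entire_space)
```

Note that only 300 of the 10,000 impressions would be available to a CVR model trained on clicks alone. ESMM itself learns pCVR as a factor of the product rather than computing this quotient.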

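## Entire-space modeling
pCTR and pCTCVR are the quantities that ESMM actually estimates over the entire space. The multiplicative form lets the three associated, jointly trained estimators exploit the sequential pattern of the data and pass information to one another during training. The ESMM loss function, given below, consists of two loss terms from the CTR and CTCVR tasks, both computed over the samples of all impressions: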

\begin{equation}
\begin{split}
L(\theta_{cvr}, \theta_{ctr}) = \sum_{i=1}^N l(y_i, f(\textbf{x}_i;\theta_{ctr})) + \sum_{i=1}^N l(y_i \& z_i, f(\textbf{x}_i;\theta_{ctr}) \times f(\textbf{x}_i;\theta_{cvr}))
\end{split}
\end{equation}

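where $\theta_{ctr}$ and $\theta_{cvr}$ are the parameters of the CTR and CVR networks, and $l(\cdot)$ is the cross-entropy loss.
Mathematically, the loss above decomposes $y \rightarrow z$ into two parts, $y$ and $y \& z$, which correspond to the labels of the CTR and CTCVR tasks; this in effect exploits the sequential dependence between click and conversion labels. The training data is constructed as follows:
For the CTR task, clicked impressions are labeled $y = 1$ and all others $y = 0$; for the CTCVR task, impressions with both a click and a conversion are labeled $y \& z = 1$ and all others $y \& z = 0$.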
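The loss and label construction described above can be sketched in NumPy as follows. This is a minimal illustration, not the project's training code: the function names `binary_cross_entropy` and `esmm_loss`, the `eps` clipping, and the four toy samples are all assumptions made for the example.

```python
import numpy as np


def binary_cross_entropy(labels, probs, eps=1e-7):
    """Element-wise cross-entropy; probs are clipped to avoid log(0)."""
    probs = np.clip(probs, eps, 1.0 - eps)
    return -(labels * np.log(probs) + (1.0 - labels) * np.log(1.0 - probs))


def esmm_loss(click, buy, p_ctr, p_cvr):
    """Sum of the CTR loss and the CTCVR loss over all impressions."""
    click = np.asarray(click, dtype=float)
    buy = np.asarray(buy, dtype=float)
    p_ctr = np.asarray(p_ctr, dtype=float)
    p_cvr = np.asarray(p_cvr, dtype=float)

    ctcvr_label = click * buy  # y & z: 1 only when both click and conversion occur
    p_ctcvr = p_ctr * p_cvr    # pCTCVR = pCTR * pCVR

    ctr_loss = binary_cross_entropy(click, p_ctr).sum()
    ctcvr_loss = binary_cross_entropy(ctcvr_label, p_ctcvr).sum()
    return ctr_loss + ctcvr_loss


# Four toy impressions: only the third one is clicked and converted.
click = [0, 1, 1, 0]
buy = [0, 0, 1, 0]
p_ctr = [0.1, 0.6, 0.8, 0.2]
p_cvr = [0.3, 0.2, 0.5, 0.1]
total = esmm_loss(click, buy, p_ctr, p_cvr)
print(total)
```

No CVR-only loss term appears: the CVR output receives gradients solely through the product inside the CTCVR term, which is computed on every impression.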

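## Transfer learning
As described in the BASE model section, the embedding layer maps large-scale sparse inputs to low-dimensional dense vectors. It accounts for most of the parameters of the deep network and needs a large number of samples to train. In ESMM, the embedding parameters of the CVR network are shared with the CTR task, following the feature representation transfer learning paradigm. The CTR task is trained on all impressions, which outnumber the samples available to the CVR task by several orders of magnitude. This parameter-sharing mechanism lets the CVR network in ESMM learn from unclicked impressions and alleviates the data sparsity problem.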
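The parameter sharing described above can be sketched with a single embedding table read by two towers. This is a simplified forward pass only, under several assumptions: the sizes, the sum pooling, the one-layer towers, and the names `shared_embedding` and `tower` are illustrative and are not taken from the ESMM code.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 1000, 8

# One embedding table, used by both the CTR tower and the CVR tower.
shared_embedding = rng.normal(scale=0.01, size=(vocab_size, embed_dim))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def tower(feature_ids, weights):
    """Look up the shared embeddings, sum-pool them, apply one linear layer."""
    pooled = shared_embedding[feature_ids].sum(axis=1)  # (batch, embed_dim)
    return sigmoid(pooled @ weights)                    # (batch,)


# Each tower keeps its own weights on top of the shared table.
ctr_weights = rng.normal(size=embed_dim)
cvr_weights = rng.normal(size=embed_dim)

batch_ids = np.array([[1, 5, 7], [2, 5, 9]])  # two samples, three feature ids each
p_ctr = tower(batch_ids, ctr_weights)
p_cvr = tower(batch_ids, cvr_weights)
p_ctcvr = p_ctr * p_cvr
print(p_ctr, p_cvr, p_ctcvr)
```

Because both towers index the same array, any update driven by the CTR loss on unclicked impressions also changes the representation seen by the CVR tower.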

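## Structural extensibility

ESMM consists mainly of two sub-networks: the CVR network on the left of the architecture diagram and the CTR network on the right. Both use the same structure as the BASE model. CTCVR takes the product of the CVR and CTR network outputs as its output. Each sub-network can be replaced with any classification or estimation network.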
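One way to picture this pluggability is to treat each sub-network as an arbitrary callable that returns a probability. The sketch below is hypothetical: `make_esmm`, `constant_ctr`, and `clipped_mean_cvr` are stand-ins invented for the example, not models from the project.

```python
from typing import Callable, Dict, Sequence

Predictor = Callable[[Sequence[float]], float]


def make_esmm(ctr_model: Predictor, cvr_model: Predictor) -> Callable[[Sequence[float]], Dict[str, float]]:
    """Combine any two probability estimators into the three ESMM outputs."""
    def predict(x: Sequence[float]) -> Dict[str, float]:
        p_ctr = ctr_model(x)
        p_cvr = cvr_model(x)
        return {"pCTR": p_ctr, "pCVR": p_cvr, "pCTCVR": p_ctr * p_cvr}
    return predict


def constant_ctr(x: Sequence[float]) -> float:
    """Toy CTR estimator that ignores its input."""
    return 0.05


def clipped_mean_cvr(x: Sequence[float]) -> float:
    """Toy CVR estimator: mean of the features, clipped to [0, 1]."""
    return min(max(sum(x) / len(x), 0.0), 1.0)


esmm = make_esmm(constant_ctr, clipped_mean_cvr)
out = esmm([0.2, 0.4])
print(out)
```

Swapping either estimator for a different model leaves the product structure untouched, which is the property the paragraph above relies on.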

# ESMM training example

## Data processing

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import numpy as np
from collections import Counter
import tensorflow as tf

import os
import pickle
import re
from tensorflow.python.ops import math_ops



## A first look at the data
After steps 1 and 2, a 2.5% sample of the data is drawn, with roughly 1 million records each for training and testing.

In [2]:
sample_feature_columns=['sample_id','click','buy','md5','feature_num','ItemID','CategoryID','ShopID','NodeID',
                       'BrandID','Com_CateID','Com_ShopID','Com_BrandID','Com_NodeID','PID']

train_sample_table = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_sample_skeleton_train_sample_feature_column.csv',
                                  sep=',',dtype={'ItemID':object,'CategoryID':object,'ShopID':object,'PID':object},header=0,names=None,
                                  engine='python')
train_sample_table = train_sample_table.drop('feature_list',axis=1)
train_sample_table.head()

Unnamed: 0,sample_id,click,buy,md5,feature_num,ItemID,CategoryID,ShopID,NodeID,BrandID,Com_CateID,Com_ShopID,Com_BrandID,Com_NodeID,PID
0,104,0,0,bacff91692951881,16,7519835,8316856,8952950,9026927|9078967|9109890|9094933|9042779|909090...,<PAD>,9355093,<PAD>,<PAD>,10023513|10010801|10008961,9351665
1,152,0,0,bacff91692951881,14,7541969,8317499,8766801,9112685|9078547|9048065,9192345,9355716,9663046,9886824,10087588|10028347|10056259,9351665
2,157,0,0,bacff91692951881,11,7882213,8315782,8716483,9115431|9019674|9080897,9187433,9354098,9628886,9883335,<PAD>,9351665
3,160,0,0,bacff91692951881,10,8135427,8315277,8980998,9095211|9023762|9022971,9206290,9353609,<PAD>,9896652,<PAD>,9351665
4,209,0,0,bacff91692951881,11,5831350,8315277,8773531,9095211|9023762|9022971,9181078,9353609,9667568,9878755,<PAD>,9351665


In [8]:
sample_feature_columns=['sample_id','click','buy','md5','feature_num','ItemID','CategoryID','ShopID','NodeID',
                       'BrandID','Com_CateID','Com_ShopID','Com_BrandID','Com_NodeID','PID']

test_sample_table = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_sample_skeleton_test_sample_feature_column.csv',
                                  sep=',',dtype={'ItemID':object,'CategoryID':object,'ShopID':object,'PID':object},header=0,names=None,
                                  engine='python')
test_sample_table = test_sample_table.drop('feature_list',axis=1)
test_sample_table.head()

Unnamed: 0,sample_id,click,buy,md5,feature_num,ItemID,CategoryID,ShopID,NodeID,BrandID,Com_CateID,Com_ShopID,Com_BrandID,Com_NodeID,PID
0,3,0,0,23bd0f75de327c60,12,6539512,8315405,8546676,9100479|9084127|9074748|9038869|9080592|910870...,9273427,<PAD>,<PAD>,<PAD>,<PAD>,9351665
1,10,0,0,543b0cd53c7d5858,15,5213924,8315479,8419269,9088205|9035664|9023041|9039089|9095177|909845...,9201467,9353805,<PAD>,9893286,<PAD>,9351665
2,20,0,0,543b0cd53c7d5858,8,5802912,8316509,8814075,9044230|9078066|9072583,9286972,<PAD>,<PAD>,<PAD>,<PAD>,9351665
3,49,0,0,543b0cd53c7d5858,11,6972165,8317356,8331019,9072169|9078426|9073266|9056371|9049913|910983...,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,9351665
4,121,0,0,a2ea4295d36bc432,15,6036239,8320324,8613357,9073352|9067159|9029051|9046436|9045892|903303...,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,9351665


In [4]:
common_feature_columns = ['md5', 'feature_num', 'UserID', 'User_CateIDs', 'User_ShopIDs', 'User_BrandIDs', 'User_NodeIDs', 'User_Cluster', 
                     'User_ClusterID', 'User_Gender', 'User_Age', 'User_Level1', 'User_Level2', 
                     'User_Occupation', 'User_Geo']
train_common_features = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_common_features_skeleton_train_sample_feature_column.csv', 
                                      sep=',', header=0, names=None, engine = 'python')
train_common_features.head()

Unnamed: 0,md5,feature_num,UserID,User_CateIDs,User_ShopIDs,User_BrandIDs,User_NodeIDs,User_Cluster,User_ClusterID,User_Gender,User_Age,User_Level1,User_Level2,User_Occupation,User_Geo
0,0000350f0c2121e7,811,392326,447553|445995|450247|449070|450980|445135|4454...,<PAD>,3716224|3514627|3772871|3543283|3728186|371080...,<PAD>,3438725,3438760,3438769,3438772,3438778,3438782,3864885,3864888
1,000091a89d1867ab,7,<PAD>,<PAD>,<PAD>,<PAD>,<PAD>,3438658,3438761,3438769,3438773,<PAD>,3438781,3864885,3864889
2,0001def19d7cb335,964,241189,449099|455676|449360|453249|449425|456071|4509...,<PAD>,3530778|3689932|3497595|3569442|3569123|368644...,<PAD>,3438685,3438762,3438769,3438774,3438778,3438782,3864885,3864888
3,0001fa8246be0940,374,407969,451311|450954|450462|451530|451099|450656|4490...,<PAD>,3504052|3507496|3622158|3630324|3566530|352097...,<PAD>,3438737,3438757,3438768,3438774,3438778,3438782,3864885,3864888
4,000260b23f85aadb,266,168295,450837|451033|450838|449949|455349|455827|4553...,<PAD>,3627914|3760360|3763560|3496527|3689932|384471...,<PAD>,3438705,3438765,3438768,3438771,3438777,3438782,3864886,3864888


In [5]:
common_feature_columns = ['md5', 'feature_num', 'UserID', 'User_CateIDs', 'User_ShopIDs', 'User_BrandIDs', 'User_NodeIDs', 'User_Cluster', 
                     'User_ClusterID', 'User_Gender', 'User_Age', 'User_Level1', 'User_Level2', 
                     'User_Occupation', 'User_Geo']
test_common_features = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_common_features_skeleton_test_sample_feature_column.csv', 
                                      sep=',', header=0, names=None, engine = 'python')
test_common_features.head()

Unnamed: 0,md5,feature_num,UserID,User_CateIDs,User_ShopIDs,User_BrandIDs,User_NodeIDs,User_Cluster,User_ClusterID,User_Gender,User_Age,User_Level1,User_Level2,User_Occupation,User_Geo
0,810d5366057b3f58,1025,412797,451311|451286|451133|450954|450656|446913|4506...,<PAD>,3592299|3449840|3650730|3650822|3792091|352433...,<PAD>,3438725,3438760,3438769,3438772,3438778,3438782,3864885,3864888
1,0001970d9ebf72cf,126,64841,451130|450658|450656|451639|453921|453929|4490...,<PAD>,3704348|3704152|3852278|3849481|3580992|366112...,<PAD>,3438658,3438756,3438769,3438771,<PAD>,3438780,3864885,<PAD>
2,0010d0b9633bb5b0,250,66015,455028|451998|451100|445269|445990|450099|4557...,<PAD>,3520924|3505215|3588720|3541711|3801132|382945...,<PAD>,3438670,3438756,3438769,3438771,3438777,3438782,3864886,3864889
3,0012aad1f55312b6,170,121803,451822|451095|449537|449301|455342|449077|4490...,<PAD>,3518975|3697970|3784310|3821497|3698768|345218...,<PAD>,3438658,3438766,3438768,3438772,<PAD>,3438782,3864885,3864889
4,0013e5c24e8dd3a6,617,135732,452511|450721|449079|450276|450656|449078|4493...,<PAD>,3707605|3632935|3809314|3703188|3700287|356905...,<PAD>,3438670,3438756,3438769,3438771,3438777,3438782,3864885,3864889


### Example: joining the two tables

In [9]:
print(train_sample_table.shape)
print(train_common_features.shape)

print(test_sample_table.shape)
print(test_common_features.shape)

(1065221, 15)
(466911, 15)
(1084385, 15)
(541312, 15)


In [10]:
merge_data=pd.merge(train_sample_table,train_common_features,on='md5',how='inner')
print(merge_data.shape)
print(merge_data.columns)
merge_data.head()

(1065221, 29)
Index(['sample_id', 'click', 'buy', 'md5', 'feature_num_x', 'ItemID',
       'CategoryID', 'ShopID', 'NodeID', 'BrandID', 'Com_CateID', 'Com_ShopID',
       'Com_BrandID', 'Com_NodeID', 'PID', 'feature_num_y', 'UserID',
       'User_CateIDs', 'User_ShopIDs', 'User_BrandIDs', 'User_NodeIDs',
       'User_Cluster', 'User_ClusterID', 'User_Gender', 'User_Age',
       'User_Level1', 'User_Level2', 'User_Occupation', 'User_Geo'],
      dtype='object')


Unnamed: 0,sample_id,click,buy,md5,feature_num_x,ItemID,CategoryID,ShopID,NodeID,BrandID,...,User_BrandIDs,User_NodeIDs,User_Cluster,User_ClusterID,User_Gender,User_Age,User_Level1,User_Level2,User_Occupation,User_Geo
0,104,0,0,bacff91692951881,16,7519835,8316856,8952950,9026927|9078967|9109890|9094933|9042779|909090...,<PAD>,...,3534361|3654192|3503151|3722909|3650730|376454...,<PAD>,3438658,3438762,3438769,3438774,<PAD>,3438782,3864885,3864887
1,152,0,0,bacff91692951881,14,7541969,8317499,8766801,9112685|9078547|9048065,9192345,...,3534361|3654192|3503151|3722909|3650730|376454...,<PAD>,3438658,3438762,3438769,3438774,<PAD>,3438782,3864885,3864887
2,157,0,0,bacff91692951881,11,7882213,8315782,8716483,9115431|9019674|9080897,9187433,...,3534361|3654192|3503151|3722909|3650730|376454...,<PAD>,3438658,3438762,3438769,3438774,<PAD>,3438782,3864885,3864887
3,160,0,0,bacff91692951881,10,8135427,8315277,8980998,9095211|9023762|9022971,9206290,...,3534361|3654192|3503151|3722909|3650730|376454...,<PAD>,3438658,3438762,3438769,3438774,<PAD>,3438782,3864885,3864887
4,209,0,0,bacff91692951881,11,5831350,8315277,8773531,9095211|9023762|9022971,9181078,...,3534361|3654192|3503151|3722909|3650730|376454...,<PAD>,3438658,3438762,3438769,3438774,<PAD>,3438782,3864885,3864887
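The join above can be reproduced on toy data to see where the `feature_num_x` and `feature_num_y` columns come from. The two small frames below are invented for illustration; `validate` and `suffixes` are standard `pandas.merge` parameters, used here as an optional safeguard that the original cell does not apply.

```python
import pandas as pd

# Toy stand-ins for the sample table (one row per impression)
# and the common-features table (one row per md5 key).
samples = pd.DataFrame({
    "sample_id": [1, 2, 3],
    "click": [0, 1, 0],
    "md5": ["a", "a", "b"],
    "feature_num": [3, 4, 5],
})
common = pd.DataFrame({
    "md5": ["a", "b"],
    "feature_num": [10, 20],
    "UserID": ["u1", "u2"],
})

merged = pd.merge(
    samples, common,
    on="md5", how="inner",
    validate="many_to_one",            # raises if md5 is duplicated in `common`
    suffixes=("_sample", "_common"),   # clearer than the default _x / _y
)
print(merged.shape)
print(merged.columns.tolist())
```

In the real output above, the merged row count equals the sample table's row count (1,065,221), which is consistent with every sample matching exactly one row of common features.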


## Implementing data preprocessing

In [11]:
train_sample_table.head()

Unnamed: 0,sample_id,click,buy,md5,feature_num,ItemID,CategoryID,ShopID,NodeID,BrandID,Com_CateID,Com_ShopID,Com_BrandID,Com_NodeID,PID
0,104,0,0,bacff91692951881,16,7519835,8316856,8952950,9026927|9078967|9109890|9094933|9042779|909090...,<PAD>,9355093,<PAD>,<PAD>,10023513|10010801|10008961,9351665
1,152,0,0,bacff91692951881,14,7541969,8317499,8766801,9112685|9078547|9048065,9192345,9355716,9663046,9886824,10087588|10028347|10056259,9351665
2,157,0,0,bacff91692951881,11,7882213,8315782,8716483,9115431|9019674|9080897,9187433,9354098,9628886,9883335,<PAD>,9351665
3,160,0,0,bacff91692951881,10,8135427,8315277,8980998,9095211|9023762|9022971,9206290,9353609,<PAD>,9896652,<PAD>,9351665
4,209,0,0,bacff91692951881,11,5831350,8315277,8773531,9095211|9023762|9022971,9181078,9353609,9667568,9878755,<PAD>,9351665


In [12]:
# Print the columns and dtypes to make sure the training and test sets can be serialized together
train_sample_table['ItemID'].head()
print(train_sample_table.dtypes)
print(test_sample_table.dtypes)

sample_id       int64
click           int64
buy             int64
md5            object
feature_num     int64
ItemID         object
CategoryID     object
ShopID         object
NodeID         object
BrandID        object
Com_CateID     object
Com_ShopID     object
Com_BrandID    object
Com_NodeID     object
PID            object
dtype: object
sample_id       int64
click           int64
buy             int64
md5            object
feature_num     int64
ItemID         object
CategoryID     object
ShopID         object
NodeID         object
BrandID        object
Com_CateID     object
Com_ShopID     object
Com_BrandID    object
Com_NodeID     object
PID            object
dtype: object


In [13]:
print(train_common_features.columns)
train_common_features['feature_num'].head()

Index(['md5', 'feature_num', 'UserID', 'User_CateIDs', 'User_ShopIDs',
       'User_BrandIDs', 'User_NodeIDs', 'User_Cluster', 'User_ClusterID',
       'User_Gender', 'User_Age', 'User_Level1', 'User_Level2',
       'User_Occupation', 'User_Geo'],
      dtype='object')


0    811
1      7
2    964
3    374
4    266
Name: feature_num, dtype: int64

Print the number of unique IDs

In [14]:
print(len(train_sample_table['ItemID'].unique()))
print(len(train_sample_table['CategoryID'].unique()))
print(len(train_sample_table['ShopID'].unique()))
print(len(train_sample_table['NodeID'].unique()))
print(len(train_sample_table['BrandID'].unique()))
print(len(train_sample_table['Com_CateID'].unique()))
print(len(train_sample_table['Com_ShopID'].unique()))
print(len(train_sample_table['Com_BrandID'].unique()))
print(len(train_sample_table['Com_NodeID'].unique()))
print(len(train_sample_table['PID'].unique()))

433465
6186
208741
683923
85623
5116
83509
40290
233650
3


In [15]:
value1 = set(train_sample_table['ShopID'].tolist())
value2 = set(test_sample_table['ShopID'].tolist())

value3 = set(train_common_features['UserID'].tolist())
value4 = set(test_common_features['UserID'].tolist())

print("inner ShopID:",len(value1&value2))

print("inner UserID:",len(value3&value4))


inner ShopID: 139604
inner UserID: 93416


In [18]:
def load_ESMM_Train_and_Test_Data():
    """
    Load Dataset from File
    """
    sample_feature_columns = ['sample_id', 'click', 'buy', 'md5', 'feature_num', 'ItemID','CategoryID','ShopID','NodeID','BrandID','Com_CateID',
                     'Com_ShopID','Com_BrandID','Com_NodeID','PID']
    
    common_feature_columns = ['md5', 'feature_num', 'UserID', 'User_CateIDs', 'User_ShopIDs', 'User_BrandIDs', 'User_NodeIDs', 'User_Cluster', 
                     'User_ClusterID', 'User_Gender', 'User_Age', 'User_Level1', 'User_Level2', 
                     'User_Occupation', 'User_Geo']
    
    # Force some columns to object: pandas otherwise infers different dtypes for them in the training and test files, which breaks the serialization later
    train_sample_table = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_sample_skeleton_train_sample_feature_column.csv', sep=',',\
                                  dtype={'ItemID': object, 'CategoryID': object, 'ShopID': object, 'PID': object},\
                                  header=0, names=None, engine = 'python')
    train_common_features = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_common_features_skeleton_train_sample_feature_column.csv', sep=',', header=0, names=None, engine = 'python')
    
    test_sample_table = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_sample_skeleton_test_sample_feature_column.csv', sep=',', \
                                  dtype={'ItemID': object, 'CategoryID': object, 'ShopID': object, 'PID': object},\
                                  header=0, names=None, engine = 'python')
    test_common_features = pd.read_table('C:/Users/zhangy/Desktop/ctr_cvr_data/BuyWeight_sampled_common_features_skeleton_test_sample_feature_column.csv', sep=',', header=0, names=None, engine = 'python')
    
    # Build the ItemID-to-integer dictionary
    ItemID_set = set()
    for val in train_sample_table['ItemID'].str.split('|'):
        ItemID_set.update(val)
    for val in test_sample_table['ItemID'].str.split('|'):
        ItemID_set.update(val)
    ItemID_set.add('<PAD>')
    ItemID2int = {val:ii for ii, val in enumerate(ItemID_set)}
    # Convert ItemID to fixed-length lists of integers (illustration only: ItemID is single-valued, so this step is not actually needed)
    ItemID_map = {val:[ItemID2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(train_sample_table['ItemID']))}
    test_ItemID_map = {val:[ItemID2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(test_sample_table['ItemID']))}
    # merge train & test
    ItemID_map.update(test_ItemID_map)
    ItemID_map_max_len = 1
    print("ItemID_map max_len:", ItemID_map_max_len)
    for key in ItemID_map:
        for cnt in range(ItemID_map_max_len - len(ItemID_map[key])):
            ItemID_map[key].insert(len(ItemID_map[key]) + cnt,ItemID2int['<PAD>'])
    train_sample_table['ItemID'] = train_sample_table['ItemID'].map(ItemID_map)
    test_sample_table['ItemID'] = test_sample_table['ItemID'].map(ItemID_map)
    print("ItemID finish")
    
    
    # Build the User_CateIDs-to-integer dictionary
    User_CateIDs_set = set()
    for val in train_common_features['User_CateIDs'].str.split('|'):
        User_CateIDs_set.update(val)
    for val in test_common_features['User_CateIDs'].str.split('|'):
        User_CateIDs_set.update(val)
    User_CateIDs_set.add('<PAD>')
    User_CateIDs2int = {val:ii for ii, val in enumerate(User_CateIDs_set)}
    # Convert User_CateIDs to fixed-length lists of integers
    User_CateIDs_map = {val:[User_CateIDs2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(train_common_features['User_CateIDs']))}
    test_User_CateIDs_map = {val:[User_CateIDs2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(test_common_features['User_CateIDs']))}
    # merge train & test
    User_CateIDs_map.update(test_User_CateIDs_map)
    User_CateIDs_map_max_len = 100
    print("User_CateIDs_map max_len:", User_CateIDs_map_max_len)
    for key in User_CateIDs_map:
        for cnt in range(User_CateIDs_map_max_len - len(User_CateIDs_map[key])):
            User_CateIDs_map[key].insert(len(User_CateIDs_map[key]) + cnt,User_CateIDs2int['<PAD>'])
    train_common_features['User_CateIDs'] = train_common_features['User_CateIDs'].map(User_CateIDs_map)
    test_common_features['User_CateIDs'] = test_common_features['User_CateIDs'].map(User_CateIDs_map)
    print("User_CateIDs finish")
    
    # Build the User_BrandIDs-to-integer dictionary
    User_BrandIDs_set = set()
    for val in train_common_features['User_BrandIDs'].str.split('|'):
        User_BrandIDs_set.update(val)
    for val in test_common_features['User_BrandIDs'].str.split('|'):
        User_BrandIDs_set.update(val)
    User_BrandIDs_set.add('<PAD>')
    User_BrandIDs2int = {val:ii for ii, val in enumerate(User_BrandIDs_set)}
    # Convert User_BrandIDs to fixed-length lists of integers
    User_BrandIDs_map = {val:[User_BrandIDs2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(train_common_features['User_BrandIDs']))}
    test_User_BrandIDs_map = {val:[User_BrandIDs2int[row] for row in val.split('|')]  \
                  for ii,val in enumerate(set(test_common_features['User_BrandIDs']))}
    # merge train & test
    User_BrandIDs_map.update(test_User_BrandIDs_map)
    User_BrandIDs_map_max_len = 100
    print("User_BrandIDs_map max_len:", User_BrandIDs_map_max_len)
    for key in User_BrandIDs_map:
        for cnt in range(User_BrandIDs_map_max_len - len(User_BrandIDs_map[key])):
            User_BrandIDs_map[key].insert(len(User_BrandIDs_map[key]) + cnt,User_BrandIDs2int['<PAD>'])
    train_common_features['User_BrandIDs'] = train_common_features['User_BrandIDs'].map(User_BrandIDs_map)
    test_common_features['User_BrandIDs'] = test_common_features['User_BrandIDs'].map(User_BrandIDs_map)
    print("User_BrandIDs finish")
    
    
    # Build the UserID-to-integer dictionary
    UserID_set = set()
    for val in train_common_features['UserID']:
        UserID_set.add(val)
    for val in test_common_features['UserID']:
        UserID_set.add(val)
    UserID2int = {val:ii for ii, val in enumerate(UserID_set)}
    UserID_map_max_len = 1
    print("UserID_map max_len:", UserID_map_max_len)
    train_common_features['UserID'] = train_common_features['UserID'].map(UserID2int)
    test_common_features['UserID'] = test_common_features['UserID'].map(UserID2int)
    print("UserID finish")
    
    # Build the User_Cluster-to-integer dictionary
    User_Cluster_set = set()
    for val in train_common_features['User_Cluster']:
        User_Cluster_set.add(val)
    for val in test_common_features['User_Cluster']:
        User_Cluster_set.add(val)
    User_Cluster2int = {val:ii for ii, val in enumerate(User_Cluster_set)}
    User_Cluster_map_max_len = 1
    print("User_Cluster_map max_len:", User_Cluster_map_max_len)
    train_common_features['User_Cluster'] = train_common_features['User_Cluster'].map(User_Cluster2int)
    test_common_features['User_Cluster'] = test_common_features['User_Cluster'].map(User_Cluster2int)
    print("User_Cluster finish")
    
    # Build the CategoryID-to-integer dictionary
    CategoryID_set = set()
    for val in train_sample_table['CategoryID']:
        CategoryID_set.add(val)
    for val in test_sample_table['CategoryID']:
        CategoryID_set.add(val)
    CategoryID2int = {val:ii for ii, val in enumerate(CategoryID_set)}
    CategoryID_map_max_len = 1
    print("CategoryID_map max_len:", CategoryID_map_max_len)
    train_sample_table['CategoryID'] = train_sample_table['CategoryID'].map(CategoryID2int)
    test_sample_table['CategoryID'] = test_sample_table['CategoryID'].map(CategoryID2int)
    print("CategoryID finish")
    
    # Build the ShopID-to-integer dictionary
    ShopID_set = set()
    for val in train_sample_table['ShopID']:
        ShopID_set.add(val)
    for val in test_sample_table['ShopID']:
        ShopID_set.add(val)
    ShopID2int = {val:ii for ii, val in enumerate(ShopID_set)}
    ShopID_map_max_len = 1
    print("ShopID_map max_len:", ShopID_map_max_len)
    train_sample_table['ShopID'] = train_sample_table['ShopID'].map(ShopID2int)
    test_sample_table['ShopID'] = test_sample_table['ShopID'].map(ShopID2int)
    print("ShopID finish")

    # Build the BrandID-to-integer dictionary
    BrandID_set = set()
    for val in train_sample_table['BrandID']:
        BrandID_set.add(val)
    for val in test_sample_table['BrandID']:
        BrandID_set.add(val)
    BrandID2int = {val:ii for ii, val in enumerate(BrandID_set)}
    BrandID_map_max_len = 1
    print("BrandID_map max_len:", BrandID_map_max_len)
    train_sample_table['BrandID'] = train_sample_table['BrandID'].map(BrandID2int)
    test_sample_table['BrandID'] = test_sample_table['BrandID'].map(BrandID2int)
    print("BrandID finish")
    
    # Build the Com_CateID-to-integer dictionary
    Com_CateID_set = set()
    for val in train_sample_table['Com_CateID']:
        Com_CateID_set.add(val)
    for val in test_sample_table['Com_CateID']:
        Com_CateID_set.add(val)
    Com_CateID2int = {val:ii for ii, val in enumerate(Com_CateID_set)}
    Com_CateID_map_max_len = 1
    print("Com_CateID_map max_len:", Com_CateID_map_max_len)
    train_sample_table['Com_CateID'] = train_sample_table['Com_CateID'].map(Com_CateID2int)
    test_sample_table['Com_CateID'] = test_sample_table['Com_CateID'].map(Com_CateID2int)
    print("Com_CateID finish")
    
    # Build the Com_ShopID-to-integer dictionary
    Com_ShopID_set = set()
    for val in train_sample_table['Com_ShopID']:
        Com_ShopID_set.add(val)
    for val in test_sample_table['Com_ShopID']:
        Com_ShopID_set.add(val)
    Com_ShopID2int = {val:ii for ii, val in enumerate(Com_ShopID_set)}
    Com_ShopID_map_max_len = 1
    print("Com_ShopID_map max_len:", Com_ShopID_map_max_len)
    train_sample_table['Com_ShopID'] = train_sample_table['Com_ShopID'].map(Com_ShopID2int)
    test_sample_table['Com_ShopID'] = test_sample_table['Com_ShopID'].map(Com_ShopID2int)
    print("Com_ShopID finish")
    
    # Build the Com_BrandID-to-integer dictionary
    Com_BrandID_set = set()
    for val in train_sample_table['Com_BrandID']:
        Com_BrandID_set.add(val)
    for val in test_sample_table['Com_BrandID']:
        Com_BrandID_set.add(val)
    Com_BrandID2int = {val:ii for ii, val in enumerate(Com_BrandID_set)}
    Com_BrandID_map_max_len = 1
    print("Com_BrandID_map max_len:", Com_BrandID_map_max_len)
    train_sample_table['Com_BrandID'] = train_sample_table['Com_BrandID'].map(Com_BrandID2int)
    test_sample_table['Com_BrandID'] = test_sample_table['Com_BrandID'].map(Com_BrandID2int)
    print("Com_BrandID finish")
    
    # Build the PID-to-integer dictionary
    PID_set = set()
    for val in train_sample_table['PID']:
        PID_set.add(val)
    for val in test_sample_table['PID']:
        PID_set.add(val)
    PID2int = {val:ii for ii, val in enumerate(PID_set)}
    PID_map_max_len = 1
    print("PID_map max_len:", PID_map_max_len)
    train_sample_table['PID'] = train_sample_table['PID'].map(PID2int)
    test_sample_table['PID'] = test_sample_table['PID'].map(PID2int)
    print("PID finish")
    
    
    # Merge the two tables on md5
    train_data = pd.merge(train_sample_table, train_common_features, on='md5',how='inner')
    test_data = pd.merge(test_sample_table, test_common_features, on='md5',how='inner')

    print("Sample/Common Merged")
    # Split the data into two tables, X (features) and y (targets)
    feature_fields = ['UserID','ItemID','User_Cluster', 'CategoryID','ShopID',\
                      'BrandID','Com_CateID','Com_ShopID','Com_BrandID','PID','User_CateIDs','User_BrandIDs']
    target_fields = ['click','buy']
    train_features_pd, train_targets_pd = train_data[feature_fields], train_data[target_fields]
    train_features = train_features_pd.values
    train_targets_values = train_targets_pd.values
    
    test_features_pd, test_targets_pd = test_data[feature_fields], test_data[target_fields]
    test_features = test_features_pd.values
    test_targets_values = test_targets_pd.values
    
    return UserID_map_max_len, ItemID_map_max_len, User_Cluster_map_max_len, \
User_CateIDs_map_max_len, User_BrandIDs_map_max_len, \
CategoryID_map_max_len, ShopID_map_max_len, BrandID_map_max_len, Com_CateID_map_max_len,\
Com_ShopID_map_max_len, Com_BrandID_map_max_len, PID_map_max_len, UserID2int, ItemID2int,\
User_Cluster2int, User_CateIDs2int, User_BrandIDs2int,  CategoryID2int, ShopID2int, BrandID2int, Com_CateID2int, \
Com_ShopID2int, Com_BrandID2int, PID2int, train_features, train_targets_values, train_data, \
test_features, test_targets_values, test_data
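The function above repeats the same two steps for every column: collect the tokens seen in training and test data into a vocabulary, then map values to integer ids (padding multi-valued fields to a fixed length). A possible way to factor that out is sketched below. The helper names `build_vocab` and `encode_multi_valued` are hypothetical, the helpers assume string-valued columns, and two behaviors differ from the original on purpose: tokens are sorted so the mapping is reproducible across runs, and lists longer than `max_len` are truncated.

```python
PAD = '<PAD>'


def build_vocab(*columns, sep=None):
    """Map every token seen in the given columns (plus PAD) to an integer id."""
    tokens = {PAD}
    for column in columns:
        for value in column:
            tokens.update(value.split(sep) if sep else [value])
    # Sorting gives a stable mapping; enumerating a raw set of strings
    # can produce a different order in each Python process.
    return {token: index for index, token in enumerate(sorted(tokens))}


def encode_multi_valued(value, vocab, max_len, sep='|'):
    """Turn 'a|b|c' into a list of ids with exactly max_len entries."""
    ids = [vocab[token] for token in value.split(sep)][:max_len]
    return ids + [vocab[PAD]] * (max_len - len(ids))


# Toy columns standing in for the train and test versions of one field.
train_col = ['a|b|c', 'b', '<PAD>']
test_col = ['c|d']

vocab = build_vocab(train_col, test_col, sep='|')
encoded = encode_multi_valued('b|d', vocab, max_len=4)
print(vocab)
print(encoded)
```

With helpers like these, each field in the loader reduces to one `build_vocab` call and one `Series.map`, which also avoids copy-paste slips between blocks.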

In [20]:
UserID_map_max_len, ItemID_map_max_len, User_Cluster_map_max_len, \
User_CateIDs_map_max_len, User_BrandIDs_map_max_len, \
CategoryID_map_max_len, ShopID_map_max_len, BrandID_map_max_len, Com_CateID_map_max_len,\
Com_ShopID_map_max_len, Com_BrandID_map_max_len, PID_map_max_len, UserID2int, ItemID2int,\
User_Cluster2int, User_CateIDs2int, User_BrandIDs2int,  CategoryID2int, ShopID2int, BrandID2int, Com_CateID2int, \
Com_ShopID2int, Com_BrandID2int, PID2int, train_features, train_targets_values, train_data, \
test_features, test_targets_values, test_data = load_ESMM_Train_and_Test_Data()
print(0)
pickle.dump((UserID_map_max_len, ItemID_map_max_len, User_Cluster_map_max_len, \
User_CateIDs_map_max_len, User_BrandIDs_map_max_len, \
CategoryID_map_max_len, ShopID_map_max_len, BrandID_map_max_len, Com_CateID_map_max_len,\
Com_ShopID_map_max_len, Com_BrandID_map_max_len, PID_map_max_len, UserID2int, ItemID2int,\
User_Cluster2int, User_CateIDs2int, User_BrandIDs2int,  CategoryID2int, ShopID2int, BrandID2int, Com_CateID2int, \
Com_ShopID2int, Com_BrandID2int, PID2int, train_features, train_targets_values, train_data, \
test_features, test_targets_values, test_data), open('C:/Users/zhangy/Desktop/ctr_cvr_data/preprocess.p', 'wb'))
print(0)

ItemID_map max_len: 1
ItemID finish
User_CateIDs_map max_len: 100
User_CateIDs finish
User_BrandIDs_map max_len: 100
User_BrandIDs finish
UserID_map max_len: 1
UserID finish
User_Cluster_map max_len: 1
User_Cluster finish
CategoryID_map max_len: 1
CategoryID finish
ShopID_map max_len: 1
ShopID finish
BrandID_map max_len: 1
BrandID finish
Com_CateID_map max_len: 1
Com_CateID finish
Com_ShopID_map max_len: 1
Com_ShopID finish
Com_BrandID_map max_len: 1
Com_BrandID finish
PID_map max_len: 1
PID finish
Sample/Common Merged
0
0


In [25]:
User_CateIDs2int['<PAD>']


4501

In [26]:
User_BrandIDs2int['<PAD>']

176659

In [34]:
test_features[0:1,0:100]

array([[152590, list([572948]), 51, 3938, 76419, 80560, 2152, 115523,
        47260, 0,
        list([11419, 2854, 11757, 1979, 3880, 1175, 7618, 11881, 11368, 8027, 7363, 10810, 6901, 7493, 10005, 11323, 3447, 594, 9895, 4722, 6254, 288, 11235, 10038, 1605, 6044, 7904, 625, 11868, 11453, 698, 100, 7113, 126, 2646, 4466, 4491, 1938, 8325, 11329, 7996, 9240, 9164, 9823, 11878, 4327, 4091, 5784, 2200, 2762, 8465, 8124, 8394, 10919, 6838, 4470, 5004, 6836, 11258, 1700, 9749, 7267, 3291, 5023, 10871, 6719, 11465, 9403, 6070, 5901, 9943, 8758, 4294, 4932, 4953, 1866, 11332, 1412, 8528, 4902, 1610, 7333, 4755, 5431, 11497, 10683, 11052, 3488, 11049, 7507, 11314, 8713, 7822, 3621, 7673, 9710, 8287, 6047, 9665, 3297]),
        list([70458, 323287, 50137, 46116, 156670, 35966, 183202, 242324, 285211, 201167, 78979, 263081, 343745, 231808, 339866, 365974, 361510, 19397, 148492, 269373, 12027, 165704, 145134, 329482, 268650, 68468, 76879, 225120, 280504, 188437, 41317, 267523, 92029, 318460, 2271

In [35]:
train_features[0:1,0:100]

array([[152590, list([408977]), 51, 749, 2879, 100799, 1486, 115523,
        47260, 0,
        list([2854, 11757, 3880, 1175, 7618, 11368, 8027, 7493, 10005, 11323, 3447, 9895, 4722, 6044, 7904, 625, 11868, 698, 7113, 2646, 4466, 4491, 8325, 7996, 9240, 9164, 9823, 11878, 4327, 5784, 2200, 8124, 8394, 10919, 6838, 4470, 5004, 11258, 1700, 9749, 7267, 3291, 5023, 10871, 6070, 9403, 5901, 9943, 8758, 4294, 4932, 4953, 1412, 8528, 1610, 4755, 11497, 10683, 11052, 3488, 11049, 7507, 11314, 8713, 7822, 7673, 3074, 8287, 6047, 9665, 8527, 5172, 797, 8898, 7451, 2442, 4337, 10287, 736, 9597, 5547, 8691, 10327, 7307, 2665, 7278, 10178, 7357, 4975, 3368, 45, 5017, 11053, 6123, 5814, 1122, 3497, 8670, 4090, 11169]),
        list([70458, 323287, 46116, 35966, 242324, 285211, 201167, 78979, 263081, 231808, 339866, 361510, 19397, 148492, 269373, 12027, 165704, 329482, 268650, 76879, 225120, 280504, 188437, 41317, 267523, 92029, 318460, 227158, 334126, 195873, 227065, 95091, 14513, 105262, 41784, 26

In [36]:
train_features.take(11,1)[1000]

[319630,
 249057,
 237207,
 266709,
 76111,
 180482,
 114700,
 210485,
 342678,
 242324,
 56281,
 357296,
 201167,
 263081,
 156690,
 331234,
 339866,
 250497,
 220155,
 341185,
 352724,
 273326,
 164690,
 326596,
 26814,
 281779,
 184788,
 160549,
 145177,
 42244,
 32673,
 56033,
 91143,
 154110,
 214127,
 194824,
 319058,
 156567,
 229023,
 225120,
 28594,
 234793,
 156285,
 343158,
 113063,
 236420,
 334126,
 346342,
 195873,
 13098,
 80038,
 248971,
 283388,
 202009,
 30871,
 224840,
 185943,
 127832,
 356137,
 266697,
 332412,
 259484,
 350573,
 313967,
 308453,
 298121,
 110622,
 240598,
 31249,
 648,
 41697,
 12280,
 131737,
 211647,
 120278,
 95074,
 182338,
 277441,
 236110,
 324430,
 90962,
 54772,
 213105,
 292453,
 270710,
 255417,
 27721,
 73795,
 97945,
 117225,
 221357,
 104246,
 357487,
 216977,
 349226,
 264644,
 103866,
 52476,
 202923,
 135617]

In [38]:
test_targets_values[0:10]
print(train_targets_values.shape)
print(test_targets_values.shape)

(1065221, 2)
(1084385, 2)


In [40]:
user_cateids = np.zeros([2, 100])
for i in range(2):
    user_cateids[i] = train_features.take(10,1)[i]
print(user_cateids)

[[ 2854. 11757.  3880.  1175.  7618. 11368.  8027.  7493. 10005. 11323.
   3447.  9895.  4722.  6044.  7904.   625. 11868.   698.  7113.  2646.
   4466.  4491.  8325.  7996.  9240.  9164.  9823. 11878.  4327.  5784.
   2200.  8124.  8394. 10919.  6838.  4470.  5004. 11258.  1700.  9749.
   7267.  3291.  5023. 10871.  6070.  9403.  5901.  9943.  8758.  4294.
   4932.  4953.  1412.  8528.  1610.  4755. 11497. 10683. 11052.  3488.
  11049.  7507. 11314.  8713.  7822.  7673.  3074.  8287.  6047.  9665.
   8527.  5172.   797.  8898.  7451.  2442.  4337. 10287.   736.  9597.
   5547.  8691. 10327.  7307.  2665.  7278. 10178.  7357.  4975.  3368.
     45.  5017. 11053.  6123.  5814.  1122.  3497.  8670.  4090. 11169.]
 [ 2854. 11757.  3880.  1175.  7618. 11368.  8027.  7493. 10005. 11323.
   3447.  9895.  4722.  6044.  7904.   625. 11868.   698.  7113.  2646.
   4466.  4491.  8325.  7996.  9240.  9164.  9823. 11878.  4327.  5784.
   2200.  8124.  8394. 10919.  6838.  4470.  5004. 11258.  1700
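The loop above copies one padded list at a time into a float array, which is why the ids print as `2854.` rather than `2854`. When every list in the column has the same length, the column can be converted in one step and kept as integers. The sketch below uses a tiny hand-built object column as a stand-in for `train_features.take(10,1)`; the values are made up.

```python
import numpy as np

# Stand-in for an object-dtype column whose cells are equal-length lists of ids.
column = np.empty(2, dtype=object)
column[0] = [1, 2, 3]
column[1] = [4, 5, 6]

# tolist() yields a list of lists, which NumPy stacks into a 2-D integer matrix.
matrix = np.array(column.tolist(), dtype=np.int64)
print(matrix.shape)
print(matrix)
```

Integer ids are usually what an embedding lookup expects, so keeping the integer dtype avoids a later cast.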

### Loading the data from local disk

In [None]:
UserID_map_max_len, ItemID_map_max_len, User_Cluster_map_max_len, \
User_CateIDs_map_max_len, User_BrandIDs_map_max_len, \
CategoryID_map_max_len, ShopID_map_max_len, BrandID_map_max_len, Com_CateID_map_max_len,\
Com_ShopID_map_max_len, Com_BrandID_map_max_len, PID_map_max_len, UserID2int, ItemID2int,\
User_Cluster2int, User_CateIDs2int, User_BrandIDs2int,  CategoryID2int, ShopID2int, BrandID2int, Com_CateID2int, \
Com_ShopID2int, Com_BrandID2int, PID2int, train_features, train_targets_values, train_data, \
test_features, test_targets_values, test_data = pickle.load(open('C:/Users/zhangy/Desktop/ctr_cvr_data/preprocess.p', mode='rb'))
print(0)
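Unpacking a 30-element tuple depends on the dump and load sides listing the names in exactly the same order. One alternative, shown here only as a sketch, is to pickle a dictionary so that each artifact is retrieved by name. The keys and values below are illustrative, and an in-memory buffer stands in for the `preprocess.p` file.

```python
import io
import pickle

# Illustrative subset of the preprocessing artifacts, keyed by name.
artifacts = {
    "max_len": {"ItemID": 1, "User_CateIDs": 100, "User_BrandIDs": 100},
    "vocab": {"PID": {"<PAD>": 0}},
}

buffer = io.BytesIO()  # stands in for open(path, 'wb')
pickle.dump(artifacts, buffer)

buffer.seek(0)         # stands in for reopening the file with mode='rb'
restored = pickle.load(buffer)
print(restored["max_len"]["User_CateIDs"])
```

With a real file, wrapping `open(...)` in a `with` block also ensures the handle is closed after the dump or load.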

### Embedding Lookup example

In [None]:
import tensorflow as tf
import numpy as np

c = np.random.random([10,1])
b = tf.nn.embedding_lookup(c,[1])
a = tf.nn.embedding_lookup(c,1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # replaces the deprecated tf.initialize_all_variables()
    print(sess.run(b))
    print(sess.run(a))
    print(c)
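For intuition, the two lookups in the cell above behave like plain row indexing, and an embedding lookup is equivalent to multiplying a one-hot matrix by the table. The NumPy sketch below illustrates both points on a small made-up table; it is an analogy for `tf.nn.embedding_lookup`, not a replacement for it.

```python
import numpy as np

# A 10 x 2 table with easy-to-read values: row i is [2*i, 2*i + 1].
table = np.arange(20, dtype=np.float32).reshape(10, 2)

single = table[1]    # scalar id -> one row, shape (2,)
batch = table[[1]]   # list of ids -> one row per id, shape (1, 2)

# The same batch lookup written as a one-hot matrix product.
one_hot = np.eye(10, dtype=np.float32)[[1]]  # shape (1, 10)
via_matmul = one_hot @ table                 # shape (1, 2)

print(single)
print(batch)
print(via_matmul)
```

Indexing returns the same rows as the matrix product without materializing the one-hot matrix, which matters when the vocabulary has hundreds of thousands of ids as in this dataset.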

In [None]:
### TensorFlow slice example
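The original cell for this heading is empty, so the following is only a sketch of the semantics rather than the author's example. `tf.slice(input, begin, size)` takes a start offset and an extent per dimension, with `-1` in `size` meaning "everything remaining". The hypothetical helper `slice_like_tf` below emulates that behavior with NumPy so it can be run without a TensorFlow session.

```python
import numpy as np


def slice_like_tf(x, begin, size):
    """Emulate tf.slice: per-axis start offset and extent, -1 = to the end."""
    index = tuple(
        slice(b, None if s == -1 else b + s)
        for b, s in zip(begin, size)
    )
    return x[index]


x = np.arange(24).reshape(2, 3, 4)

# Second block, its first two rows, and every column from index 1 onward.
block = slice_like_tf(x, begin=[1, 0, 1], size=[1, 2, -1])
print(block.shape)
print(block)
```

The main difference from Python's `start:stop` slicing is that the second argument is a length, not an end index.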