diff --git a/README.md b/README.md index 4727aa0e..70ac2da2 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star | Deep Session Interest Network | [IJCAI 2019][Deep Session Interest Network for Click-Through Rate Prediction ](https://arxiv.org/abs/1905.06482) | | FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) | | FLEN | [arxiv 2019][FLEN: Leveraging Field for Scalable CTR Prediction](https://arxiv.org/pdf/1911.04690.pdf) | +| DMR | [AAAI 2020][Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://ojs.aaai.org//index.php/AAAI/article/view/5346) | ## Citation diff --git a/deepctr/layers/__init__.py b/deepctr/layers/__init__.py index 89cc60fb..c718c84b 100644 --- a/deepctr/layers/__init__.py +++ b/deepctr/layers/__init__.py @@ -1,7 +1,7 @@ import tensorflow as tf from .activation import Dice -from .core import DNN, LocalActivationUnit, PredictionLayer +from .core import DNN, LocalActivationUnit, PredictionLayer,SampledSoftmaxLayer,EmbeddingIndex,PoolingLayer from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet, InnerProductLayer, InteractingLayer, OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction, @@ -9,7 +9,7 @@ from .normalization import LayerNormalization from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM, KMaxPooling, SequencePoolingLayer,WeightedSequenceLayer, - Transformer, DynamicGRU) + Transformer, DynamicGRU,PositionalEncoding) from .utils import NoMask, Hash,Linear,Add,combined_dnn_input custom_objects = {'tf': tf, @@ -42,5 +42,10 @@ 'WeightedSequenceLayer':WeightedSequenceLayer, 'Add':Add, 'FieldWiseBiInteraction':FieldWiseBiInteraction, - 'FwFMLayer': FwFMLayer + 'FwFMLayer': FwFMLayer, + 'SampledSoftmaxLayer': SampledSoftmaxLayer, + 'EmbeddingIndex': EmbeddingIndex, + 'PoolingLayer': PoolingLayer, + 'PositionalEncoding':PositionalEncoding } + diff --git a/deepctr/layers/core.py b/deepctr/layers/core.py index f81bf97b..fe2fa7e4 100644 --- a/deepctr/layers/core.py +++ b/deepctr/layers/core.py @@ -13,7 +13,7 @@ from tensorflow.python.keras.regularizers import l2 from .activation import activation_layer - +from .utils import reduce_max,reduce_mean,reduce_sum,concat_func class LocalActivationUnit(Layer): """The LocalActivationUnit used in DIN with which the representation of @@ -36,19 +36,22 @@ class LocalActivationUnit(Layer): - **use_bn**: bool. Whether use BatchNormalization before activation or not in attention net. + - **self_attention**: bool.Whether or not use self_attention. + - **seed**: A Python integer to use as random seed. References - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 
ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf) """ - def __init__(self, hidden_units=(64, 32), activation='sigmoid', l2_reg=0, dropout_rate=0, use_bn=False, seed=1024, - **kwargs): + def __init__(self, hidden_units=(64, 32), activation='sigmoid', l2_reg=0, dropout_rate=0, use_bn=False, + self_attention=False, seed=1024, **kwargs): self.hidden_units = hidden_units self.activation = activation self.l2_reg = l2_reg self.dropout_rate = dropout_rate self.use_bn = use_bn + self.self_attention = self_attention self.seed = seed super(LocalActivationUnit, self).__init__(**kwargs) self.supports_masking = True @@ -63,8 +66,13 @@ def build(self, input_shape): raise ValueError("Unexpected inputs dimensions %d and %d, expect to be 3 dimensions" % ( len(input_shape[0]), len(input_shape[1]))) - if input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1: - raise ValueError('A `LocalActivationUnit` layer requires ' + if self.self_attention and input_shape[0][-1] != input_shape[1][-1]: + raise ValueError('A `LocalActivationUnit` layer with self_attention is True requires ' + 'inputs of a two inputs with shape (None,T,embedding_size) and (None,T,embedding_size)' + 'Got different shapes: %s,%s' % (input_shape[0], input_shape[1])) + + if not self.self_attention and (input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1): + raise ValueError('A `LocalActivationUnit` layer with not self_attention requires ' 'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)' 'Got different shapes: %s,%s' % (input_shape[0], input_shape[1])) size = 4 * \ @@ -88,8 +96,11 @@ def call(self, inputs, training=None, **kwargs): query, keys = inputs - keys_len = keys.get_shape()[1] - queries = K.repeat_elements(query, keys_len, 1) + if not self.self_attention: + keys_len = keys.get_shape()[1] + queries = K.repeat_elements(query, keys_len, 1) + else: + queries = query att_input = tf.concat( [queries, keys, queries - keys, queries * keys], axis=-1) @@ -255,3 +266,102 @@ def get_config(self, ): config = {'task': self.task, 'use_bias': self.use_bias} base_config = super(PredictionLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) + + +class SampledSoftmaxLayer(Layer): + + def __init__(self, num_sampled=2, **kwargs): + self.num_sampled = num_sampled + super(SampledSoftmaxLayer, self).__init__(**kwargs) + + def build(self, input_shape): + self.size = input_shape[0][0] + self.zero_bias = self.add_weight(shape=[self.size], + initializer=Zeros, + dtype=tf.float32, + trainable=False, + name="bias") + super(SampledSoftmaxLayer, self).build(input_shape) + + def call(self, inputs_with_label_idx, training=None, **kwargs): + """ + The first input should be the model as it were, and the second the + target (i.e., a repeat of the training data) to compute the labels + argument + """ + embeddings, inputs, label_idx = inputs_with_label_idx + + loss = tf.nn.sampled_softmax_loss(weights=embeddings, # self.item_embedding. 
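+                                          # Note (added explanation): tf.nn.sampled_softmax_loss draws `num_sampled`
+                                          # negative classes per example and evaluates a softmax cross-entropy over
+                                          # the true label plus those negatives, so the full item vocabulary never
+                                          # has to be scored. In DMR this is used as the auxiliary match loss of the
+                                          # User-to-Item network (added to the model via model.add_loss).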
+ biases=self.zero_bias, + labels=label_idx, + inputs=inputs, + num_sampled=self.num_sampled, + num_classes=self.size, # self.target_song_size + ) + + return reduce_mean(loss) + + def compute_output_shape(self, input_shape): + return (None, 1) + + def get_config(self, ): + config = {'num_sampled': self.num_sampled} + base_config = super(SampledSoftmaxLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class EmbeddingIndex(Layer): + + def __init__(self, index, **kwargs): + self.index = index + super(EmbeddingIndex, self).__init__(**kwargs) + + def build(self, input_shape): + super(EmbeddingIndex, self).build( + input_shape) # Be sure to call this somewhere! + + def call(self, x, **kwargs): + return tf.constant(self.index) + + def get_config(self, ): + config = {'index': self.index, } + base_config = super(EmbeddingIndex, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class PoolingLayer(Layer): + + def __init__(self, mode='mean', supports_masking=False, **kwargs): + + if mode not in ['sum', 'mean', 'max']: + raise ValueError("mode must be sum or mean") + self.mode = mode + self.eps = tf.constant(1e-8, tf.float32) + super(PoolingLayer, self).__init__(**kwargs) + + self.supports_masking = supports_masking + + def build(self, input_shape): + + super(PoolingLayer, self).build( + input_shape) # Be sure to call this somewhere! + + def call(self, seq_value_len_list, mask=None, **kwargs): + if not isinstance(seq_value_len_list, list): + seq_value_len_list = [seq_value_len_list] + if len(seq_value_len_list) == 1: + return seq_value_len_list[0] + expand_seq_value_len_list = list(map(lambda x: tf.expand_dims(x, axis=-1), seq_value_len_list)) + a = concat_func(expand_seq_value_len_list) + if self.mode == "mean": + hist = reduce_mean(a, axis=-1, ) + elif self.mode == "sum": + hist = reduce_sum(a, axis=-1, ) + elif self.mode == "max": + hist = reduce_max(a, axis=-1, ) + return hist + + def get_config(self, ): + config = {'mode': self.mode, 'supports_masking': self.supports_masking} + base_config = super(PoolingLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/deepctr/layers/sequence.py b/deepctr/layers/sequence.py index 3c767a07..ccad61e0 100644 --- a/deepctr/layers/sequence.py +++ b/deepctr/layers/sequence.py @@ -39,15 +39,18 @@ class SequencePoolingLayer(Layer): Arguments - **mode**:str.Pooling operation to be used,can be sum,mean or max. + - **padding_first**: bool.Is padding at the beginning of the sequence + - **supports_masking**:If True,the input need to support masking. 
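+
+        For example, with ``padding_first=True`` a behavior sequence of length 3 padded to max length 5 is
+        expected to look like ``[0, 0, x1, x2, x3]`` (0 is the padding value), which is the layout the DMR model assumes.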
""" - def __init__(self, mode='mean', supports_masking=False, **kwargs): + def __init__(self, mode='mean', padding_first=False, supports_masking=False, **kwargs): if mode not in ['sum', 'mean', 'max']: raise ValueError("mode must be sum or mean") self.mode = mode self.eps = tf.constant(1e-8, tf.float32) + self.padding_first = padding_first super(SequencePoolingLayer, self).__init__(**kwargs) self.supports_masking = supports_masking @@ -74,6 +77,9 @@ def call(self, seq_value_len_list, mask=None, **kwargs): self.seq_len_max, dtype=tf.float32) mask = tf.transpose(mask, (0, 2, 1)) + if self.padding_first: + mask = tf.reverse(mask, axis=[-1]) + embedding_size = uiseq_embed_list.shape[-1] mask = tf.tile(mask, [1, 1, embedding_size]) @@ -188,7 +194,8 @@ class AttentionSequencePoolingLayer(Layer): Input shape - A list of three tensor: [query,keys,keys_length] - - query is a 3D tensor with shape: ``(batch_size, 1, embedding_size)`` + - query is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` with self_attention is True, + and ``(batch_size, 1, embedding_size)``with self_attention is False, - keys is a 3D tensor with shape: ``(batch_size, T, embedding_size)`` @@ -204,6 +211,12 @@ class AttentionSequencePoolingLayer(Layer): - **weight_normalization**: bool.Whether normalize the attention score of local activation unit. + - **padding_first**: bool.Is padding at the beginning of the sequence + + - **causality**: bool.Whether or not use blinding. + + - **self_attention**: bool.Whether or not use self_attention. + - **supports_masking**:If True,the input need to support masking. References @@ -211,13 +224,16 @@ class AttentionSequencePoolingLayer(Layer): """ def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False, - return_score=False, + return_score=False, padding_first=False, causality=False,self_attention=False, supports_masking=False, **kwargs): self.att_hidden_units = att_hidden_units self.att_activation = att_activation self.weight_normalization = weight_normalization self.return_score = return_score + self.padding_first = padding_first + self.causality = causality + self.self_attention = self_attention super(AttentionSequencePoolingLayer, self).__init__(**kwargs) self.supports_masking = supports_masking @@ -232,14 +248,20 @@ def build(self, input_shape): "Unexpected inputs dimensions,the 3 tensor dimensions are %d,%d and %d , expect to be 3,3 and 2" % ( len(input_shape[0]), len(input_shape[1]), len(input_shape[2]))) - if input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1 or input_shape[2][1] != 1: - raise ValueError('A `AttentionSequencePoolingLayer` layer requires ' - 'inputs of a 3 tensor with shape (None,1,embedding_size),(None,T,embedding_size) and (None,1)' - 'Got different shapes: %s' % (input_shape)) + if self.self_attention and input_shape[0][-1] != input_shape[1][-1]: + raise ValueError('A `LocalActivationUnit` layer with self_attention is True requires ' + 'inputs of a two inputs with shape (None,T,embedding_size) and (None,T,embedding_size)' + 'Got different shapes: %s,%s' % (input_shape[0], input_shape[1])) + + if not self.self_attention and (input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1): + raise ValueError('A `LocalActivationUnit` layer with not self_attention requires ' + 'inputs of a two inputs with shape (None,1,embedding_size) and (None,T,embedding_size)' + 'Got different shapes: %s,%s' % (input_shape[0], input_shape[1])) else: pass self.local_att = LocalActivationUnit( - 
self.att_hidden_units, self.att_activation, l2_reg=0, dropout_rate=0, use_bn=False, seed=1024, ) + self.att_hidden_units, self.att_activation, l2_reg=0, dropout_rate=0, use_bn=False, + self_attention=self.self_attention,seed=1024) super(AttentionSequencePoolingLayer, self).build( input_shape) # Be sure to call this somewhere! @@ -253,11 +275,13 @@ def call(self, inputs, mask=None, training=None, **kwargs): key_masks = tf.expand_dims(mask[-1], axis=1) else: - queries, keys, keys_length = inputs hist_len = keys.get_shape()[1] key_masks = tf.sequence_mask(keys_length, hist_len) + if self.padding_first: + key_masks = tf.reverse(key_masks, axis=[-1]) + attention_score = self.local_att([queries, keys], training=training) outputs = tf.transpose(attention_score, (0, 2, 1)) @@ -269,6 +293,14 @@ def call(self, inputs, mask=None, training=None, **kwargs): outputs = tf.where(key_masks, outputs, paddings) + if self.causality: + scores_tile = tf.tile(tf.reduce_sum(outputs, axis=1), [1, tf.shape(outputs)[-1]]) + scores_tile = tf.reshape(scores_tile, [-1, tf.shape(outputs)[-1], tf.shape(outputs)[-1]]) + diag_vals = tf.ones_like(scores_tile) + tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() + paddings = tf.ones_like(tril) * (-2 ** 32 + 1) + outputs = tf.where(tf.equal(tril, 0), paddings, scores_tile) + if self.weight_normalization: outputs = softmax(outputs) @@ -295,6 +327,7 @@ def get_config(self, ): config = {'att_hidden_units': self.att_hidden_units, 'att_activation': self.att_activation, 'weight_normalization': self.weight_normalization, 'return_score': self.return_score, + 'padding_first': self.padding_first, 'causality': self.causality, 'supports_masking': self.supports_masking} base_config = super(AttentionSequencePoolingLayer, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -648,6 +681,101 @@ def positional_encoding(inputs, return outputs + inputs +class PositionalEncoding(Layer): + """The PositionalEncoding is used to encode position information for the behavior sequence. + + Input shape + - A 3D tensor with shape: ``(batch_size, T, embedding_size)`` + + Output shape + - A 3D tensor with shape: ``(batch_size, T, embedding_size)`` with use_concat is False + or A 3D tensor with shape: ``(batch_size, T, embedding_size*2)`` with use_concat is True; + + Arguments + - **use_sinusoidal**: Whether or not use sinusoidal positional encoding. + + - **zero_pad**: Bool.If True, all the values of the first row (id = 0) should be constant zero + + - **scale**:Bool.If True, the output will be multiplied by sqrt num_units(check details from paper) + + - **use_concat**:Bool.If True, the positional encoding will concat with input; else, the it will be added to the input + """ + def __init__(self, use_sinusoidal=True, zero_pad=False, scale=True, use_concat=False, seed=1024, **kwargs): + self.use_sinusoidal = use_sinusoidal + self.zero_pad = zero_pad + self.scale = scale + self.use_concat = use_concat + self.seed = seed + super(PositionalEncoding, self).__init__(**kwargs) + + def build(self, input_shape): + + embed_size = input_shape[2].value + seq_len_max = input_shape[1].value + if not self.use_sinusoidal: + self.position_embedding = self.add_weight('position_embedding', shape=(seq_len_max, embed_size), + initializer=TruncatedNormal( + mean=0.0, stddev=0.0001, seed=self.seed)) + + # Be sure to call this somewhere! 
+ super(PositionalEncoding, self).build(input_shape) + + def call(self, inputs, mask=None): + """ + :param inputs: None * field_size * embedding_size + :return: None * field_size * embedding_size or None * (field_size*2) * embedding_size + """ + _, T, num_units = inputs.get_shape().as_list() + position_ind = tf.expand_dims(tf.range(T), 0) + if self.use_sinusoidal: + position_enc = np.array([ + [pos / np.power(10000, 2. * i / num_units) + for i in range(num_units)] + for pos in range(T)]) + + # Second part, apply the cosine to even columns and sin to odds. + position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i + position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 + # Convert to a tensor + lookup_table = tf.convert_to_tensor(position_enc) + lookup_table = tf.cast(lookup_table, tf.float32) + else: + lookup_table = self.position_embedding + + if self.zero_pad: + lookup_table = tf.concat((tf.zeros(shape=[1, num_units]), + lookup_table[1:, :]), 0) + + outputs = tf.nn.embedding_lookup(lookup_table, position_ind) + + if self.scale: + outputs = outputs * num_units ** 0.5 + + if self.use_concat: + outputs = tf.squeeze(outputs,axis=0) + outputs = tf.tile(outputs, [tf.shape(inputs)[0], 1]) + outputs = tf.reshape(outputs, [tf.shape(inputs)[0], -1, outputs.get_shape().as_list()[1]]) + outputs = tf.concat([outputs, inputs], -1) + else: + outputs = outputs + inputs + + return outputs + + def compute_output_shape(self, input_shape): + + return input_shape + + def compute_mask(self, inputs, mask=None): + return mask + + def get_config(self, ): + + config = {'use_sinusoidal': self.use_sinusoidal,'zero_pad': self.zero_pad,'scale': self.scale, + 'use_concat': self.use_concat, 'seed': self.seed,} + base_config = super(PositionalEncoding, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + class BiasEncoding(Layer): def __init__(self, sess_max_count, seed=1024, **kwargs): self.sess_max_count = sess_max_count diff --git a/deepctr/models/__init__.py b/deepctr/models/__init__.py index bc536a5e..ced23f94 100644 --- a/deepctr/models/__init__.py +++ b/deepctr/models/__init__.py @@ -18,6 +18,7 @@ from .fibinet import FiBiNET from .flen import FLEN from .fwfm import FwFM +from .dmr import DMR __all__ = ["AFM", "CCPM","DCN", "MLR", "DeepFM", "MLR", "NFM", "DIN", "DIEN", "FNN", "PNN", - "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM"] + "WDL", "xDeepFM", "AutoInt", "ONN", "FGCNN", "DSIN", "FiBiNET", 'FLEN', "FwFM", "DMR"] diff --git a/deepctr/models/dmr.py b/deepctr/models/dmr.py new file mode 100644 index 00000000..186d78a0 --- /dev/null +++ b/deepctr/models/dmr.py @@ -0,0 +1,158 @@ +# -*- coding:utf-8 -*- +""" +Author: + Fan Zhang,826427920@qq.com +Reference: + [1] Ze Lyu, Yu Dong, Chengfu Huo, Weijun Ren. Deep Match to Rank Model for Personalized Click-Through Rate Prediction[J]. 
(https://github.com/lvze92/DMR)
+"""
+
+import tensorflow as tf
+from tensorflow.python.keras.layers import Lambda, Dot, Multiply
+from tensorflow.python.keras import backend
+
+from ..feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, build_input_features
+from ..inputs import create_embedding_matrix, embedding_lookup, get_dense_input, varlen_embedding_lookup, \
+    get_varlen_pooling_list
+from ..layers.core import DNN, PredictionLayer, SampledSoftmaxLayer, EmbeddingIndex, PoolingLayer
+from ..layers.sequence import AttentionSequencePoolingLayer, PositionalEncoding, SequencePoolingLayer
+from ..layers.utils import concat_func, NoMask, combined_dnn_input
+
+
+def User2ItemNetwork(query, keys, user_behavior_length, deep_match_id, features, sparse_feature_columns, att_hidden_size,
+                     att_activation, padding_first, att_weight_normalization, l2_reg_embedding, seed):
+    """User-to-Item network: attends over the behavior sequence to build a user representation,
+    returns its relevance to the target item together with the auxiliary sampled-softmax match loss."""
+    dm_item_id = list(filter(lambda x: x.name == deep_match_id, sparse_feature_columns))
+    dm_item_id_input = features[deep_match_id]
+    dm_hist_item_id_input = features['hist_' + deep_match_id]
+    dm_iid_embedding_table = create_embedding_matrix(dm_item_id, l2_reg_embedding, seed, prefix="deep_match_")[
+        deep_match_id]
+    dm_iid_embedding = dm_iid_embedding_table(dm_item_id_input)
+
+    hist = AttentionSequencePoolingLayer(att_hidden_size, att_activation, padding_first=padding_first, causality=True,
+                                         weight_normalization=att_weight_normalization, self_attention=True,
+                                         supports_masking=False)([query, keys, user_behavior_length])
+
+    hist = tf.keras.layers.Dense(dm_iid_embedding.get_shape().as_list()[-1], name='dm_align_2')(hist)
+
+    user_embedding = tf.keras.layers.PReLU()(Lambda(lambda x: x[:, -1, :])(hist))
+    dm_embedding = tf.keras.layers.PReLU()(Lambda(lambda x: x[:, -2, :])(hist))
+
+    rel_u2i = Dot(axes=-1)([dm_iid_embedding, user_embedding])
+
+    item_index = EmbeddingIndex(list(range(dm_item_id[0].vocabulary_size)))(dm_item_id[0])
+    item_embedding_matrix = dm_iid_embedding_table
+    item_embedding_weight = NoMask()(item_embedding_matrix(item_index))
+    pooling_item_embedding_weight = PoolingLayer()([item_embedding_weight])
+
+    aux_loss = SampledSoftmaxLayer()([pooling_item_embedding_weight, dm_embedding,
+                                      tf.cast(tf.reshape(dm_hist_item_id_input[:, -1], [-1, 1]), tf.int64), ])
+    return rel_u2i, aux_loss
+
+
+def Item2ItemNetwork(query, keys, user_behavior_length, att_hidden_size, att_activation, padding_first):
+    """Item-to-Item network: target-item attention over the behavior sequence,
+    returns the attention-weighted behavior representation and the item-to-item relevance score."""
+    scores = AttentionSequencePoolingLayer(att_hidden_size, att_activation, padding_first=padding_first,
+                                           weight_normalization=False, supports_masking=False, return_score=True)([
+        query, keys, user_behavior_length])
+    scores_norm = tf.keras.layers.Activation('softmax')(scores)
+
+    att_sum = Lambda(lambda x: backend.batch_dot(x[0], x[1]))([scores_norm, keys])
+    rel_i2i = Lambda(lambda z: backend.sum(z, axis=-1, keepdims=False))(scores)
+    return att_sum, rel_i2i
+
+
+def DMR(dnn_feature_columns, history_feature_list, deep_match_id, dnn_use_bn=False, padding_first=True,
+        dnn_hidden_units=(200, 80), dnn_activation='relu', att_hidden_size=(80, 40), att_activation="sigmoid",
+        att_weight_normalization=True, l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, seed=1024,
+        task='binary'):
+    """Instantiates the Deep Match to Rank (DMR) architecture.
+
+    :param dnn_feature_columns: An iterable containing all the features used by the deep part of the model.
+    :param history_feature_list: list, to indicate the sequence sparse fields.
+    :param deep_match_id: str. An id that appears in the history_feature_list and will be used for deep matching.
+ :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in deep net + :param padding_first: bool. Is padding at the beginning of the sequence + :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net + :param dnn_activation: Activation function to use in deep net + :param att_hidden_size: list,list of positive integer , the layer number and units in each layer of attention net + :param att_activation: Activation function to use in attention net + :param att_weight_normalization: bool.Whether normalize the attention score of local activation unit. + :param l2_reg_dnn: float. L2 regularizer strength applied to DNN + :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector + :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. + :param seed: integer ,to use as random seed. + :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss + :return: A Keras model instance. + + """ + if not padding_first: + raise ValueError("Now DMR only support padding first, " + "input history sequence should be like [0,0,1,2,3](0 is the padding).") + if deep_match_id not in history_feature_list: + raise ValueError("deep_match_id must appear in the history_feature_list.") + + features = build_input_features(dnn_feature_columns) + + user_behavior_length = features["seq_length"] + + sparse_feature_columns = list( + filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns)) if dnn_feature_columns else [] + dense_feature_columns = list( + filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns)) if dnn_feature_columns else [] + varlen_sparse_feature_columns = list( + filter(lambda x: isinstance(x, VarLenSparseFeat), dnn_feature_columns)) if dnn_feature_columns else [] + + history_feature_columns = [] + sparse_varlen_feature_columns = [] + history_fc_names = list(map(lambda x: "hist_" + x, history_feature_list)) + for fc in varlen_sparse_feature_columns: + feature_name = fc.name + if feature_name in history_fc_names: + history_feature_columns.append(fc) + else: + sparse_varlen_feature_columns.append(fc) + + inputs_list = list(features.values()) + + embedding_dict = create_embedding_matrix(dnn_feature_columns, l2_reg_embedding, seed, prefix="") + + query_emb_list = embedding_lookup(embedding_dict, features, sparse_feature_columns, history_feature_list, + history_feature_list, to_list=True) + keys_emb_list = embedding_lookup(embedding_dict, features, history_feature_columns, history_fc_names, + history_fc_names, to_list=True) + dnn_input_emb_list = embedding_lookup(embedding_dict, features, sparse_feature_columns, + mask_feat_list=history_feature_list, to_list=True) + dense_value_list = get_dense_input(features, dense_feature_columns) + + sequence_embed_dict = varlen_embedding_lookup(embedding_dict, features, sparse_varlen_feature_columns) + sequence_embed_list = get_varlen_pooling_list(sequence_embed_dict, features, sparse_varlen_feature_columns, + to_list=True) + + dnn_input_emb_list += sequence_embed_list + deep_input_emb = concat_func(dnn_input_emb_list) + + keys_emb = concat_func(keys_emb_list, mask=True) + query_emb = concat_func(query_emb_list, mask=True) + + keys_emb_pos = PositionalEncoding(use_sinusoidal=False, zero_pad=True, scale=False, use_concat=True)(keys_emb) + keys_emb_pos = tf.keras.layers.Dense(keys_emb.get_shape().as_list()[-1], name='dm_align_1')(keys_emb_pos) + keys_emb_pos = 
tf.keras.layers.PReLU()(keys_emb_pos) + + rel_u2i, aux_loss = User2ItemNetwork(keys_emb_pos,keys_emb,user_behavior_length,deep_match_id,features,sparse_feature_columns, + att_hidden_size,att_activation,padding_first,att_weight_normalization, + l2_reg_embedding,seed) + + att_sum, rel_i2i = Item2ItemNetwork(query_emb,keys_emb,user_behavior_length,att_hidden_size,att_activation,padding_first) + + hist_embedding = SequencePoolingLayer(mode='sum', padding_first=True)([keys_emb, user_behavior_length]) + + sim_u2i = Multiply()([query_emb, hist_embedding]) + + deep_input_emb = tf.keras.layers.Concatenate()([NoMask()(deep_input_emb),NoMask()(sim_u2i),hist_embedding,att_sum]) + deep_input_emb = tf.keras.layers.Flatten()(deep_input_emb) + dnn_input = combined_dnn_input([deep_input_emb,rel_u2i,rel_i2i], dense_value_list) + output = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, dnn_use_bn, seed=seed)(dnn_input) + final_logit = tf.keras.layers.Dense(1, use_bias=False, + kernel_initializer=tf.keras.initializers.glorot_normal(seed))(output) + + output = PredictionLayer(task)(final_logit) + + model = tf.keras.models.Model(inputs=inputs_list, outputs=output) + model.add_loss(aux_loss) + return model diff --git a/examples/run_dmr.py b/examples/run_dmr.py new file mode 100644 index 00000000..d7be2ff6 --- /dev/null +++ b/examples/run_dmr.py @@ -0,0 +1,39 @@ +import numpy as np +from deepctr.models import DMR +from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat,get_feature_names + +def get_xy_fd(): + + feature_columns = [SparseFeat('user',3,embedding_dim=10),SparseFeat( + 'gender', 2,embedding_dim=4), SparseFeat('item_id', 3 + 1,embedding_dim=8), SparseFeat('cate_id', 2 + 1,embedding_dim=4),DenseFeat('pay_score', 1)] + + feature_columns += [ + VarLenSparseFeat(SparseFeat('hist_item_id', vocabulary_size=3 + 1, embedding_dim=8, embedding_name='item_id'), + maxlen=4, length_name="seq_length"), + VarLenSparseFeat(SparseFeat('hist_cate_id', 2 + 1, embedding_dim=4, embedding_name='cate_id'), maxlen=4, + length_name="seq_length")] + + behavior_feature_list = ["item_id", "cate_id"] + uid = np.array([0, 1, 2]) + ugender = np.array([0, 1, 0]) + iid = np.array([1, 2, 3]) # 0 is mask value + cate_id = np.array([1, 2, 2]) # 0 is mask value + pay_score = np.array([0.1, 0.2, 0.3]) + + hist_iid = np.array([[1, 2, 3, 0], [3, 2, 1, 0], [1, 2, 0, 0]]) + hist_cate_id = np.array([[1, 2, 2, 0], [2, 2, 1, 0], [1, 2, 0, 0]]) + behavior_length = np.array([3, 3, 2]) + + feature_dict = {'user': uid, 'gender': ugender, 'item_id': iid, 'cate_id': cate_id, + 'hist_item_id': hist_iid, 'hist_cate_id': hist_cate_id, 'pay_score': pay_score, "seq_length": behavior_length} + x = {name:feature_dict[name] for name in get_feature_names(feature_columns)} + y = np.array([1, 0, 1]) + return x, y, feature_columns, behavior_feature_list + +if __name__ == "__main__": + x, y, feature_columns, behavior_feature_list = get_xy_fd() + deep_match_id = "item_id" + model = DMR(feature_columns, behavior_feature_list, deep_match_id) + model.compile('adam', 'binary_crossentropy', + metrics=['binary_crossentropy']) + history = model.fit(x, y, verbose=1, epochs=10, validation_split=0.5) diff --git a/setup.py b/setup.py index 5f06066d..16bbd85c 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ long_description = fh.read() REQUIRED_PACKAGES = [ - 'h5py','requests' + 'h5py==2.10.0','requests' ] setuptools.setup(