Dev zany fun9 #303

Open · wants to merge 12 commits into master
1 change: 1 addition & 0 deletions README.md
@@ -54,6 +54,7 @@ Let's [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Star
| Deep Session Interest Network | [IJCAI 2019][Deep Session Interest Network for Click-Through Rate Prediction ](https://arxiv.org/abs/1905.06482) |
| FiBiNET | [RecSys 2019][FiBiNET: Combining Feature Importance and Bilinear feature Interaction for Click-Through Rate Prediction](https://arxiv.org/pdf/1905.09433.pdf) |
| FLEN | [arxiv 2019][FLEN: Leveraging Field for Scalable CTR Prediction](https://arxiv.org/pdf/1911.04690.pdf) |
| DMR | [AAAI 2020][Deep Match to Rank Model for Personalized Click-Through Rate Prediction](https://ojs.aaai.org//index.php/AAAI/article/view/5346) |

## Citation

11 changes: 8 additions & 3 deletions deepctr/layers/__init__.py
@@ -1,15 +1,15 @@
import tensorflow as tf

from .activation import Dice
from .core import DNN, LocalActivationUnit, PredictionLayer
from .core import DNN, LocalActivationUnit, PredictionLayer, SampledSoftmaxLayer, EmbeddingIndex, PoolingLayer
from .interaction import (CIN, FM, AFMLayer, BiInteractionPooling, CrossNet,
InnerProductLayer, InteractingLayer,
OutterProductLayer, FGCNNLayer, SENETLayer, BilinearInteraction,
FieldWiseBiInteraction, FwFMLayer)
from .normalization import LayerNormalization
from .sequence import (AttentionSequencePoolingLayer, BiasEncoding, BiLSTM,
KMaxPooling, SequencePoolingLayer,WeightedSequenceLayer,
Transformer, DynamicGRU)
Transformer, DynamicGRU, PositionalEncoding)
from .utils import NoMask, Hash,Linear,Add,combined_dnn_input

custom_objects = {'tf': tf,
@@ -42,5 +42,10 @@
'WeightedSequenceLayer':WeightedSequenceLayer,
'Add':Add,
'FieldWiseBiInteraction':FieldWiseBiInteraction,
'FwFMLayer': FwFMLayer
'FwFMLayer': FwFMLayer,
'SampledSoftmaxLayer': SampledSoftmaxLayer,
'EmbeddingIndex': EmbeddingIndex,
'PoolingLayer': PoolingLayer,
'PositionalEncoding': PositionalEncoding
}
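
For context, a minimal sketch of how `custom_objects` is used when reloading a saved model; the file name here is hypothetical:

```python
from tensorflow.python.keras.models import load_model

from deepctr.layers import custom_objects

# custom_objects maps layer names back to the classes registered above,
# so Keras can deserialize DeepCTR's custom layers.
model = load_model('DIN.h5', custom_objects)  # 'DIN.h5' is a hypothetical path
```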

124 changes: 117 additions & 7 deletions deepctr/layers/core.py
@@ -13,7 +13,7 @@
from tensorflow.python.keras.regularizers import l2

from .activation import activation_layer

from .utils import reduce_max, reduce_mean, reduce_sum, concat_func

class LocalActivationUnit(Layer):
"""The LocalActivationUnit used in DIN with which the representation of
@@ -36,19 +36,22 @@ class LocalActivationUnit(Layer):

- **use_bn**: bool. Whether use BatchNormalization before activation or not in attention net.

- **self_attention**: bool. Whether to use self-attention. If True, the query is treated as a sequence of the same length as the keys instead of a single item.

- **seed**: A Python integer to use as random seed.

References
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self, hidden_units=(64, 32), activation='sigmoid', l2_reg=0, dropout_rate=0, use_bn=False, seed=1024,
**kwargs):
def __init__(self, hidden_units=(64, 32), activation='sigmoid', l2_reg=0, dropout_rate=0, use_bn=False,
self_attention=False, seed=1024, **kwargs):
self.hidden_units = hidden_units
self.activation = activation
self.l2_reg = l2_reg
self.dropout_rate = dropout_rate
self.use_bn = use_bn
self.self_attention = self_attention
self.seed = seed
super(LocalActivationUnit, self).__init__(**kwargs)
self.supports_masking = True
@@ -63,8 +66,13 @@ def build(self, input_shape):
raise ValueError("Unexpected inputs dimensions %d and %d, expect to be 3 dimensions" % (
len(input_shape[0]), len(input_shape[1])))

if input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1:
raise ValueError('A `LocalActivationUnit` layer requires '
if self.self_attention and input_shape[0][-1] != input_shape[1][-1]:
raise ValueError('A `LocalActivationUnit` layer with self_attention=True requires '
'two inputs with shape (None,T,embedding_size) and (None,T,embedding_size). '
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))

if not self.self_attention and (input_shape[0][-1] != input_shape[1][-1] or input_shape[0][1] != 1):
raise ValueError('A `LocalActivationUnit` layer with self_attention=False requires '
'two inputs with shape (None,1,embedding_size) and (None,T,embedding_size). '
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
size = 4 * \
@@ -88,8 +96,11 @@ def call(self, inputs, training=None, **kwargs):

query, keys = inputs

keys_len = keys.get_shape()[1]
queries = K.repeat_elements(query, keys_len, 1)
if not self.self_attention:
keys_len = keys.get_shape()[1]
queries = K.repeat_elements(query, keys_len, 1)
else:
queries = query

att_input = tf.concat(
[queries, keys, queries - keys, queries * keys], axis=-1)
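
A minimal sketch of what the new `self_attention` flag changes in practice, assuming standard Keras functional inputs (all shapes are illustrative): with `self_attention=True` the query keeps its full sequence length instead of being tiled from a single item.

```python
import tensorflow as tf

from deepctr.layers.core import LocalActivationUnit

# Illustrative shapes: sequence length T=4, embedding size E=8.
queries = tf.keras.layers.Input(shape=(4, 8))  # (None, T, E), not (None, 1, E)
keys = tf.keras.layers.Input(shape=(4, 8))     # (None, T, E)

# Each query position is scored against the matching key position;
# the attention scores come out with shape (None, T, 1).
scores = LocalActivationUnit(hidden_units=(16,), self_attention=True)([queries, keys])
```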
@@ -255,3 +266,102 @@ def get_config(self, ):
config = {'task': self.task, 'use_bias': self.use_bias}
base_config = super(PredictionLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class SampledSoftmaxLayer(Layer):

def __init__(self, num_sampled=2, **kwargs):
self.num_sampled = num_sampled
super(SampledSoftmaxLayer, self).__init__(**kwargs)

def build(self, input_shape):
self.size = input_shape[0][0]
self.zero_bias = self.add_weight(shape=[self.size],
initializer=Zeros,
dtype=tf.float32,
trainable=False,
name="bias")
super(SampledSoftmaxLayer, self).build(input_shape)

def call(self, inputs_with_label_idx, training=None, **kwargs):
"""
The first input should be the model as it were, and the second the
target (i.e., a repeat of the training data) to compute the labels
argument
"""
embeddings, inputs, label_idx = inputs_with_label_idx

loss = tf.nn.sampled_softmax_loss(weights=embeddings, # self.item_embedding.
biases=self.zero_bias,
labels=label_idx,
inputs=inputs,
num_sampled=self.num_sampled,
num_classes=self.size, # self.target_song_size
)

return reduce_mean(loss)

def compute_output_shape(self, input_shape):
return (None, 1)

def get_config(self, ):
config = {'num_sampled': self.num_sampled}
base_config = super(SampledSoftmaxLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
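
A minimal sketch of how this layer is wired, following the pattern used in DeepCTR-style match models; all names and sizes below are illustrative. The first input is the full item embedding matrix (the softmax weights), the second the model's output vector, and the third the target item indices:

```python
from tensorflow.python.keras.layers import Embedding, Input
from tensorflow.python.keras.models import Model

from deepctr.layers.core import EmbeddingIndex, SampledSoftmaxLayer

item_vocab_size, emb_dim = 1000, 16            # illustrative sizes
user_vec = Input(shape=(emb_dim,))             # e.g. the output of a user DNN
item_label = Input(shape=(1,), dtype='int64')  # target item ids

# Materialize the full (vocab, dim) embedding table to serve as the
# sampled-softmax weights (see EmbeddingIndex below).
item_embedding = Embedding(item_vocab_size, emb_dim)
all_item_emb = item_embedding(EmbeddingIndex(list(range(item_vocab_size)))(item_label))

# The layer's scalar output is the sampled-softmax loss itself.
loss = SampledSoftmaxLayer(num_sampled=5)([all_item_emb, user_vec, item_label])
model = Model([user_vec, item_label], loss)
```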


class EmbeddingIndex(Layer):

def __init__(self, index, **kwargs):
self.index = index
super(EmbeddingIndex, self).__init__(**kwargs)

def build(self, input_shape):
super(EmbeddingIndex, self).build(
input_shape) # Be sure to call this somewhere!

def call(self, x, **kwargs):
return tf.constant(self.index)

def get_config(self, ):
config = {'index': self.index, }
base_config = super(EmbeddingIndex, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
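
A quick sketch of the behavior (sizes are illustrative): the layer ignores the tensor passed to it and emits the constant index list, so an embedding lookup on its output returns the whole table.

```python
from tensorflow.python.keras.layers import Embedding, Input

from deepctr.layers.core import EmbeddingIndex

vocab_size, emb_dim = 100, 8                # illustrative sizes
item_id = Input(shape=(1,), dtype='int32')

# Output is tf.constant([0, 1, ..., vocab_size - 1]) regardless of item_id,
# so the lookup yields the full (vocab_size, emb_dim) matrix.
full_table = Embedding(vocab_size, emb_dim)(EmbeddingIndex(list(range(vocab_size)))(item_id))
```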


class PoolingLayer(Layer):

def __init__(self, mode='mean', supports_masking=False, **kwargs):

if mode not in ['sum', 'mean', 'max']:
raise ValueError("mode must be sum, mean or max")
self.mode = mode
self.eps = tf.constant(1e-8, tf.float32)
super(PoolingLayer, self).__init__(**kwargs)

self.supports_masking = supports_masking

def build(self, input_shape):

super(PoolingLayer, self).build(
input_shape) # Be sure to call this somewhere!

def call(self, seq_value_len_list, mask=None, **kwargs):
if not isinstance(seq_value_len_list, list):
seq_value_len_list = [seq_value_len_list]
if len(seq_value_len_list) == 1:
return seq_value_len_list[0]
expand_seq_value_len_list = list(map(lambda x: tf.expand_dims(x, axis=-1), seq_value_len_list))
a = concat_func(expand_seq_value_len_list)
if self.mode == "mean":
hist = reduce_mean(a, axis=-1, )
elif self.mode == "sum":
hist = reduce_sum(a, axis=-1, )
elif self.mode == "max":
hist = reduce_max(a, axis=-1, )
return hist

def get_config(self, ):
config = {'mode': self.mode, 'supports_masking': self.supports_masking}
base_config = super(PoolingLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
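
A minimal usage sketch (shapes are illustrative): the layer stacks a list of same-shaped tensors along a new trailing axis and reduces them element-wise.

```python
import tensorflow as tf

from deepctr.layers.core import PoolingLayer

# Three (None, 1, 8) embeddings pooled element-wise into one (None, 1, 8).
a = tf.keras.layers.Input(shape=(1, 8))
b = tf.keras.layers.Input(shape=(1, 8))
c = tf.keras.layers.Input(shape=(1, 8))

pooled = PoolingLayer(mode='max')([a, b, c])
```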