In [18]:
import tensorflow as tf
from tensorflow.keras import layers


class Dice(layers.Layer):
    def __init__(self):
        super(Dice, self).__init__()
        self.bn = layers.BatchNormalization(center=False, scale=False)
        self.alpha = self.add_weight(shape=(), initializer='zeros', trainable=True)

    def call(self, x):
        x_normed = self.bn(x)
        p = tf.sigmoid(x_normed)
        return p * x + (1 - p) * self.alpha * x


class Attention(layers.Layer):
    def __init__(self, hidden_units, use_softmax=False):
        super(Attention, self).__init__()
        self.dense_layer = tf.keras.Sequential()
        for units in hidden_units:
            self.dense_layer.add(layers.Dense(units, activation='sigmoid'))
        self.output_layer = layers.Dense(1)
        self.use_softmax = use_softmax

    def call(self, query, keys, keys_mask):
        # query: [batch_size, embedding_dim]
        # keys: [batch_size, max_length, embedding_dim]
        # keys_mask: [batch_size, max_length]
        query = tf.tile(tf.expand_dims(query, 1), [1, tf.shape(keys)[1], 1])  # [batch_size, max_length, embedding_dim]
        inputs = tf.concat([query, keys, query - keys, query * keys], axis=-1)
        outputs = self.dense_layer(inputs)
        scores = self.output_layer(outputs)  # [batch_size, max_length, 1]
        scores = tf.squeeze(scores, axis=-1)  # [batch_size, max_length]

        # 应用掩码
        paddings = tf.ones_like(scores) * (-2 ** 32 + 1)
        scores = tf.where(keys_mask, scores, paddings)

        if self.use_softmax:
            scores = tf.nn.softmax(scores)  # [batch_size, max_length]
        output = tf.reduce_sum(keys * tf.expand_dims(scores, -1), axis=1)  # [batch_size, embedding_dim]
        return output


class DIN(tf.keras.Model):
    def __init__(self, user_feature_dim, item_feature_dim, embedding_dim, hidden_units, attention_hidden_units):
        super(DIN, self).__init__()
        self.user_embedding = layers.Embedding(user_feature_dim, embedding_dim)
        self.item_embedding = layers.Embedding(item_feature_dim, embedding_dim)
        self.attention = Attention(attention_hidden_units)
        self.dense_layer = tf.keras.Sequential()
        for units in hidden_units:
            self.dense_layer.add(layers.Dense(units))
            self.dense_layer.add(Dice())
        self.output_layer = layers.Dense(1, activation='sigmoid')

    def call(self, user_features, item_features, hist_item_features, hist_mask):
        user_emb = self.user_embedding(user_features)  # [batch_size, embedding_dim]
        item_emb = self.item_embedding(item_features)  # [batch_size, embedding_dim]
        hist_item_embs = self.item_embedding(hist_item_features)  # [batch_size, max_length, embedding_dim]

        hist_attention_emb = self.attention(item_emb, hist_item_embs, hist_mask)  # [batch_size, embedding_dim]

        inputs = tf.concat([user_emb, item_emb, hist_attention_emb], axis=-1)
        outputs = self.dense_layer(inputs)
        output = self.output_layer(outputs)
        return output


# 示例使用
user_feature_dim = 100
item_feature_dim = 200
embedding_dim = 16
hidden_units = [32, 16]
attention_hidden_units = [32, 16]

model = DIN(user_feature_dim, item_feature_dim, embedding_dim, hidden_units, attention_hidden_units)

# 模拟输入数据
batch_size = 32
user_features = tf.random.uniform([batch_size], minval=0, maxval=user_feature_dim, dtype=tf.int32)
item_features = tf.random.uniform([batch_size], minval=0, maxval=item_feature_dim, dtype=tf.int32)
max_length = 10
hist_item_features = tf.random.uniform([batch_size, max_length], minval=0, maxval=item_feature_dim, dtype=tf.int32)
real_length = tf.random.uniform([batch_size], minval=0, maxval=max_length, dtype=tf.int32)
hist_mask = tf.sequence_mask(real_length, max_length)

# 前向传播
output = model(user_features, item_features, hist_item_features, hist_mask)

In [19]:
output

<tf.Tensor: shape=(32, 1), dtype=float32, numpy=
array([[0.],
       [0.],
       [1.],
       [0.],
       [0.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [1.],
       [0.],
       [1.],
       [0.],
       [0.],
       [0.],
       [1.],
       [0.],
       [0.],
       [0.],
       [1.],
       [1.],
       [0.],
       [0.],
       [0.],
       [1.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.],
       [0.]], dtype=float32)>