In [1]:
# coding=utf-8
import tensorflow as tf
import numpy as np


def target_attention(Q, K, V):
    """ target_attention implementation
    :param Q:
    :param K:
    :param V:
    :return: target_attention tensor
    """
    k1, k2 = K.get_shape().as_list()[-1], Q.get_shape().as_list()[-1]
    W = tf.get_variable("w", shape=[k1, k2], initializer=tf.keras.initializers.he_normal())

    K_transform = tf.tensordot(K, W, axes=1)
    d_k = tf.cast((k1 + k2) / 2, dtype=tf.float32)
    logit = tf.matmul(K_transform, tf.expand_dims(Q, axis=-1)) / tf.sqrt(d_k)
    weight = tf.nn.softmax(tf.squeeze(logit, axis=-1), axis=-1)

    attention = tf.matmul(tf.expand_dims(weight, axis=1), V)

    return tf.squeeze(attention, axis=1)


if __name__ == '__main__':

    seq_len = 4
    embedding_size = 4

    seq_tensor = tf.placeholder(dtype=tf.float32, shape=(None, seq_len, embedding_size))  # 序列特征
    target_tensor = tf.placeholder(dtype=tf.float32, shape=(None, embedding_size))  # target_item

    t_attn = target_attention(target_tensor, seq_tensor, seq_tensor)

    feed_dict = {
        target_tensor: np.array([
            [3.0, 4.0, 5.0, 6.0]
        ]),
        seq_tensor: np.array([[
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [5.0, 4.0, 3.0, 8.0]
        ]])
    }

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        t_attn_out = sess.run(t_attn, feed_dict=feed_dict)
        print(t_attn_out)


Instructions for updating:
Colocations handled automatically by placer.
[[  8.99884224   9.99884224  10.99884224  11.99884224]]


In [2]:

import torch
import torch.nn as nn

class TargetAttention(nn.Module):
    def __init__(self, num_heads, d_model, dropout=0.1):
        super(TargetAttention, self).__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.dropout = nn.Dropout(p=dropout)
        
        # 定义一个线性层，用于将输入的序列映射到查询空间
        self.query = nn.Linear(d_model, d_model)
        
        # 定义两个线性层，分别用于计算注意力权重和输出
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        
        # 定义一个可训练的参数来控制注意力的缩放因子
        self.scale = nn.Parameter(torch.Tensor([1.0]))
        
    def forward(self, x):
        # 将输入x通过查询层映射到查询空间
        query = self.query(x)
        
        # 将查询向量分成num_heads个部分，并计算它们的均值
        query = query.view(query.size(0), query.size(1), self.num_heads, self.d_k).mean(dim=2)
        
        # 将查询向量与键向量进行点积，得到注意力权重矩阵
        scores = torch.matmul(query, self.key(x).transpose(1,2)) / math.sqrt(self.d_k)
        
        # 对注意力权重进行缩放，以控制输出的大小
        scores = scores * self.scale.repeat(scores.size(0), 1, 1).unsqueeze(1)
        
        # 利用softmax函数对注意力权重进行归一化处理
        scores = scores / torch.max(scores, dim=-1, keepdim=True)[0].unsqueeze(2)
        
        # 将归一化后的注意力权重与值向量相乘，得到最终的输出
        output = scores.matmul(self.value(x))
        
        # 对输出进行dropout操作，以防止过拟合
        output = self.dropout(output)
        
        return output
    
# 测试代码
x = torch.randn(10, 50)
target = torch.randn(10, 50)
attn = TargetAttention(num_heads=2, d_model=50, dropout=0.3)
output = attn(x, target)
print(output)

TypeError: forward() takes 2 positional arguments but 3 were given