# Cross Attention

##### cross Attention  
与 Self-Attention 非常类似，区别在于它计算的是一个序列与另一个序列之间的相关性。
> 通常一个序列提供 Q，另一个序列提供 K、V；个别情况下也有一个序列提供 K、另一个序列提供 Q、V 的做法。
> Transformer 中 Decoder 的第二个 MultiHeadAttention 使用的就是 Cross Attention。

In [3]:
import math
import torch
import torch.nn as nn

# MultiHeadAttention是一个特征提取器。
# 输入query, key, value三个张量，输出是融合了上下文语义信息的单词表示，输出维度和query相同。
# 可以兼容transformer中的三类Attention：
#   1. encoder self-attention：无mask，输入query = key = value。
#   2. decoder self-attention：有sequence mask，保证当前单词只能看到之前的单词，看不到之后的单词；输入query = key = value。
#   3. encoder-decoder attention：实现encoder和decoder的交互，query是decoder层的输入，key = value 为encoder的输出。


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (no output projection).

    The same module serves self-attention (``query``, ``key`` and ``value``
    are the same tensor) and cross-attention (``query`` comes from one
    sequence, ``key``/``value`` from another).

    Args:
        heads: Number of attention heads; must divide ``hidden_size``.
        hidden_size: Size of the last dimension of the inputs and the output.
    """

    def __init__(self, heads, hidden_size):
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % heads == 0
        self.hidden_size = hidden_size
        self.heads = heads
        # One linear projection each for queries, keys and values.
        self.wq = nn.Linear(hidden_size, hidden_size)
        self.wk = nn.Linear(hidden_size, hidden_size)
        self.wv = nn.Linear(hidden_size, hidden_size)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape [batch_size, seq2_len, hidden_size].
            key: Tensor of shape [batch_size, seq1_len, hidden_size].
            value: Tensor of shape [batch_size, seq1_len, hidden_size].
            mask: Optional mask accepted by ``Tensor.masked_fill``. Positions
                where it is True are filled with ``-inf`` before the softmax,
                i.e. they receive zero attention weight. A mask with fewer
                dimensions than the score tensor (e.g.
                [batch_size, seq2_len, seq1_len]) gets a head axis inserted
                at dim 1. The caller's tensor is never modified.

        Returns:
            Tensor of shape [batch_size, seq2_len, hidden_size].
        """
        batch_size, seq2_len, _ = query.shape
        seq1_len = key.shape[1]
        d_k = self.hidden_size // self.heads

        # Project, split the last dim into heads, then move heads ahead of
        # the sequence axis: [batch_size, heads, seq_len, d_k].
        q = self.wq(query).view(batch_size, seq2_len, self.heads, -1).permute(0, 2, 1, 3)
        k = self.wk(key).view(batch_size, seq1_len, self.heads, -1).permute(0, 2, 1, 3)
        v = self.wv(value).view(batch_size, seq1_len, self.heads, -1).permute(0, 2, 1, 3)

        # Scaled dot-product scores: [batch_size, heads, seq2_len, seq1_len].
        attention = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() != attention.dim():
                # Out-of-place on purpose: an in-place unsqueeze_ would
                # change the shape of the tensor the caller still holds.
                mask = mask.unsqueeze(1)  # [batch_size, 1, seq2_len, seq1_len]
            attention = attention.masked_fill(mask, float("-inf"))

        # NOTE: a query row whose keys are all masked is all -inf, and the
        # softmax of such a row is NaN.
        score = nn.functional.softmax(attention, dim=-1)
        output = torch.matmul(score, v)  # [batch_size, heads, seq2_len, d_k]
        # Merge heads back: [batch_size, seq2_len, heads, d_k]
        # -> [batch_size, seq2_len, hidden_size].
        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq2_len, -1)
        return output


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    The inner layer is four times as wide as ``model_size``; the output has
    the same last-dimension size as the input.

    Args:
        model_size: Size of the last dimension of the input and output.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, model_size, dropout=0.1):
        super().__init__()
        inner_size = 4 * model_size
        self.linear1 = nn.Linear(model_size, inner_size)
        self.linear2 = nn.Linear(inner_size, model_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand, activate, apply dropout, then project back to ``model_size``."""
        hidden = self.linear1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.linear2(hidden)

测试代码

In [9]:
def testCrossAttention():
    """Smoke-test cross-attention: print the output shape for a 3-step query
    sequence attending over a 5-step key/value sequence."""
    query_seq = torch.randn(2, 3, 100)    # [batch_size, seq2_len, hidden_size]
    context_seq = torch.randn(2, 5, 100)  # [batch_size, seq1_len, hidden_size]
    cross_attention = MultiHeadAttention(1, 100)
    # Queries come from one sequence; keys and values both come from the other.
    result = cross_attention(query_seq, key=context_seq, value=context_seq)
    print(result.shape)

运行代码

In [10]:
# Run the cross-attention smoke test only when this file is executed directly.
if __name__ == "__main__":
    testCrossAttention()

torch.Size([2, 3, 100])
