## 点乘注意力机制

In [3]:
import math
import torch 
import torch.nn as nn
class DotProductAttention(nn.Module): 
    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.softmax = nn.Softmax()

    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, xx)
    def forward(self, query, key, value):
        d = query.shape[-1]
        # set transpose_b=True to swap the last two dimensions of key        
        scores = torch.bmm(query, key.transpose(1,2)) / math.sqrt(d)
        attention_weights = self.softmax(scores)
        print(attention_weights)
        return torch.bmm(attention_weights, value)

atten = DotProductAttention()
keys = torch.ones((2,10,2),dtype=torch.float)
values = torch.arange((40), dtype=torch.float).view(1,10,4).repeat(2,1,1)
print(values.shape)
print(values)
out = atten(torch.ones((2,1,2),dtype=torch.float), keys, values)
print(out.shape)

torch.Size([2, 10, 4])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]],

        [[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]]])
tensor([[[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
          0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
          0.5000, 0.5000]]])
torch.Size([2, 1, 4])


  attention_weights = self.softmax(scores)


## MLP注意力

In [6]:
class MLPAttention(nn.Module):  
    def __init__(self, dim, **kwargs):
        super(MLPAttention, self).__init__(**kwargs)
        # Use flatten=True to keep query's and key's 3-D shapes.
        self.W_k = nn.Linear(2, dim, bias=False)
        self.W_q = nn.Linear(2, dim, bias=False)
        self.v = nn.Linear(dim, 1, bias=False)
        self.softmax = nn.Softmax()

    def forward(self, query, key, value):
        query, key = self.W_k(query), self.W_q(key)
        # expand query to (batch_size, #querys, 1, units), and key to
        # (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.
        features = query.unsqueeze(2) + key.unsqueeze(1)
        scores = self.v(features).squeeze(-1) 
        attention_weights = self.softmax(scores)
        print(attention_weights)
        return torch.bmm(attention_weights, value)

atten = MLPAttention(4)
keys = torch.ones((2,10,2),dtype=torch.float)
values = torch.arange((40), dtype=torch.float).view(1,10,4).repeat(2,1,1)
print(values.shape)
print(values)
atten(torch.ones((2,1,2),dtype=torch.float), keys, values)

torch.Size([2, 10, 4])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]],

        [[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.],
         [12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.],
         [24., 25., 26., 27.],
         [28., 29., 30., 31.],
         [32., 33., 34., 35.],
         [36., 37., 38., 39.]]])
tensor([[[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
          0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
          0.5000, 0.5000]]], grad_fn=<SoftmaxBackward0>)


  attention_weights = self.softmax(scores)


tensor([[[ 90.,  95., 100., 105.]],

        [[ 90.,  95., 100., 105.]]], grad_fn=<BmmBackward0>)