In [2]:
import torch
import torch.nn as nn

* q和k算完以后，最后两个维度都是时间维度
* 矩阵相乘（torch.matmul）只对最后两个维度做乘法，前面的维度都当作batch维度；只要第一个矩阵的最后一个维度和第二个矩阵的倒数第二个维度相等就可以了

In [3]:
# Two random 4-D tensors of identical shape stand in for queries and keys.
q, k = (torch.randn(5, 3, 135, 39) for _ in range(2))
# Swap the last two axes of k so the product contracts over the size-39 axis.
z = q @ k.transpose(-1, -2)
z.shape

torch.Size([5, 3, 135, 135])

In [4]:
# Normalising over the last axis leaves the shape of the scores unchanged.
z.softmax(dim=-1).shape

torch.Size([5, 3, 135, 135])

In [5]:
v = torch.randn(5,3,135,39)
# Attention weights the values with the softmax-normalised scores, not the raw
# scores; the resulting shape is the same either way.
torch.matmul(torch.softmax(z, dim=-1), v).shape

torch.Size([5, 3, 135, 39])

* 通过上面其实可以看出来，向量又变回原来的维度了
* 下面是完整的Multi-head attention

* Multi-head的本质其实是把原来的dimension拆成了head*sub_dimension，然后拿sub_dimension去做处理
* 所以原来传的dimension是head*sub_dimension的大小，这里sub_dimension就是词向量的维度
* Multi-head attention里面用的还是普通的attention，反正都只处理传进来的数据的最后两个维度而已

In [52]:
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention over the last two axes of q, k and v.

    Args:
        q: query tensor; only its last two axes take part in the product.
        k: key tensor; its last two axes are swapped before multiplying by q.
        v: value tensor, multiplied by the attention weights.
        d_k: the raw scores are divided by sqrt(d_k).
        mask: optional tensor. An axis is inserted at position 1 before it is
            broadcast against the scores, and positions where it equals 0 are
            suppressed (e.g. a (batch, 1, key_len) mask becomes
            (batch, 1, 1, key_len)).
        dropout: optional callable applied to the attention weights.

    Returns:
        The product of the softmax-normalised scores and v.
    """
    # After this matmul the last axis indexes the keys' time steps:
    # (batch, heads, q_t, k_t) when called from MultiHeadAttention.
    scores = torch.matmul(q, k.transpose(-2, -1)) /  math.sqrt(d_k)
    print("scores.shape_1:",scores.shape)
    if mask is not None:
        # Multi-head scores are 4-D, so the mask needs one more size-1 axis
        # (the heads axis) before it can broadcast.
        mask = mask.unsqueeze(1)
        print("mask.shape:", mask.shape)
        # Masked positions get a large negative score (-1e9) so softmax gives
        # them ~0 weight. -1e9 is not representable in float16 and would make
        # masked_fill raise, so clamp to the dtype's finite minimum; float32
        # and float64 still use exactly -1e9.
        fill_value = max(-1e9, torch.finfo(scores.dtype).min)
        scores = scores.masked_fill(mask == 0, fill_value)
        print("scores_mask.shape",scores.shape)  # shape is unchanged by masking
        print("scores_mask:",scores)
    # Softmax runs over the last axis, i.e. over the key time steps.
    scores = F.softmax(scores, dim=-1)
    print("scores.shape_2:",scores.shape)

    if dropout is not None:
        scores = dropout(scores)

    output = torch.matmul(scores, v)
    return output

In [53]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: splits d_model into `heads` slices of size d_k.

    Each of q, k and v goes through its own linear layer, is reshaped to
    (batch, heads, seq_len, d_k), attended per head with `attention`, then the
    heads are concatenated and passed through a final linear layer.
    """

    def __init__(self, heads, d_model, dropout = 0.1):
        """
        Args:
            heads: number of attention heads.
            d_model: size of the last axis of the inputs; must be a multiple
                of `heads`.
            dropout: dropout probability applied to the attention weights.

        Raises:
            ValueError: if d_model is not divisible by heads.
        """
        super().__init__()

        # Without this check d_model // heads silently truncates, and the
        # .view(bs, -1, h, d_k) in forward may then reshape into a wrong
        # sequence length instead of failing.
        if d_model % heads != 0:
            raise ValueError(
                "d_model (%d) must be a multiple of heads (%d)" % (d_model, heads)
            )

        self.d_model = d_model
        self.d_k = d_model // heads  # per-head slice of d_model
        self.h = heads
        # q, k and v each get their own linear projection.
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: tensors whose first axis is the batch and whose last axis
                has size d_model.
            mask: optional mask forwarded to `attention`.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        bs = q.size(0)

        # Project, then split d_model into (heads, d_k):
        # (batch_size, seq_len, heads, d_k).
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        print("k.shape_1",k.shape)

        # Move the heads axis forward: (batch_size, heads, seq_len, d_k).
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        print("k.shape_2",k.shape)

        # Plain attention applied per head on the d_k-sized slices.
        attn_out = attention(q, k, v, self.d_k, mask, self.dropout)

        # Concatenate the heads back into d_model and apply the output layer.
        concat = attn_out.transpose(1,2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        # output: (batch_size, seq_len, d_model)
        return output


In [54]:
import math
import torch.nn.functional as F

In [55]:
# 512-wide model split across 4 heads (128 features per head).
multi_head = MultiHeadAttention(d_model=512, heads=4)

In [56]:
# q, k and v can all have exactly the same shape as the input x.
# Per the original note, the last dimension could also be varied.
q, k, v = (torch.randn(5, 135, 512) for _ in range(3))

In [57]:
# The encoder mask is shaped (batch_size, 1, time_step).
mask = torch.ones(5, 1, 135)
# Zero the last two time steps of the first sequence.
mask[0, 0, -2:] = 0

In [58]:
# Forward pass; the module prints the intermediate shapes as it runs.
out = multi_head(q, k, v, mask=mask)

k.shape_1 torch.Size([5, 135, 4, 128])
k.shape_2 torch.Size([5, 4, 135, 128])
scores.shape_1: torch.Size([5, 4, 135, 135])
mask.shape: torch.Size([5, 1, 1, 135])
scores_mask.shape torch.Size([5, 4, 135, 135])
scores_mask: tensor([[[[ 2.9316e-01, -7.4257e-01, -4.4617e-01,  ..., -8.1529e-01,
           -1.0000e+09, -1.0000e+09],
          [-2.5498e-01,  4.0815e-01,  5.3102e-01,  ...,  3.2908e-02,
           -1.0000e+09, -1.0000e+09],
          [-5.1913e-02, -3.4576e-01, -1.3457e-01,  ...,  6.6825e-01,
           -1.0000e+09, -1.0000e+09],
          ...,
          [-1.7907e-01, -4.3929e-01, -1.6597e-02,  ...,  3.0488e-01,
           -1.0000e+09, -1.0000e+09],
          [-2.4973e-01, -3.8612e-01, -3.6779e-01,  ...,  2.5824e-01,
           -1.0000e+09, -1.0000e+09],
          [-6.1649e-01,  1.6495e-01,  5.4718e-01,  ..., -5.4148e-02,
           -1.0000e+09, -1.0000e+09]],

         [[-3.0802e-02, -2.2213e-01, -1.3141e-01,  ...,  5.5849e-01,
           -1.0000e+09, -1.0000e+09],
          [-

* 还是用的最后那两个维度去做q,k,v

In [45]:
# Display the shape of the multi-head attention output.
out.size()

torch.Size([5, 135, 512])

### 总的来说就是，一开始传进来的是(batch_size, seq_len, d_model),d_model = head * word_embedding, 最后传出去的也是batch_size,seq_len,d_model
### 和普通的attention唯一的不同就是
* 普通的attention不是d_model, 直接就是word_embedding
* Multi-head attention的d_model是head * word_embedding那么大