In [3]:
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker

In [5]:
class PrepareForMultiHeadAttention(nn.Module):
    """Project the feature dimension and split it into attention heads.

    A single linear layer maps the last dimension (``d_model`` features) to
    ``heads * d_k`` features, which are then viewed as ``heads`` separate
    vectors of size ``d_k``.
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        """
        Args:
            d_model: number of input features in the last dimension.
            heads: number of attention heads.
            d_k: number of features per head.
            bias: whether the linear projection has a learnable bias term.
        """
        super().__init__()
        self.heads = heads
        self.d_k = d_k
        # One projection produces the features of every head at once.
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)

    def forward(self, x: torch.Tensor):
        """Return ``x`` projected and reshaped to ``[..., heads, d_k]``.

        Only the last dimension is transformed, so any number of leading
        dimensions (e.g. ``[seq_len, batch_size, d_model]`` or
        ``[batch_size, d_model]``) is accepted and preserved.
        """
        # Every dimension except the feature dimension is kept as-is.
        leading_dims = x.shape[:-1]
        projected = self.linear(x)
        # Split the projected features into `heads` chunks of size `d_k`.
        return projected.view(*leading_dims, self.heads, self.d_k)

In [None]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``query``, ``key`` and ``value`` are each projected and split into
    ``heads`` vectors of size ``d_k = d_model // heads``; attention is
    computed independently per head, and the heads are concatenated and
    passed through a final linear layer.
    """

    def __init__(self, heads, d_model, dropout_prob=0.1, bias=True):
        """
        Args:
            heads: number of attention heads.
            d_model: number of features in the query, key and value vectors.
            dropout_prob: dropout probability applied to the attention weights.
            bias: whether the query and key projections use a bias term
                (the value projection always has one).

        Raises:
            ValueError: if ``heads`` is not positive or does not divide
                ``d_model`` evenly.
        """
        super().__init__()
        if heads <= 0 or d_model % heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"number of heads ({heads})"
            )
        # Features per head. The heads are concatenated back to
        # heads * d_k == d_model features, which is what `self.output` expects.
        self.d_k = d_model // heads
        self.heads = heads
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        # Scores are laid out as [seq_len_q, seq_len_k, batch, heads], so
        # dim=1 normalises over the key positions for every query position.
        self.softmax = nn.Softmax(dim=1)
        self.output = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor applied to the dot products before the softmax.
        self.scale = 1 / math.sqrt(self.d_k)
        # Attention weights of the most recent forward pass (detached).
        self.attn = None

    def get_scores(self, query, key):
        """Return the dot product of every query with every key, per head.

        ``query`` is ``[seq_len_q, batch, heads, d_k]`` and ``key`` is
        ``[seq_len_k, batch, heads, d_k]``; the result is
        ``[seq_len_q, seq_len_k, batch, heads]``.
        """
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask, query_shape, key_shape):
        """Validate ``mask`` and add a trailing dimension for the heads.

        ``mask`` must be ``[seq_len_q, seq_len_k, batch_size]``, where the
        first and last dimensions may be 1 so they broadcast. The returned
        mask is ``[seq_len_q, seq_len_k, batch_size, 1]``.
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]
        # The same mask applies to every head, so it broadcasts over that axis.
        mask = mask.unsqueeze(-1)
        return mask

    def forward(self, *, query, key, value, mask: Optional[torch.Tensor] = None):
        """Compute multi-head attention.

        Args:
            query: tensor of shape ``[seq_len_q, batch_size, d_model]``.
            key: tensor of shape ``[seq_len_k, batch_size, d_model]``.
            value: tensor of shape ``[seq_len_k, batch_size, d_model]``.
            mask: optional tensor of shape ``[seq_len_q, seq_len_k, batch_size]``
                (first and last dimensions may be 1). Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``[seq_len_q, batch_size, d_model]``.
        """
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            # Validate against the un-projected shapes before they change.
            mask = self.prepare_mask(mask, query.shape, key.shape)

        # Project and split into heads: [seq_len, batch_size, heads, d_k].
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Scaled dot products: [seq_len_q, seq_len_k, batch_size, heads].
        scores = self.get_scores(query, key) * self.scale
        if mask is not None:
            # -inf scores become zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = self.softmax(scores)
        tracker.debug('attn', attn)
        attn = self.dropout(attn)

        # Weighted sum of the values for every query position and head.
        x = torch.einsum('ijbh,jbhd->ibhd', attn, value)
        self.attn = attn.detach()

        # Concatenate the heads back into a single feature dimension.
        x = x.reshape(seq_len, batch_size, -1)
        return self.output(x)
