In [None]:
import torch
from torch import nn

In [None]:
# 多头自注意力层
# 输入x, shape=(batch_size, sequence_size, embedding_dim)
# 输出y, shape=(batch_size, sequence_size, output_dim-一般与输入维度相同)

class MultiheadAttentionLayer(nn.Module):
    """Multi-head self-attention followed by a residual connection and layer norm.

    The input ``X`` is expected to have shape
    ``(batch_size, sequence_size, embedding_dim)`` and the result has shape
    ``(batch_size, sequence_size, output_dim)``.  The projected attention
    output is added to ``X`` before normalization, so ``output_dim`` has to be
    compatible with ``embedding_dim`` (normally they are equal).

    Args:
        head_num: Number of attention heads.
        embedding_dim: Size of the last dimension of the input.
        att_dim: Per-head size of the query / key / value projections.
        output_dim: Size of the last dimension produced by the output
            projection ``WZ``.
    """

    def __init__(self, head_num, embedding_dim, att_dim, output_dim):
        super().__init__()
        self.head_num = head_num
        self.att_dim = att_dim
        # All heads share one Linear per role; the heads are split apart by a
        # reshape in ``forward``.
        proj_dim = att_dim * head_num
        self.WQ = nn.Linear(in_features=embedding_dim, out_features=proj_dim, bias=True)
        self.WK = nn.Linear(in_features=embedding_dim, out_features=proj_dim, bias=True)
        self.WV = nn.Linear(in_features=embedding_dim, out_features=proj_dim, bias=True)
        # Maps the concatenated heads back to the model dimension.
        self.WZ = nn.Linear(proj_dim, output_dim, bias=True)
        # MyLayerNorm is the hand-written layer norm defined elsewhere in this file.
        self.myLayerNorm = MyLayerNorm(embedding_dim)

    def forward(self, X, maskedMatrix=None):
        """Apply self-attention to ``X`` and return the normalized residual sum.

        Args:
            X: Tensor of shape ``(batch_size, sequence_size, embedding_dim)``.
            maskedMatrix: Optional attention mask, see ``getAttWeight``.

        Returns:
            Tensor of shape ``(batch_size, sequence_size, output_dim)``.
        """
        batch_size, seq_size = X.shape[0], X.shape[1]
        split_shape = (batch_size, seq_size, self.head_num, self.att_dim)

        # (batch, seq, head*att) -> (batch, seq, head, att) -> head-major layouts.
        Q = self.WQ(X).reshape(split_shape).permute(0, 2, 1, 3)  # (batch, head, seq, att)
        K = self.WK(X).reshape(split_shape).permute(0, 2, 3, 1)  # (batch, head, att, seq)
        V = self.WV(X).reshape(split_shape).permute(0, 2, 1, 3)  # (batch, head, seq, att)

        att_weight_matrix = self.getAttWeight(
            Q, K, nhead=self.head_num, maskedMatrix=maskedMatrix
        )  # (batch, head, seq, seq)

        # Weighted sum of the values for every query position.  V must be in
        # the permuted (head-major) layout here; reshaping the un-permuted
        # tensor would mix up sequence positions and heads.
        multihead_seqs = att_weight_matrix.matmul(V)  # (batch, head, seq, att)

        # Put the heads back next to each other and project.
        merged = multihead_seqs.permute(0, 2, 1, 3).reshape(batch_size, seq_size, -1)
        Z = self.WZ(merged)
        return self.myLayerNorm(Z + X)

    @staticmethod
    def getAttWeight(Q, K, nhead=1, maskedMatrix=None):
        """Return softmax-normalized scaled dot-product attention weights.

        Args:
            Q: Queries of shape ``(batch_size, nhead, seq_size, dims)``.
            K: Keys, already transposed to ``(batch_size, nhead, dims, seq_size)``.
            nhead: Unused; kept so existing call sites remain valid.
            maskedMatrix: Optional mask broadcastable to
                ``(batch_size, nhead, seq_size, seq_size)``.  A boolean mask
                marks positions that must NOT be attended to with ``True``;
                any other dtype is added to the raw scores (use ``-inf`` to
                block a position).

        Returns:
            Tensor of shape ``(batch_size, nhead, seq_size, seq_size)`` whose
            last dimension sums to 1.
        """
        # Scale by sqrt(d_k); the per-head size is the last dimension of Q.
        att_score_matrix = Q.matmul(K) / (Q.shape[-1] ** 0.5)
        if maskedMatrix is not None:
            if maskedMatrix.dtype == torch.bool:
                att_score_matrix = att_score_matrix.masked_fill(maskedMatrix, float("-inf"))
            else:
                att_score_matrix = att_score_matrix + maskedMatrix
        # torch.softmax is numerically stable for large scores and -inf entries.
        return torch.softmax(att_score_matrix, dim=-1)
        

In [None]:
class MyLayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Args:
        dim: Size of the last dimension of the inputs.
        epsilon: Small constant added to the variance so that constant rows
            (zero variance) do not cause a division by zero.
    """

    def __init__(self, dim, epsilon=1e-5):
        super().__init__()
        # Registered as Parameters so they are leaf tensors that show up in
        # ``parameters()``, receive gradients and follow ``.to(device)``.
        self.alpha = nn.Parameter(torch.ones(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))
        self.epsilon = epsilon

    def forward(self, X):
        """Normalize ``X`` along its last dimension.

        Args:
            X: Tensor whose last dimension has size ``dim``; the (1, 1, dim)
                scale and shift broadcast over the leading dimensions.

        Returns:
            Tensor with the same shape as ``X``.
        """
        X_mean = X.mean(-1, keepdim=True)
        # Layer norm uses the biased (population) variance; the default
        # unbiased estimator would also yield NaN when dim == 1.
        X_var = X.var(-1, keepdim=True, unbiased=False)
        return (X - X_mean) / torch.sqrt(X_var + self.epsilon) * self.alpha + self.beta

In [None]:
def softmax(x, dim=-1):
    """Return the softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating.  This does
    not change the result mathematically but keeps ``exp`` from overflowing
    to ``inf`` (and the quotient from becoming NaN) for large inputs.

    Args:
        x: Input tensor.
        dim: Dimension along which the outputs sum to 1.

    Returns:
        Tensor with the same shape as ``x``.
    """
    x_shifted = x - x.max(dim=dim, keepdim=True).values
    x_exp = x_shifted.exp()
    x_exp_sum = x_exp.sum(dim=dim, keepdim=True)
    return x_exp / x_exp_sum
    

In [None]:
class PointwiseLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Both linear layers map ``dim`` features to ``dim`` features and act on the
    last dimension of the input.

    Args:
        dim: Size of the last dimension of the input and of the output.
    """

    def __init__(self, dim):
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.W1 = nn.Linear(in_features=dim, out_features=dim, bias=True)
        self.W2 = nn.Linear(in_features=dim, out_features=dim, bias=True)

    def forward(self, X):
        """Return ``W2(relu(W1(X)))``; the output has the same shape as ``X``."""
        # torch.relu is the built-in element-wise max(0, x).
        return self.W2(torch.relu(self.W1(X)))
        
        

In [None]:
def relu(X):
    """Element-wise rectified linear unit: keep positive entries, zero the rest."""
    zero_floor = torch.zeros_like(X)
    rectified = torch.max(X, zero_floor)
    return rectified

In [None]:
class MaskedMultiheadAttentionLayer(nn.Module):
    def __init__(self, nhead, in_features):
        self.nhead = nhead
        self.in_features = in_features
        self.att_dim = in_features / nhead
        self.WQ = nn.Linear(in_features, in_features, bias=True)
        self.WK = nn.Linear(in_features, in_features, bias=True)
        self.WV = nn.Linear(in_features, in_features, bias=True)

    def forward(self, X):
        