In [4]:
import torch
import torch.nn as nn
from torch import functional as F
import math

In [5]:
# Hyperparameters shared by the cells below:
# embedding width and number of attention heads.
d_model, n_head = 512, 8

In [6]:
# Seed the global RNG so the demo input is identical on every
# Restart & Run All; without it each run produces different values.
torch.manual_seed(0)

# Demo input of shape (16, 64, 512); the last dimension has to equal
# d_model for the attention and layer-norm cells below to accept it.
X = torch.randn(16, 64, 512)

In [18]:
class Multi_Head_Attention(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    Args:
        d_model: Width of the input and output embeddings.
        n_head: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        # Fail early with a clear message instead of a cryptic view() error
        # inside forward().
        if n_head <= 0 or d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, causal=True):
        """Apply attention of ``q`` over ``k`` / ``v``.

        Args:
            q: Query tensor of shape ``(B, T_q, d_model)``.
            k: Key tensor of shape ``(B, T_k, d_model)``.
            v: Value tensor of shape ``(B, T_k, d_model)``.
            causal: When True (the default), query position ``i`` may only
                attend to key positions ``<= i``. Pass False for unmasked
                attention.

        Returns:
            Tensor of shape ``(B, T_q, d_model)``.
        """
        B, T_q = q.shape[0], q.shape[1]
        # Keys/values may have a different sequence length than the queries.
        T_k, T_v = k.shape[1], v.shape[1]

        n_d = self.d_model // self.n_head  # dimension of each head

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # (B, T, d_model) -> (B, n_head, T, n_d)
        q = q.view(B, T_q, self.n_head, n_d).transpose(1, 2)
        k = k.view(B, T_k, self.n_head, n_d).transpose(1, 2)
        v = v.view(B, T_v, self.n_head, n_d).transpose(1, 2)

        score = q @ k.transpose(2, 3) / math.sqrt(n_d)

        if causal:
            # Build the mask on the same device as the scores; a CPU mask
            # would make masked_fill fail for inputs living on an accelerator.
            mask = torch.tril(
                torch.ones(T_q, T_k, dtype=torch.bool, device=score.device)
            )
            # -inf gives masked positions exactly zero weight after softmax.
            # Every row keeps column 0 unmasked, so no row is entirely -inf.
            score = score.masked_fill(~mask, float("-inf"))

        score = self.softmax(score)

        score = score @ v

        # (B, n_head, T_q, n_d) -> (B, T_q, d_model)
        x_concate = score.transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        x_output = self.w_o(x_concate)

        return x_output

In [19]:
# Instantiate the attention layer with the notebook-wide hyperparameters.
atte = Multi_Head_Attention(d_model=d_model, n_head=n_head)

In [20]:
# Self-attention: the same tensor serves as query, key and value.
Y = atte(q=X, k=X, v=X)

In [22]:
# The attention output keeps the input's shape.
print(Y.size())

torch.Size([16, 64, 512])


In [26]:
# Layer normalization (layer norm)
class Layer_Norm(nn.Module):
    """Normalize the last dimension to zero mean and unit variance, then
    apply a learnable per-feature scale and shift.

    Args:
        d_model: Size of the last (feature) dimension being normalized.
        eps: Small constant added to the variance before the square root
            to avoid division by zero.
    """

    def __init__(self, d_model, eps = 1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # learnable shift
        self.eps = eps

    def forward(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        mean = x.mean(-1, keepdim = True)
        # Layer norm uses the biased (divide-by-N) variance. The default
        # unbiased estimator divides by N-1, which skews the scale and
        # yields NaN when the feature dimension has a single element.
        var = x.var(-1, unbiased = False, keepdim = True)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

In [27]:
# Layer-norm module sized to the model's embedding width.
LN = Layer_Norm(d_model=d_model)

In [28]:
# Show the configured embedding width; the explicit separator spells out
# print's default single space between the two arguments.
print("d_model: ", d_model, sep=" ")

d_model:  512


In [31]:
# Report the shapes of the learnable scale and shift, one per line.
print(
    f"LN gamma: {LN.gamma.shape}",
    f"LN beta: {LN.beta.shape}",
    sep="\n",
)

LN gamma: torch.Size([512])
LN beta: torch.Size([512])


In [33]:
# Normalize the demo input; layer norm leaves the shape unchanged.
Y_ln = LN(x=X)
print(Y_ln.size())

torch.Size([16, 64, 512])
