In [37]:
import torch.nn as nn
import torch

In [38]:
class MultiHeadAttention(nn.Module):
    def __init__(self, batch_size, hidden_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.h_dim = hidden_dim
        self.n_heads = num_heads

        assert self.h_dim % self.n_heads == 0
        self.bs = batch_size
        self.w_q = nn.Linear(hidden_dim, hidden_dim)
        self.w_k = nn.Linear(hidden_dim, hidden_dim)
        self.w_v = nn.Linear(hidden_dim, hidden_dim)
        self.w_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hidden_dim // num_heads]))
    
    def forward(self, q, k, v, Mask=None):
        self.mask = Mask
        
        Q = self.w_q(q)
        K = self.w_k(k)
        V = self.w_v(v)

        Q = Q.view(self.bs, -1, self.n_heads, self.h_dim // self.n_heads).permute(0, 2, 1, 3)
        K = K.view(self.bs, -1, self.n_heads, self.h_dim // self.n_heads).permute(0, 2, 1, 3)
        V = V.view(self.bs, -1, self.n_heads, self.h_dim // self.n_heads).permute(0, 2, 1, 3)
        
        att_scores = self.attention(Q, K, V, self.mask)
        return att_scores
    

    def attention(self, Q, K, V, Mask):
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        
        if Mask is not None:
            scores = scores.masked_fill(Mask==0, -1e10)
        
        sft_scores = self.dropout(torch.softmax(scores, dim=-1))

        output = torch.matmul(sft_scores, V)

        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(self.bs, -1, self.h_dim)
        
        output = self.w_o(output)
        return output

In [39]:
# bs, h_dim, n_heads
bs = 32
q_len = 12
k_len, v_len = 10,10
h_dim = 512
n_heads = 8
q = torch.rand(bs, q_len, h_dim)
k = torch.rand(bs, k_len, h_dim)
v = torch.rand(bs, v_len, h_dim)

mha2 = MultiHeadAttention(bs, h_dim, n_heads, 0.2)
att2 = mha2(q,k,v)
att2

tensor([[[-0.1847,  0.1332,  0.1810,  ...,  0.1395,  0.0891,  0.0164],
         [-0.2831,  0.0968,  0.2607,  ...,  0.0749,  0.0551, -0.0144],
         [-0.1594,  0.1094,  0.1718,  ...,  0.1591,  0.0712, -0.0204],
         ...,
         [-0.1714,  0.1697,  0.2321,  ...,  0.1548,  0.1060, -0.0109],
         [-0.2557,  0.1208,  0.2601,  ...,  0.1334,  0.0721, -0.0028],
         [-0.2001,  0.1494,  0.2060,  ...,  0.1904,  0.0636, -0.0242]],

        [[-0.2306,  0.1464,  0.2053,  ...,  0.1728,  0.0129, -0.0715],
         [-0.1311,  0.1801,  0.1181,  ...,  0.1231, -0.0084, -0.0314],
         [-0.1773,  0.2176,  0.1569,  ...,  0.1959, -0.0013, -0.0103],
         ...,
         [-0.1553,  0.2128,  0.1303,  ...,  0.2243, -0.0195, -0.0230],
         [-0.1812,  0.1897,  0.1222,  ...,  0.0616, -0.0064, -0.0487],
         [-0.2439,  0.1516,  0.1738,  ...,  0.1110, -0.0098, -0.0506]],

        [[-0.2014,  0.1517,  0.1643,  ...,  0.1603,  0.0605, -0.0436],
         [-0.2008,  0.0615,  0.2367,  ...,  0

Encoder的组成：\
    - input: PE+WE后的向量 \
    - MHA \
    - Add & LN \
    - FFNN \
    - Add & LN \
    - output：z，K，V向量，z用于下一层Encoder \
Decoder的组成: \
    - input:第一个词是BOS\
    - MHA\
    - Add & LN\
    - Mask MHA\
    - Add & LN\
    - FFNN\
    - Add & LN\
    - softmax\
    - output：作为下一个词的input\

In [None]:
class LayerNorm(nn.Module):
    pass

In [None]:
class SublayerConnection(nn.Module):
    pass 

In [None]:
class EncoderLayer(nn.Module):
    pass 

In [None]:
def clones(module, N):
    return nn.ModuleList([copy.deepcopy])

In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.layer = EncoderLayer()
        self.norm = LayerNorm()

    def forward(self):
        pass 

In [None]:
class Decoder(nn.Module):
    pass

In [None]:
class PositionEmbedding(nn.Module):
    pass

In [None]:
class MyTransformer(nn.Module):
    pass