# 1. 多头自注意力 (包含残差连接和层归一化)

In [2]:
import numpy as np
import torch
import torch.nn as nn

d_k = 64  # K, Q 维度
d_v = 64  # V 维度


class ScaledDotProductAttention(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        scores.masked_fill_(attn_mask, -1e9)
        weights = nn.Softmax(dim=-1)(scores)

        context = torch.matmul(weights, V)
        return context, weights





In [3]:
d_embedding = 512
n_heads = 8
batch_size = 3


class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_Q = nn.Linear(d_embedding, d_k * n_heads)
        self.W_K = nn.Linear(d_embedding, d_k * n_heads)
        self.W_V = nn.Linear(d_embedding, d_v * n_heads)
        self.linear = nn.Linear(n_heads * d_v, d_embedding)
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, Q, K, V, attn_mask):
        residual, batch_size = Q, Q.size(0)
        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context, weights = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads*d_v)
        output = self.linear(context)
        output = self.layer_norm(output + residual)
        return output, weights




# 2. 逐位置前馈网络(包含残差连接和层归一化)

In [4]:
class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_ff=2048):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)
        self.layer_norm = nn.LayerNorm(d_embedding)

    def forward(self, inputs):
        residual = inputs
        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))
        output = self.conv2(output).transpose(1, 2)
        output = self.layer_norm(output + residual)
        return output

# 3. 正弦位置编码表

In [5]:
def get_sin_enc_table(n_position, embedding_dim):
    sinusoid_table = np.zeros((n_position, embedding_dim))
    for pos_i in range(n_position):
        for hid_j in range(embedding_dim):
            angle = pos_i / np.power(10000, 2 * (hid_j // 2) / embedding_dim)
            sinusoid_table[pos_i, hid_j] = angle
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])
    return torch.FloatTensor(sinusoid_table)


# 4. 填充掩码

In [6]:
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k)
    return pad_attn_mask


# 5. 编码器层

In [7]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        enc_outputs, attn_weights = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn_weights


# 6. 编码器

In [8]:
n_layer = 6


class Encoder(nn.Module):
    def __init__(self, corpus):
        super().__init__()
        self.src_emb = nn.Embedding(len(corpus.src_vocab), d_embedding)
        self.pos_emb = nn.Embedding.from_pretrained(get_sin_enc_table(corpus.src_len + 1, d_embedding), freeze=True)
        self.layers = nn.ModuleList(EncoderLayer() for _ in range(n_layer))

    def forward(self, enc_inputs):
        pos_indices = torch.arange(1, enc_inputs.size(1) + 1).unsqueeze(0).to(enc_inputs)
        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(pos_indices)
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        enc_self_attn_weights = []
        for layer in self.layers:
            enc_outputs, enc_self_attn_weight = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attn_weights.append(enc_self_attn_weight)
        return enc_outputs, enc_self_attn_weights

# 7. 后续掩码

In [9]:
def get_attn_subsequent_mask(seq):
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    subsequent_mask = np.triu(np.ones(attn_shape), k=1)
    subsequent_mask = torch.from_numpy(subsequent_mask).byte()
    return subsequent_mask


# 8. 解码器层

In [10]:
class DecoderLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn

# 9. 解码器

In [11]:
n_layer = 6


class Decoder(nn.Module):
    def __init__(self, corpus):
        super().__init__()
        self.tgt_emb = nn.Embedding(len(corpus.tgt_vocab), d_embedding)
        self.pos_emb = nn.Embedding.from_pretrained(get_sin_enc_table(corpus.tgt_len + 1, d_embedding), freeze=True)
        self.layers = nn.ModuleList(DecoderLayer() for _ in range(n_layer))

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        pos_indices = torch.arange(1, dec_inputs.size(1) + 1).unsqueeze(0).to(dec_inputs)
        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(pos_indices)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)
        dec_self_subsequent_mask = get_attn_subsequent_mask(dec_inputs)
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_subsequent_mask), 0)
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_enc_attn_mask,
                                                             dec_self_attn_mask)
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        return dec_outputs, dec_self_attns, dec_enc_attns


# 10. Transformer 类

In [12]:
class Transformer(nn.Module):
    def __init__(self, corpus):
        super().__init__()
        self.encoder = Encoder(corpus)
        self.decoder = Decoder(corpus)
        self.projection = nn.Linear(d_embedding, len(corpus.tgt_vocab), bias=False)

    def forward(self, enc_inputs, dec_inputs):
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        dec_logits = self.projection(dec_outputs)
        return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns


11. 翻译任务

In [13]:
sentences = [
    ["咖哥 喜欢 小冰", "KaGe likes XiaoBing"],
    ["我 爱 学习 人工智能", "I love studying AI"],
    ["深度学习 改变 世界", "DL changed the world"],
    ["自然语言处理 很 强大", "NLP is powerful"],
    ["神经网络 非常 复杂", "Neural-networks are complex"]
]

In [14]:
from collections import Counter


class TranslationCorpus(object):
    def __init__(self, sentences):
        self.sentences = sentences
        self.src_len = max(len(sentence[0].split()) for sentence in self.sentences) + 1
        self.tgt_len = max(len(sentence[1].split()) for sentence in self.sentences) + 2
        self.src_vocab, self.tgt_vocab = self.create_vocabularies()
        self.src_idx2word = {v: k for k, v in self.src_vocab.items()}
        self.tgt_idx2word = {v: k for k, v in self.tgt_vocab.items()}

    def create_vocabularies(self):
        src_counter = Counter(word for sentence in self.sentences for word in sentence[0].split())
        tgt_counter = Counter(word for sentence in self.sentences for word in sentence[1].split())
        src_vocab = {"<pad>": 0, **{word: i + 1 for i, word in enumerate(src_counter)}}
        tgt_vodab = {"<pad>": 0, "<sos>": 1, "<eos>": 2, **{word: i + 3 for i, word in enumerate(tgt_counter)}}
        return src_vocab, tgt_vodab

    def make_batch(self, batch_size, test_batch=False):
        input_batch, output_batch, target_batch = [], [], []
        sentence_indices = torch.randperm(len(self.sentences))[:batch_size]
        for index in sentence_indices:
            src_sentence, tgt_sentence = self.sentences[index]
            src_seq = [self.src_vocab[word] for word in src_sentence.split()]
            tgt_seq = [self.tgt_vocab["<sos>"]] + [self.tgt_vocab[word] for word in tgt_sentence.split()] + [
                self.tgt_vocab["<eos>"]]
            src_seq += [self.src_vocab["<pad>"]] * (self.src_len - len(src_seq))
            tgt_seq += [self.tgt_vocab["<pad>"]] * (self.tgt_len - len(tgt_seq))
            input_batch.append(src_seq)
            output_batch.append(
                [self.tgt_vocab["<sos>"]] + ([self.tgt_vocab["<pad>"]] * (self.tgt_len - 2)) if test_batch else tgt_seq[
                                                                                                                :-1])
            target_batch.append(tgt_seq[1:])
        input_batch = torch.LongTensor(input_batch)
        output_batch = torch.LongTensor(output_batch)
        target_batch = torch.LongTensor(target_batch)
        return input_batch, output_batch, target_batch
        

In [15]:
corpus = TranslationCorpus(sentences)

In [26]:
import torch
import torch.optim as optim

model = Transformer(corpus)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
epoch = 5000
for i in range(epoch):
    optimizer.zero_grad()
    enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size)
    outputs, _, _, _ = model(enc_inputs, dec_inputs)
    loss = criterion(outputs.view(-1, len(corpus.tgt_vocab)), target_batch.view(-1))
    if (i + 1) % 100 == 0:
        print(f"Epoch: {i+1}, loss: {loss}")
    loss.backward()
    optimizer.step()


Epoch: 100, loss: 0.0007152888574637473
Epoch: 200, loss: 0.0004245429008733481
Epoch: 300, loss: 0.00031567850965075195
Epoch: 400, loss: 0.0002447215956635773
Epoch: 500, loss: 0.00018159396131522954
Epoch: 600, loss: 0.0001502319355495274
Epoch: 700, loss: 0.00013289311027619988
Epoch: 800, loss: 0.00012239628995303065
Epoch: 900, loss: 0.00010145763371838257
Epoch: 1000, loss: 8.782143413554877e-05
Epoch: 1100, loss: 8.161515870597214e-05
Epoch: 1200, loss: 7.616380753461272e-05
Epoch: 1300, loss: 6.967136141611263e-05
Epoch: 1400, loss: 6.19550482952036e-05
Epoch: 1500, loss: 5.748100011260249e-05
Epoch: 1600, loss: 5.912597771384753e-05
Epoch: 1700, loss: 5.0996382924495265e-05
Epoch: 1800, loss: 4.494087988859974e-05
Epoch: 1900, loss: 4.487726982915774e-05
Epoch: 2000, loss: 4.2850802856264636e-05
Epoch: 2100, loss: 4.126142812310718e-05
Epoch: 2200, loss: 3.60561789420899e-05
Epoch: 2300, loss: 3.752635166165419e-05
Epoch: 2400, loss: 3.7844223697902635e-05
Epoch: 2500, loss: 

In [27]:
enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size=1, test_batch=True)
predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)

predict = predict.view(-1, len(corpus.tgt_vocab))
predict = predict.max(1, keepdim=True)[1]

translated_sentence = [corpus.tgt_idx2word[idx.item()] for idx in predict.squeeze(0)]

input_sentence = " ".join([corpus.src_idx2word[idx.item()] for idx in enc_inputs[0]])
print(f"{input_sentence} -> {translated_sentence}")


我 爱 学习 人工智能 <pad> -> ['I', 'I', 'I', 'I', 'I']


# 11. 修正结果

In [41]:
def greedy_decoder(model, enc_inputs, start_symbol):
    enc_outputs, enc_self_attns = model.encoder(enc_inputs)
    dec_inputs = torch.zeros(1, 5).type_as(enc_inputs.data)
    next_symbol = start_symbol
    for i in range(0, 5):
        dec_inputs[0][i] = next_symbol
        dec_outputs, _, _ = model.decoder(dec_inputs, enc_inputs, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[i]
        next_symbol = next_word.item()
    dec_outputs = dec_inputs
    return dec_outputs

In [42]:
enc_inputs, dec_inputs, target_batch = corpus.make_batch(batch_size=1, test_batch=True)
predict, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)

gready_dec_input = greedy_decoder(model, enc_inputs, start_symbol=corpus.tgt_vocab["<sos>"])
greedy_dec_output_words = [corpus.tgt_idx2word[n.item()] for n in gready_dec_input.squeeze()]

input_sentence = " ".join([corpus.src_idx2word[idx.item()] for idx in enc_inputs[0]])
print(f"{input_sentence} -> {greedy_dec_output_words}")

咖哥 喜欢 小冰 <pad> <pad> -> ['<sos>', 'KaGe', 'likes', 'XiaoBing', '<eos>']
