# 向量点积注意力

In [1]:
import torch
import torch.nn.functional as F

# Two random batches: x1 holds 3 vectors per sample, x2 holds 5; both are 4-dimensional.
x1 = torch.randn(2, 3, 4)
x2 = torch.randn(2, 5, 4)

# Similarity of every x1 vector with every x2 vector -> (2, 3, 5).
raw_weights = x1 @ x2.transpose(-2, -1)
print(raw_weights)

# Normalise over the x2 positions so each row of weights sums to 1.
attn_weights = raw_weights.softmax(dim=-1)
print(attn_weights)

# Weighted sum of the x2 vectors for every x1 position -> (2, 3, 4).
attn_output = attn_weights @ x2
print(attn_output)

# 缩放向量点积注意力

In [3]:
import torch
import torch.nn.functional as F

# Two random batches: x1 holds 3 vectors per sample, x2 holds 5; both are 4-dimensional.
x1 = torch.randn(2, 3, 4)
x2 = torch.randn(2, 5, 4)

# Raw dot-product scores between every x1 vector and every x2 vector -> (2, 3, 5).
raw_weights = torch.bmm(x1, x2.transpose(1, 2))
print(raw_weights)

# Scale by the square root of the feature dimension (sqrt(4) == 2.0 here) so the
# scores do not grow with the vector width before they reach the softmax.
scale_factor = x1.size(-1) ** 0.5
scaled_weights = raw_weights / scale_factor
print(scaled_weights)

# The softmax must consume the *scaled* scores, otherwise the scaling has no effect.
attn_weights = F.softmax(scaled_weights, dim=2)
print(attn_weights)

# Weighted sum of the x2 vectors for every x1 position -> (2, 3, 4).
attn_output = torch.bmm(attn_weights, x2)
print(attn_output)

In [7]:
# Square root of the size of x1's last dimension; as the only expression in
# the cell, its value is what the notebook displays.
x1.size(-1)**0.5

In [1]:
%reset -f
# Parallel corpus. Each entry is
# [Chinese source, English decoder input starting with <sos>, English target ending with <eos>].
sentences = [
    ["咖哥 喜欢 小冰", "<sos> KaGe likes XiaoBing", "KaGe likes XiaoBing <eos>"],
    ["我 爱 学习 人工智能", "<sos> I love studying AI", "I love studying AI <eos>"],
    ["深度学习 改变 世界", "<sos> DL changed the world", "DL changed the world <eos>"],
    ["自然 语言 处理 很 强大", "<sos> NLP is so powerful", "NLP is so powerful <eos>"],
    ["神经网络 非常 复杂", "<sos> Neural-Nets are complex", "Neural-Nets are complex <eos>"]
]

# Collect the distinct whitespace-separated tokens of each language.
vocab_cn, vocab_en = set(), set()
for s in sentences:
    vocab_cn.update(s[0].split())
    vocab_en.update(s[1].split())
    vocab_en.update(s[2].split())

# Sort the vocabularies: the iteration order of a set of strings can differ
# between interpreter runs (string hashes are randomised), which would give
# every run a different token -> index mapping. Sorting keeps it reproducible.
word_list_cn = sorted(vocab_cn)
word_list_en = sorted(vocab_en)

# Token -> index lookup tables.
word_2_idx_cn = {w: i for i, w in enumerate(word_list_cn)}
word_2_idx_en = {w: i for i, w in enumerate(word_list_en)}

# Index -> token lookup tables (inverse of the mappings above).
idx_2_word_cn = dict(enumerate(word_list_cn))
idx_2_word_en = dict(enumerate(word_list_en))

voc_size_cn = len(word_list_cn)
voc_size_en = len(word_list_en)

print(f"句子数量: {len(sentences)}")
print(f"中文词汇表大小: {voc_size_cn}")
print(f"英文词汇表大小: {voc_size_en}")
print(f"中文词汇到索引: {word_2_idx_cn}")
print(f"英文词汇到索引: {word_2_idx_en}")

In [2]:
import numpy as np
import torch
import random

def make_data(sentences):
    """Pick one random sentence triple and encode it as index tensors.

    Args:
        sentences: sequence of [source, decoder input, target] strings whose
            whitespace-separated tokens are keys of ``word_2_idx_cn``
            (source) or ``word_2_idx_en`` (decoder input and target).

    Returns:
        ``(encoder_input, decoder_input, target)`` LongTensors, each with a
        leading batch dimension of 1.
    """

    def to_batch(text, vocab):
        # The extra pair of brackets adds the batch dimension of size 1.
        ids = np.array([[vocab[token] for token in text.split()]])
        return torch.LongTensor(ids)

    picked = random.choice(sentences)
    return (
        to_batch(picked[0], word_2_idx_cn),
        to_batch(picked[1], word_2_idx_en),
        to_batch(picked[2], word_2_idx_en),
    )


# Draw one sample so the tensors are available for inspection in later cells.
encoder_input, decoder_input, target = make_data(sentences)

# 1. 定义 Attention 类

In [3]:
import torch.nn as nn

class Attention(nn.Module):
    """Parameter-free dot-product attention over a set of encoder states."""

    def __init__(self):
        super().__init__()

    def forward(self, decoder_context, encoder_context):
        """Attend from every decoder state to every encoder state.

        Args:
            decoder_context: query states; the last dimension must match
                ``encoder_context``.
            encoder_context: states that act as both keys and values.

        Returns:
            ``(context, attn_weights)``: the weighted sum of encoder states
            for each decoder state, and the softmax-normalised weights that
            produced it.
        """
        # Dot product of each decoder state with each encoder state.
        similarity = decoder_context @ encoder_context.transpose(-2, -1)
        # Normalise across encoder positions.
        attn_weights = torch.softmax(similarity, dim=-1)
        return attn_weights @ encoder_context, attn_weights

# 2. 重构 Decoder 类

In [4]:
import torch.nn as nn

class Encoder(nn.Module):
    """Embeds source token ids and encodes them with a batch-first ``nn.RNN``."""

    def __init__(self, input_size, hidden_size):
        """
        Args:
            input_size: number of entries in the embedding table
                (source vocabulary size).
            hidden_size: width of both the embeddings and the RNN state.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)

    def forward(self, inputs, hidden):
        """Return ``(output, hidden)`` exactly as produced by the RNN.

        Args:
            inputs: integer token ids, batch dimension first.
            hidden: initial RNN hidden state.
        """
        return self.rnn(self.embedding(inputs), hidden)

class DecoderWithAttention(nn.Module):
    """RNN decoder whose output layer sees both its own states and an attention context."""

    def __init__(self, hidden_size, output_size):
        """
        Args:
            hidden_size: width of the embeddings and the RNN state.
            output_size: target vocabulary size (embedding rows and number
                of output scores per position).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(num_embeddings=output_size, embedding_dim=hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.attention = Attention()
        # The projection consumes the RNN state concatenated with the context,
        # hence twice the hidden width on the input side.
        self.out = nn.Linear(in_features=2 * hidden_size, out_features=output_size)

    def forward(self, dec_input, hidden, enc_output):
        """Decode ``dec_input`` while attending over ``enc_output``.

        Args:
            dec_input: integer target-side token ids, batch dimension first.
            hidden: initial RNN hidden state.
            enc_output: encoder states to attend over.

        Returns:
            ``(dec_output, hidden, attn_weights)``: vocabulary scores per
            position, the final RNN hidden state, and the attention weights.
        """
        rnn_output, hidden = self.rnn(self.embedding(dec_input), hidden)
        context, attn_weights = self.attention(rnn_output, enc_output)
        combined = torch.cat([rnn_output, context], dim=-1)
        return self.out(combined), hidden, attn_weights

# Width shared by the embeddings and RNN states of both networks.
n_hidden = 128

# Source-side network, sized by the Chinese vocabulary.
encoder = Encoder(input_size=voc_size_cn, hidden_size=n_hidden)
print(f"编码器: {encoder}")

# Target-side network, sized by the English vocabulary.
decoder = DecoderWithAttention(hidden_size=n_hidden, output_size=voc_size_en)
print(f"解码器: {decoder}")

# 3. 重构 Seq2Seq 类

In [5]:
class Seq2Seq(nn.Module):
    """Chains an encoder and an attention decoder into a single module."""

    def __init__(self, encoder, decoder):
        """
        Args:
            encoder: callable as ``encoder(encoder_input, hidden)`` returning
                ``(output, hidden)``.
            decoder: callable as ``decoder(decoder_input, hidden, encoder_output)``
                returning ``(output, hidden, attn_weights)``.
        """
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, encoder_input, hidden, decoder_input):
        """Encode the source, then decode with the encoder's final state.

        Returns:
            ``(decoder_output, attn_weights)``; the decoder's final hidden
            state is discarded.
        """
        enc_states, enc_final = self.encoder(encoder_input, hidden)
        # The encoder's final hidden state initialises the decoder.
        decoder_output, _, attn_weights = self.decoder(decoder_input, enc_final, enc_states)
        return decoder_output, attn_weights

# Wire the encoder and decoder together and show the resulting module tree.
model = Seq2Seq(encoder=encoder, decoder=decoder)

print(model)

In [6]:
def train_seq2seq(model, creterion, optimizer, epochs):
    """Train ``model`` on one randomly drawn sentence pair per step.

    Args:
        model: called as ``model(encoder_input, hidden, decoder_input)`` and
            expected to return ``(output, attn_weights)``.
        creterion: loss function applied to the flattened scores and targets
            (parameter name kept as spelled for existing callers).
        optimizer: optimizer over the model's parameters.
        epochs: number of training steps to run.
    """
    report_every = 100
    for epoch in range(epochs):
        # Fresh random sample and an all-zero initial hidden state each step.
        src_ids, dec_ids, gold_ids = make_data(sentences)
        initial_hidden = torch.zeros(1, src_ids.size(0), n_hidden)

        optimizer.zero_grad()
        logits, _ = model(src_ids, initial_hidden, dec_ids)
        # Flatten to (positions, vocabulary) vs. (positions,) for the loss.
        flat_logits = logits.view(-1, voc_size_en)
        flat_gold = gold_ids.view(-1)
        loss = creterion(flat_logits, flat_gold)

        if (epoch + 1) % report_every == 0:
            print(f"Epoch: {epoch+1}, Loss: {loss:.10f}")

        loss.backward()
        optimizer.step()

# Training configuration: 10,000 steps, cross-entropy loss, Adam at lr 0.001.
epochs = 10_000
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_seq2seq(model=model, creterion=criterion, optimizer=optimizer, epochs=epochs)

In [7]:
import matplotlib.pyplot as plt
import seaborn as sns
# Use a CJK-capable font for the Chinese axis labels. "Arial Unicode MS" is
# kept as the first choice; the remaining entries are fallbacks for systems
# where that font is not installed (Matplotlib uses the first family in the
# list that it can find).
plt.rcParams["font.family"] = ["sans-serif"]
plt.rcParams["font.sans-serif"] = [
    "Arial Unicode MS",
    "Microsoft YaHei",
    "SimHei",
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
    "DejaVu Sans",
]
# Draw minus signs with the ASCII hyphen so tick labels stay readable with
# fonts that lack the Unicode minus glyph.
plt.rcParams["axes.unicode_minus"] = False

def visualize_attention(source_sentence, predicted_sentence, atten_weights):
    """Show the attention weights as an annotated heatmap.

    Args:
        source_sentence: whitespace-separated source string; its tokens label
            the x axis.
        predicted_sentence: labels for the y axis, passed to seaborn as given.
        atten_weights: 2-D array of weights to plot.
    """
    plt.figure(figsize=(10, 10))
    heatmap_options = dict(
        annot=True,
        cbar=False,
        xticklabels=source_sentence.split(),
        yticklabels=predicted_sentence,
        cmap="Greens",
    )
    sns.heatmap(atten_weights, **heatmap_options)
    plt.xlabel("源序列")
    plt.ylabel("目标序列")
    plt.show()

In [8]:
def test_seq2seq(model, source_sentence):
    """Translate ``source_sentence`` with ``model``, print it and plot the attention.

    Args:
        model: called as ``model(encoder_input, hidden, decoder_input)`` and
            expected to return ``(scores, attn_weights)``.
        source_sentence: whitespace-separated source string; every token must
            be a key of ``word_2_idx_cn``.

    Raises:
        ValueError: if ``source_sentence`` contains no tokens.
        KeyError: if a token is missing from ``word_2_idx_cn``.
    """
    source_ids = [word_2_idx_cn[w] for w in source_sentence.split()]
    if not source_ids:
        raise ValueError("source_sentence must contain at least one token")

    encoder_input = torch.LongTensor([source_ids])
    # NOTE(review): the decoder is fed "<sos>" followed by "<eos>" placeholders,
    # one position per source token, rather than its own previous predictions,
    # so the output always has as many tokens as the source sentence.
    placeholder_ids = [word_2_idx_en["<sos>"]] + [word_2_idx_en["<eos>"]] * (len(source_ids) - 1)
    decoder_input = torch.LongTensor([placeholder_ids])
    hidden = torch.zeros(1, encoder_input.size(0), n_hidden)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predict, attn_weights = model(encoder_input, hidden, decoder_input)

    # Index the batch dimension explicitly; a bare squeeze() would also drop
    # the sequence dimension for a one-token sentence and break the iteration.
    predicted_ids = predict.argmax(dim=2)[0]
    predicted_words = [idx_2_word_en[i] for i in predicted_ids.tolist()]
    print(f"{source_sentence} -> {predicted_words}")

    attn_matrix = attn_weights[0].detach().numpy()
    visualize_attention(source_sentence, predicted_words, attn_matrix)

# Translate two of the training sentences and plot their attention maps.
for demo_sentence in ("咖哥 喜欢 小冰", "自然 语言 处理 很 强大"):
    test_seq2seq(model, demo_sentence)