In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 编码器层


编码器层结构如下:

![](./images/image-20241103212541411.png)

它的组成部分如下:
1. Embedding 层（其输出再经过 Dropout）
2. GRU 层

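输入一个批次的文本（词索引序列），先通过 Embedding 层将其转化为词向量并施加 Dropout；接着送入 GRU 网络，最后返回 GRU 在所有时间步上的输出（`outputs`，形状为 `[batch_size, src_len, hidden_dim]`）以及最后一个时间步的隐藏状态（`hidden`，形状为 `[1, batch_size, hidden_dim]`）。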



In [3]:
class Encoder(nn.Module):
    """GRU encoder: embeds a batch of source token ids and runs a single-layer GRU.

    Args:
        input_dim: size of the source vocabulary (rows of the embedding table).
        embed_dim: dimensionality of the token embeddings.
        hidden_dim: size of the GRU hidden state.
        dropout: dropout probability applied to the embeddings. Defaults to
            0.5, the value that used to be hard-coded.
    """

    def __init__(self, input_dim, embed_dim, hidden_dim, dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        # Attribute name kept as `GRU` so existing state_dict keys stay valid.
        self.GRU = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: token ids, [batch_size, src_len].

        Returns:
            outputs: GRU outputs for every time step, [batch_size, src_len, hidden_dim].
            hidden: final hidden state, [1, batch_size, hidden_dim].
        """
        embedded = self.dropout(self.embedding(src))
        # embedded: [batch_size, src_len, embed_dim]
        outputs, hidden = self.GRU(embedded)
        return outputs, hidden

In [6]:
class Attention(nn.Module):
    """Scaled dot-product attention for a single decoder query.

    Computes softmax(Q K^T / sqrt(hidden_dim)) V over the source positions,
    optionally ignoring positions excluded by a boolean mask.

    Args:
        hidden_dim: feature size of query/key; its square root is the scale factor.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.softmax = nn.Softmax(dim=2)

    def forward(self, query, key, value, mask=None):
        """Attend from `query` over `key`/`value`.

        Args:
            query: [batch_size, 1, hidden_dim]
            key: [batch_size, src_len, hidden_dim]
            value: [batch_size, src_len, hidden_dim]
            mask: optional boolean tensor of shape [batch_size, src_len] or
                [batch_size, 1, src_len]. True marks positions that may be
                attended to. False positions (e.g. source padding) get zero
                weight. Every row must keep at least one True entry, otherwise
                the softmax yields NaN. None (the default) attends to all positions.

        Returns:
            context: weighted sum of `value`, [batch_size, 1, hidden_dim].
            attention_weights: [batch_size, 1, src_len], summing to 1 over src_len.
        """
        # QK^T -> scores: [batch_size, 1, src_len]
        scores = torch.bmm(query, key.transpose(1, 2))
        # Plain Python scalar: no need to allocate a tensor on every call.
        scores = scores / (self.hidden_dim ** 0.5)

        if mask is not None:
            if mask.dim() == 2:
                # [batch_size, src_len] -> [batch_size, 1, src_len]
                mask = mask.unsqueeze(1)
            scores = scores.masked_fill(~mask, float("-inf"))

        attention_weights = self.softmax(scores)
        # context: [batch_size, 1, hidden_dim]
        context = torch.bmm(attention_weights, value)
        return context, attention_weights

In [8]:
class Decoder(nn.Module):
    """Attention-based GRU decoder that predicts one target token per call.

    At each step, the previous decoder hidden state is the attention query over
    the encoder outputs. The resulting context vector is concatenated with the
    embedded input token and fed to the GRU. The GRU output, concatenated with
    the context, is then projected to vocabulary logits.

    Args:
        output_dim: size of the target vocabulary.
        embed_dim: dimensionality of the token embeddings.
        hidden_dim: size of the GRU hidden state. It must match the feature size
            of the encoder outputs, because they serve as attention keys/values
            for the hidden-state query.
        dropout: dropout probability applied to the embeddings. Defaults to
            0.5, the value that used to be hard-coded.
    """

    def __init__(self, output_dim, embed_dim, hidden_dim, dropout=0.5):
        super().__init__()
        self.output_dim = output_dim
        self.attention = Attention(hidden_dim)
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # GRU input = embedded token concatenated with the attention context.
        self.GRU = nn.GRU(embed_dim + hidden_dim, hidden_dim, batch_first=True)
        # Logits are computed from the GRU output concatenated with the context.
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, hidden, encoder_outputs):
        """Run one decoding step.

        Args:
            input: current target token ids, [batch_size].
            hidden: previous decoder hidden state, [1, batch_size, hidden_dim].
            encoder_outputs: [batch_size, src_len, hidden_dim].

        Returns:
            prediction: vocabulary logits, [batch_size, output_dim].
            hidden: updated hidden state, [1, batch_size, hidden_dim].
        """
        # [batch_size] -> [batch_size, 1] so the GRU sees a length-1 sequence.
        input = input.unsqueeze(1)
        embedded = self.dropout(self.embedding(input))
        # embedded: [batch_size, 1, embed_dim]

        # [1, batch_size, hidden_dim] -> [batch_size, 1, hidden_dim] to act as the query.
        query = hidden.permute(1, 0, 2)

        context, attention_weights = self.attention(
            query, encoder_outputs, encoder_outputs
        )
        # context: [batch_size, 1, hidden_dim]

        rnn_input = torch.cat((embedded, context), dim=2)
        # rnn_input: [batch_size, 1, embed_dim + hidden_dim]

        output, hidden = self.GRU(rnn_input, hidden)
        # output: [batch_size, 1, hidden_dim]

        output = output.squeeze(1)
        context = context.squeeze(1)
        # output/context: [batch_size, hidden_dim]

        prediction = self.fc_out(torch.cat((output, context), dim=1))
        return prediction, hidden

In [16]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes one target token per step.

    Args:
        encoder: module mapping src [batch_size, src_len] to
            (encoder_outputs, hidden).
        decoder: module called as decoder(input, hidden, encoder_outputs). It
            returns (logits [batch_size, output_dim], hidden) and must expose
            an `output_dim` attribute.
        device: device on which the output tensors are allocated.
        trg_pad_idx: index of <pad> in the target vocabulary.
        trg_eos_idx: index of <eos> in the target vocabulary.
    """

    def __init__(self, encoder, decoder, device, trg_pad_idx, trg_eos_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.trg_pad_idx = trg_pad_idx  # index of <pad> in the target vocab
        self.trg_eos_idx = trg_eos_idx  # index of <eos> in the target vocab

    def forward(self, src, trg, teacher_forcing_ratio=0.5, max_len=100):
        """Encode `src`, then decode step by step.

        Args:
            src: source token ids, [batch_size, src_len].
            trg: target token ids, [batch_size, trg_len]. Column 0 must hold
                <sos>. If trg contains only that column (inference without
                references), decoding is greedy for `max_len` output steps.
            teacher_forcing_ratio: probability of feeding the ground-truth
                token instead of the prediction. It is drawn once per step for
                the whole batch, and only applies when targets are available.
            max_len: number of output steps (including the <sos> placeholder)
                used when trg holds only the <sos> column.

        Returns:
            Logits of shape [batch_size, steps, output_dim], where steps is
            trg_len (or max_len without targets). Step 0 is an all-zero
            placeholder for <sos>, so the output always aligns with trg.

        In training mode with real targets, every step is decoded. Otherwise
        (eval mode, or no targets), a sequence that emits <eos> is marked
        finished and its later inputs are replaced by <pad>. Decoding stops once
        every sequence has finished, and the remaining steps stay zero.
        """
        batch_size = src.shape[0]
        trg_vocab_size = self.decoder.output_dim

        has_targets = trg.shape[1] > 1
        n_steps = trg.shape[1] if has_targets else max_len
        # Stopping on predicted <eos> would truncate the output relative to
        # trg and overwrite teacher-forced inputs, so only do it when we are
        # not training against a target sequence.
        track_eos = not (self.training and has_targets)

        encoder_outputs, hidden = self.encoder(src)
        # encoder_outputs: [batch_size, src_len, hidden_dim]

        # Pre-allocated so the result always has n_steps time steps;
        # step 0 (<sos>) and any steps skipped by an early stop remain zero.
        outputs = torch.zeros(
            batch_size, n_steps, trg_vocab_size, device=self.device
        )
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        # The first decoder input is the <sos> column.
        input = trg[:, 0]

        for t in range(1, n_steps):
            output, hidden = self.decoder(input, hidden, encoder_outputs)
            # output: [batch_size, output_dim]
            outputs[:, t] = output

            top1 = output.argmax(1)

            if track_eos:
                finished = finished | (top1 == self.trg_eos_idx)
                if finished.all():
                    break

            teacher_force = (
                has_targets and torch.rand(1).item() < teacher_forcing_ratio
            )
            input = trg[:, t] if teacher_force else top1

            if track_eos:
                # Finished sequences keep receiving <pad> as input.
                input = input.masked_fill(finished, self.trg_pad_idx)

        return outputs

In [17]:
TRG_PAD_IDX = 0  # index of <pad> in the target vocabulary
TRG_EOS_IDX = 1  # index of <eos> in the target vocabulary
TRG_SOS_IDX = 2  # index of <sos> in the target vocabulary (assumed for this demo)

# Model hyper-parameters
INPUT_DIM = 50  # source vocabulary size
OUTPUT_DIM = 50  # target vocabulary size
EMBED_DIM = 16  # embedding dimension
HIDDEN_DIM = 32  # hidden-state dimension
SEED = 0  # makes weight init, the random batch and teacher forcing reproducible
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)

# Build the model
encoder = Encoder(INPUT_DIM, EMBED_DIM, HIDDEN_DIM).to(DEVICE)
decoder = Decoder(OUTPUT_DIM, EMBED_DIM, HIDDEN_DIM).to(DEVICE)
model = Seq2Seq(encoder, decoder, DEVICE, TRG_PAD_IDX, TRG_EOS_IDX).to(DEVICE)

# A toy batch
batch_size = 2
src_len = 5
trg_len = 7

# Random source and target sequences
src = torch.randint(2, INPUT_DIM, (batch_size, src_len), device=DEVICE)
trg = torch.randint(2, OUTPUT_DIM, (batch_size, trg_len), device=DEVICE)
# First time step is <sos>
trg[:, 0] = TRG_SOS_IDX
# Last time step is <eos>
trg[:, -1] = TRG_EOS_IDX

# Run the model
outputs = model(src, trg)
print(outputs.shape)  # expected: [batch_size, trg_len, OUTPUT_DIM]

torch.Size([2, 7, 50])
