## 导库

In [1]:
import re
import torch
import torch.nn as nn
import numpy as np
import random
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
import torch.nn.functional as F
import torch.optim as optim
import time
import math
import jieba
import plotly.express as px
import os
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = 'TRUE'
jieba.load_userdict("./data/dict.txt")

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\xiaof\AppData\Local\Temp\jieba.cache
Loading model cost 0.357 seconds.
Prefix dict has been built successfully.


In [2]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

## 数据

In [3]:
train_data_file = './data/MuCGEC_dev_filtered.txt'  # 训练数据
test_data_file = './data/MuCGEC_test.txt'

In [4]:
UNK = 0


class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {
            "[UNK]": 0
        }
        self.word2count = {
            "[UNK]": 0
        }
        self.index2word = {0: "[UNK]"}
        # self.index2word = {}
        self.n_words = 1  # Count SOS and EOS and UNK

    def addSentence(self, sentence):
        for word in jieba.lcut(sentence):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [5]:
MAX_LENGTH = 40


def filterPair(p):
    ret = len(jieba.lcut(p[0])) < MAX_LENGTH and len(
        jieba.lcut(p[1])) < MAX_LENGTH
    return ret

In [6]:
input_lang = Lang('input')
output_lang = Lang('output')
pairs = []
with open(train_data_file, 'r') as f:
    lines = f.readlines()
    for line in lines:
        pair = line.strip('\n').split(' ')
        if not filterPair(pair):
            # 如果长度不符合要求则跳过
            continue
        pairs.append(pair)
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])

In [7]:
test_sentences = []
with open(test_data_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()
    for line in lines:
        s = line.split('\t')[1]
        test_sentences.append(s.strip('\n'))
test_sentences[:5]

['冬阴功是泰国最著名的菜之一，它虽然不是很豪华，但它的味确实让人上瘾，做法也不难、不复杂。',
 '首先，我们得准备：大虾六到九只、盐一茶匙、已搾好的柠檬汁三汤匙、泰国柠檬叶三叶、柠檬香草一根、鱼酱两汤匙、辣椒6粒，纯净水4量杯、香菜半量杯和草菇10个。',
 '这样，你就会尝到泰国人死爱的味道。',
 '另外，冬阴功对外国人的喜爱不断地增加。',
 '这部电影不仅是国内，在国外也很有名。']

In [8]:
filtered_test_sentences = []
for sentence in test_sentences:
    cut_s = jieba.lcut(sentence, HMM=False)
    if len(cut_s) < MAX_LENGTH:
        filtered_test_sentences.append(sentence)
for sentence in filtered_test_sentences:
    words = jieba.lcut(sentence, HMM=False)
    for w in words:
        if w not in input_lang.word2index.keys():
            input_lang.addWord(w)
        if w not in output_lang.word2index.keys():
            output_lang.addWord(w)

In [9]:
print("Counted words:")
print(input_lang.name, input_lang.n_words)
print(output_lang.name, output_lang.n_words)
print(random.choice(pairs))

Counted words:
input 12798
output 12776
['它们中一个甚至是像曼陀罗一样漂亮的', '它们中一个甚至是像曼陀罗一样漂亮的']


## 构建模型

### 编码模型

In [10]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

### 注意力模型

In [11]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

### 编辑预测模型

In [None]:
class Seq2Edit(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_edit_types, dropout_p=0.1):
        super(Seq2Edit, self).__init__()
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)
        self.edit_fc = nn.Linear(hidden_size, num_edit_types)  # 预测四种编辑操作
        self.char_fc = nn.Linear(hidden_size, vocab_size)  # 预测替换/插入的字符

    def forward(self, src, encoder_outputs, encoder_hidden, target_tensor, trg=None):
        # 编码器
        batch_size = src.size(0)
        device = src.device

        # 解码器初始化
        decoder_input = torch.zeros(
            batch_size, 1, dtype=torch.long, device=device)  # 初始输入为 <SOS>
        decoder_hidden = encoder_hidden
        edit_logits, char_logits = [], []

        # 解码过程
        for i in range(src.size(1)):
            # 注意力机制
            query = decoder_hidden.permute(1, 0, 2)
            context, _ = self.attention(query, encoder_outputs)

            # GRU 输入
            embedded = self.dropout(self.encoder.embedding(decoder_input))
            input_gru = torch.cat((embedded, context), dim=2)

            # GRU 解码
            output, decoder_hidden = self.gru(input_gru, decoder_hidden)

            # 预测编辑操作和字符
            edit_logit = self.edit_fc(output.squeeze(1))
            char_logit = self.char_fc(output.squeeze(1))
            edit_logits.append(edit_logit)
            char_logits.append(char_logit)

            # 更新解码器输入
            if trg is not None:
                decoder_input = trg[:, i].unsqueeze(1)  # 使用真实标签
            else:
                decoder_input = torch.argmax(
                    char_logit, dim=1).unsqueeze(1)  # 使用预测结果

        # 拼接输出
        edit_logits = torch.stack(edit_logits, dim=1)
        char_logits = torch.stack(char_logits, dim=1)
        return edit_logits, char_logits

In [13]:
def apply_edits(original_text, predicted_edits, predicted_chars=None):
    """
    根据 `Seq2Edit` 预测的编辑操作，应用到原始文本上。

    :param original_text: str, 原始输入文本
    :param predicted_edits: List[int], 每个字符对应的编辑操作 (0: 保持, 1: 替换, 2: 删除, 3: 插入)
    :param predicted_chars: List[str] 或 None, 仅用于 `替换` 和 `插入` 的新字符预测 (如果模型支持)
    :return: str, 纠正后的文本
    """
    corrected_text = []
    predicted_chars = predicted_chars or [
        ""] * len(original_text)  # 如果没有预测字符，则默认空

    for i, (char, edit_op) in enumerate(zip(original_text, predicted_edits)):
        if edit_op == 0:  # 保持
            corrected_text.append(char)
        elif edit_op == 1:  # 替换
            corrected_text.append(
                predicted_chars[i] if predicted_chars[i] else char)
        elif edit_op == 2:  # 删除
            continue  # 跳过此字符
        elif edit_op == 3:  # 插入
            corrected_text.append(char + predicted_chars[i])  # 在当前字符后插入新字符

    return "".join(corrected_text)

## 训练

In [14]:
def indexesFromSentence(lang, sentence):
    temp_list = []
    for word in jieba.lcut(sentence):
        temp_list.append(lang.word2index.get(word, lang.word2index['[UNK]']))
    return temp_list


def tensorFromSentence(lang, sentence):
    # 将句子转换为索引的Tensor
    indexes = indexesFromSentence(lang, sentence)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

In [15]:
def get_dataloader(batch_size):
    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)

        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(
        train_data, sampler=train_sampler, batch_size=batch_size)
    return train_dataloader

### 每周期的训练

In [16]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
                decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        print(input_tensor.shape)
        print(target_tensor.shape)

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        edit_logits, char_logits, _ = decoder(input_tensor,
                                              encoder_outputs, encoder_hidden, target_tensor)

        corrected_tensor = apply_edits(input_tensor, edit_logits, char_logits)

        # 定义损失函数 ,查看编辑后的文本和原文本的差异
        loss = criterion(
            corrected_tensor.view(-1, corrected_tensor.size(-1)),
            target_tensor.view(-1)
        )

        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [17]:
def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    encoder.train()
    decoder.train()

    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder,
                           encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                         epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

### 启动

In [18]:
hidden_size = 256
batch_size = 64
n_epochs = 100

train_dataloader = get_dataloader(batch_size)

encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = Seq2Edit(hidden_size, output_lang.n_words,
                   num_edit_types=4).to(device)

train(train_dataloader, encoder, decoder,
      n_epochs, print_every=1, plot_every=1)

torch.Size([64, 40])
torch.Size([64, 40])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x256 and 12776x12776)