### Transformer

In [None]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding is split into ``heads`` chunks of ``head_dim`` features and
    the same ``head_dim -> head_dim`` projection is applied to every chunk of
    the values, keys and queries.
    """

    def __init__(self, embed_size, heads):
        """
        Args:
            embed_size: Width of the input and output embeddings.
            heads: Number of attention heads; must divide ``embed_size``.
        """
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size),  "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads*self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: Tensor reshaped here to (N, value_len, heads, head_dim).
            keys: Tensor reshaped here to (N, key_len, heads, head_dim).
            query: Tensor reshaped here to (N, query_len, heads, head_dim).
            mask: ``None``, or a tensor broadcastable to
                (N, heads, query_len, key_len); positions equal to 0 are
                excluded from attention.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads pieces.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # (N, q, h, d) x (N, k, h, d) -> (N, h, q, k)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # A very large negative score becomes ~0 weight after softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Scale by the per-head key width (sqrt(d_k)), not the full embedding
        # width: each dot product only sums over head_dim terms.
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # (N, h, q, l) x (N, l, h, d) -> (N, q, h, d), then merge the heads.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """Attention followed by a position-wise feed-forward network.

    Each of the two sub-layers is wrapped as ``dropout(LayerNorm(sublayer + residual))``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        """
        Args:
            embed_size: Embedding width.
            heads: Number of attention heads.
            dropout: Dropout probability used after both normalisations.
            forward_expansion: Multiplier for the hidden width of the
                feed-forward network.
        """
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run attention and the feed-forward network, each with a residual."""
        attended = self.attention(value, key, query, mask)

        # The query is the residual branch of the attention sub-layer.
        hidden = self.dropout(self.norm1(attended + query))

        expanded = self.feed_forward(hidden)
        return self.dropout(self.norm2(expanded + hidden))

class Encoder(nn.Module):
    """Token + learned position embeddings followed by a stack of TransformerBlocks."""

    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        """
        Args:
            src_vocab_size: Number of rows in the token embedding table.
            embed_size: Embedding width.
            num_layers: Number of stacked TransformerBlocks.
            heads: Attention heads per block.
            device: Stored on the instance; no longer needed by ``forward``.
            forward_expansion: Feed-forward width multiplier.
            dropout: Dropout probability.
            max_length: Number of rows in the position embedding table.
        """
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
            for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of shape (N, seq_length).
            mask: Attention mask passed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        N, seq_length = x.shape
        # Create the position ids directly on the input's device: this avoids
        # a host-to-device copy per call and does not depend on the stored
        # device matching where the module actually lives.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Self-attention: value, key and query are all the running output.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

class DecoderBlock(nn.Module):
    """Masked self-attention over the target, then a TransformerBlock over the encoder output."""

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        """
        Args:
            embed_size: Embedding width.
            heads: Number of attention heads.
            forward_expansion: Feed-forward width multiplier.
            dropout: Dropout probability.
            device: Accepted for signature compatibility; not used here.
        """
        super().__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm = nn.LayerNorm(embed_size)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout=dropout, forward_expansion=forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Return the block output for target states ``x``.

        ``trg_mask`` restricts the self-attention over ``x``; ``src_mask`` is
        handed to the inner TransformerBlock together with ``value``/``key``.
        """
        self_attended = self.attention(x, x, x, trg_mask)
        residual = self_attended + x
        query = self.dropout(self.norm(residual))
        return self.transformer_block(value, key, query, src_mask)

class Decoder(nn.Module):
    """Token + learned position embeddings, a stack of DecoderBlocks, and a vocabulary projection."""

    def __init__(self,
                 trg_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 forward_expansion,
                 dropout,
                 device,
                 max_length):
        """
        Args:
            trg_vocab_size: Size of the target vocabulary (embedding rows and
                output logits).
            embed_size: Embedding width.
            num_layers: Number of stacked DecoderBlocks.
            heads: Attention heads per block.
            forward_expansion: Feed-forward width multiplier.
            dropout: Dropout probability.
            device: Stored on the instance; no longer needed by ``forward``.
            max_length: Number of rows in the position embedding table.
        """
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
             for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode target ids ``x`` against the encoder output.

        Args:
            x: Integer tensor of shape (N, seq_length).
            enc_out: Encoder output, used as both value and key in every layer.
            src_mask: Mask applied when attending over ``enc_out``.
            trg_mask: Mask applied to the target self-attention.

        Returns:
            Logits over the target vocabulary for every target position.
        """
        N, seq_length = x.shape
        # Build position ids on the input's device instead of creating them
        # on CPU and copying them to a stored device.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        return self.fc_out(x)

class Transformer(nn.Module):
    """Encoder-decoder model assembled from the Encoder and Decoder classes."""

    def __init__(self,
                 src_vocab_size,
                 trg_vocab_size,
                 src_pad_idx,
                 trg_pad_idx,
                 embed_size=256,
                 num_layers=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device="cuda",
                 max_length=100
                 ):
        """
        Args:
            src_vocab_size: Source vocabulary size.
            trg_vocab_size: Target vocabulary size.
            src_pad_idx: Token id treated as padding in the source.
            trg_pad_idx: Stored on the instance; not used by the masks below.
            embed_size: Embedding width.
            num_layers: Number of encoder and of decoder layers.
            forward_expansion: Feed-forward width multiplier.
            heads: Number of attention heads.
            dropout: Dropout probability.
            device: Forwarded to the sub-modules and stored on the instance.
            max_length: Longest sequence the position embeddings cover.
        """
        super(Transformer, self).__init__()
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length
        )
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a (N, 1, 1, src_len) bool mask that is True on non-padding tokens."""
        # The comparison already lives on src's device, so no transfer is
        # needed and the "cuda" default cannot break a CPU-only run.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a (N, 1, trg_len, trg_len) lower-triangular mask of ones."""
        N, trg_len = trg.shape
        lower = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return lower.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against it; returns the decoder output."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)

if __name__ == "__main__":
    # Smoke test: push two toy sentence pairs through a freshly built model
    # and print the shape of the decoder output.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    x = torch.tensor(
        [
            [1, 5, 6, 4, 3, 9, 5, 2, 0],
            [1, 8, 7, 3, 4, 5, 6, 7, 2],
        ]
    ).to(device)
    trg = torch.tensor(
        [
            [1, 7, 4, 3, 5, 9, 2, 0],
            [1, 5, 6, 2, 4, 7, 6, 2],
        ]
    ).to(device)

    src_pad_idx, trg_pad_idx = 0, 0
    src_vocab_size, trg_vocab_size = 10, 10

    model = Transformer(src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, device=device)
    model = model.to(device)

    # The decoder is fed the target without its last token.
    out = model(x, trg[:, :-1])
    print(out.shape)


### 超参数

In [1]:
# Hyper-parameters read by the data-loading and training cells.
epochs = 10
max_len = 64
batch_size = 16
# Alternative checkpoint: "google-bert/bert-base-cased"
bert_dir = "tugstugi/bert-base-mongolian-cased"

### 自建分词器

In [None]:
import re
import torch
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.decoders import ByteLevel

class CustomTokenizer():
    """Thin wrapper around a ``tokenizers.Tokenizer`` whose vocabulary is grown with ``add_word``.

    Exposes the subset of the Hugging Face tokenizer interface this file uses
    (``encode``, ``decode``, ``batch_decode``, ``encode_plus``, ``vocab_size``).
    """

    def __init__(self):
        self.tokenizer = Tokenizer(BPE())
        self.tokenizer.pre_tokenizer = Whitespace()
        self.tokenizer.decoder = ByteLevel()
        self.tokenizer.add_tokens([" "])
        # Read the size only after the initial token has been added so the
        # attribute is not stale.
        self.vocab_size = self.tokenizer.get_vocab_size()

    def encode(self, text):
        """Return the list of token ids for ``text``."""
        return self.tokenizer.encode(text).ids

    def decode(self, text, skip_special_tokens=True):
        """Turn token ids (a list, or anything with ``.tolist()``) back into a string."""
        # Convert up front instead of wrapping the decode call in try/except,
        # which could also swallow an AttributeError raised by decode itself.
        ids = text.tolist() if hasattr(text, "tolist") else text
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special_tokens)

    def batch_decode(self, tokens, skip_special_tokens=True):
        """Decode every id sequence in ``tokens``; returns a list of strings."""
        return [
            self.decode(sequence, skip_special_tokens=skip_special_tokens)
            for sequence in tokens
        ]

    def add_word(self, text):
        """Add every space-separated word of ``text`` to the vocabulary."""
        # split(" ") yields empty strings for leading/repeated spaces; skip them.
        words = [word for word in text.split(" ") if word]
        self.tokenizer.add_tokens(words)
        self.vocab_size = self.tokenizer.get_vocab_size()

    def save_vocab(self, file_path):
        """Serialise the underlying tokenizer to ``file_path``."""
        self.tokenizer.save(file_path)

    def load_vocab(self, file_path):
        """Replace the underlying tokenizer with the one stored at ``file_path``."""
        self.tokenizer = Tokenizer.from_file(file_path)
        self.vocab_size = self.tokenizer.get_vocab_size()

    def encode_plus(self, text, max_length, **kwargs):
        """Encode ``text`` and right-pad it with 0 up to ``max_length``.

        Args:
            text: String to encode.
            max_length: Target sequence length.
            **kwargs: Hugging Face style options. Only ``truncation`` is
                honoured: when truthy, sequences longer than ``max_length``
                are cut to ``max_length``. Other options are ignored.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` tensors, each with
            one leading batch dimension of size 1.
        """
        # NOTE(review): 0 is used as the padding id, and " " is the first token
        # added in __init__ — confirm the two ids cannot collide.
        input_ids = list(self.tokenizer.encode(text).ids)
        if kwargs.get("truncation"):
            input_ids = input_ids[:max_length]
        attention_mask = [1] * len(input_ids)

        # Pad up to the maximum length (a non-positive count pads nothing).
        padding_length = max_length - len(input_ids)
        input_ids = input_ids + [0] * padding_length
        attention_mask = attention_mask + [0] * padding_length

        return {
            'input_ids': torch.tensor([input_ids]),
            'attention_mask': torch.tensor([attention_mask])
        }

# text = "хүүхэд буруу зүйл хийсэн даруйд хажууд нь бэлэн байхыг хүснэ ."
# tokenizer = CustomTokenizer()
# tokenizer.add_word(text)
# tokenizer.save_vocab("../custom_tokenizer.json")
# tokenizer.load_vocab("../custom_tokenizer.json")
# out = tokenizer.encode_plus(text, max_length=20)
# print(out)
# encoded_text = tokenizer.encode(text)
# print(encoded_text)
# decoded_text = tokenizer.decode(encoded_text)
# print(decoded_text)

### 数据处理

In [7]:
from tqdm import tqdm

# The first TRAIN_SIZE sentence pairs go to training, the remainder to validation.
TRAIN_SIZE = 40000

# Read both corpora as UTF-8 and close the handles as soon as the lines are in memory.
with open('../train_clean.txt', 'r', encoding='utf-8') as clean_file:
    f_clean = clean_file.readlines()
with open('../train_spell_error.txt', 'r', encoding='utf-8') as error_file:
    f_error = error_file.readlines()

train_data = []
val_data = []
text_parts = []  # joined once at the end; repeated `text +=` is quadratic
i = 0
for i, (c, e) in enumerate(tqdm(zip(f_clean, f_error), total=len(f_clean)), start=1):
    c, e = c.strip(), e.strip()
    text_parts.append(" " + c + " " + e)
    # NOTE(review): pairs are stored as [clean, error] and
    # TextErrorCorrectionDataset unpacks each pair as (src, tar), so the clean
    # sentence becomes the model input — confirm that is the intended direction.
    if i <= TRAIN_SIZE:
        train_data.append([c, e])
    else:
        val_data.append([c, e])
    # For a quick experiment, stop here once i >= 10000.

# All sentences concatenated; used to build the custom tokenizer vocabulary.
text = "".join(text_parts)
print(len(train_data), len(val_data))

# tokenizer = CustomTokenizer()
# tokenizer.add_word(text)
# tokenizer.save_vocab("../custom_tokenizer.json")

100%|██████████| 50000/50000 [00:25<00:00, 1973.88it/s]

40000 10000





In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

def collate_fn(batch):
    """Stack a list of 6-tuples and trim the padding down to the longest sample.

    Each sample is unpacked as (src ids, target ids, src padding mask, target
    padding mask, src length, target length). The two length columns are only
    used to decide how many columns of the stacked tensors to keep.

    Returns:
        (src_input_ids, tar_input_ids, src_key_padding_mask, tgt_key_padding_mask)
    """
    columns = [torch.stack(column) for column in zip(*batch)]
    src_ids, tar_ids, src_padding, tar_padding, src_lengths, tar_lengths = columns

    longest_src = src_lengths.max().item()
    longest_tar = tar_lengths.max().item()

    return (
        src_ids[:, :longest_src],
        tar_ids[:, :longest_tar],
        src_padding[:, :longest_src],
        tar_padding[:, :longest_tar],
    )

def gen_nopeek_mask(length):
    """Return the additive "no-peek" (causal) attention mask.

    Row ``i`` holds 0.0 for columns ``j <= i`` (positions that may be attended
    to) and ``-inf`` for columns ``j > i`` (future positions).

    Parameters:
        length (int): Number of tokens in each sentence of the batch.

    Returns:
        mask (Tensor): float tensor of shape (length, length), e.g. for
        ``length == 3``::

            [[0., -inf, -inf],
             [0.,   0., -inf],
             [0.,   0.,   0.]]
    """
    # The lower triangle (including the diagonal) marks the allowed positions.
    # The previous version filled the lower triangle with -inf instead, i.e.
    # it returned the transpose of the documented mask.
    allowed = torch.tril(torch.ones(length, length)) == 1
    return torch.zeros(length, length).masked_fill(~allowed, float('-inf'))

class TextErrorCorrectionDataset(Dataset):
    """Dataset of (source, target) sentence pairs tokenised to a fixed length."""

    def __init__(self, data, max_len, tokenizer, device):
        """
        Args:
            data: Indexable collection of (source, target) string pairs.
            max_len: Length every sequence is padded/truncated to.
            tokenizer: Object providing a Hugging Face style ``encode_plus``.
            device: Device every returned tensor is moved to.
        """
        self.data = data
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.device = device

    def __len__(self):
        return len(self.data)

    def _encode(self, sentence):
        """Tokenise one sentence to ``max_len`` ids plus its attention mask."""
        return self.tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

    def __getitem__(self, idx):
        """Return the tensors for pair ``idx`` as a dict.

        ``*_input_ids`` and ``*_key_padding_mask`` are flattened to one
        dimension; the key padding masks are the tokenizer's attention masks
        unchanged. ``src_mask``/``tar_mask`` come from ``gen_nopeek_mask``.
        """
        src, tar = self.data[idx]
        src_tokens = self._encode(src)
        tar_tokens = self._encode(tar)

        # Use the device given to the constructor rather than whatever global
        # `device` happens to exist when the item is fetched.
        device = self.device

        src_key_padding_mask = src_tokens['attention_mask'].flatten()
        tar_key_padding_mask = tar_tokens['attention_mask'].flatten()

        return {
            'src_input_ids': src_tokens['input_ids'].flatten().to(device),
            'tar_input_ids': tar_tokens['input_ids'].flatten().to(device),
            'src_mask': gen_nopeek_mask(src_tokens['input_ids'].shape[1]).to(device),
            'tar_mask': gen_nopeek_mask(tar_tokens['input_ids'].shape[1]).to(device),
            'src_key_padding_mask': src_key_padding_mask.to(device),
            'tar_key_padding_mask': tar_key_padding_mask.to(device),
            # The decoder's memory is the encoder output, so it shares the
            # source padding mask.
            'memory_key_padding_mask': src_key_padding_mask.clone().to(device)
        }

# Toy (misspelled, corrected) English pairs.
data = [
    ("This is a sentece with a typo.", "This is a sentence with a typo a ? ."),
    ("The quik brown fox jummped over the laizy dog.", "The quick brown fox jumped over the lazy dog.")
    ]

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained(bert_dir)

train_dataset = TextErrorCorrectionDataset(train_data, max_len=max_len, tokenizer=tokenizer, device=device)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
evalu_dataset = TextErrorCorrectionDataset(val_data, max_len=max_len, tokenizer=tokenizer, device=device)
# Evaluate on the held-out split (the loader used to wrap train_dataset) and
# in a fixed order, so predictions are comparable between runs.
evalu_loader = DataLoader(evalu_dataset, batch_size=batch_size, shuffle=False)

# Quick sanity check: vocabulary size, tensor shapes, and a decoded example.
print(tokenizer.vocab_size)
for k, v in train_loader.dataset[0].items():
    print(k, v.shape)
tokenizer.decode(train_loader.dataset[0]['src_input_ids'])

  from .autonotebook import tqdm as notebook_tqdm


32000
src_input_ids torch.Size([64])
tar_input_ids torch.Size([64])
src_mask torch.Size([64, 64])
tar_mask torch.Size([64, 64])
src_key_padding_mask torch.Size([64])
tar_key_padding_mask torch.Size([64])
memory_key_padding_mask torch.Size([64])


'[CLS] хүүхэд буруу зүйл хийсэн даруйд хажууд нь бэлэн байхыг хүснэ.[SEP]<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

### *模型*



### Models

Transformer 的序列到序列模型，由编码器和解码器组成。编码器接收源序列并输出连续表示，然后将其馈入解码器以生成目标序列。该模型使用自注意力机制和前馈神经网络 （FFNN） 来处理输入序列


**优势**

并行化：Transformer 架构允许对输入序列的各个位置并行计算，因此在处理长序列时比循环神经网络（RNN）快得多。

可伸缩性：该模型可以处理任意长度的输入序列，使其适用于需要处理长序列的任务。

灵活性：Transformer 架构非常灵活，可以很容易地适应不同的序列到序列任务。

**劣势**

计算复杂性：Transformer 架构需要大量的计算资源和内存，尤其是对于长序列。

过拟合：模型可能会受到过拟合的影响，尤其是当训练数据集较小时。

缺乏可解释性：Transformer 架构复杂且难以解释，因此很难理解模型为什么要做出某些预测。

**改进建议**

正则化技术：使用正则化技术，例如 Dropout（随机失活）和权重衰减，以防止过拟合。


**评估指标**

Accuracy: 精度，计算模型正确纠错的比例。

Precision: 精准率，计算模型正确纠错的比例，考虑了模型的 false positive。

Recall: 召回率，计算模型正确纠错的比例，考虑了模型的 false negative。

F1-score: F1 分数，计算模型的精准率和召回率的调和平均值。

BLEU score: BLEU 评分，计算模型生成的文本与参考文本的相似度。

ROUGE score: ROUGE 评分，计算模型生成的文本与参考文本的相似度，考虑了词语的顺序。

Word Error Rate (WER): 词错误率，计算模型生成的文本与参考文本之间词级别的错误率。

Character Error Rate (CER): 字符错误率，计算模型生成的文本与参考文本的字符级别错误率。


In [None]:
!pip install rouge

In [6]:
from tqdm import tqdm
import json
import torch
import math
import torch.nn as nn
import torch.nn.modules.transformer as T
import torch.optim as optim
import numpy as np
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge

class Embeddings(nn.Module):
    """Token embedding lookup whose output is multiplied by sqrt(d_model)."""

    def __init__(self, d_model, vocab):
        """
        Args:
            d_model: Width of each embedding vector.
            vocab: Number of rows in the embedding table.
        """
        super().__init__()
        self.d_model = d_model
        self.lut = nn.Embedding(vocab, d_model)

    def forward(self, x):
        """Look up the ids in ``x`` and scale the vectors by sqrt(d_model)."""
        scale = math.sqrt(self.d_model)
        looked_up = self.lut(x)
        return looked_up * scale

class PositionalEncoding(nn.Module):
    """Adds fixed sine/cosine position signals to a sequence-first tensor, then applies dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """
        Args:
            d_model: Feature width of the inputs.
            dropout: Dropout probability applied after the addition.
            max_len: Number of positions to precompute.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = positions * frequencies

        # Even feature columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        # Stored as (max_len, 1, d_model); a buffer follows the module across
        # devices but is not a trainable parameter.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Add the first ``x.size(0)`` rows of the table to ``x`` and apply dropout."""
        return self.dropout(x + self.pe[:x.size(0), :])

class Transformer(nn.Module):
    """Sequence-to-sequence model built from torch's TransformerEncoder/TransformerDecoder.

    Source and target share one embedding table and one positional encoding;
    a bias-free linear layer projects decoder states to vocabulary logits.
    """

    def __init__(self,
        max_len: int = 64,
        num_of_vocab: int = 21128,
        d_model: int = 768,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu"
        ):
        super(Transformer, self).__init__()

        # Encoder: a stack of encoder layers followed by a final LayerNorm.
        encoder_norm = T.LayerNorm(d_model)
        encoder_layer = T.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        self.encoder = T.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Decoder: a stack of decoder layers followed by a final LayerNorm.
        decoder_norm = T.LayerNorm(d_model)
        decoder_layer = T.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
        self.decoder = T.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self.embeddings = Embeddings(d_model, num_of_vocab)
        self.positional_encoding = PositionalEncoding(d_model=d_model, max_len=max_len)
        self.projection = nn.Linear(d_model, num_of_vocab, bias=False)

    def _per_head_mask(self, mask, batch_size):
        """Expand a per-sample (N, L, S) mask to the (N * nhead, L, S) layout.

        ``None``, 2-D masks, and masks whose first dimension is not the batch
        size are returned unchanged.
        """
        if mask is None or mask.dim() != 3 or mask.size(0) != batch_size:
            return mask
        # Read the head count from the attention module itself so this also
        # works for instances created before this method existed.
        nhead = self.encoder.layers[0].self_attn.num_heads
        # Multi-head attention flattens (batch, head) as batch * nhead + head,
        # so each sample's mask must be repeated nhead times consecutively.
        return mask.repeat_interleave(nhead, dim=0)

    def forward(self, src_input_ids, tar_input_ids, src_mask, tar_mask, src_key_padding_mask, tar_key_padding_mask, memory_key_padding_mask):
        """Run the encoder-decoder and return vocabulary logits.

        N is the batch size, S/T the source/target lengths, E ``d_model``.

        Args:
            src_input_ids: Source token ids, (N, S).
            tar_input_ids: Target token ids, (N, T).
            src_mask: Attention mask for the encoder, (S, S) or (N, S, S).
            tar_mask: Attention mask for the decoder self-attention,
                (T, T) or (N, T, T).
            src_key_padding_mask: (N, S) mask over source keys.
            tar_key_padding_mask: (N, T) mask over target keys.
            memory_key_padding_mask: (N, S) mask over encoder outputs.

            For boolean masks torch treats ``True`` as "do not attend" /
            "this key is padding".

        Returns:
            Logits of shape (N, T, num_of_vocab).
        """
        batch_size = src_input_ids.size(0)

        # Go sequence-first BEFORE adding positions: PositionalEncoding indexes
        # its table with dim 0, which must therefore be the sequence dimension.
        # (N, S, E) -> (S, N, E)
        src_embedding = self.embeddings(src_input_ids).permute(1, 0, 2)
        src_embedding = self.positional_encoding(src_embedding)

        memory = self.encoder(
            src_embedding,
            mask=self._per_head_mask(src_mask, batch_size),
            src_key_padding_mask=src_key_padding_mask
            )

        # Decoder input: token embeddings plus positions, again sequence-first.
        tar_embedding = self.embeddings(tar_input_ids).permute(1, 0, 2)
        tar_embedding = self.positional_encoding(tar_embedding)

        # The decoder consumes the encoder output (memory) and the target embeddings.
        decoder_out = self.decoder(
            tar_embedding, memory,
            tgt_mask=self._per_head_mask(tar_mask, batch_size),
            tgt_key_padding_mask=tar_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
            )

        # (T, N, E) -> (N, T, E) before projecting to the vocabulary.
        decoder_out = decoder_out.permute(1, 0, 2)
        return self.projection(decoder_out)

def accuracy(predicts, labels):
    """Return the number of exactly matching (prediction, label) pairs divided by ``len(labels)``."""
    matches = sum(1 for predicted, expected in zip(predicts, labels) if predicted == expected)
    return matches / len(labels)

def precision(predicts, labels):
    """Return the share of (prediction, label) pairs that match exactly.

    A matching pair counts as a true positive and every other pair as a false
    positive, so the result is matches / number of compared pairs.
    """
    outcomes = [predicted == expected for predicted, expected in zip(predicts, labels)]
    true_positives = sum(1 for hit in outcomes if hit)
    false_positives = len(outcomes) - true_positives
    return true_positives / (true_positives + false_positives)

def recall(predicts, labels):
    """Return the share of (prediction, label) pairs that match exactly.

    A matching pair counts as a true positive and every other pair as a false
    negative, so the result is matches / number of compared pairs.
    """
    outcomes = [predicted == expected for predicted, expected in zip(predicts, labels)]
    true_positives = sum(1 for hit in outcomes if hit)
    false_negatives = len(outcomes) - true_positives
    return true_positives / (true_positives + false_negatives)

def f1_score(predicts, labels):
    """Return the harmonic mean of ``precision`` and ``recall``.

    Returns 0.0 when both are zero (no prediction matches its label) instead
    of dividing by zero.
    """
    precision_val = precision(predicts, labels)
    recall_val = recall(predicts, labels)
    denominator = precision_val + recall_val
    if denominator == 0:
        return 0.0
    return 2 * (precision_val * recall_val) / denominator

def bleu_score(predicts, labels):
    """Return the mean sentence-level BLEU over all (prediction, label) pairs.

    Both strings are split on whitespace; each label is the single reference
    for its prediction.
    """
    per_sentence = [
        sentence_bleu([reference.split()], hypothesis.split())
        for hypothesis, reference in zip(predicts, labels)
    ]
    return np.mean(per_sentence)

def rouge_score(predicts, labels):
    """Return the mean ROUGE-1 F value over all (prediction, label) pairs."""
    scorer = Rouge()
    per_sentence = [
        scorer.get_scores(hypothesis, reference)[0]['rouge-1']['f']
        for hypothesis, reference in zip(predicts, labels)
    ]
    return np.mean(per_sentence)

def _edit_distance(hypothesis, reference):
    """Return the Levenshtein distance between two sequences (insert/delete/substitute cost 1)."""
    previous = list(range(len(reference) + 1))
    for i, hyp_item in enumerate(hypothesis, start=1):
        current = [i]
        for j, ref_item in enumerate(reference, start=1):
            current.append(min(
                previous[j] + 1,                           # drop hyp_item
                current[j - 1] + 1,                        # insert ref_item
                previous[j - 1] + (hyp_item != ref_item),  # keep or substitute
            ))
        previous = current
    return previous[-1]


def _error_rate(pairs):
    """Sum edit distances over (hypothesis, reference) pairs and divide by total reference length."""
    errors = 0
    reference_length = 0
    for hypothesis, reference in pairs:
        errors += _edit_distance(hypothesis, reference)
        reference_length += len(reference)
    if reference_length == 0:
        # No reference units: perfect if nothing was predicted, unbounded otherwise.
        return 0.0 if errors == 0 else float("inf")
    return errors / reference_length


def wer(predicts, labels):
    """Word error rate: word-level edit distance divided by the number of reference words.

    Sentences are split on whitespace. The previous implementation compared
    characters position by position, which was neither word-level nor aware
    of insertions and deletions.
    """
    return _error_rate(
        (predicted.split(), expected.split())
        for predicted, expected in zip(predicts, labels)
    )


def cer(predicts, labels):
    """Character error rate: character-level edit distance divided by the number of reference characters."""
    return _error_rate(zip(predicts, labels))


### 训练&评估

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Transformer(max_len=max_len, num_of_vocab=tokenizer.vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=0.0001)


def _future_blocked(batch_size, length, mask_device):
    """Return an (N, L, L) bool mask that is True where a query would look at a later position."""
    blocked = torch.ones(length, length, device=mask_device).triu(diagonal=1).bool()
    return blocked.unsqueeze(0).expand(batch_size, length, length)


def _forward_batch(batch):
    """Run `model` on one dataset batch; returns (logits, next-token labels)."""
    src_ids = batch['src_input_ids']
    tar_ids = batch['tar_input_ids']
    # Id 32000 is remapped to 0, the index the loss ignores.
    # NOTE(review): presumably 32000 is the tokenizer's padding id — confirm.
    src_ids[src_ids == 32000] = 0
    tar_ids[tar_ids == 32000] = 0

    # Teacher forcing: the decoder sees tokens [0, T-1) and must predict
    # tokens [1, T). Feeding the same unshifted sequence as input and label
    # lets the model copy its own input.
    decoder_input = tar_ids[:, :-1]
    labels = tar_ids[:, 1:]

    # The dataset's *_key_padding_mask is the tokenizer attention mask
    # (1 = real token); torch expects True on the positions to IGNORE.
    src_padding = batch['src_key_padding_mask'] == 0
    tar_padding = (batch['tar_key_padding_mask'] == 0)[:, :-1]

    # Build the boolean causal masks here, with the polarity torch expects,
    # instead of casting the dataset's additive float masks with .bool().
    batch_size, src_len = src_ids.shape
    src_mask = _future_blocked(batch_size, src_len, src_ids.device)
    tar_mask = _future_blocked(batch_size, decoder_input.shape[1], src_ids.device)

    # The memory is the encoder output, so it shares the source padding mask.
    logits = model(src_ids, decoder_input, src_mask, tar_mask, src_padding, tar_padding, src_padding)
    return logits, labels


for epoch in range(epochs):
    # --------- training
    model.train()
    total_loss = 0.0
    train_iterator = tqdm(train_loader, desc="Training", unit="batch")
    for step, batch in enumerate(train_iterator, start=1):
        output, target = _forward_batch(batch)

        loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
        total_loss += loss.item()

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running mean over the batches seen so far.
        train_iterator.set_postfix({'epoch': epoch + 1, 'loss': total_loss / step})

    # Checkpoint only. Reloading the file into `model` here would create new
    # parameter tensors that `optimizer` does not own, so every later epoch
    # would stop updating the model.
    torch.save(model, 'transformers_text_denoise_v2.bin')

    # --------- evaluation
    model.eval()
    predicts = []
    labels = []
    total_loss = 0.0
    evalu_iterator = tqdm(evalu_loader, desc="Evaluaing", unit="batch")
    with torch.no_grad():
        for step, batch in enumerate(evalu_iterator, start=1):
            output, target = _forward_batch(batch)

            loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
            total_loss += loss.item()

            # Positions whose label is the ignored index are not scored by the
            # loss; blank them in the prediction too so the decoded strings
            # are compared on the same positions.
            output_tokens = torch.argmax(output, dim=-1).masked_fill(target == 0, 0)
            predicts.extend(tokenizer.batch_decode(output_tokens, skip_special_tokens=True))
            labels.extend(tokenizer.batch_decode(target, skip_special_tokens=True))

            evalu_iterator.set_postfix({'epoch': epoch + 1, 'loss': total_loss / step})

    # Evaluate the model
    accuracy_val = accuracy(predicts, labels)
    precision_val = precision(predicts, labels)
    recall_val = recall(predicts, labels)
    # f1_score_val = f1_score(predicts, labels)
    # bleu_score_val = bleu_score(predicts, labels)
    rouge_score_val = rouge_score(predicts, labels)
    wer_val = wer(predicts, labels)
    cer_val = cer(predicts, labels)

    # Display the results
    print("Accuracy:", accuracy_val)
    print("Precision:", precision_val)
    print("Recall:", recall_val)
    # print("F1 Score:", f1_score_val)
    # print("BLEU Score:", bleu_score_val)
    print("ROUGE Score:", rouge_score_val)
    print("Word Error Rate (WER):", wer_val)
    print("Character Error Rate (CER):", cer_val)

    # Overwritten every epoch with the latest predictions.
    data = [{'predict': p, 'label': l} for p, l in zip(predicts, labels)]

    with open('output.json', 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False)


### 预测

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss(ignore_index=0)

# Load the trained model directly; building a fresh Transformer first only
# allocated a model that was immediately thrown away.
# NOTE: torch.load unpickles a whole module object and can execute arbitrary
# code — only load checkpoints you trust.
# NOTE(review): the training cell saves to 'transformers_text_denoise_v2.bin'
# in the working directory while this path points one level up — confirm.
model = torch.load('../transformers_text_denoise_v2.bin', map_location=torch.device(device))
optimizer = optim.Adam(model.parameters(), lr=0.0001)
model.eval()


def _future_blocked(batch_size, length, mask_device):
    """Return an (N, L, L) bool mask that is True where a query would look at a later position."""
    blocked = torch.ones(length, length, device=mask_device).triu(diagonal=1).bool()
    return blocked.unsqueeze(0).expand(batch_size, length, length)


def _forward_batch(batch):
    """Run `model` on one dataset batch; returns (logits, next-token labels)."""
    src_ids = batch['src_input_ids']
    tar_ids = batch['tar_input_ids']
    # Id 32000 is remapped to 0, the index the loss ignores.
    # NOTE(review): presumably 32000 is the tokenizer's padding id — confirm.
    src_ids[src_ids == 32000] = 0
    tar_ids[tar_ids == 32000] = 0

    # Teacher forcing: the decoder sees tokens [0, T-1) and is scored on
    # tokens [1, T), so it cannot simply copy its own input.
    decoder_input = tar_ids[:, :-1]
    labels = tar_ids[:, 1:]

    # The dataset's *_key_padding_mask is the tokenizer attention mask
    # (1 = real token); torch expects True on the positions to IGNORE.
    src_padding = batch['src_key_padding_mask'] == 0
    tar_padding = (batch['tar_key_padding_mask'] == 0)[:, :-1]

    # Boolean causal masks with the polarity torch expects.
    batch_size, src_len = src_ids.shape
    src_mask = _future_blocked(batch_size, src_len, src_ids.device)
    tar_mask = _future_blocked(batch_size, decoder_input.shape[1], src_ids.device)

    logits = model(src_ids, decoder_input, src_mask, tar_mask, src_padding, tar_padding, src_padding)
    return logits, labels


predicts = []
labels = []
total_loss = 0.0
evalu_iterator = tqdm(evalu_loader, desc="Evaluaing", unit="batch")
with torch.no_grad():
    for step, batch in enumerate(evalu_iterator, start=1):
        output, target = _forward_batch(batch)

        loss = criterion(output.reshape(-1, output.size(-1)), target.reshape(-1))
        total_loss += loss.item()

        # Blank the positions the loss ignores so prediction and label are
        # decoded over the same positions.
        output_tokens = torch.argmax(output, dim=-1).masked_fill(target == 0, 0)
        predicts.extend(tokenizer.batch_decode(output_tokens, skip_special_tokens=True))
        labels.extend(tokenizer.batch_decode(target, skip_special_tokens=True))

        # Running mean over the batches seen so far.
        evalu_iterator.set_postfix({'loss': total_loss / step})

# Evaluate the model
accuracy_val = accuracy(predicts, labels)
precision_val = precision(predicts, labels)
recall_val = recall(predicts, labels)
# f1_score_val = f1_score(predicts, labels)
# bleu_score_val = bleu_score(predicts, labels)
rouge_score_val = rouge_score(predicts, labels)
wer_val = wer(predicts, labels)
cer_val = cer(predicts, labels)

# Display the results
print("Accuracy:", accuracy_val)
print("Precision:", precision_val)
print("Recall:", recall_val)
# print("F1 Score:", f1_score_val)
# print("BLEU Score:", bleu_score_val)
print("ROUGE Score:", rouge_score_val)
print("Word Error Rate (WER):", wer_val)
print("Character Error Rate (CER):", cer_val)

# Start from a fresh list: appending to whatever `data` an earlier cell left
# behind would mix old entries into the output file.
data = [{'predict': p, 'label': l} for p, l in zip(predicts, labels)]

with open('../output.json', 'w', encoding='utf-8') as json_file:
    json.dump(data, json_file, ensure_ascii=False)