In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

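# Project-local modules (implemented in this repository).
# NOTE(review): names used below without a definition in this notebook (e.g. `device`,
# `checkpoint`, the SIGHAN_* paths) presumably come from `from config import *` — confirm.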
from config import *
from data_processer import CSCDataset
from models import CombineBertModel, DecoderLSTM

In [2]:
def cal_err(raw_sentence, pred_sentence, corr_sentence):
    """Count correction outcomes for one sentence at character and sentence level.

    The three sequences are compared position by position. They may be strings,
    lists of token ids, or 1-D tensors/arrays, and are expected to have the same
    length (a shorter prediction or reference raises IndexError).

    Args:
        raw_sentence: the original (possibly misspelled) sequence.
        pred_sentence: the model's corrected sequence.
        corr_sentence: the reference (gold) corrected sequence.

    Returns:
        A ``(char_level, sent_level)`` tuple of dicts with plain ``int`` values
        under the keys ``"over_corr"``, ``"total_err"`` and ``"true_corr"``:

        * char level - ``total_err``: positions where raw differs from the
          reference; ``true_corr``: those positions where the prediction equals
          the reference; ``over_corr``: positions where raw was already right
          but the prediction differs from the reference.
        * sent level - ``total_err``: 1 if the sentence has at least one error;
          ``true_corr``: 1 if it has an error and the whole prediction equals
          the reference; ``over_corr``: 1 if any position was over-corrected.
    """
    metrics = ("over_corr", "total_err", "true_corr")
    char_level = dict.fromkeys(metrics, 0)
    sent_level = dict.fromkeys(metrics, 0)

    has_error = False
    has_over_corr = False
    for i, raw_char in enumerate(raw_sentence):
        pred_char, corr_char = pred_sentence[i], corr_sentence[i]
        # bool() collapses 0-dim tensor comparisons so the counters stay plain ints.
        pred_is_right = bool(pred_char == corr_char)

        if bool(corr_char != raw_char):
            char_level["total_err"] += 1
            char_level["true_corr"] += int(pred_is_right)
            has_error = True
        elif not pred_is_right:
            char_level["over_corr"] += 1
            has_over_corr = True

    if has_error:
        sent_level["total_err"] = 1
        # Compare element-wise rather than `all(pred == corr)`: for lists and
        # strings `pred == corr` is a single bool, which `all()` cannot iterate.
        same_length = len(pred_sentence) == len(corr_sentence)
        sent_level["true_corr"] = int(
            same_length
            and all(bool(p == c) for p, c in zip(pred_sentence, corr_sentence))
        )
    sent_level["over_corr"] = int(has_over_corr)

    return char_level, sent_level

In [3]:
def test(model, tokenizer, test_data_loader):
    """Evaluate ``model`` on ``test_data_loader`` and print the results.

    Prints the mean cross-entropy loss per batch followed by the accumulated
    character-level and sentence-level counters produced by ``cal_err``.

    Args:
        model: callable as ``model(input_ids, src_mask=attention_mask)``.
            Its output is presumably (batch, seq_len, vocab) logits — TODO confirm.
        tokenizer: only ``pad_token_id`` is read, to exclude padding from the loss.
        test_data_loader: yields dict batches with the keys ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``.

    Note:
        ``device`` is read from the enclosing notebook namespace.
    """
    model.eval()
    total_loss = 0
    metrics = ["over_corr", "total_err", "true_corr"]
    test_char_level = {key: 0 for key in metrics}
    test_sent_level = {key: 0 for key in metrics}

    with torch.no_grad():
        for batch in test_data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device).type(torch.float)
            labels = batch["labels"].to(device)

            outputs = model(input_ids, src_mask=attention_mask)
            # cross_entropy expects the class dimension second.
            logits = outputs.permute(0, 2, 1)

            # Call through `F`, which this file imports; a bare `cross_entropy`
            # is not defined by any visible import.
            loss = F.cross_entropy(
                logits, labels, ignore_index=tokenizer.pad_token_id
            )
            total_loss += loss.item()

            predicted_ids = torch.argmax(outputs, dim=-1)
            # Zero out predictions at masked positions. The mask is cast to an
            # integer type so the ids stay integers (assumes a 0/1 mask — confirm).
            masked_ids = predicted_ids * attention_mask.long()

            # cal_err walks the sequences element by element; do that on CPU
            # copies so each comparison does not trigger a device sync.
            raw_cpu = input_ids.cpu()
            pred_cpu = masked_ids.cpu()
            gold_cpu = labels.cpu()

            for i in range(len(pred_cpu)):
                char_level, sent_level = cal_err(raw_cpu[i], pred_cpu[i], gold_cpu[i])
                # int() keeps the running totals as plain Python ints even when
                # cal_err hands back 0-dim tensors.
                for key, v in char_level.items():
                    test_char_level[key] += int(v)
                for key, v in sent_level.items():
                    test_sent_level[key] += int(v)

    num_batches = len(test_data_loader)
    # An empty loader has no loss to average; report NaN instead of crashing.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    print(mean_loss, test_char_level, test_sent_level)

In [4]:
# Load the pretrained BERT tokenizer and encoder from the same checkpoint.
# NOTE(review): `checkpoint` is not defined in this notebook; presumably it is
# provided by `from config import *` — confirm.
tokenizer = BertTokenizer.from_pretrained(checkpoint)
encoder_model = BertModel.from_pretrained(checkpoint)

# Decoder hyperparameters; these could instead be defined in config.py.
hidden_size = 1024
num_layers = 4

# The decoder's input_size is set to the tokenizer's vocabulary size.
decoder_model = DecoderLSTM(
    input_size=tokenizer.vocab_size, hidden_size=hidden_size, num_layers=num_layers
)

# Combine the BERT encoder and the LSTM decoder into the model used below.
csc_model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

In [8]:
# Training split: built from a pair of paths (erroneous text, corrected text).
# NOTE(review): the SIGHAN_* path names are not defined in this notebook;
# presumably they come from `from config import *` — confirm.
train_dataset = CSCDataset(
    [SIGHAN_train_dir_err, SIGHAN_train_dir_corr], tokenizer
)
train_data_loader = DataLoader(
    train_dataset, num_workers=0, shuffle=True, batch_size=16
)

# Evaluation split.
# NOTE(review): these paths are named `SIGHAN_train_dir_*14`, so the "test" set
# looks like it is built from a training directory — confirm this is intended.
test_dataset = CSCDataset(
    [SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14], tokenizer
)
# NOTE(review): shuffle=True gives a different batch order on every run;
# evaluation loaders are usually left unshuffled.
test_data_loader = DataLoader(test_dataset, num_workers=0, shuffle=True, batch_size=32)

preprocessing sighan dataset: 2339it [00:00, 115839.85it/s]
preprocessing sighan dataset: 100%|█████████████████████████████████████████████| 2339/2339 [00:00<00:00, 90032.46it/s]


共2339句，共73264字，最长的句子有171字


preprocessing sighan dataset: 3437it [00:00, 74463.05it/s]
preprocessing sighan dataset: 100%|█████████████████████████████████████████████| 3437/3437 [00:00<00:00, 73102.92it/s]

共3437句，共170330字，最长的句子有258字



