In [1]:
import torch
import torch.nn as nn

# Implemented by myself
from config import *
from data_processer import CSCDataset, split_torch_dataset
from models import DecoderBaseRNN, DecoderTransformer
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class TempBertModel(nn.Module):
    def __init__(self, encoder_model, decoder_model, output_size=2):
        super(TempBertModel, self).__init__()
        self.encoder = encoder_model
        self.decoder = decoder_model

        self.linear = nn.Linear(decoder_model.config.hidden_size, output_size)

    def forward(self, src, src_mask):
        x = self.encoder(src, attention_mask=src_mask).last_hidden_state
        x = self.decoder(x)

        x = self.linear(x)

        return x

    def save(self, store_path):
        torch.save(self, store_path)

    def save_state(self, store_path):
        torch.save(self.state_dict(), store_path)

In [3]:
tokenizer = BertTokenizer.from_pretrained(checkpoint)

In [4]:
train_dataset = CSCDataset([SIGHAN_train_dir_err, SIGHAN_train_dir_corr], tokenizer)
test_dataset = CSCDataset([SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14], tokenizer)

preprocessing sighan dataset: 2339it [00:00, 930255.74it/s]
preprocessing sighan dataset: 100%|██████████| 2339/2339 [00:00<00:00, 1577247.12it/s]


共2339句，共73264字，最长的句子有171字


preprocessing sighan dataset: 3437it [00:00, 862809.60it/s]
preprocessing sighan dataset: 100%|██████████| 3437/3437 [00:00<00:00, 1344132.67it/s]

共3437句，共170330字，最长的句子有258字





In [5]:
# split data
train_data, dev_data = split_torch_dataset(train_dataset, 0.3)

train_data_loader = DataLoader(train_data, num_workers=4, shuffle=True, batch_size=16)
dev_data_loader = DataLoader(dev_data, num_workers=4, shuffle=True, batch_size=16)
test_data_loader = DataLoader(test_dataset, num_workers=4, shuffle=True, batch_size=32)

In [18]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils import cal_err


class Trainer:
    def __init__(self, model, tokenizer, optimizer):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer

    def train(self, dataloader, epoch, test_dataloader=None, printepoch=float("inf")):
        self.iteration(dataloader, epoch, test_dataloader, printepoch)

    def test(self, dataloader):
        matrices = ["over_corr", "total_err", "true_corr"]
        self.test_char_level = {key: 0 for key in matrices}
        self.test_sent_level = {key: 0 for key in matrices}
        with torch.no_grad():
            self.iteration(dataloader, train=False)

    def iteration(
        self,
        dataloader,
        epochs=1,
        test_dataloader=None,
        printepoch=float("inf"),
        train=True,
    ):
        mode = "train" if train else "dev"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        for epoch in range(epochs):
            self.model.train() if train else self.model.eval()
            total_loss = 0

            progress_bar = tqdm(
                enumerate(dataloader),
                desc=f"{mode} Epoch:{epoch+1}/{epochs}",
                total=len(dataloader),
            )
            for i, batch in progress_bar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device).type(torch.long)
                labels = batch["labels"].to(device)
                labels = (labels == input_ids).to(torch.long)

                outputs = self.model(input_ids, src_mask=attention_mask)
                logits = outputs.permute(0, 2, 1)  # (batch_size, vocab_size, seq_len)

                # 反向传播在这，故labels不需要传入模型
                # loss = F.cross_entropy(
                #     logits, new_labels, ignore_index=self.tokenizer.pad_token_id
                # )
                loss = F.cross_entropy(
                    logits, labels
                )
                total_loss += loss.item()

                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                else:
                    t = torch.argmax(outputs, dim=-1)
                    for i in range(len(t)):
                        char_level, sent_level = cal_err(
                            (input_ids[i]==input_ids[i]),
                            t[i],
                            labels[i],
                            sum(attention_mask[i].to("cpu")),
                        )
                        self.test_char_level = {
                            key: self.test_char_level[key] + v
                            for key, v in char_level.items()
                        }
                        self.test_sent_level = {
                            key: self.test_sent_level[key] + v
                            for key, v in sent_level.items()
                        }

                progress_bar.set_postfix({"batches loss": "{:.3f}".format(loss.item())})
                if i == len(progress_bar) - 1:
                    progress_bar.set_postfix({"avg loss": "{:.3f}".format(total_loss / len(dataloader))})

            if (epoch + 1) % printepoch == 0:
                with torch.no_grad():
                    t = torch.argmax(outputs, dim=-1)
                    nt = t * attention_mask
                    pred = self.tokenizer.batch_decode(nt, skip_special_tokens=True)

                    for i, v in enumerate(nt):
                        r, l = input_ids[i], labels[i]
                        limit_length = sum(attention_mask[i].to("cpu"))
                        print(self.tokenizer.decode(r, skip_special_tokens=True))
                        print(self.tokenizer.decode(v, skip_special_tokens=True))
                        print(self.tokenizer.decode(l, skip_special_tokens=True))
                        print(cal_err(r, v, l, limit_length))

            # print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader)}")

            # dev
            if test_dataloader:
                self.test(test_dataloader)

        if mode == "dev":
            print(
                total_loss / len(dataloader),
                self.test_char_level,
                self.test_sent_level,
            )

In [19]:
# The Hyperparameters can be defined in config.py
hidden_size = 1024
num_layers = 2
output_size = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderBaseRNN(
    model=nn.LSTM,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model = TempBertModel(encoder_model=encoder_model, decoder_model=decoder_model, output_size=output_size)

optimizer = AdamW(model.parameters(), lr=learning_rate)

In [20]:

trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)
epochs = 20
trainer.train(dataloader=train_data_loader, epoch=epochs, test_dataloader=dev_data_loader)
# trainer.test(test_data_loader)

train Epoch:1/20: 100%|██████████| 103/103 [00:23<00:00,  4.44it/s, avg loss=0.041]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:06<00:00,  6.80it/s, batches loss=0.011]

0.011814418206499382 {'over_corr': 0, 'total_err': 974, 'true_corr': tensor(0, device='cuda:0')} {'over_corr': 0, 'total_err': 701, 'true_corr': 0}



train Epoch:2/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.009]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:06<00:00,  6.71it/s, batches loss=0.009]

0.0076307733860713515 {'over_corr': 2, 'total_err': 974, 'true_corr': tensor(7, device='cuda:0')} {'over_corr': 2, 'total_err': 701, 'true_corr': 5}



train Epoch:3/20: 100%|██████████| 103/103 [00:23<00:00,  4.41it/s, avg loss=0.006]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:06<00:00,  6.45it/s, batches loss=0.004]

0.005939558539962904 {'over_corr': 62, 'total_err': 974, 'true_corr': tensor(186, device='cuda:0')} {'over_corr': 59, 'total_err': 701, 'true_corr': 106}



train Epoch:4/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.004]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.39it/s, batches loss=0.008]

0.005587149088651958 {'over_corr': 154, 'total_err': 974, 'true_corr': tensor(579, device='cuda:0')} {'over_corr': 146, 'total_err': 701, 'true_corr': 359}



train Epoch:5/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.002]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.69it/s, batches loss=0.006]

0.006230368892746893 {'over_corr': 85, 'total_err': 974, 'true_corr': tensor(485, device='cuda:0')} {'over_corr': 80, 'total_err': 701, 'true_corr': 308}



train Epoch:6/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.002]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.44it/s, batches loss=0.009]

0.005751423648854887 {'over_corr': 190, 'total_err': 974, 'true_corr': tensor(661, device='cuda:0')} {'over_corr': 175, 'total_err': 701, 'true_corr': 381}



train Epoch:7/20: 100%|██████████| 103/103 [00:23<00:00,  4.44it/s, avg loss=0.001]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.36it/s, batches loss=0.017]

0.0058627356854479085 {'over_corr': 207, 'total_err': 974, 'true_corr': tensor(670, device='cuda:0')} {'over_corr': 177, 'total_err': 701, 'true_corr': 400}



train Epoch:8/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.001]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.43it/s, batches loss=0.004]

0.006481772839007052 {'over_corr': 256, 'total_err': 974, 'true_corr': tensor(706, device='cuda:0')} {'over_corr': 214, 'total_err': 701, 'true_corr': 382}



train Epoch:9/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.001]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.41it/s, batches loss=0.011]

0.005756701924838126 {'over_corr': 186, 'total_err': 974, 'true_corr': tensor(670, device='cuda:0')} {'over_corr': 167, 'total_err': 701, 'true_corr': 389}



train Epoch:10/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.001]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.40it/s, batches loss=0.006]


0.006417100756979463 {'over_corr': 167, 'total_err': 974, 'true_corr': tensor(659, device='cuda:0')} {'over_corr': 156, 'total_err': 701, 'true_corr': 400}


train Epoch:11/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.42it/s, batches loss=0.005]

0.006812394519908015 {'over_corr': 123, 'total_err': 974, 'true_corr': tensor(617, device='cuda:0')} {'over_corr': 119, 'total_err': 701, 'true_corr': 385}



train Epoch:12/20: 100%|██████████| 103/103 [00:23<00:00,  4.43it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.41it/s, batches loss=0.005]

0.00714326270496134 {'over_corr': 188, 'total_err': 974, 'true_corr': tensor(660, device='cuda:0')} {'over_corr': 175, 'total_err': 701, 'true_corr': 392}



train Epoch:13/20: 100%|██████████| 103/103 [00:23<00:00,  4.41it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.006]

0.007130208017770201 {'over_corr': 217, 'total_err': 974, 'true_corr': tensor(689, device='cuda:0')} {'over_corr': 191, 'total_err': 701, 'true_corr': 411}



train Epoch:14/20: 100%|██████████| 103/103 [00:23<00:00,  4.41it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.36it/s, batches loss=0.003]

0.007133552637903697 {'over_corr': 105, 'total_err': 974, 'true_corr': tensor(611, device='cuda:0')} {'over_corr': 104, 'total_err': 701, 'true_corr': 381}



train Epoch:15/20: 100%|██████████| 103/103 [00:23<00:00,  4.43it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.33it/s, batches loss=0.013]

0.007810324753253636 {'over_corr': 173, 'total_err': 974, 'true_corr': tensor(633, device='cuda:0')} {'over_corr': 163, 'total_err': 701, 'true_corr': 406}



train Epoch:16/20: 100%|██████████| 103/103 [00:23<00:00,  4.41it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.40it/s, batches loss=0.010]

0.007143394349523905 {'over_corr': 149, 'total_err': 974, 'true_corr': tensor(641, device='cuda:0')} {'over_corr': 136, 'total_err': 701, 'true_corr': 399}



train Epoch:17/20: 100%|██████████| 103/103 [00:23<00:00,  4.43it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.44it/s, batches loss=0.002]

0.007401261393996802 {'over_corr': 110, 'total_err': 974, 'true_corr': tensor(595, device='cuda:0')} {'over_corr': 108, 'total_err': 701, 'true_corr': 376}



train Epoch:18/20: 100%|██████████| 103/103 [00:23<00:00,  4.43it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.38it/s, batches loss=0.004]

0.0074878564443100586 {'over_corr': 134, 'total_err': 974, 'true_corr': tensor(615, device='cuda:0')} {'over_corr': 124, 'total_err': 701, 'true_corr': 387}



train Epoch:19/20: 100%|██████████| 103/103 [00:23<00:00,  4.42it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.40it/s, batches loss=0.008]


0.0075878641853870995 {'over_corr': 162, 'total_err': 974, 'true_corr': tensor(664, device='cuda:0')} {'over_corr': 149, 'total_err': 701, 'true_corr': 396}


train Epoch:20/20: 100%|██████████| 103/103 [00:23<00:00,  4.41it/s, avg loss=0.000]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.31it/s, batches loss=0.009]

0.007925395343177528 {'over_corr': 182, 'total_err': 974, 'true_corr': tensor(690, device='cuda:0')} {'over_corr': 170, 'total_err': 701, 'true_corr': 414}





In [21]:
trainer.test(test_data_loader)

dev Epoch:1/1: 100%|██████████| 108/108 [00:39<00:00,  2.74it/s, batches loss=0.008]

0.011483406777390175 {'over_corr': 1461, 'total_err': 5278, 'true_corr': tensor(3284, device='cuda:0')} {'over_corr': 1240, 'total_err': 3436, 'true_corr': 1570}



