In [1]:
import torch
import torch.nn as nn

# Implemented by myself
from config import *
from data_processer import CSCDataset, load_confusion, split_torch_dataset
from models import CombineBertModel, DecoderBaseRNN, DecoderTransformer, Trainer
from opencc import OpenCC
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = BertTokenizer.from_pretrained(checkpoint)

In [3]:
converter = OpenCC("t2s")

In [4]:
train_dataset = CSCDataset(
    [SIGHAN_train_dir_err, SIGHAN_train_dir_corr], tokenizer, converter
)
test_dataset = CSCDataset(
    [SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14], tokenizer, converter
)

preprocessing sighan dataset: 2339it [00:00, 10266.85it/s]
preprocessing sighan dataset: 100%|██████████| 2339/2339 [00:00<00:00, 10218.49it/s]


共2339句，共73264字，最长的句子有171字


preprocessing sighan dataset: 3437it [00:00, 6757.74it/s]
preprocessing sighan dataset: 100%|██████████| 3437/3437 [00:00<00:00, 6624.68it/s]

共3437句，共170330字，最长的句子有258字





In [5]:
# split data
train_data, dev_data = split_torch_dataset(train_dataset, 0.3)

train_data_loader = DataLoader(train_data, num_workers=4, shuffle=True, batch_size=16)
dev_data_loader = DataLoader(dev_data, num_workers=4, shuffle=True, batch_size=16)
test_data_loader = DataLoader(test_dataset, num_workers=4, shuffle=True, batch_size=32)

In [6]:
confusion_set = load_confusion(confusion_set_path, tokenizer)

In [7]:
# The Hyperparameters can be defined in config.py
hidden_size = 1024
num_layers = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderBaseRNN(
    model=nn.LSTM,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model, confusion_set=confusion_set)

optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [9]:
epochs = 35

trainer.train(
    dataloader=train_data_loader, epoch=epochs, test_dataloader=dev_data_loader
)
trainer.test(test_data_loader)

train Epoch:1/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.850]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.50it/s, batches loss=1.062]

1.002017307010564 {'over_corr': 1895, 'total_err': 882, 'true_corr': tensor(195, device='cuda:0')} {'over_corr': 544, 'total_err': 681, 'true_corr': 34}



train Epoch:2/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.776]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.57it/s, batches loss=1.148]

0.9422040608796206 {'over_corr': 1744, 'total_err': 882, 'true_corr': tensor(204, device='cuda:0')} {'over_corr': 515, 'total_err': 681, 'true_corr': 41}



train Epoch:3/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.708]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.50it/s, batches loss=0.669]

0.8903517127037048 {'over_corr': 1555, 'total_err': 882, 'true_corr': tensor(206, device='cuda:0')} {'over_corr': 495, 'total_err': 681, 'true_corr': 46}



train Epoch:4/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.648]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.41it/s, batches loss=0.841]

0.8459505154327913 {'over_corr': 1418, 'total_err': 882, 'true_corr': tensor(212, device='cuda:0')} {'over_corr': 484, 'total_err': 681, 'true_corr': 53}



train Epoch:5/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.589]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.44it/s, batches loss=0.819]

0.8035763468254696 {'over_corr': 1324, 'total_err': 882, 'true_corr': tensor(217, device='cuda:0')} {'over_corr': 472, 'total_err': 681, 'true_corr': 64}



train Epoch:6/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.537]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.45it/s, batches loss=0.505]

0.7649020755832846 {'over_corr': 1224, 'total_err': 882, 'true_corr': tensor(233, device='cuda:0')} {'over_corr': 450, 'total_err': 681, 'true_corr': 69}



train Epoch:7/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.493]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.46it/s, batches loss=0.667]

0.7274990623647516 {'over_corr': 1128, 'total_err': 882, 'true_corr': tensor(230, device='cuda:0')} {'over_corr': 432, 'total_err': 681, 'true_corr': 71}



train Epoch:8/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.447]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.38it/s, batches loss=0.494]

0.7015925943851471 {'over_corr': 1076, 'total_err': 882, 'true_corr': tensor(248, device='cuda:0')} {'over_corr': 423, 'total_err': 681, 'true_corr': 81}



train Epoch:9/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.411]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.38it/s, batches loss=0.553]

0.677770508961244 {'over_corr': 1031, 'total_err': 882, 'true_corr': tensor(245, device='cuda:0')} {'over_corr': 415, 'total_err': 681, 'true_corr': 83}



train Epoch:10/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.375]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.46it/s, batches loss=0.762]

0.6552754674445499 {'over_corr': 968, 'total_err': 882, 'true_corr': tensor(241, device='cuda:0')} {'over_corr': 396, 'total_err': 681, 'true_corr': 82}



train Epoch:11/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.344]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.37it/s, batches loss=0.558]

0.630167964507233 {'over_corr': 922, 'total_err': 882, 'true_corr': tensor(248, device='cuda:0')} {'over_corr': 383, 'total_err': 681, 'true_corr': 95}



train Epoch:12/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.314]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.29it/s, batches loss=0.463]

0.6157019544731487 {'over_corr': 888, 'total_err': 882, 'true_corr': tensor(253, device='cuda:0')} {'over_corr': 367, 'total_err': 681, 'true_corr': 100}



train Epoch:13/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.292]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.32it/s, batches loss=0.609]

0.5951351862062108 {'over_corr': 818, 'total_err': 882, 'true_corr': tensor(239, device='cuda:0')} {'over_corr': 357, 'total_err': 681, 'true_corr': 95}



train Epoch:14/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.268]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.527]

0.5786606689745729 {'over_corr': 794, 'total_err': 882, 'true_corr': tensor(257, device='cuda:0')} {'over_corr': 352, 'total_err': 681, 'true_corr': 107}



train Epoch:15/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.246]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.34it/s, batches loss=0.579]

0.5608029433272101 {'over_corr': 764, 'total_err': 882, 'true_corr': tensor(263, device='cuda:0')} {'over_corr': 348, 'total_err': 681, 'true_corr': 106}



train Epoch:16/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.227]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.30it/s, batches loss=0.327]

0.5498586378314279 {'over_corr': 726, 'total_err': 882, 'true_corr': tensor(250, device='cuda:0')} {'over_corr': 333, 'total_err': 681, 'true_corr': 105}



train Epoch:17/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.209]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.25it/s, batches loss=0.672]

0.5396966873244806 {'over_corr': 690, 'total_err': 882, 'true_corr': tensor(257, device='cuda:0')} {'over_corr': 324, 'total_err': 681, 'true_corr': 111}



train Epoch:18/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.193]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.407]

0.525457542728294 {'over_corr': 653, 'total_err': 882, 'true_corr': tensor(257, device='cuda:0')} {'over_corr': 312, 'total_err': 681, 'true_corr': 111}



train Epoch:19/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.178]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.20it/s, batches loss=0.818]

0.520991438491778 {'over_corr': 621, 'total_err': 882, 'true_corr': tensor(255, device='cuda:0')} {'over_corr': 300, 'total_err': 681, 'true_corr': 116}



train Epoch:20/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.165]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.23it/s, batches loss=0.455]

0.5094762007621202 {'over_corr': 601, 'total_err': 882, 'true_corr': tensor(255, device='cuda:0')} {'over_corr': 301, 'total_err': 681, 'true_corr': 113}



train Epoch:21/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.152]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.28it/s, batches loss=0.345]

0.5027322511781346 {'over_corr': 572, 'total_err': 882, 'true_corr': tensor(258, device='cuda:0')} {'over_corr': 292, 'total_err': 681, 'true_corr': 119}



train Epoch:22/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.142]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.30it/s, batches loss=0.392]

0.4974165924570777 {'over_corr': 556, 'total_err': 882, 'true_corr': tensor(254, device='cuda:0')} {'over_corr': 283, 'total_err': 681, 'true_corr': 120}



train Epoch:23/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.132]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.31it/s, batches loss=0.774]

0.4926236298951236 {'over_corr': 534, 'total_err': 882, 'true_corr': tensor(256, device='cuda:0')} {'over_corr': 278, 'total_err': 681, 'true_corr': 123}



train Epoch:24/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.121]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.416]

0.47918081080371683 {'over_corr': 503, 'total_err': 882, 'true_corr': tensor(254, device='cuda:0')} {'over_corr': 264, 'total_err': 681, 'true_corr': 125}



train Epoch:25/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.113]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.21it/s, batches loss=0.504]

0.48037028177218005 {'over_corr': 509, 'total_err': 882, 'true_corr': tensor(258, device='cuda:0')} {'over_corr': 270, 'total_err': 681, 'true_corr': 123}



train Epoch:26/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.106]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.852]

0.47655446150086145 {'over_corr': 487, 'total_err': 882, 'true_corr': tensor(264, device='cuda:0')} {'over_corr': 260, 'total_err': 681, 'true_corr': 127}



train Epoch:27/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.098]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.27it/s, batches loss=0.603]

0.47188553654334764 {'over_corr': 507, 'total_err': 882, 'true_corr': tensor(269, device='cuda:0')} {'over_corr': 273, 'total_err': 681, 'true_corr': 125}



train Epoch:28/35: 100%|██████████| 103/103 [00:29<00:00,  3.50it/s, avg loss=0.090]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.21it/s, batches loss=0.388]

0.4635430811481042 {'over_corr': 474, 'total_err': 882, 'true_corr': tensor(261, device='cuda:0')} {'over_corr': 257, 'total_err': 681, 'true_corr': 129}



train Epoch:29/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.084]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.26it/s, batches loss=0.307]

0.4605134078724818 {'over_corr': 442, 'total_err': 882, 'true_corr': tensor(257, device='cuda:0')} {'over_corr': 241, 'total_err': 681, 'true_corr': 132}



train Epoch:30/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.077]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.21it/s, batches loss=0.345]

0.4565389671108939 {'over_corr': 442, 'total_err': 882, 'true_corr': tensor(262, device='cuda:0')} {'over_corr': 242, 'total_err': 681, 'true_corr': 136}



train Epoch:31/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.071]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.22it/s, batches loss=0.539]

0.4603310190141201 {'over_corr': 420, 'total_err': 882, 'true_corr': tensor(246, device='cuda:0')} {'over_corr': 237, 'total_err': 681, 'true_corr': 126}



train Epoch:32/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.065]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.24it/s, batches loss=0.373]

0.4501369216225364 {'over_corr': 407, 'total_err': 882, 'true_corr': tensor(255, device='cuda:0')} {'over_corr': 229, 'total_err': 681, 'true_corr': 136}



train Epoch:33/35: 100%|██████████| 103/103 [00:29<00:00,  3.52it/s, avg loss=0.060]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.23it/s, batches loss=0.316]

0.4458432810550386 {'over_corr': 407, 'total_err': 882, 'true_corr': tensor(268, device='cuda:0')} {'over_corr': 231, 'total_err': 681, 'true_corr': 138}



train Epoch:34/35: 100%|██████████| 103/103 [00:29<00:00,  3.51it/s, avg loss=0.055]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.19it/s, batches loss=0.240]

0.44602563977241516 {'over_corr': 392, 'total_err': 882, 'true_corr': tensor(259, device='cuda:0')} {'over_corr': 226, 'total_err': 681, 'true_corr': 139}



train Epoch:35/35:  11%|█         | 11/103 [00:03<00:29,  3.16it/s, batches loss=0.052]


KeyboardInterrupt: 

In [None]:
# The Hyperparameters can be defined in config.py
hidden_size = 1024
num_layers = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderBaseRNN(
    model=nn.GRU,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
trainer.train(
    dataloader=train_data_loader, epoch=epochs, test_dataloader=dev_data_loader
)
trainer.test(test_data_loader)

In [None]:
nhead = 2
num_encoder_layers = 2
num_decoder_layers = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderTransformer(
    input_size=encoder_model.config.hidden_size,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
trainer.train(
    dataloader=train_data_loader, epoch=epochs, test_dataloader=dev_data_loader
)
trainer.test(test_data_loader)