In [1]:
import torch
import torch.nn as nn

# Implemented by myself
from config import *
from data_processer import CSCDataset, split_torch_dataset
from models import CombineBertModel, DecoderBaseRNN, DecoderTransformer, Trainer
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

  from .autonotebook import tqdm as notebook_tqdm


#### Tokenizer

In [2]:
# Load the BERT tokenizer matching the pretrained `checkpoint` (brought in by `from config import *`).
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

#### Dataset

In [3]:
# Train/dev data: SIGHAN training file pair (erroneous / corrected, judging by
# the `_err` / `_corr` names), tokenized by CSCDataset.
train_sources = [SIGHAN_train_dir_err, SIGHAN_train_dir_corr]
train_dataset = CSCDataset(train_sources, tokenizer)

# Held-out evaluation data built from the SIGHAN14 file pair.
# NOTE(review): these are the `SIGHAN_train_dir_*14` (training) paths; confirm
# this is the intended test split and that it does not overlap the training data.
test_sources = [SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14]
test_dataset = CSCDataset(test_sources, tokenizer)

preprocessing sighan dataset: 2339it [00:00, 965978.44it/s]
preprocessing sighan dataset: 100%|██████████| 2339/2339 [00:00<00:00, 1584635.29it/s]


共2339句，共73264字，最长的句子有171字


preprocessing sighan dataset: 3437it [00:00, 888877.97it/s]
preprocessing sighan dataset: 100%|██████████| 3437/3437 [00:00<00:00, 1370846.60it/s]

共3437句，共170330字，最长的句子有258字





In [4]:
# Seed torch's global RNG so the shuffled training order (DataLoader's default
# sampler generator) is reproducible across runs.
# NOTE(review): this presumably also makes `split_torch_dataset` deterministic
# if it draws from torch's RNG (e.g. via random_split); confirm in data_processer.
torch.manual_seed(42)

# split data: the second split (dev) receives the 0.3 fraction; the training
# log below shows 44 dev batches vs 103 train batches at batch_size=16.
train_data, dev_data = split_torch_dataset(train_dataset, 0.3)

# Only the training loader shuffles. The evaluation loaders keep a fixed order,
# so dev/test passes are deterministic and comparable between runs and models.
train_data_loader = DataLoader(train_data, num_workers=4, shuffle=True, batch_size=16)
dev_data_loader = DataLoader(dev_data, num_workers=4, shuffle=False, batch_size=16)
test_data_loader = DataLoader(test_dataset, num_workers=4, shuffle=False, batch_size=32)

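#### Config
Most settings (`checkpoint`, `epochs`, `learning_rate`, and the SIGHAN data paths) come from `config.py` via `from config import *`.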

In [5]:
# `epochs` is imported from config.py (the training run below used 50 epochs).
# Uncomment the next line to override it for a shorter experiment:
# epochs = 35

#### BERT + LSTM

In [6]:
# --- BERT encoder + LSTM decoder ---
# Decoder hyperparameters. Assigning them here overrides any same-named values
# brought in by `from config import *`.
hidden_size, num_layers = 1024, 2

encoder_model = BertModel.from_pretrained(checkpoint)
# The decoder's input width is tied to BERT's hidden size (the encoder output width).
bert_hidden = encoder_model.config.hidden_size
decoder_model = DecoderBaseRNN(
    model=nn.LSTM,
    input_size=bert_hidden,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

# AdamW over all parameters (BERT is fine-tuned too); `learning_rate` comes from config.py.
optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(optimizer=optimizer, model=model, tokenizer=tokenizer)

In [7]:
# Train for `epochs` epochs, evaluating on the dev split after every epoch,
# then score the held-out test loader once.
trainer.train(
    dataloader=train_data_loader,
    test_dataloader=dev_data_loader,
    epoch=epochs,
)
trainer.test(test_data_loader)

train Epoch:1/50: 100%|██████████| 103/103 [00:28<00:00,  3.61it/s, avg loss=7.074]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.93it/s, batches loss=5.807]

5.756522612138228 {'over_corr': 20612, 'total_err': 925, 'true_corr': tensor(17, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:2/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=5.677]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.93it/s, batches loss=5.641]

5.63110322302038 {'over_corr': 20511, 'total_err': 925, 'true_corr': tensor(1, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:3/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=5.605]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.95it/s, batches loss=5.664]

5.572760657830671 {'over_corr': 19736, 'total_err': 925, 'true_corr': tensor(17, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:4/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=5.438]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.93it/s, batches loss=5.159]

5.249195348132741 {'over_corr': 19447, 'total_err': 925, 'true_corr': tensor(2, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:5/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=5.276]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.91it/s, batches loss=5.382]

5.21177951856093 {'over_corr': 18822, 'total_err': 925, 'true_corr': tensor(3, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:6/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=5.055]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.90it/s, batches loss=4.909]

4.87856032631614 {'over_corr': 17169, 'total_err': 925, 'true_corr': tensor(15, device='cuda:0')} {'over_corr': 701, 'total_err': 701, 'true_corr': 0}



train Epoch:7/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=4.676]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.90it/s, batches loss=4.413]

4.4556426351720635 {'over_corr': 14673, 'total_err': 925, 'true_corr': tensor(25, device='cuda:0')} {'over_corr': 700, 'total_err': 701, 'true_corr': 0}



train Epoch:8/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=4.228]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.95it/s, batches loss=3.993]

3.97667917880145 {'over_corr': 12468, 'total_err': 925, 'true_corr': tensor(26, device='cuda:0')} {'over_corr': 700, 'total_err': 701, 'true_corr': 0}



train Epoch:9/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=3.736]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.96it/s, batches loss=3.308]

3.476110740141435 {'over_corr': 9576, 'total_err': 925, 'true_corr': tensor(65, device='cuda:0')} {'over_corr': 699, 'total_err': 701, 'true_corr': 1}



train Epoch:10/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=3.098]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.92it/s, batches loss=2.801]

2.8167636177756568 {'over_corr': 7749, 'total_err': 925, 'true_corr': tensor(99, device='cuda:0')} {'over_corr': 696, 'total_err': 701, 'true_corr': 0}



train Epoch:11/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=2.622]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.93it/s, batches loss=2.177]

2.4502982876517554 {'over_corr': 6113, 'total_err': 925, 'true_corr': tensor(121, device='cuda:0')} {'over_corr': 694, 'total_err': 701, 'true_corr': 0}



train Epoch:12/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=2.271]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.86it/s, batches loss=2.392]

2.142581983046098 {'over_corr': 4987, 'total_err': 925, 'true_corr': tensor(147, device='cuda:0')} {'over_corr': 684, 'total_err': 701, 'true_corr': 2}



train Epoch:13/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=1.969]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.84it/s, batches loss=1.472]

1.8880878686904907 {'over_corr': 4220, 'total_err': 925, 'true_corr': tensor(160, device='cuda:0')} {'over_corr': 671, 'total_err': 701, 'true_corr': 6}



train Epoch:14/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=1.712]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.87it/s, batches loss=1.851]

1.6763903119347312 {'over_corr': 3413, 'total_err': 925, 'true_corr': tensor(176, device='cuda:0')} {'over_corr': 647, 'total_err': 701, 'true_corr': 8}



train Epoch:15/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=1.500]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.85it/s, batches loss=1.540]

1.4891671863469211 {'over_corr': 2843, 'total_err': 925, 'true_corr': tensor(204, device='cuda:0')} {'over_corr': 623, 'total_err': 701, 'true_corr': 16}



train Epoch:16/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=1.318]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.82it/s, batches loss=1.482]

1.3324302583932877 {'over_corr': 2371, 'total_err': 925, 'true_corr': tensor(231, device='cuda:0')} {'over_corr': 601, 'total_err': 701, 'true_corr': 30}



train Epoch:17/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=1.156]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.81it/s, batches loss=1.094]

1.2080601833083413 {'over_corr': 2080, 'total_err': 925, 'true_corr': tensor(242, device='cuda:0')} {'over_corr': 577, 'total_err': 701, 'true_corr': 35}



train Epoch:18/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=1.017]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.73it/s, batches loss=1.157]

1.0925631306388162 {'over_corr': 1821, 'total_err': 925, 'true_corr': tensor(271, device='cuda:0')} {'over_corr': 548, 'total_err': 701, 'true_corr': 54}



train Epoch:19/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.898]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.74it/s, batches loss=0.896]

1.0008708374066786 {'over_corr': 1583, 'total_err': 925, 'true_corr': tensor(269, device='cuda:0')} {'over_corr': 516, 'total_err': 701, 'true_corr': 58}



train Epoch:20/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.806]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.67it/s, batches loss=0.921]

0.9272495792670683 {'over_corr': 1475, 'total_err': 925, 'true_corr': tensor(283, device='cuda:0')} {'over_corr': 504, 'total_err': 701, 'true_corr': 70}



train Epoch:21/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.716]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.66it/s, batches loss=0.860]

0.8541995354674079 {'over_corr': 1281, 'total_err': 925, 'true_corr': tensor(286, device='cuda:0')} {'over_corr': 468, 'total_err': 701, 'true_corr': 77}



train Epoch:22/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.641]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.59it/s, batches loss=0.751]

0.800656187263402 {'over_corr': 1173, 'total_err': 925, 'true_corr': tensor(301, device='cuda:0')} {'over_corr': 455, 'total_err': 701, 'true_corr': 86}



train Epoch:23/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.571]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.62it/s, batches loss=0.737]

0.7549976747144352 {'over_corr': 1078, 'total_err': 925, 'true_corr': tensor(305, device='cuda:0')} {'over_corr': 441, 'total_err': 701, 'true_corr': 88}



train Epoch:24/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.517]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.60it/s, batches loss=0.894]

0.7142892581495371 {'over_corr': 985, 'total_err': 925, 'true_corr': tensor(297, device='cuda:0')} {'over_corr': 417, 'total_err': 701, 'true_corr': 95}



train Epoch:25/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.469]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.52it/s, batches loss=0.788]

0.6817657019604336 {'over_corr': 909, 'total_err': 925, 'true_corr': tensor(313, device='cuda:0')} {'over_corr': 406, 'total_err': 701, 'true_corr': 107}



train Epoch:26/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.425]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:07<00:00,  5.55it/s, batches loss=0.420]

0.6491995480927554 {'over_corr': 850, 'total_err': 925, 'true_corr': tensor(303, device='cuda:0')} {'over_corr': 393, 'total_err': 701, 'true_corr': 108}



train Epoch:27/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.387]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.49it/s, batches loss=0.580]

0.6234647550366141 {'over_corr': 798, 'total_err': 925, 'true_corr': tensor(308, device='cuda:0')} {'over_corr': 374, 'total_err': 701, 'true_corr': 116}



train Epoch:28/50: 100%|██████████| 103/103 [00:28<00:00,  3.65it/s, avg loss=0.351]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.47it/s, batches loss=0.640]

0.6034310155294158 {'over_corr': 768, 'total_err': 925, 'true_corr': tensor(309, device='cuda:0')} {'over_corr': 375, 'total_err': 701, 'true_corr': 111}



train Epoch:29/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.321]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.43it/s, batches loss=0.595]

0.5752637420188297 {'over_corr': 724, 'total_err': 925, 'true_corr': tensor(318, device='cuda:0')} {'over_corr': 358, 'total_err': 701, 'true_corr': 125}



train Epoch:30/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.295]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.48it/s, batches loss=0.635]

0.5635733645070683 {'over_corr': 692, 'total_err': 925, 'true_corr': tensor(310, device='cuda:0')} {'over_corr': 344, 'total_err': 701, 'true_corr': 127}



train Epoch:31/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.271]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.48it/s, batches loss=0.526]

0.5456715984777971 {'over_corr': 639, 'total_err': 925, 'true_corr': tensor(306, device='cuda:0')} {'over_corr': 321, 'total_err': 701, 'true_corr': 132}



train Epoch:32/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.247]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.43it/s, batches loss=0.411]

0.5381894328377463 {'over_corr': 610, 'total_err': 925, 'true_corr': tensor(300, device='cuda:0')} {'over_corr': 315, 'total_err': 701, 'true_corr': 131}



train Epoch:33/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.226]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.43it/s, batches loss=0.469]

0.5173768191175028 {'over_corr': 556, 'total_err': 925, 'true_corr': tensor(307, device='cuda:0')} {'over_corr': 292, 'total_err': 701, 'true_corr': 139}



train Epoch:34/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.210]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.42it/s, batches loss=0.399]

0.5072791657664559 {'over_corr': 545, 'total_err': 925, 'true_corr': tensor(316, device='cuda:0')} {'over_corr': 286, 'total_err': 701, 'true_corr': 143}



train Epoch:35/50: 100%|██████████| 103/103 [00:28<00:00,  3.65it/s, avg loss=0.191]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.40it/s, batches loss=0.553]

0.5001249963586981 {'over_corr': 525, 'total_err': 925, 'true_corr': tensor(317, device='cuda:0')} {'over_corr': 286, 'total_err': 701, 'true_corr': 147}



train Epoch:36/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.176]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.45it/s, batches loss=0.329]

0.489075054499236 {'over_corr': 501, 'total_err': 925, 'true_corr': tensor(315, device='cuda:0')} {'over_corr': 274, 'total_err': 701, 'true_corr': 147}



train Epoch:37/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.163]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.43it/s, batches loss=0.440]

0.4821327904408628 {'over_corr': 476, 'total_err': 925, 'true_corr': tensor(310, device='cuda:0')} {'over_corr': 265, 'total_err': 701, 'true_corr': 147}



train Epoch:38/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.153]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.40it/s, batches loss=0.482]

0.4759653528982943 {'over_corr': 447, 'total_err': 925, 'true_corr': tensor(310, device='cuda:0')} {'over_corr': 262, 'total_err': 701, 'true_corr': 153}



train Epoch:39/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.138]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.38it/s, batches loss=0.500]

0.4696621678092263 {'over_corr': 423, 'total_err': 925, 'true_corr': tensor(305, device='cuda:0')} {'over_corr': 247, 'total_err': 701, 'true_corr': 150}



train Epoch:40/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.128]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.39it/s, batches loss=0.735]

0.46718160198493436 {'over_corr': 409, 'total_err': 925, 'true_corr': tensor(305, device='cuda:0')} {'over_corr': 238, 'total_err': 701, 'true_corr': 156}



train Epoch:41/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.117]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.38it/s, batches loss=0.469]

0.45912517607212067 {'over_corr': 401, 'total_err': 925, 'true_corr': tensor(306, device='cuda:0')} {'over_corr': 240, 'total_err': 701, 'true_corr': 154}



train Epoch:42/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.108]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.34it/s, batches loss=0.443]

0.45253015038642014 {'over_corr': 383, 'total_err': 925, 'true_corr': tensor(308, device='cuda:0')} {'over_corr': 233, 'total_err': 701, 'true_corr': 158}



train Epoch:43/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.101]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.35it/s, batches loss=0.427]

0.44910801472989 {'over_corr': 371, 'total_err': 925, 'true_corr': tensor(311, device='cuda:0')} {'over_corr': 227, 'total_err': 701, 'true_corr': 159}



train Epoch:44/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.092]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.39it/s, batches loss=0.316]

0.44341119717467914 {'over_corr': 359, 'total_err': 925, 'true_corr': tensor(309, device='cuda:0')} {'over_corr': 221, 'total_err': 701, 'true_corr': 160}



train Epoch:45/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.085]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.37it/s, batches loss=0.470]

0.4448757645758716 {'over_corr': 354, 'total_err': 925, 'true_corr': tensor(305, device='cuda:0')} {'over_corr': 218, 'total_err': 701, 'true_corr': 158}



train Epoch:46/50: 100%|██████████| 103/103 [00:28<00:00,  3.65it/s, avg loss=0.079]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.37it/s, batches loss=0.480]

0.4385344535112381 {'over_corr': 349, 'total_err': 925, 'true_corr': tensor(308, device='cuda:0')} {'over_corr': 218, 'total_err': 701, 'true_corr': 164}



train Epoch:47/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.073]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.35it/s, batches loss=0.562]

0.44403391297567973 {'over_corr': 367, 'total_err': 925, 'true_corr': tensor(317, device='cuda:0')} {'over_corr': 225, 'total_err': 701, 'true_corr': 167}



train Epoch:48/50: 100%|██████████| 103/103 [00:28<00:00,  3.65it/s, avg loss=0.071]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.31it/s, batches loss=0.447]

0.4373138587583195 {'over_corr': 359, 'total_err': 925, 'true_corr': tensor(320, device='cuda:0')} {'over_corr': 225, 'total_err': 701, 'true_corr': 168}



train Epoch:49/50: 100%|██████████| 103/103 [00:28<00:00,  3.67it/s, avg loss=0.064]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.33it/s, batches loss=0.463]

0.4357498958706856 {'over_corr': 335, 'total_err': 925, 'true_corr': tensor(312, device='cuda:0')} {'over_corr': 215, 'total_err': 701, 'true_corr': 166}



train Epoch:50/50: 100%|██████████| 103/103 [00:28<00:00,  3.66it/s, avg loss=0.058]    
dev Epoch:1/1: 100%|██████████| 44/44 [00:08<00:00,  5.34it/s, batches loss=0.292]

0.42910330403934827 {'over_corr': 317, 'total_err': 925, 'true_corr': tensor(310, device='cuda:0')} {'over_corr': 200, 'total_err': 701, 'true_corr': 171}



dev Epoch:1/1: 100%|██████████| 108/108 [00:38<00:00,  2.84it/s, batches loss=0.404]

0.5043595328375146 {'over_corr': 3832, 'total_err': 5278, 'true_corr': tensor(1187, device='cuda:0')} {'over_corr': 1710, 'total_err': 3436, 'true_corr': 378}





In [42]:
# Optionally replace the freshly trained BERT + LSTM model with the checkpoint on
# disk (e.g. after a kernel restart, to skip retraining).
# Off by default: on Restart & Run All this cell runs *before* the save cell
# below, so an unconditional reload would silently swap in a checkpoint from a
# previous run (or fail if none exists yet).
RELOAD_LSTM_CHECKPOINT = False

# Defined here because the original cell read `save_path` from the later save
# cell (executed first as In [41]) and raised NameError on a fresh kernel.
save_path = "weights/BertLstm.pt"

if RELOAD_LSTM_CHECKPOINT:
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
    # On PyTorch versions where `weights_only` defaults to True, loading a whole
    # pickled module may need `weights_only=False`; confirm against how
    # `model.save` serializes.
    model = torch.load(save_path)

In [44]:
# Beam-search evaluation helper, currently disabled (every line is commented out).
# If revived, `effectiveness_of_beam(model, test_data, beam_width)`:
#   - walks `test_data` one example at a time (1-based counter `i`),
#   - keeps the highest-scoring sequence returned by `model.generate_with_beam(beam_width, data)`,
#   - runs a plain forward pass on "cuda" for comparison (printed as "r-model"),
#   - scores the beam output with `utils.cal_err`, truncated to the beam length,
#   - prints origin / correct / beam-predicted / forward-pass decodes per example,
#     each preceded by `split_lines(i)`, which prints "第{i}句" ("sentence i") centred in dashes,
#   - finally prints the elapsed wall-clock time (seconds, from `time.time()`).
# It also reads the notebook-global `tokenizer`.
# NOTE(review) before reviving:
#   - `test_char_level` / `test_sent_level` are accumulated but never printed or returned.
#   - non-inplace `Tensor.resize` is deprecated; `input_ids.unsqueeze(0)` adds a batch dim.
#   - the comparison forward pass is not wrapped in `torch.no_grad()`.
#   - if the beam is empty, `best_sequence` stays None and `len(best_sequence)` fails.
# import time
# from utils import cal_err

# def split_lines(i, length=20):
#     print(f"第{i}句".center(length, '-'))

# def effectiveness_of_beam(model, test_data, beam_width):
#     begin_time = time.time()
#     matrices = ["over_corr", "total_err", "true_corr"]
#     test_char_level = {key: 0 for key in matrices}
#     test_sent_level = {key: 0 for key in matrices}

#     for i, data in enumerate(test_data, 1):
#         best_score, best_sequence = -float('inf'), None
#         beam = model.generate_with_beam(beam_width, data)
#         for score, seq in beam:
#             if score > best_score:
#                 best_score = score
#                 best_sequence = seq
#         split_lines(i)
#         input_ids = data["input_ids"]
#         labels = data["labels"]
#         attention_mask = data["attention_mask"]

#         raw_model_output = model(input_ids.resize(1, len(input_ids)).to("cuda"),
#                   attention_mask.resize(1, len(input_ids)).to("cuda"))

#         length = len(best_sequence)
#         char_level, sent_level = cal_err(
#             input_ids[:length],
#             torch.tensor(best_sequence),
#             labels[:length],
#             length,
#         )
#         test_char_level = {
#             key: test_char_level[key] + v
#             for key, v in char_level.items()
#         }
#         test_sent_level = {
#             key: test_sent_level[key] + v
#             for key, v in sent_level.items()
#         }

#         print(f"origin sentence:  {tokenizer.decode(input_ids, skip_special_tokens=True)}")
#         print(f"correct sentence: {tokenizer.decode(labels, skip_special_tokens=True)}")
#         print(f"predict sentence: {tokenizer.decode(best_sequence, skip_special_tokens=True)}")
#         print("r-model sentence:", tokenizer.decode(raw_model_output.argmax(dim=-1).squeeze()[:length], skip_special_tokens=True))

#     end_time = time.time()
#     print(f"It cost total {end_time - begin_time} time")

# effectiveness_of_beam(model, test_dataset, 4)

In [41]:
# Persist the trained BERT + LSTM model. `save_path` stays defined so the
# checkpoint can later be reloaded with `torch.load(save_path)`.
save_path = "weights/BertLstm.pt"
model.save(save_path)

#### BERT + GRU

In [None]:
# --- BERT encoder + GRU decoder ---
# Decoder hyperparameters. Assigning them here overrides any same-named values
# brought in by `from config import *`.
hidden_size = 1024
num_layers = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderBaseRNN(
    model=nn.GRU,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

# Build a fresh optimizer + Trainer bound to the GRU model, mirroring the LSTM
# section. The previous `train(model, tokenizer, ...)` call used a name no
# visible import defines. Without rebuilding `trainer`, the next cell would keep
# training the LSTM model held by the old trainer.
optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
# Fit the current `trainer`'s model for `epochs` epochs with per-epoch dev
# evaluation, then report metrics on the test loader.
trainer.train(
    epoch=epochs,
    dataloader=train_data_loader,
    test_dataloader=dev_data_loader,
)
trainer.test(test_data_loader)

In [None]:
# Beam-search evaluation is disabled. `effectiveness_of_beam` is only defined in
# the commented-out helper cell of the BERT + LSTM section, so calling it here
# raised NameError. Revive that helper before uncommenting the call below.
# effectiveness_of_beam(model, test_dataset, 4)

In [None]:
# Persist the trained BERT + GRU model under its own file name.
save_path = "weights/BertGru.pt"
model.save(save_path)

#### BERT + Transformer

In [None]:
# --- BERT encoder + Transformer decoder ---
# Decoder hyperparameters for the Transformer head.
# NOTE(review): if these feed an nn.Transformer-style module, `input_size`
# (BERT's hidden size) must be divisible by `nhead`; confirm in DecoderTransformer.
nhead = 2
num_encoder_layers = 2
num_decoder_layers = 2

encoder_model = BertModel.from_pretrained(checkpoint)
decoder_model = DecoderTransformer(
    input_size=encoder_model.config.hidden_size,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
)
model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

# Build a fresh optimizer + Trainer bound to the Transformer model, mirroring
# the LSTM section. The previous `train(model, tokenizer, ...)` call used a name
# no visible import defines. Without rebuilding `trainer`, the next cell would
# keep training whichever model the old trainer held.
optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
# Train/validate for `epochs` epochs, then evaluate once on the test loader.
trainer.train(
    test_dataloader=dev_data_loader,
    dataloader=train_data_loader,
    epoch=epochs,
)
trainer.test(test_data_loader)

In [None]:
# Persist the trained BERT + Transformer model under its own file name.
save_path = "weights/BertTransformer.pt"
model.save(save_path)