In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from transformers import BertModel, BertTokenizer

# Implemented by myself
from config import *
from data_processer import CSCDataset, split_torch_dataset
from models import CombineBertModel, DecoderBaseRNN, DecoderTransformer, Trainer
from utils import cal_err

In [2]:
tokenizer = BertTokenizer.from_pretrained(checkpoint)

In [3]:
train_dataset = CSCDataset([SIGHAN_train_dir_err, SIGHAN_train_dir_corr], tokenizer)

test_dataset = CSCDataset([SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14], tokenizer)

preprocessing sighan dataset: 2339it [00:00, 65705.43it/s]
preprocessing sighan dataset: 100%|████████████████████████████████████████████| 2339/2339 [00:00<00:00, 137932.89it/s]


共2339句，共73264字，最长的句子有171字


preprocessing sighan dataset: 3437it [00:00, 94304.27it/s]
preprocessing sighan dataset: 100%|█████████████████████████████████████████████| 3437/3437 [00:00<00:00, 97672.81it/s]

共3437句，共170330字，最长的句子有258字





In [4]:
train_data, dev_data = split_torch_dataset(train_dataset, 0.3)

train_data_loader = DataLoader(train_data, num_workers=4, shuffle=True, batch_size=16)

dev_data_loader = DataLoader(dev_data, num_workers=4, shuffle=True, batch_size=16)

test_data_loader = DataLoader(test_dataset, num_workers=4, shuffle=True, batch_size=32)

In [5]:
epochs = 35

# tokenizer = BertTokenizer.from_pretrained(checkpoint)
encoder_model = BertModel.from_pretrained(checkpoint)

# The Hyperparameters can be defined in config.py
hidden_size = 1024
num_layers = 2

decoder_model = DecoderBaseRNN(
    model=nn.LSTM,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)

model = CombineBertModel(encoder_model=encoder_model, decoder_model=decoder_model)

optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
trainer.train(dataloader=train_data_loader, epoch=1, test_dataloader=dev_data_loader)
trainer.test(test_data_loader)

Epoch:1/1:   3%|█▋                                                         | 3/103 [01:34<52:57, 31.78s/it, loss=9.926]