In [1]:
import torch
import torch.nn as nn

# Project-local modules (config, data_processer, models)
from config import *
from data_processer import CSCDataset, split_torch_dataset
from models import DecoderBaseRNN, DecoderTransformer
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer

In [7]:
class TempBertModel(nn.Module):
    """Encoder -> decoder -> linear head producing per-position logits.

    The encoder is called as ``encoder(src, attention_mask=src_mask)`` and must
    return an object exposing ``last_hidden_state``; that tensor is fed to the
    decoder, and the decoder output is projected by a linear layer.

    Args:
        encoder_model: Module used as the encoder (see call convention above).
        decoder_model: Module applied to the encoder's ``last_hidden_state``.
            Must expose ``config.hidden_size``, which sizes the linear head.
        output_size: Number of output features of the linear head.
    """

    def __init__(self, encoder_model, decoder_model, output_size=2):
        # Zero-argument super(): the previous call named `CombineBertModel`,
        # which is not defined here and raised NameError on construction.
        super().__init__()
        self.encoder = encoder_model
        self.decoder = decoder_model

        self.linear = nn.Linear(decoder_model.config.hidden_size, output_size)

    def forward(self, src, src_mask):
        """Return the linear head's output for ``src``.

        Args:
            src: Input passed positionally to the encoder.
            src_mask: Passed to the encoder as ``attention_mask``.

        Returns:
            Output of the linear layer applied to the decoder output.
        """
        x = self.encoder(src, attention_mask=src_mask).last_hidden_state
        x = self.decoder(x)

        x = self.linear(x)

        return x

    def save(self, store_path):
        """Serialize the whole module object to ``store_path`` with ``torch.save``."""
        torch.save(self, store_path)

    def save_state(self, store_path):
        """Save only the module's ``state_dict`` to ``store_path``."""
        torch.save(self.state_dict(), store_path)

In [8]:
# Load the BERT tokenizer for the pretrained checkpoint.
# NOTE(review): `checkpoint` is not defined in this notebook — presumably it is
# supplied by `from config import *`; verify.
tokenizer = BertTokenizer.from_pretrained(checkpoint)

In [None]:
# Build the training and test datasets; each takes a two-element list of paths
# (judging by the names: erroneous text first, corrected text second) plus the tokenizer.
# NOTE(review): the SIGHAN_* names are not defined in this notebook — presumably they
# come from `from config import *`. The test set uses names containing "train_dir";
# confirm they point at the intended evaluation files.
train_dataset = CSCDataset([SIGHAN_train_dir_err, SIGHAN_train_dir_corr], tokenizer)
test_dataset = CSCDataset([SIGHAN_train_dir_err14, SIGHAN_train_dir_corr14], tokenizer)

In [None]:
# Split the training set into train/dev subsets.
# NOTE(review): 0.3 is presumably the dev fraction — confirm against `split_torch_dataset`.
train_data, dev_data = split_torch_dataset(train_dataset, 0.3)

# Only the training loader shuffles. The evaluation loaders keep a fixed order so the
# reported loss (a mean of per-batch means) is reproducible: shuffling would change
# which samples land in the final, possibly smaller, batch on every pass.
train_data_loader = DataLoader(train_data, num_workers=4, shuffle=True, batch_size=16)
dev_data_loader = DataLoader(dev_data, num_workers=4, shuffle=False, batch_size=16)
test_data_loader = DataLoader(test_dataset, num_workers=4, shuffle=False, batch_size=32)

In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
from utils import cal_err


class Trainer:
    """Training / evaluation loop for a per-position two-class classifier.

    For every batch the target at each position is 1 where ``labels`` equals
    ``input_ids`` and 0 where they differ. Positions whose ``attention_mask``
    is 0 are excluded from the loss.

    Args:
        model: Called as ``model(input_ids, src_mask=attention_mask)``.
            Assumed to return logits shaped (batch, seq_len, num_classes) —
            TODO confirm against the model in use.
        tokenizer: Stored on the instance; not read by the loop.
        optimizer: Optimizer stepped once per training batch.
    """

    # Target value skipped by F.cross_entropy. It must not collide with the
    # 0/1 class ids, which is why the tokenizer's pad id cannot be used here.
    _IGNORE_INDEX = -100

    def __init__(self, model, tokenizer, optimizer):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = optimizer

    def train(self, dataloader, epoch, test_dataloader=None, printepoch=float("inf")):
        """Train for ``epoch`` epochs over ``dataloader``.

        Args:
            dataloader: Iterable of batches; must support ``len()``.
            epoch: Number of epochs to run.
            test_dataloader: If given, ``test`` is run on it after every epoch.
            printepoch: Accepted for compatibility; currently unused.
        """
        self.iteration(dataloader, epoch, test_dataloader, printepoch)

    def test(self, dataloader):
        """Run one evaluation pass over ``dataloader`` without gradients.

        Resets the char-level and sentence-level counters to zero first; the
        loop does not currently update them, it only reports the average loss.
        """
        metrics = ["over_corr", "total_err", "true_corr"]
        self.test_char_level = {key: 0 for key in metrics}
        self.test_sent_level = {key: 0 for key in metrics}
        with torch.no_grad():
            self.iteration(dataloader, train=False)

    def iteration(
        self,
        dataloader,
        epochs=1,
        test_dataloader=None,
        printepoch=float("inf"),
        train=True,
    ):
        """Shared loop behind ``train`` and ``test``.

        Args:
            dataloader: Iterable of dict batches with ``"input_ids"``,
                ``"attention_mask"`` and ``"labels"`` tensors; must support ``len()``.
            epochs: Number of passes over ``dataloader``.
            test_dataloader: If not None, evaluated via ``test`` after each epoch.
            printepoch: Accepted for compatibility; currently unused.
            train: When True the optimizer is stepped on every batch; when
                False the model is put in eval mode and the average loss of
                the last pass is printed at the end.
        """
        mode = "train" if train else "dev"
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        num_batches = len(dataloader)
        # Defined up front so the final summary is safe even when epochs == 0.
        total_loss = 0.0

        for epoch in range(epochs):
            # Re-applied every epoch: the per-epoch evaluation below switches
            # the model to eval mode.
            if train:
                self.model.train()
            else:
                self.model.eval()
            total_loss = 0.0

            progress_bar = tqdm(
                enumerate(dataloader),
                desc=f"{mode} Epoch:{epoch+1}/{epochs}",
                total=num_batches,
            )
            for i, batch in progress_bar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device).type(torch.float)
                labels = batch["labels"].to(device)

                # cross_entropy needs integer class indices, not a bool tensor.
                targets = (labels == input_ids).long()
                # Drop padded positions from the loss.
                # Assumes attention_mask has the same shape as labels — TODO confirm.
                targets = targets.masked_fill(attention_mask == 0, self._IGNORE_INDEX)

                outputs = self.model(input_ids, src_mask=attention_mask)
                # cross_entropy expects the class dimension second:
                # (batch_size, num_classes, seq_len)
                logits = outputs.permute(0, 2, 1)

                # The loss (and its backward pass) is computed here, so the
                # labels do not need to be passed into the model.
                loss = F.cross_entropy(
                    logits, targets, ignore_index=self._IGNORE_INDEX
                )
                batch_loss = loss.item()
                total_loss += batch_loss

                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                progress_bar.set_postfix({"batches loss": "{:.3f}".format(batch_loss)})
                if i == num_batches - 1:
                    # On the last batch, show the epoch average instead.
                    progress_bar.set_postfix(
                        {"avg loss": "{:.3f}".format(total_loss / num_batches)}
                    )

            # Per-epoch evaluation. Compared against None because a DataLoader's
            # truthiness depends on its length.
            if test_dataloader is not None:
                self.test(test_dataloader)

        if mode == "dev":
            # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
            print(
                total_loss / max(num_batches, 1),
            )

In [None]:
# Decoder / head hyperparameters. They could instead live in config.py.
# NOTE(review): these assignments replace any same-named values brought in by
# `from config import *`.
hidden_size = 1024
num_layers = 2
output_size = 2

# Pretrained BERT encoder; `checkpoint` is presumably supplied by config.py — verify.
encoder_model = BertModel.from_pretrained(checkpoint)
# Decoder wrapped around nn.LSTM, fed with vectors of the encoder's hidden size.
decoder_model = DecoderBaseRNN(
    model=nn.LSTM,
    input_size=encoder_model.config.hidden_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
)
# TempBertModel sizes its linear head from `decoder_model.config.hidden_size`,
# so the decoder object must expose that attribute.
model = TempBertModel(encoder_model=encoder_model, decoder_model=decoder_model, output_size=output_size)

# One optimizer over all parameters (encoder, decoder and head).
# `learning_rate` is presumably supplied by config.py — verify.
optimizer = AdamW(model.parameters(), lr=learning_rate)
trainer = Trainer(model=model, tokenizer=tokenizer, optimizer=optimizer)

In [None]:
# Run training. `epochs` is not defined in this notebook — presumably it comes
# from `from config import *`; verify.
# NOTE(review): `dev_data_loader` is built earlier but is not passed as
# `test_dataloader`, so no per-epoch evaluation runs during this call.
trainer.train(dataloader=train_data_loader, epoch=epochs)
# Evaluation on the test loader is currently disabled:
# trainer.test(test_data_loader)