In [None]:
import torch
import transformer
from transformer import Bert
import numpy as np
import torch.nn as nn

### 加载数据

In [None]:
from tqdm import tqdm

# Padded token sequences (one list of `max_length` strings per kept sample).
pinyin_list = []
hanzi_list = []
# Collected as a set of token strings here; converted to an ordered list below.
vocab = set()
max_length = 64     # every kept sample is padded to exactly this many tokens
max_samples = 3000  # only the first `max_samples` kept samples are used

# NOTE(review): assumes each row of zh.tsv is
#   <id> TAB <space-separated pinyin> TAB <space-separated hanzi>
# (the code reads fields 1 and 2 and splits them on spaces) — confirm against the file.
with open("./zh.tsv", errors='ignore', encoding='utf-8') as f:
    for line in f:
        # Fields are tab-separated; splitting the row on spaces would leave
        # a single token in fields[1] / fields[2].
        fields = line.strip().split("\t")
        if len(fields) < 3:
            continue  # blank or malformed row
        pinyin = fields[1].split(" ")
        hanzi = fields[2].split(" ")
        # Tokens of every parsed row go into the vocabulary, including rows
        # that are filtered out below.
        vocab.update(pinyin)
        vocab.update(hanzi)
        # Keep only rows where pinyin and hanzi align token-for-token and fit
        # in `max_length`, so both padded sequences have the same length.
        if len(pinyin) != len(hanzi) or len(pinyin) > max_length:
            continue
        pinyin_list.append(pinyin + ["PAD"] * (max_length - len(pinyin)))
        hanzi_list.append(hanzi + ["PAD"] * (max_length - len(hanzi)))

# "PAD" is always index 0; drop it from the set first so it cannot appear twice.
vocab = ["PAD"] + sorted(vocab - {"PAD"})
vocab_size = len(vocab)

pinyin_list = pinyin_list[:max_samples]
hanzi_list = hanzi_list[:max_samples]

def get_token_ids():
    """Convert the padded pinyin/hanzi token sequences into vocabulary indices.

    Reads the module-level ``pinyin_list``, ``hanzi_list`` and ``vocab``.

    Returns:
        tuple: ``(pinyin_ids, hanzi_ids)`` — two lists with one list of
        integer indices per sample, in the same order as the input lists.

    Raises:
        KeyError: if a token of a sample is not present in ``vocab``.
    """
    # Build the lookup once; list.index() would rescan the vocabulary for
    # every single token.
    token_to_id = {}
    for index, token in enumerate(vocab):
        # setdefault keeps the first position of a token, like list.index().
        token_to_id.setdefault(token, index)

    pinyin_ids = []
    hanzi_ids = []
    # zip() needs both lists; tqdm wraps the zipped pairs. zip has no len(),
    # so the total is passed explicitly for the progress bar.
    pairs = zip(pinyin_list, hanzi_list)
    total = min(len(pinyin_list), len(hanzi_list))
    for pinyin, hanzi in tqdm(pairs, total=total):
        pinyin_ids.append([token_to_id[p] for p in pinyin])
        hanzi_ids.append([token_to_id[h] for h in hanzi])
    return pinyin_ids, hanzi_ids

class TextSampleDS(torch.utils.data.Dataset):
    """Dataset of paired index sequences (pinyin input, hanzi target).

    Both arguments are indexable collections of integer sequences; item
    ``idx`` of one is paired with item ``idx`` of the other.
    """

    def __init__(self, pinyin_ids, hanzi_ids):
        super().__init__()
        self.pinyin_ids, self.hanzi_ids = pinyin_ids, hanzi_ids

    def __len__(self):
        # The number of samples is taken from the pinyin side.
        sample_count = len(self.pinyin_ids)
        return sample_count

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``(pinyin_tensor, hanzi_tensor)`` pair."""
        source_tensor = torch.tensor(self.pinyin_ids[idx])
        target_tensor = torch.tensor(self.hanzi_ids[idx])
        return source_tensor, target_tensor
    
from torch.utils.data import DataLoader
# Tokenise once, wrap the id lists in a dataset, and batch them with shuffling.
token_id_lists = get_token_ids()
train_dataset = TextSampleDS(*token_id_lists)
loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

### 训练代码

In [None]:
class BertModel(nn.Module):
    """Encoder followed by dropout and a linear projection to vocabulary logits.

    Args:
        bert_encoder: module called as ``bert_encoder(input_ids, attention_mask=...)``.
            Assumed to return features whose last dimension is ``d_model`` —
            TODO confirm against the encoder implementation.
        d_model: input feature size of the output projection.
        vocab_size: number of output classes of the projection.
        dropout: dropout probability applied to the encoder output.
    """

    def __init__(self, bert_encoder, d_model, vocab_size, dropout=0.1):
        # nn.Module must be initialised before any submodule is assigned;
        # otherwise the attribute assignments below raise AttributeError.
        super().__init__()
        self.bert_encoder = bert_encoder
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.dropout = nn.Dropout(dropout)
        self.output_layer = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return the logits produced by the output projection.

        Args:
            input_ids: passed unchanged to the encoder.
            attention_mask: optional, forwarded to the encoder as a keyword argument.

        Returns:
            The output of ``nn.Linear(d_model, vocab_size)`` applied to the
            dropped-out encoder output (last dimension is ``vocab_size``).
        """
        encoder = self.bert_encoder(input_ids, attention_mask=attention_mask)
        encoder = self.dropout(encoder)
        logits = self.output_layer(encoder)
        return logits
# Feature size shared by the encoder and the output projection.
hidden_size = 768
bert_encoder = Bert.BERT(vocab_size, d_model=hidden_size)
model = BertModel(bert_encoder, hidden_size, vocab_size, dropout=0.1)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

In [None]:
# Adam optimizer with a cosine learning-rate schedule that is stepped once per batch.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8)
# NOTE(review): T_max is fixed at 2400 scheduler steps, independent of
# len(loader) and the epoch count below — confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=2400, eta_min=2e-6, last_epoch=-1)
from accelerate import Accelerator
accelerator = Accelerator()
model, optimizer, loader, lr_scheduler = accelerator.prepare(model, optimizer, loader, lr_scheduler)
# NOTE(review): PAD target positions contribute to the loss (no ignore_index) — confirm this is intended.
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    model.train()
    pbar = tqdm(loader)
    for pinyin_ids, hanzi_ids in pbar:
        optimizer.zero_grad()
        # Use the accelerator's device rather than the manually chosen `device`:
        # the two can differ, and the model lives where prepare() put it.
        pinyin_ids = pinyin_ids.to(accelerator.device)
        hanzi_ids = hanzi_ids.to(accelerator.device)
        logits = model(pinyin_ids)
        # Flatten to (positions, classes) vs. (positions,) for token-level cross entropy.
        loss = criterion(logits.view(-1, logits.size(-1)), hanzi_ids.view(-1))
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        # Adjacent f-strings are concatenated; a single-quoted f-string cannot
        # contain a raw line break.
        pbar.set_description(
            f"epoch:{epoch+1}, train_loss:{loss.item():.4f}, "
            f"lr:{lr_scheduler.get_last_lr()[0]:.6f}"
        )

# Save the weights of the underlying model so the checkpoint keys do not depend
# on any wrapper added by accelerator.prepare().
torch.save(accelerator.unwrap_model(model).state_dict(), "./model.pth")
#model.load_state_dict(torch.load("./model.pth"))

### 预测

In [None]:
# Run the trained model on one batch from `loader` and decode the first sample.
# NOTE(review): the original cell read `pred` before it was ever assigned;
# predicting on a batch from `loader` is an assumption about the intended input.
model.eval()
with torch.no_grad():
    pinyin_ids, hanzi_ids = next(iter(loader))
    # Send the input to wherever the model's parameters currently live.
    model_device = next(model.parameters()).device
    logits = model(pinyin_ids.to(model_device))
    # argmax over logits equals argmax over softmax(logits), so the softmax
    # is not needed to obtain the predicted token ids.
    pred = torch.argmax(logits, dim=-1)

# Map the predicted ids of the first sample back to tokens, dropping padding.
predicted_hanzi = "".join(vocab[i] for i in pred[0].tolist() if vocab[i] != "PAD")
predicted_hanzi
    
