In [1]:
from src.dataset import BertGenerateDataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torch.nn.utils.rnn import pad_sequence
from src.model import BertGenerate
import torch
from general_trainer.loss import MaskedSoftmaxCELoss
from transformers import BertTokenizer, BertConfig
from general_trainer.trainer import GeneralTrainer
from general_trainer.optim_schedule import ScheduleOptim

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def collate_fn(batch):
    """Collate a list of samples into padded batch tensors placed on ``device``.

    Every sample is unpacked as
    ``(input_ids, input_valid_len, label_ids, label_valid_len)``.
    Samples are ordered by descending ``len(input_ids)``, both id sequences
    are right-padded with 0 up to the longest sequence in the batch, and the
    valid lengths are packed into 1-D long tensors.

    Args:
        batch: list of 4-tuples as described above. The id entries must be
            tensors accepted by ``pad_sequence``.

    Returns:
        Tuple ``(input_ids, input_valid_lens, label_ids, label_valid_lens)``
        where the id tensors are batch-first and all four tensors live on
        ``device``.
    """
    # Longest input first; labels and lengths follow the same ordering.
    ordered = sorted(batch, key=lambda sample: len(sample[0]), reverse=True)
    input_ids, input_valid_lens, label_ids, label_valid_lens = zip(*ordered)

    # NOTE: tensors are moved to `device` inside the collate function, so a
    # CUDA device generally requires the DataLoader to run with num_workers=0.
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0).to(device)
    label_ids = pad_sequence(label_ids, batch_first=True, padding_value=0).to(device)
    input_valid_lens = torch.LongTensor(input_valid_lens).to(device)
    label_valid_lens = torch.LongTensor(label_valid_lens).to(device)
    return input_ids, input_valid_lens, label_ids, label_valid_lens

# Seed the RNGs so the train/test split and the shuffled batch order (and any
# torch-based random initialisation inside the model) are reproducible across
# "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Load the samples and hold out 20% of them for evaluation.
dataset = BertGenerateDataset('data_.json')
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=SEED)

# Only the training loader is shuffled; both pad their batches via collate_fn.
train_loader = DataLoader(train_data, batch_size=2, shuffle=True, collate_fn=collate_fn)
test_loader = DataLoader(test_data, batch_size=2, shuffle=False, collate_fn=collate_fn)

model = BertGenerate()
model.to(device)

# `loss` is the criterion instance that cal_loss calls on each batch.
loss = MaskedSoftmaxCELoss()

# The tokenizer is only needed for its vocabulary: cal_loss looks up the id of
# '[SEP]' in `tgt_vocab` to build the decoder's first input token.
tokenize = BertTokenizer.from_pretrained('bert-base-chinese')
tgt_vocab = tokenize.vocab
def cal_loss(data, model):
    """Return the summed loss of ``model`` on one batch, using teacher forcing.

    Args:
        data: 4-tuple ``(X, X_valid_len, Y, Y_valid_len)`` as produced by the
            data loader's collate function.
        model: callable invoked as ``model(X, dec_input, X_valid_len)`` that
            returns a pair whose first element is the prediction tensor.

    Returns:
        The ``.sum()`` of ``loss(prediction, Y, Y_valid_len)``.
    """
    src_ids, src_valid_len, tgt_ids, tgt_valid_len = data

    # One '[SEP]' id per target row, shaped (batch, 1); it serves as the
    # start-of-sequence token fed to the decoder.
    sep_id = tgt_vocab['[SEP]']
    n_rows = tgt_ids.shape[0]
    start_col = torch.tensor([sep_id for _ in range(n_rows)], device=device).unsqueeze(1)

    # Teacher forcing: the decoder sees the target shifted right by one
    # position (last token dropped, start token prepended).
    shifted_tgt = tgt_ids[:, :-1]
    dec_input = torch.cat((start_col, shifted_tgt), dim=1)

    # The second element of the model output is not used here.
    prediction, _unused = model(src_ids, dec_input, src_valid_len)
    per_item_loss = loss(prediction, tgt_ids, tgt_valid_len)
    return per_item_loss.sum()
# Pretrained BERT configuration; only `hidden_size` is read from it below.
config = BertConfig.from_pretrained('bert-base-chinese')

# Adam over every model parameter, wrapped by the project's LR scheduler.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# NOTE(review): the third argument is presumably a warm-up step count (TODO
# confirm against ScheduleOptim). The logged run performs 4 steps per epoch for
# 10 epochs and avg_loss stays around 14-15, so this value may be far too
# large for a dataset of this size.
scheduler = ScheduleOptim(optimizer, config.hidden_size, 10000)

# NOTE(review): the loss *class* (not an instance) is passed positionally,
# while cal_loss uses the module-level `loss` instance; verify this is what
# GeneralTrainer expects.
trainer = GeneralTrainer(
    model,
    MaskedSoftmaxCELoss,
    train_loader,
    test_loader,
    scheduler=scheduler,
    cal_loss=cal_loss,
)
trainer.training()

100%|██████████| 10/10 [00:00<00:00, 582.44it/s]


Total number of parameters:  248088968
Using device:  cuda


EP_train:0: 100%|██████████| 4/4 [00:04<00:00,  1.17s/it]


EP_train:0 avg_loss:15.4664


EP_train:1: 100%|██████████| 4/4 [00:00<00:00,  4.51it/s]


EP_train:1 avg_loss:15.5947


EP_train:2: 100%|██████████| 4/4 [00:00<00:00,  5.29it/s]


EP_train:2 avg_loss:15.3642


EP_train:3: 100%|██████████| 4/4 [00:00<00:00,  4.82it/s]


EP_train:3 avg_loss:14.4406


EP_train:4: 100%|██████████| 4/4 [00:00<00:00,  4.82it/s]


EP_train:4 avg_loss:14.6792


EP_train:5: 100%|██████████| 4/4 [00:00<00:00,  4.50it/s]


EP_train:5 avg_loss:14.5216


EP_train:6: 100%|██████████| 4/4 [00:00<00:00,  4.61it/s]


EP_train:6 avg_loss:14.5091


EP_train:7: 100%|██████████| 4/4 [00:00<00:00,  4.38it/s]


EP_train:7 avg_loss:14.3598


EP_train:8: 100%|██████████| 4/4 [00:00<00:00,  4.49it/s]


EP_train:8 avg_loss:15.4915


EP_train:9: 100%|██████████| 4/4 [00:00<00:00,  4.90it/s]

EP_train:9 avg_loss:15.1087



