In [1]:
import argparse
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer, BertForTokenClassification
from torch.utils.data import DataLoader
from tools.data_utils import get_unique_tags
from tools.data import SequenceLabelingDataset, collate_fn, PredSequenceLabelingDataset, pred_collate_fn
from tools.data_utils import read_data
from torch.utils.data import RandomSampler, SequentialSampler
from train_and_eval import Trainer
import torch
from tools.utils import get_logger
from config import Config


def load_data(opt, batch_size=32):
    """Build the train / dev / test dataloaders described by ``opt``.

    Reads the three raw splits with ``read_data(opt.data_dir)``. It then derives
    the tag vocabulary from the *training* split and stores it on ``opt`` so
    that later code can map tags to ids and size the BERT classification head.

    Args:
        opt: configuration object. Reads ``data_dir``, ``tokenizer`` and
            ``max_length``. Writes ``unique_tags``, ``labels_to_ids`` and
            ``ids_to_labels``.
        batch_size: number of examples per batch for all three loaders.
            Defaults to 32, the value that used to be hard-coded.

    Returns:
        ``(train_dataloader, dev_dataloader, test_dataloader)``. The train
        loader samples randomly. The dev and test loaders iterate their own
        dataset in order.
    """
    # Load the raw splits
    train_data, dev_data, test_data = read_data(opt.data_dir)

    # Tag metadata: used to convert tags to indices, and by the caller to
    # initialise BERT with the right number of labels
    unique_tags, labels_to_ids, ids_to_labels = get_unique_tags(train_data)
    opt.unique_tags = unique_tags
    opt.labels_to_ids = labels_to_ids
    opt.ids_to_labels = ids_to_labels

    train_dataset = SequenceLabelingDataset(data=train_data, labels_to_ids=labels_to_ids, tokenizer=opt.tokenizer, max_length=opt.max_length)
    dev_dataset = SequenceLabelingDataset(data=dev_data, labels_to_ids=labels_to_ids, tokenizer=opt.tokenizer, max_length=opt.max_length)
    # The test split is built without the tag mapping and uses the
    # prediction-specific dataset / collate function
    test_dataset = PredSequenceLabelingDataset(data=test_data, tokenizer=opt.tokenizer, max_length=opt.max_length)

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, collate_fn=collate_fn)

    dev_sampler = SequentialSampler(dev_dataset)
    dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, sampler=dev_sampler, collate_fn=collate_fn)

    # Each loader needs a sampler over its OWN dataset. Reusing dev_sampler here
    # would draw len(dev_dataset) indices into the test set: it would silently
    # drop test examples, or raise IndexError when the test set is smaller.
    test_sampler = SequentialSampler(test_dataset)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler, collate_fn=pred_collate_fn)

    return train_dataloader, dev_dataloader, test_dataloader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Project-wide settings object; later cells read fields such as log_path,
# device, data_dir and max_length from it
opt = Config()
# Build the logger once and keep it both as a global and on opt, so that
# later cells can log through opt.logger
opt.logger = logger = get_logger(opt.log_path)

In [3]:
# The tokenizer turns raw text into BERT input ids; load_data reads it from opt
opt.tokenizer = tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Build the three dataloaders (this also fills in opt.unique_tags /
# opt.labels_to_ids / opt.ids_to_labels)
opt.logger.info("开始加载数据")
train_dataloader, dev_dataloader, test_dataloader = load_data(opt)
opt.logger.info("加载数据完成")
opt.logger.info(f"标签信息: {opt.labels_to_ids}")

2022-07-10 09:55:11 - INFO: 开始加载数据
2022-07-10 09:55:11 - INFO: 加载数据完成
2022-07-10 09:55:11 - INFO: 标签信息: {'B-address': 0, 'B-book': 1, 'B-company': 2, 'B-game': 3, 'B-government': 4, 'B-movie': 5, 'B-name': 6, 'B-organization': 7, 'B-position': 8, 'B-scene': 9, 'I-address': 10, 'I-book': 11, 'I-company': 12, 'I-game': 13, 'I-government': 14, 'I-movie': 15, 'I-name': 16, 'I-organization': 17, 'I-position': 18, 'I-scene': 19, 'O': 20}


In [4]:
# Token-classification model with one output per tag found by load_data.
# The original notes say batches are placed on opt.device, so the model is
# moved to that same device (nn.Module.to returns the module itself).
model = BertForTokenClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=len(opt.unique_tags),
).to(opt.device)

# Wrap the model for training / prediction
trainer = Trainer(model=model, opt=opt)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-c

In [5]:
# Peek at a single collated test batch to sanity-check its keys and tensors
batch = next(iter(test_dataloader), None)
if batch is not None:
    print(batch)

{'input_ids': tensor([[ 101, 1724, 2335,  ..., 5865,  868,  511],
        [ 101, 2225, 3189,  ...,    0,    0,    0],
        [ 101, 7218, 1545,  ...,    0,    0,    0],
        ...,
        [ 101,  100,  157,  ...,    0,    0,    0],
        [ 101,  517, 4868,  ...,    0,    0,    0],
        [ 101,  123,  121,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'sent_length': tensor([48, 31, 17, 21, 40, 26, 15, 49, 44, 46, 26, 41, 49, 45, 37, 45, 36, 49,
        22, 40, 48,  8, 11, 21, 45, 47, 37, 42, 49, 25, 43, 38])}


In [6]:
# Run inference on the test split. This cell's logged output shows
# Trainer.predict also saving results to ./predict_results/predict.txt
# (output path presumably comes from Config — confirm).
pred_result = trainer.predict(test_dataloader)

2022-07-10 09:55:43 - INFO: Start Predicting...
Predicting: 100%|██████████| 42/42 [00:02<00:00, 17.63it/s]
2022-07-10 09:55:46 - INFO: Predicting Done.
2022-07-10 09:55:46 - INFO: Predict result save to ./predict_results/predict.txt
