In [1]:
from transformers import AutoModelForTokenClassification, AutoTokenizer,DataCollatorForTokenClassification
import torch
import evaluate  # pip install evaluate
import seqeval   # pip install seqeval

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Inspect the dataset
from datasets import load_dataset

data = load_dataset("doushabao4766/msra_ner_k_V3")
train_features = data["train"].features
label_names = train_features["ner_tags"].feature.names
first_example = data["train"][0]

print(train_features)
print("标签种类：", label_names)
# Tag scheme: O = not part of an entity; B-/I-PER = beginning/inside of a
# person name; B-/I-ORG = beginning/inside of an organisation;
# B-/I-LOC = beginning/inside of a location.
print(first_example["ner_tags"])
print(len(first_example["tokens"]))
print(first_example["knowledge"])


{'id': Value(dtype='string', id=None), 'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None), length=-1, id=None), 'knowledge': Value(dtype='string', id=None)}
标签种类： ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
50



In [20]:
def data_input_proc(item, tokenizer=None):
    """Tokenize a batch of pre-split sentences and attach aligned NER labels.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        item: Batch dict with ``'tokens'`` (list of word lists) and
            ``'ner_tags'`` (list of tag-id lists, one id per word).
        tokenizer: Optional tokenizer to use. When omitted, the
            ``google-bert/bert-base-chinese`` tokenizer is loaded on each
            call; pass a preloaded one (e.g. via ``map(fn_kwargs=...)``) to
            avoid reloading it for every batch. It must return an encoding
            that provides ``word_ids(batch_index=...)``.

    Returns:
        The tokenizer output with an added ``'labels'`` entry holding one
        label per non-padding token.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')
    input_data = tokenizer(item['tokens'], is_split_into_words=True, truncation=True, padding="max_length", add_special_tokens=False, max_length=512)

    # Align labels with the produced tokens instead of copying ner_tags
    # verbatim: a word may yield zero pieces (dropped), several pieces, or be
    # cut off by truncation, and in each case the raw tag list would no longer
    # line up with input_ids.
    aligned_labels = []
    for batch_index, tags in enumerate(item['ner_tags']):
        row = []
        previous_word = None
        for word_index in input_data.word_ids(batch_index=batch_index):
            if word_index is None:
                # No special tokens are added, so None marks padding only.
                # Padding positions are left out so the stored labels stay one
                # per real token; the data collator pads them when batching.
                continue
            if word_index != previous_word:
                row.append(tags[word_index])
            else:
                # Continuation piece of the same word: excluded from the loss.
                row.append(-100)
            previous_word = word_index
        aligned_labels.append(row)

    input_data['labels'] = aligned_labels
    return input_data

# Apply data_input_proc to every split in batches; the result carries the
# tokenizer outputs plus the `labels` column next to the original columns.
ds2 = data.map(data_input_proc,batched=True)
# Return the listed model-input columns as PyTorch tensors
ds2.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
print(ds2.shape)

Map: 100%|██████████| 45001/45001 [00:09<00:00, 4909.00 examples/s]
Map: 100%|██████████| 3443/3443 [00:00<00:00, 4613.06 examples/s]

{'train': (45001, 8), 'test': (3443, 8)}





In [21]:
# Build both directions of the tag-name <-> tag-id mapping from the dataset's
# own label list, so ids follow the order of the ClassLabel feature.
tags = data["train"].features["ner_tags"].feature.names
id2lbl = dict(enumerate(tags))
lbl2id = {name: idx for idx, name in id2lbl.items()}

In [22]:
print(data)

# Size the classification head from the label mapping instead of a hard-coded
# 7, so the head always agrees with id2label/label2id if the tag set changes.
model = AutoModelForTokenClassification.from_pretrained(
    'google-bert/bert-base-chinese',
    num_labels=len(id2lbl),
    id2label=id2lbl,
    label2id=lbl2id,
)
model


Some weights of BertForTokenClassification were not initialized from the model checkpoint at google-bert/bert-base-chinese and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'knowledge'],
        num_rows: 45001
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags', 'knowledge'],
        num_rows: 3443
    })
})


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12

In [23]:

# 评估指标
from sklearn.metrics import classification_report
import numpy as np

def compute_metrics(p):
    """Compute entity-level precision/recall/F1 and accuracy with seqeval.

    Args:
        p: Pair ``(predictions, labels)``. The class with the highest score
            along axis 2 of ``predictions`` is taken as the predicted tag id;
            positions whose label is -100 are ignored.

    Returns:
        Dict with the keys ``precision``, ``recall``, ``f1`` and ``accuracy``
        taken from seqeval's ``overall_*`` results.
    """
    scores, gold_ids = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept_predictions = []
        kept_labels = []
        for pred, lab in zip(predicted_row, gold_row):
            if lab == -100:
                # Ignored position (padding etc.): drop it from both sequences.
                continue
            kept_predictions.append(id2lbl[pred])
            kept_labels.append(id2lbl[lab])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    # seqeval scores whole entities rather than individual tokens
    seqeval_metric = evaluate.load("seqeval")
    results = seqeval_metric.compute(predictions=true_predictions, references=true_labels)

    summary = {}
    for name in ("precision", "recall", "f1", "accuracy"):
        summary[name] = results["overall_" + name]
    return summary

In [27]:
## Model training
### TrainingArguments
from transformers import TrainingArguments, Trainer
args = TrainingArguments(
    output_dir="ner_train",  # working directory for training artifacts (TensorBoard files, intermediate checkpoints, logs)
    num_train_epochs=1,  # number of training epochs
    save_safetensors=False,  # False so the saved files can be loaded with torch.load
    per_device_train_batch_size=8,  # training batch size per device
    per_device_eval_batch_size=8,
    report_to='tensorboard',  # where training metrics are reported
    eval_strategy="epoch",
)
# The collator needs a tokenizer to pad each batch (inputs and labels).
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

# Train and evaluate on subsets only: the first 5000 training rows and the
# first 500 test rows.
trainer = Trainer(
    model,
    args,
    train_dataset=ds2['train'].select(range(5000)),
    eval_dataset=ds2['test'].select(range(500)),
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

In [28]:
# Run the training loop configured on `trainer`.
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.0833,0.057909,0.809359,0.854529,0.831331,0.980037


TrainOutput(global_step=625, training_loss=0.07402871475219727, metrics={'train_runtime': 19491.0342, 'train_samples_per_second': 0.257, 'train_steps_per_second': 0.032, 'total_flos': 1306542842880000.0, 'train_loss': 0.07402871475219727, 'epoch': 1.0})

In [32]:
# Run prediction on the first 20 rows of the test split.
result = trainer.predict(ds2['test'].select(range(20)))

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 5, 5, 5, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
[   0    0    0    0    0    0    0    0    0    0    3    4    4    4
    4    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    5    6    0    0    0    0
    0    0    0    0    0    0    0    5    6    0    0    0    5    5
    5    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    5    6    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0 -100 -100 -100

In [64]:
#print(type(ds2['test'][0]['labels']))
#print(result.label_ids[0])
def get_ent_list(tokens, labels, label_names=None):
    """Pair each token with its tag name and keep only entity tokens.

    Args:
        tokens: Sequence of tokens.
        labels: Sequence of tag ids, positionally matched to ``tokens``
            (plain ints or integer-like scalars such as tensor elements).
        label_names: Optional list mapping tag id -> tag name. Defaults to
            the tag names of the training split of the global ``data``.

    Returns:
        List of ``(token, tag_name)`` tuples for every position whose tag is
        a valid id and is not ``'O'``.
    """
    if label_names is None:
        label_names = data["train"].features["ner_tags"].feature.names

    entity_pairs = []
    for token, label in zip(tokens, labels):
        label_id = int(label)
        # Skip ids outside the tag list (e.g. the -100 padding/ignore value).
        # Plain list indexing would raise on -100 and would silently wrap
        # around for small negative ids.
        if not 0 <= label_id < len(label_names):
            continue
        tag = label_names[label_id]
        if tag != 'O':
            entity_pairs.append((token, tag))
    return entity_pairs

def merge_entities_to_dict(filtered):
    """Merge consecutive B-/I- tagged tokens into whole entities.

    Args:
        filtered: Iterable of ``(token, tag)`` pairs, where ``tag`` looks
            like ``'B-XXX'`` or ``'I-XXX'``.

    Returns:
        Dict mapping the concatenated entity text to its type (the tag
        without the ``B-``/``I-`` prefix). A repeated entity text keeps the
        type of its last occurrence.
    """
    entities = {}
    word = ""
    kind = ""

    for token, tag in filtered:
        prefix, label = tag[:2], tag[2:]

        # Continuation of the entity currently being built.
        if prefix == 'I-' and label == kind:
            word += token
            continue

        # Anything else ends the current entity (if there is one).
        if word:
            entities[word] = kind

        if prefix == 'B-':
            # Start a new entity with this token.
            word, kind = token, label
        else:
            # Non-matching or unexpected tag: start over with nothing pending.
            word, kind = "", ""

    # Flush the entity still pending at the end of the sequence.
    if word:
        entities[word] = kind

    return entities

# Demo on the first test example: collect its entity tokens from the stored
# labels, then merge them into an {entity text: entity type} dict.
single_result = merge_entities_to_dict(
    get_ent_list(data["test"][0]["tokens"], ds2['test'][0]['labels'])
)
print(single_result)


{'中共中央': 'ORG', '中国致公党十一大': 'ORG', '中国致公党第十一次全国代表大会': 'ORG', '中国共产党中央委员会': 'ORG', '致公党': 'ORG'}
