In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBert
from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBertConfig
from ark_nlp.model.ner.global_pointer_bert import Dataset
from ark_nlp.model.ner.global_pointer_bert import Task
from ark_nlp.model.ner.global_pointer_bert import get_default_model_optimizer
from ark_nlp.model.ner.global_pointer_bert import Tokenizer

### 一、数据读入与处理

#### 1. 数据读入

In [2]:
import os
from ark_nlp.factory.utils.conlleval import get_entity_bio


datalist = []
with open('./data/train_data/train.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
    lines.append('\n')
    
    text = []
    labels = []
    label_set = set()
    
    for line in lines: 
        if line == '\n':                
            text = ''.join(text)
            entity_labels = []
            for _type, _start_idx, _end_idx in get_entity_bio(labels, id2label=None):
                entity_labels.append({
                    'start_idx': _start_idx,
                    'end_idx': _end_idx,
                    'type': _type,
                    'entity': text[_start_idx: _end_idx+1]
                })
                
            if text == '':
                continue
            
            datalist.append({
                'text': text,
                'label': entity_labels
            })
            
            text = []
            labels = []
            
        elif line == '  O\n':
            text.append(' ')
            labels.append('O')
        else:
            line = line.strip('\n').split()
            if len(line) == 1:
                term = ' '
                label = line[0]
            else:
                term, label = line
            text.append(term)
            label_set.add(label.split('-')[-1])
            labels.append(label)

In [3]:
# 这里随意分割了一下看指标，建议实际使用sklearn分割或者交叉验证

train_data_df = pd.DataFrame(datalist)
train_data_df['label'] = train_data_df['label'].apply(lambda x: str(x))

dev_data_df = pd.DataFrame(datalist[-400:])
dev_data_df['label'] = dev_data_df['label'].apply(lambda x: str(x))

In [4]:
ner_train_dataset = Dataset(train_data_df, categories=label_set)
ner_dev_dataset = Dataset(dev_data_df, categories=ner_train_dataset.categories)

#### 2. 词典创建和生成分词器

In [5]:
tokenizer = Tokenizer(vocab='hfl/chinese-bert-wwm', max_seq_len=128)

#### 3. ID化

In [6]:
ner_train_dataset.convert_to_ids(tokenizer)
ner_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [7]:
config = GlobalPointerBertConfig.from_pretrained('hfl/chinese-roberta-wwm-ext-large', 
                                                 num_labels=len(ner_train_dataset.cat2id))

#### 2. 模型创建

In [8]:
torch.cuda.empty_cache()

In [9]:
dl_module = GlobalPointerBert.from_pretrained('hfl/chinese-roberta-wwm-ext-large', 
                                              config=config)

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext-large were not used when initializing GlobalPointerBert: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing GlobalPointerBert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GlobalPointerBert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GlobalPointerBert were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext-lar

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [10]:
# 设置运行次数
num_epoches = 10
batch_size = 32

In [11]:
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [12]:
model = Task(dl_module, optimizer, 'gpce', cuda_device=0)

#### 3. 训练

In [13]:
model.fit(ner_train_dataset, 
          ner_dev_dataset,
          lr=2e-5,
          epochs=num_epoches, 
          batch_size=batch_size
         )

  8%|██████▎                                                                        | 100/1250 [00:42<07:57,  2.41it/s]

[99/1250],train loss is:2.869875


 16%|████████████▋                                                                  | 200/1250 [01:23<07:13,  2.42it/s]

[199/1250],train loss is:1.706675


 24%|██████████████████▉                                                            | 300/1250 [02:05<06:37,  2.39it/s]

[299/1250],train loss is:1.264363


 32%|█████████████████████████▎                                                     | 400/1250 [02:47<05:54,  2.40it/s]

[399/1250],train loss is:1.032047


 40%|███████████████████████████████▌                                               | 500/1250 [03:29<05:13,  2.39it/s]

[499/1250],train loss is:0.893194


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:11<04:33,  2.37it/s]

[599/1250],train loss is:0.797776


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:53<03:49,  2.39it/s]

[699/1250],train loss is:0.727979


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:35<03:11,  2.36it/s]

[799/1250],train loss is:0.675154


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:18<02:27,  2.37it/s]

[899/1250],train loss is:0.633732


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [07:00<01:45,  2.38it/s]

[999/1250],train loss is:0.599352


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:42<01:02,  2.39it/s]

[1099/1250],train loss is:0.571305


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:24<00:21,  2.38it/s]

[1199/1250],train loss is:0.547438


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:45<00:00,  2.38it/s]


epoch:[0],train loss is:0.537170 

eval loss is 0.247390, precision is:5236.0, recall is:12980.0, f1_score is:0.8067796610169492


  8%|██████▎                                                                        | 100/1250 [00:42<08:01,  2.39it/s]

[99/1250],train loss is:0.273027


 16%|████████████▋                                                                  | 200/1250 [01:24<07:38,  2.29it/s]

[199/1250],train loss is:0.275952


 24%|██████████████████▉                                                            | 300/1250 [02:06<06:42,  2.36it/s]

[299/1250],train loss is:0.274062


 32%|█████████████████████████▎                                                     | 400/1250 [02:48<05:56,  2.39it/s]

[399/1250],train loss is:0.272838


 40%|███████████████████████████████▌                                               | 500/1250 [03:30<05:15,  2.38it/s]

[499/1250],train loss is:0.272484


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:13<04:34,  2.36it/s]

[599/1250],train loss is:0.271448


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:55<03:53,  2.35it/s]

[699/1250],train loss is:0.270597


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:37<03:09,  2.38it/s]

[799/1250],train loss is:0.269977


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:19<02:27,  2.37it/s]

[899/1250],train loss is:0.269695


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [07:04<01:52,  2.21it/s]

[999/1250],train loss is:0.269060


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:49<01:05,  2.29it/s]

[1099/1250],train loss is:0.268346


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:33<00:22,  2.25it/s]

[1199/1250],train loss is:0.268147


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:56<00:00,  2.33it/s]


epoch:[1],train loss is:0.267854 

eval loss is 0.220065, precision is:5414.0, recall is:13095.0, f1_score is:0.8268804887361588


  8%|██████▎                                                                        | 100/1250 [00:44<08:30,  2.25it/s]

[99/1250],train loss is:0.245692


 16%|████████████▋                                                                  | 200/1250 [01:28<07:50,  2.23it/s]

[199/1250],train loss is:0.244454


 24%|██████████████████▉                                                            | 300/1250 [02:13<07:03,  2.24it/s]

[299/1250],train loss is:0.245184


 32%|█████████████████████████▎                                                     | 400/1250 [02:58<06:13,  2.27it/s]

[399/1250],train loss is:0.245402


 40%|███████████████████████████████▌                                               | 500/1250 [03:42<05:34,  2.24it/s]

[499/1250],train loss is:0.244520


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:26<04:47,  2.26it/s]

[599/1250],train loss is:0.244595


 56%|████████████████████████████████████████████▏                                  | 700/1250 [05:11<04:02,  2.27it/s]

[699/1250],train loss is:0.244687


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:55<03:18,  2.26it/s]

[799/1250],train loss is:0.244319


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:42<02:38,  2.20it/s]

[899/1250],train loss is:0.243830


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [07:28<01:56,  2.15it/s]

[999/1250],train loss is:0.243709


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [08:13<01:07,  2.22it/s]

[1099/1250],train loss is:0.243775


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:58<00:22,  2.24it/s]

[1199/1250],train loss is:0.243575


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [09:20<00:00,  2.23it/s]


epoch:[2],train loss is:0.243547 

eval loss is 0.197997, precision is:5501.0, recall is:13031.0, f1_score is:0.8442943749520374


  8%|██████▎                                                                        | 100/1250 [00:45<08:39,  2.22it/s]

[99/1250],train loss is:0.218538


 16%|████████████▋                                                                  | 200/1250 [01:30<07:59,  2.19it/s]

[199/1250],train loss is:0.220662


 24%|██████████████████▉                                                            | 300/1250 [02:15<07:10,  2.21it/s]

[299/1250],train loss is:0.221111


 32%|█████████████████████████▎                                                     | 400/1250 [03:00<06:31,  2.17it/s]

[399/1250],train loss is:0.222321


 40%|███████████████████████████████▌                                               | 500/1250 [03:45<05:41,  2.20it/s]

[499/1250],train loss is:0.222899


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:28<04:37,  2.34it/s]

[599/1250],train loss is:0.223524


 56%|████████████████████████████████████████████▏                                  | 700/1250 [05:10<03:52,  2.36it/s]

[699/1250],train loss is:0.223757


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:53<03:10,  2.36it/s]

[799/1250],train loss is:0.224179


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:35<02:28,  2.36it/s]

[899/1250],train loss is:0.224411


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [07:17<01:46,  2.34it/s]

[999/1250],train loss is:0.225068


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [08:00<01:04,  2.33it/s]

[1099/1250],train loss is:0.225215


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:43<00:21,  2.34it/s]

[1199/1250],train loss is:0.225512


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [09:04<00:00,  2.30it/s]


epoch:[3],train loss is:0.225702 

eval loss is 0.177722, precision is:5537.0, recall is:12901.0, f1_score is:0.8583830710797613


  8%|██████▎                                                                        | 100/1250 [00:42<08:06,  2.37it/s]

[99/1250],train loss is:0.199850


 16%|████████████▋                                                                  | 200/1250 [01:24<07:17,  2.40it/s]

[199/1250],train loss is:0.202689


 24%|██████████████████▉                                                            | 300/1250 [02:06<06:36,  2.39it/s]

[299/1250],train loss is:0.204328


 32%|█████████████████████████▎                                                     | 400/1250 [02:49<06:06,  2.32it/s]

[399/1250],train loss is:0.204123


 40%|███████████████████████████████▌                                               | 500/1250 [03:31<05:15,  2.37it/s]

[499/1250],train loss is:0.205347


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:12<04:29,  2.41it/s]

[599/1250],train loss is:0.206231


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:55<03:48,  2.41it/s]

[699/1250],train loss is:0.206463


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:37<03:09,  2.37it/s]

[799/1250],train loss is:0.206570


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:19<02:25,  2.40it/s]

[899/1250],train loss is:0.206697


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [07:01<01:44,  2.39it/s]

[999/1250],train loss is:0.207053


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:43<01:02,  2.41it/s]

[1099/1250],train loss is:0.207301


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:24<00:20,  2.42it/s]

[1199/1250],train loss is:0.207742


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:45<00:00,  2.38it/s]


epoch:[4],train loss is:0.208012 

eval loss is 0.157265, precision is:5667.0, recall is:12971.0, f1_score is:0.8737953897155193


  8%|██████▎                                                                        | 100/1250 [00:41<07:56,  2.41it/s]

[99/1250],train loss is:0.182637


 16%|████████████▋                                                                  | 200/1250 [01:23<07:15,  2.41it/s]

[199/1250],train loss is:0.184236


 24%|██████████████████▉                                                            | 300/1250 [02:04<06:34,  2.41it/s]

[299/1250],train loss is:0.184038


 32%|█████████████████████████▎                                                     | 400/1250 [02:46<05:52,  2.41it/s]

[399/1250],train loss is:0.186420


 40%|███████████████████████████████▌                                               | 500/1250 [03:27<05:10,  2.42it/s]

[499/1250],train loss is:0.186689


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:09<04:29,  2.41it/s]

[599/1250],train loss is:0.187515


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:50<03:48,  2.40it/s]

[699/1250],train loss is:0.188037


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:32<03:06,  2.41it/s]

[799/1250],train loss is:0.188433


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:13<02:25,  2.41it/s]

[899/1250],train loss is:0.188607


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [06:55<01:43,  2.41it/s]

[999/1250],train loss is:0.189152


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:37<01:02,  2.40it/s]

[1099/1250],train loss is:0.189507


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:18<00:20,  2.40it/s]

[1199/1250],train loss is:0.190072


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:39<00:00,  2.40it/s]


epoch:[5],train loss is:0.190201 

eval loss is 0.136241, precision is:5774.0, recall is:12971.0, f1_score is:0.8902937321717678


  8%|██████▎                                                                        | 100/1250 [00:41<07:55,  2.42it/s]

[99/1250],train loss is:0.160314


 16%|████████████▋                                                                  | 200/1250 [01:23<07:15,  2.41it/s]

[199/1250],train loss is:0.163678


 24%|██████████████████▉                                                            | 300/1250 [02:04<06:33,  2.41it/s]

[299/1250],train loss is:0.165643


 32%|█████████████████████████▎                                                     | 400/1250 [02:46<05:52,  2.41it/s]

[399/1250],train loss is:0.168367


 40%|███████████████████████████████▌                                               | 500/1250 [03:27<05:11,  2.41it/s]

[499/1250],train loss is:0.169336


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:09<04:29,  2.41it/s]

[599/1250],train loss is:0.169715


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:50<03:48,  2.41it/s]

[699/1250],train loss is:0.170058


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:32<03:06,  2.41it/s]

[799/1250],train loss is:0.170866


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:13<02:24,  2.41it/s]

[899/1250],train loss is:0.171367


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [06:55<01:44,  2.40it/s]

[999/1250],train loss is:0.172151


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:37<01:02,  2.41it/s]

[1099/1250],train loss is:0.172435


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:18<00:20,  2.40it/s]

[1199/1250],train loss is:0.172635


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:39<00:00,  2.41it/s]


epoch:[6],train loss is:0.172976 

eval loss is 0.119482, precision is:5825.0, recall is:12875.0, f1_score is:0.9048543689320389


  8%|██████▎                                                                        | 100/1250 [00:41<07:56,  2.41it/s]

[99/1250],train loss is:0.146675


 16%|████████████▋                                                                  | 200/1250 [01:23<07:16,  2.41it/s]

[199/1250],train loss is:0.146065


 24%|██████████████████▉                                                            | 300/1250 [02:04<06:35,  2.40it/s]

[299/1250],train loss is:0.148205


 32%|█████████████████████████▎                                                     | 400/1250 [02:46<05:53,  2.41it/s]

[399/1250],train loss is:0.149154


 40%|███████████████████████████████▌                                               | 500/1250 [03:27<05:13,  2.39it/s]

[499/1250],train loss is:0.150373


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:09<04:29,  2.41it/s]

[599/1250],train loss is:0.151439


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:51<03:48,  2.41it/s]

[699/1250],train loss is:0.152589


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:32<03:07,  2.39it/s]

[799/1250],train loss is:0.153209


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:14<02:25,  2.40it/s]

[899/1250],train loss is:0.153428


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [06:56<01:43,  2.41it/s]

[999/1250],train loss is:0.153748


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:37<01:02,  2.40it/s]

[1099/1250],train loss is:0.154142


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:19<00:20,  2.41it/s]

[1199/1250],train loss is:0.154299


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:39<00:00,  2.40it/s]


epoch:[7],train loss is:0.154607 

eval loss is 0.098849, precision is:6042.0, recall is:13063.0, f1_score is:0.9250555002679324


  8%|██████▎                                                                        | 100/1250 [00:41<08:02,  2.38it/s]

[99/1250],train loss is:0.126738


 16%|████████████▋                                                                  | 200/1250 [01:23<07:24,  2.36it/s]

[199/1250],train loss is:0.143530


 24%|██████████████████▉                                                            | 300/1250 [02:05<06:36,  2.40it/s]

[299/1250],train loss is:0.142241


 32%|█████████████████████████▎                                                     | 400/1250 [02:47<05:51,  2.42it/s]

[399/1250],train loss is:0.140792


 40%|███████████████████████████████▌                                               | 500/1250 [03:29<05:10,  2.42it/s]

[499/1250],train loss is:0.140621


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:10<04:29,  2.41it/s]

[599/1250],train loss is:0.141191


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:52<03:47,  2.42it/s]

[699/1250],train loss is:0.141305


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:33<03:07,  2.40it/s]

[799/1250],train loss is:0.140923


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:15<02:24,  2.42it/s]

[899/1250],train loss is:0.140879


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [06:56<01:43,  2.41it/s]

[999/1250],train loss is:0.141060


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:38<01:02,  2.41it/s]

[1099/1250],train loss is:0.141252


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:19<00:20,  2.41it/s]

[1199/1250],train loss is:0.141677


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:40<00:00,  2.40it/s]


epoch:[8],train loss is:0.142001 

eval loss is 0.084189, precision is:6133.0, recall is:13062.0, f1_score is:0.9390598683203185


  8%|██████▎                                                                        | 100/1250 [00:41<07:58,  2.40it/s]

[99/1250],train loss is:0.116023


 16%|████████████▋                                                                  | 200/1250 [01:23<07:15,  2.41it/s]

[199/1250],train loss is:0.116736


 24%|██████████████████▉                                                            | 300/1250 [02:04<06:35,  2.40it/s]

[299/1250],train loss is:0.118983


 32%|█████████████████████████▎                                                     | 400/1250 [02:46<05:52,  2.41it/s]

[399/1250],train loss is:0.119712


 40%|███████████████████████████████▌                                               | 500/1250 [03:27<05:13,  2.39it/s]

[499/1250],train loss is:0.120414


 48%|█████████████████████████████████████▉                                         | 600/1250 [04:09<04:29,  2.41it/s]

[599/1250],train loss is:0.121242


 56%|████████████████████████████████████████████▏                                  | 700/1250 [04:50<03:48,  2.41it/s]

[699/1250],train loss is:0.121546


 64%|██████████████████████████████████████████████████▌                            | 800/1250 [05:32<03:07,  2.41it/s]

[799/1250],train loss is:0.121808


 72%|████████████████████████████████████████████████████████▉                      | 900/1250 [06:14<02:26,  2.40it/s]

[899/1250],train loss is:0.122856


 80%|██████████████████████████████████████████████████████████████▍               | 1000/1250 [06:55<01:44,  2.39it/s]

[999/1250],train loss is:0.123107


 88%|████████████████████████████████████████████████████████████████████▋         | 1100/1250 [07:37<01:02,  2.40it/s]

[1099/1250],train loss is:0.123441


 96%|██████████████████████████████████████████████████████████████████████████▉   | 1200/1250 [08:19<00:20,  2.39it/s]

[1199/1250],train loss is:0.124105


100%|██████████████████████████████████████████████████████████████████████████████| 1250/1250 [08:39<00:00,  2.40it/s]


epoch:[9],train loss is:0.124447 

eval loss is 0.069019, precision is:6248.0, recall is:13135.0, f1_score is:0.9513513513513514


<br>

### 四、生成提交数据

In [14]:
import json
import torch
import numpy as np

# ark-nlp提供该函数：from ark_nlp.model.ner.global_pointer_bert import Predictor
# 这里主要是为了可以比较清晰地看到解码过程，所以将代码copy到这
class GlobalPointerNERPredictor(object):
    """
    GlobalPointer命名实体识别的预测器

    Args:
        module: 深度学习模型
        tokernizer: 分词器
        cat2id (:obj:`dict`): 标签映射
    """  # noqa: ignore flake8"

    def __init__(
        self,
        module,
        tokernizer,
        cat2id
    ):
        self.module = module
        self.module.task = 'TokenLevel'

        self.cat2id = cat2id
        self.tokenizer = tokernizer
        self.device = list(self.module.parameters())[0].device

        self.id2cat = {}
        for cat_, idx_ in self.cat2id.items():
            self.id2cat[idx_] = cat_

    def _convert_to_transfomer_ids(
        self,
        text
    ):

        tokens = self.tokenizer.tokenize(text)
        token_mapping = self.tokenizer.get_token_mapping(text, tokens)

        input_ids = self.tokenizer.sequence_to_ids(tokens)
        input_ids, input_mask, segment_ids = input_ids

        zero = [0 for i in range(self.tokenizer.max_seq_len)]
        span_mask = [input_mask for i in range(sum(input_mask))]
        span_mask.extend([zero for i in range(sum(input_mask), self.tokenizer.max_seq_len)])
        span_mask = np.array(span_mask)

        features = {
            'input_ids': input_ids,
            'attention_mask': input_mask,
            'token_type_ids': segment_ids,
            'span_mask': span_mask
        }

        return features, token_mapping

    def _get_input_ids(
        self,
        text
    ):
        if self.tokenizer.tokenizer_type == 'vanilla':
            return self._convert_to_vanilla_ids(text)
        elif self.tokenizer.tokenizer_type == 'transfomer':
            return self._convert_to_transfomer_ids(text)
        elif self.tokenizer.tokenizer_type == 'customized':
            return self._convert_to_customized_ids(text)
        else:
            raise ValueError("The tokenizer type does not exist")

    def _get_module_one_sample_inputs(
        self,
        features
    ):
        return {col: torch.Tensor(features[col]).type(torch.long).unsqueeze(0).to(self.device) for col in features}

    def predict_one_sample(
        self,
        text='',
        threshold=0
    ):
        """
        单样本预测

        Args:
            text (:obj:`string`): 输入文本
            threshold (:obj:`float`, optional, defaults to 0): 预测的阈值
        """  # noqa: ignore flake8"

        features, token_mapping = self._get_input_ids(text)
        self.module.eval()

        with torch.no_grad():
            inputs = self._get_module_one_sample_inputs(features)
            scores = self.module(**inputs)[0].cpu()
            
        scores[:, [0, -1]] -= np.inf
        scores[:, :, [0, -1]] -= np.inf

        entities = []

        for category, start, end in zip(*np.where(scores > threshold)):
            if end-1 > token_mapping[-1][-1]:
                break
            if token_mapping[start-1][0] <= token_mapping[end-1][-1]:
                entitie_ = {
                    "start_idx": token_mapping[start-1][0],
                    "end_idx": token_mapping[end-1][-1],
                    "entity": text[token_mapping[start-1][0]: token_mapping[end-1][-1]+1],
                    "type": self.id2cat[category]
                }

                if entitie_['entity'] == '':
                    continue

                entities.append(entitie_)

        return entities


In [15]:
ner_predictor_instance = GlobalPointerNERPredictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [16]:
from tqdm import tqdm

predict_results = []

with open('./data/preliminary_test_a/sample_per_line_preliminary_A.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
    for _line in tqdm(lines):
        label = len(_line) * ['O']
        for _preditc in ner_predictor_instance.predict_one_sample(_line[:-1]):
            if 'I' in label[_preditc['start_idx']]:
                continue
            if 'B' in label[_preditc['start_idx']] and 'O' not in label[_preditc['end_idx']]:
                continue
            if 'O' in label[_preditc['start_idx']] and 'B' in label[_preditc['end_idx']]:
                continue

            label[_preditc['start_idx']] = 'B-' +  _preditc['type']
            label[_preditc['start_idx']+1: _preditc['end_idx']+1] = (_preditc['end_idx'] - _preditc['start_idx']) * [('I-' +  _preditc['type'])]
            
        predict_results.append([_line, label])

100%|████████████████████████████████████████████████████████████████████████████| 10000/10000 [07:13<00:00, 23.07it/s]


In [17]:
with open('gobal_pointer_baseline.txt', 'w', encoding='utf-8') as f:
    for _result in predict_results:
        for word, tag in zip(_result[0], _result[1]):
            if word == '\n':
                continue
            f.write(f'{word} {tag}\n')
        f.write('\n')