In [12]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBert
from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBertConfig
from ark_nlp.model.ner.global_pointer_bert import Dataset
from ark_nlp.model.ner.global_pointer_bert import Task
from ark_nlp.model.ner.global_pointer_bert import get_default_model_optimizer
from ark_nlp.model.ner.global_pointer_bert import Tokenizer

### 一、数据读入与处理

#### 1. 数据读入

In [13]:
import os
from ark_nlp.factory.utils.conlleval import get_entity_bio


# Parse the training file into `datalist`, a list of {'text', 'label'} records.
# Each non-blank line holds one "<char> <tag>" pair; a blank line ends a sample.
datalist = []
with open('../data/train_data/train.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()
    # Sentinel blank line so the last sample is flushed even when the file
    # does not end with an empty line.
    lines.append('\n')

    text = []
    labels = []
    # Tag names with the B-/I- prefix removed.
    # NOTE(review): 'O' is added too whenever a "<char> O" line goes through the
    # generic branch below - confirm that an 'O' category is really wanted.
    label_set = set()

    for line in lines:
        if line == '\n':
            # Sample boundary: emit the accumulated sample, if there is one.
            if text:
                sentence = ''.join(text)
                entity_labels = []
                # get_entity_bio is unpacked as (type, start, end); the slice below
                # treats `end` as inclusive.
                for _type, _start_idx, _end_idx in get_entity_bio(labels, id2label=None):
                    entity_labels.append({
                        'start_idx': _start_idx,
                        'end_idx': _end_idx,
                        'type': _type,
                        'entity': sentence[_start_idx: _end_idx+1]
                    })

                datalist.append({
                    'text': sentence,
                    'label': entity_labels
                })

            # Always restart with fresh lists. Previously an empty sample (e.g. two
            # consecutive blank lines) skipped this reset and left `text` as the
            # string '', so the next `text.append` raised AttributeError.
            text = []
            labels = []

        elif line == '  O\n':
            # The character itself is a space, tagged 'O'.
            text.append(' ')
            labels.append('O')
        else:
            line = line.strip('\n').split()
            if len(line) == 1:
                # Only the tag survived split(): the character was a space.
                term = ' '
                label = line[0]
            else:
                term, label = line
            text.append(term)
            label_set.add(label.split('-')[-1])
            labels.append(label)

In [14]:
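# A simple 90/10 hold-out split (sklearn, fixed seed) used only to monitor metrics; consider cross-validation for real use. The commented-out lines below are the earlier ad-hoc split.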
# train_data_df = pd.DataFrame(datalist)
# train_data_df['label'] = train_data_df['label'].apply(lambda x: str(x))

# dev_data_df = pd.DataFrame(datalist[-400:])
# dev_data_df['label'] = dev_data_df['label'].apply(lambda x: str(x))


from sklearn.model_selection import train_test_split

# Store every entity list as its string form, then hold out 10% of the
# samples as the dev set (shuffled, fixed seed).
labelled_df = pd.DataFrame(datalist).assign(
    label=lambda frame: frame['label'].map(str)
)
train_data_df, dev_data_df = train_test_split(
    labelled_df,
    test_size=0.1,
    shuffle=True,
    random_state=42,
)

In [17]:
import ast

# The 'label' column holds str() of a list of dicts; parse it back with
# ast.literal_eval, which only accepts Python literals, instead of eval().
# NOTE(review): [0] is an index-label lookup, so it only works when the row
# with original index 0 is part of the training split.
ast.literal_eval(train_data_df['label'][0])

[{'start_idx': 0, 'end_idx': 1, 'type': '40', 'entity': '手机'},
 {'start_idx': 2, 'end_idx': 4, 'type': '4', 'entity': '三脚架'},
 {'start_idx': 5, 'end_idx': 6, 'type': '14', 'entity': '网红'},
 {'start_idx': 7, 'end_idx': 8, 'type': '5', 'entity': '直播'},
 {'start_idx': 9, 'end_idx': 10, 'type': '4', 'entity': '支架'},
 {'start_idx': 11, 'end_idx': 12, 'type': '7', 'entity': '桌面'},
 {'start_idx': 13, 'end_idx': 15, 'type': '4', 'entity': '自拍杆'},
 {'start_idx': 16, 'end_idx': 17, 'type': '11', 'entity': '蓝牙'},
 {'start_idx': 18, 'end_idx': 19, 'type': '11', 'entity': '遥控'},
 {'start_idx': 20, 'end_idx': 22, 'type': '4', 'entity': '三脚架'},
 {'start_idx': 23, 'end_idx': 24, 'type': '5', 'entity': '摄影'},
 {'start_idx': 25, 'end_idx': 26, 'type': '5', 'entity': '拍摄'},
 {'start_idx': 27, 'end_idx': 28, 'type': '5', 'entity': '拍照'},
 {'start_idx': 29, 'end_idx': 30, 'type': '13', 'entity': '抖音'},
 {'start_idx': 31, 'end_idx': 35, 'type': '4', 'entity': '看电视神器'},
 {'start_idx': 36, 'end_idx': 38, 'typ

In [4]:
# Pass the tag names as a sorted list: the iteration order of a set of strings
# changes between interpreter sessions, which would make the category order
# (and presumably the category -> id mapping built from it) differ per run.
ner_train_dataset = Dataset(train_data_df, categories=sorted(label_set))
# The dev set reuses the training categories so both share one label space.
ner_dev_dataset = Dataset(dev_data_df, categories=ner_train_dataset.categories)

In [5]:
# Directory of the pretrained checkpoint; the same path is handed to the
# tokenizer, the model config and the model loader in the cells below.
path = '../pretrain_model/chinese-roberta-wwm-ext-large/'

#### 2. 词典创建和生成分词器

In [6]:
# Build the tokenizer from the checkpoint directory, with the sequence length
# setting fixed at 128.
tokenizer = Tokenizer(vocab=path, max_seq_len=128)

#### 3. ID化

In [7]:
# Convert both datasets with the shared tokenizer. The return values are not
# used, so the conversion presumably happens in place - verify in ark-nlp.
ner_train_dataset.convert_to_ids(tokenizer)
ner_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [8]:
# Load the model config from the checkpoint directory; num_labels is the
# number of categories in the training dataset's cat2id mapping.
config = GlobalPointerBertConfig.from_pretrained(path, num_labels=len(ner_train_dataset.cat2id))

#### 2. 模型创建

In [9]:
# Release cached, currently unused GPU memory held by PyTorch before the
# model is created.
torch.cuda.empty_cache()

In [10]:
# Instantiate the GlobalPointer model from the pretrained checkpoint using the
# config built above.
dl_module = GlobalPointerBert.from_pretrained(path, config=config)

Some weights of the model checkpoint at ../pretrain_model/chinese-roberta-wwm-ext-large/ were not used when initializing GlobalPointerBert: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing GlobalPointerBert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GlobalPointerBert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GlobalPointerBert were not initialized from the model checkpoint at ../pretrain_mode

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [11]:
# Number of training epochs and batch size passed to model.fit below.
num_epoches = 5
batch_size = 32

In [12]:
# Use the library's default optimizer for this model.
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [13]:
# Wrap model and optimizer into a training task on CUDA device 0.
# NOTE(review): 'gpce' presumably selects the GlobalPointer cross-entropy loss - confirm in ark-nlp.
model = Task(dl_module, optimizer, 'gpce', cuda_device=0)

#### 3. 训练

In [14]:
# Train on the training set, passing the dev set as the second argument, with
# learning rate 2e-5 and the epoch count / batch size configured above.
model.fit(ner_train_dataset, 
          ner_dev_dataset,
          lr=2e-5,
          epochs=num_epoches, 
          batch_size=batch_size
         )

  9%|████▍                                             | 100/1125 [00:42<07:24,  2.31it/s]

[99/1125],train loss is:2.875331


 18%|████████▉                                         | 200/1125 [01:24<06:26,  2.40it/s]

[199/1125],train loss is:1.671295


 27%|█████████████▎                                    | 300/1125 [02:06<05:45,  2.39it/s]

[299/1125],train loss is:1.238539


 36%|█████████████████▊                                | 400/1125 [02:49<05:03,  2.39it/s]

[399/1125],train loss is:1.013926


 44%|██████████████████████▏                           | 500/1125 [03:30<04:19,  2.41it/s]

[499/1125],train loss is:0.875486


 53%|██████████████████████████▋                       | 600/1125 [04:12<03:37,  2.41it/s]

[599/1125],train loss is:0.780941


 62%|███████████████████████████████                   | 700/1125 [04:53<02:56,  2.41it/s]

[699/1125],train loss is:0.713699


 71%|███████████████████████████████████▌              | 800/1125 [05:35<02:14,  2.42it/s]

[799/1125],train loss is:0.662399


 80%|████████████████████████████████████████          | 900/1125 [06:17<01:33,  2.41it/s]

[899/1125],train loss is:0.621719


 89%|███████████████████████████████████████████▌     | 1000/1125 [06:58<00:51,  2.40it/s]

[999/1125],train loss is:0.588364


 98%|███████████████████████████████████████████████▉ | 1100/1125 [07:40<00:10,  2.41it/s]

[1099/1125],train loss is:0.561435


100%|█████████████████████████████████████████████████| 1125/1125 [07:50<00:00,  2.39it/s]


epoch:[0],train loss is:0.555345 

eval loss is 0.281849, precision is:50834.0, recall is:129103.0, f1_score is:0.7874952557260482


  9%|████▍                                             | 100/1125 [00:41<07:07,  2.40it/s]

[99/1125],train loss is:0.271461


 18%|████████▉                                         | 200/1125 [01:22<06:26,  2.39it/s]

[199/1125],train loss is:0.272167


 27%|█████████████▎                                    | 300/1125 [02:04<05:43,  2.40it/s]

[299/1125],train loss is:0.269575


 36%|█████████████████▊                                | 400/1125 [02:46<05:00,  2.41it/s]

[399/1125],train loss is:0.268301


 44%|██████████████████████▏                           | 500/1125 [03:27<04:19,  2.41it/s]

[499/1125],train loss is:0.268452


 53%|██████████████████████████▋                       | 600/1125 [04:09<03:38,  2.40it/s]

[599/1125],train loss is:0.267422


 62%|███████████████████████████████                   | 700/1125 [04:50<02:56,  2.41it/s]

[699/1125],train loss is:0.266638


 71%|███████████████████████████████████▌              | 800/1125 [05:32<02:15,  2.41it/s]

[799/1125],train loss is:0.265916


 80%|████████████████████████████████████████          | 900/1125 [06:13<01:33,  2.41it/s]

[899/1125],train loss is:0.265215


 89%|███████████████████████████████████████████▌     | 1000/1125 [06:55<00:51,  2.41it/s]

[999/1125],train loss is:0.265080


 98%|███████████████████████████████████████████████▉ | 1100/1125 [07:36<00:10,  2.41it/s]

[1099/1125],train loss is:0.264745


100%|█████████████████████████████████████████████████| 1125/1125 [07:47<00:00,  2.41it/s]


epoch:[1],train loss is:0.264626 

eval loss is 0.266450, precision is:51339.0, recall is:129155.0, f1_score is:0.7949982579071658


  9%|████▍                                             | 100/1125 [00:41<07:03,  2.42it/s]

[99/1125],train loss is:0.238633


 18%|████████▉                                         | 200/1125 [01:22<06:25,  2.40it/s]

[199/1125],train loss is:0.238633


 27%|█████████████▎                                    | 300/1125 [02:05<05:54,  2.33it/s]

[299/1125],train loss is:0.238617


 36%|█████████████████▊                                | 400/1125 [02:47<05:06,  2.37it/s]

[399/1125],train loss is:0.238723


 44%|██████████████████████▏                           | 500/1125 [03:29<04:21,  2.39it/s]

[499/1125],train loss is:0.239092


 53%|██████████████████████████▋                       | 600/1125 [04:11<03:37,  2.42it/s]

[599/1125],train loss is:0.238997


 62%|███████████████████████████████                   | 700/1125 [04:52<02:56,  2.41it/s]

[699/1125],train loss is:0.238403


 71%|███████████████████████████████████▌              | 800/1125 [05:34<02:14,  2.41it/s]

[799/1125],train loss is:0.239020


 80%|████████████████████████████████████████          | 900/1125 [06:15<01:33,  2.41it/s]

[899/1125],train loss is:0.238548


 89%|███████████████████████████████████████████▌     | 1000/1125 [06:57<00:51,  2.41it/s]

[999/1125],train loss is:0.238652


 98%|███████████████████████████████████████████████▉ | 1100/1125 [07:38<00:10,  2.41it/s]

[1099/1125],train loss is:0.238787


100%|█████████████████████████████████████████████████| 1125/1125 [07:48<00:00,  2.40it/s]


epoch:[2],train loss is:0.238913 

eval loss is 0.263871, precision is:52545.0, recall is:130746.0, f1_score is:0.8037721995319168


  9%|████▍                                             | 100/1125 [00:42<07:17,  2.34it/s]

[99/1125],train loss is:0.218775


 18%|████████▉                                         | 200/1125 [01:25<06:36,  2.33it/s]

[199/1125],train loss is:0.216446


 27%|█████████████▎                                    | 300/1125 [02:08<05:56,  2.31it/s]

[299/1125],train loss is:0.216180


 36%|█████████████████▊                                | 400/1125 [02:52<05:12,  2.32it/s]

[399/1125],train loss is:0.216733


 44%|██████████████████████▏                           | 500/1125 [03:35<04:30,  2.31it/s]

[499/1125],train loss is:0.216783


 53%|██████████████████████████▋                       | 600/1125 [04:18<03:40,  2.38it/s]

[599/1125],train loss is:0.216917


 62%|███████████████████████████████                   | 700/1125 [05:00<03:02,  2.32it/s]

[699/1125],train loss is:0.217385


 71%|███████████████████████████████████▌              | 800/1125 [05:43<02:21,  2.30it/s]

[799/1125],train loss is:0.217770


 80%|████████████████████████████████████████          | 900/1125 [06:25<01:33,  2.41it/s]

[899/1125],train loss is:0.218223


 89%|███████████████████████████████████████████▌     | 1000/1125 [07:06<00:51,  2.41it/s]

[999/1125],train loss is:0.218093


 98%|███████████████████████████████████████████████▉ | 1100/1125 [07:48<00:10,  2.40it/s]

[1099/1125],train loss is:0.217982


100%|█████████████████████████████████████████████████| 1125/1125 [07:58<00:00,  2.35it/s]


epoch:[3],train loss is:0.218072 

eval loss is 0.266369, precision is:51717.0, recall is:129128.0, f1_score is:0.8010191437953039


  9%|████▍                                             | 100/1125 [00:41<07:04,  2.42it/s]

[99/1125],train loss is:0.192180


 18%|████████▉                                         | 200/1125 [01:22<06:23,  2.41it/s]

[199/1125],train loss is:0.194129


 27%|█████████████▎                                    | 300/1125 [02:06<05:55,  2.32it/s]

[299/1125],train loss is:0.194167


 36%|█████████████████▊                                | 400/1125 [02:49<05:13,  2.32it/s]

[399/1125],train loss is:0.195416


 44%|██████████████████████▏                           | 500/1125 [03:32<04:18,  2.42it/s]

[499/1125],train loss is:0.196097


 53%|██████████████████████████▋                       | 600/1125 [04:13<03:38,  2.41it/s]

[599/1125],train loss is:0.196861


 62%|███████████████████████████████                   | 700/1125 [04:55<02:57,  2.40it/s]

[699/1125],train loss is:0.196512


 71%|███████████████████████████████████▌              | 800/1125 [05:37<02:14,  2.41it/s]

[799/1125],train loss is:0.196663


 80%|████████████████████████████████████████          | 900/1125 [06:19<01:33,  2.41it/s]

[899/1125],train loss is:0.197268


 89%|███████████████████████████████████████████▌     | 1000/1125 [07:00<00:52,  2.38it/s]

[999/1125],train loss is:0.197640


 98%|███████████████████████████████████████████████▉ | 1100/1125 [07:43<00:10,  2.31it/s]

[1099/1125],train loss is:0.197836


100%|█████████████████████████████████████████████████| 1125/1125 [07:54<00:00,  2.37it/s]


epoch:[4],train loss is:0.197796 

eval loss is 0.283281, precision is:53493.0, recall is:133319.0, f1_score is:0.8024812667361741


<br>

### 四、生成提交数据

In [15]:
import json
import torch
import numpy as np

# ark-nlp ships an equivalent predictor:
#   from ark_nlp.model.ner.global_pointer_bert import Predictor
# The code is copied here mainly so that the decoding procedure is easy to follow.
class GlobalPointerNERPredictor(object):
    """
    Predictor for GlobalPointer named entity recognition.

    Args:
        module: the deep learning model; it is called with the keyword tensors
            built by ``_get_module_one_sample_inputs`` and the first element of
            its output is used as the span score tensor
        tokernizer: the tokenizer; must provide ``tokenize``, ``get_token_mapping``,
            ``sequence_to_ids``, ``max_seq_len`` and ``tokenizer_type``
        cat2id (:obj:`dict`): mapping from entity type to class index
    """  # noqa: ignore flake8"

    def __init__(
        self,
        module,
        tokernizer,
        cat2id
    ):
        self.module = module
        self.module.task = 'TokenLevel'

        self.cat2id = cat2id
        self.tokenizer = tokernizer
        # Inputs are moved to the device of the model's first parameter.
        self.device = list(self.module.parameters())[0].device

        # Reverse mapping: class index -> entity type.
        self.id2cat = {idx_: cat_ for cat_, idx_ in self.cat2id.items()}

    def _convert_to_transfomer_ids(
        self,
        text
    ):
        """
        Tokenize ``text`` and build the model input features.

        Returns:
            tuple: ``(features, token_mapping)`` where ``features`` holds
            'input_ids', 'attention_mask', 'token_type_ids' and 'span_mask', and
            ``token_mapping`` is whatever ``tokenizer.get_token_mapping`` returns
            (used by ``predict_one_sample`` as per-token lists of character offsets).
        """
        tokens = self.tokenizer.tokenize(text)
        token_mapping = self.tokenizer.get_token_mapping(text, tokens)

        input_ids, input_mask, segment_ids = self.tokenizer.sequence_to_ids(tokens)

        # Square mask of side max_seq_len: the first sum(input_mask) rows are
        # copies of input_mask, the remaining rows are all zeros.
        max_len = self.tokenizer.max_seq_len
        valid_len = sum(input_mask)
        zero = [0 for _ in range(max_len)]
        span_mask = [input_mask for _ in range(valid_len)]
        span_mask.extend([zero for _ in range(valid_len, max_len)])
        span_mask = np.array(span_mask)

        features = {
            'input_ids': input_ids,
            'attention_mask': input_mask,
            'token_type_ids': segment_ids,
            'span_mask': span_mask
        }

        return features, token_mapping

    def _get_input_ids(
        self,
        text
    ):
        """
        Dispatch to the feature builder matching ``tokenizer.tokenizer_type``.

        Only the 'transfomer' builder is defined in this class; the 'vanilla' and
        'customized' branches call methods that a subclass has to provide.

        Raises:
            ValueError: if the tokenizer type is not one of the three known values
        """
        if self.tokenizer.tokenizer_type == 'vanilla':
            return self._convert_to_vanilla_ids(text)
        elif self.tokenizer.tokenizer_type == 'transfomer':
            return self._convert_to_transfomer_ids(text)
        elif self.tokenizer.tokenizer_type == 'customized':
            return self._convert_to_customized_ids(text)
        else:
            raise ValueError("The tokenizer type does not exist")

    def _get_module_one_sample_inputs(
        self,
        features
    ):
        """
        Turn each feature into a long tensor with a leading batch dimension of
        size 1, placed on the model's device.
        """
        # Build the tensors directly as long instead of going through a float32
        # tensor first.
        return {
            col: torch.as_tensor(features[col], dtype=torch.long).unsqueeze(0).to(self.device)
            for col in features
        }

    def predict_one_sample(
        self,
        text='',
        threshold=0
    ):
        """
        Predict the entities of a single text.

        Args:
            text (:obj:`string`): input text
            threshold (:obj:`float`, optional, defaults to 0): a span is kept when
                its score is strictly greater than this value

        Returns:
            list: one dict per entity with the keys 'start_idx' and 'end_idx'
            (character offsets into ``text``, end inclusive), 'entity' (the
            covered substring) and 'type' (the entity type from ``cat2id``).
        """  # noqa: ignore flake8"

        features, token_mapping = self._get_input_ids(text)
        self.module.eval()

        with torch.no_grad():
            inputs = self._get_module_one_sample_inputs(features)
            scores = self.module(**inputs)[0].cpu()

        # Exclude the first and last sequence positions as span start or end
        # (position 0 is presumably the [CLS] token).
        scores[:, [0, -1]] -= np.inf
        scores[:, :, [0, -1]] -= np.inf

        num_tokens = len(token_mapping)
        entities = []

        for category, start, end in zip(*np.where((scores > threshold).numpy())):
            # Sequence positions are shifted by one relative to token_mapping.
            start_token, end_token = start - 1, end - 1

            # Skip spans that reach past the last mapped token. The bound is the
            # number of tokens, not a character offset, and only this span is
            # skipped: np.where orders hits by category first, so stopping here
            # would also drop valid spans of all later categories.
            if start_token >= num_tokens or end_token >= num_tokens:
                continue

            start_idx = token_mapping[start_token][0]
            end_idx = token_mapping[end_token][-1]
            if start_idx > end_idx:
                continue

            entity_text = text[start_idx: end_idx+1]
            if entity_text == '':
                continue

            entities.append({
                "start_idx": start_idx,
                "end_idx": end_idx,
                "entity": entity_text,
                "type": self.id2cat[int(category)]
            })

        return entities


In [16]:
# Build the predictor from the trained task's module, the shared tokenizer and
# the training dataset's category -> id mapping.
ner_predictor_instance = GlobalPointerNERPredictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [17]:
from tqdm import tqdm

# Predict entities for every test line and turn them into per-character BIO tags.
# predict_results collects [raw_line, tags] pairs for the writer cell.
predict_results = []

with open('../data/preliminary_test_a/sample_per_line_preliminary_A.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()

for _line in tqdm(lines):
    # One tag per character of the raw line, including the trailing newline.
    label = len(_line) * ['O']

    # Strip only the line terminator: slicing with [:-1] would drop a real
    # character when the last line of the file has no trailing newline.
    for _preditc in ner_predictor_instance.predict_one_sample(_line.rstrip('\n')):
        _start, _end = _preditc['start_idx'], _preditc['end_idx']

        # Do not start inside an entity that is already tagged.
        if label[_start].startswith('I-'):
            continue
        # Do not end on an already tagged character: overwriting only the head of
        # an existing entity would leave its remaining I- tags without a B- tag.
        if label[_end] != 'O':
            continue

        label[_start] = 'B-' + _preditc['type']
        label[_start+1: _end+1] = (_end - _start) * [('I-' + _preditc['type'])]

    predict_results.append([_line, label])


100%|███████████████████████████████████████████████| 10000/10000 [06:09<00:00, 27.07it/s]


In [19]:
# Write the predictions as one "<char> <tag>" pair per line, with an empty line
# after every sentence; newline characters of the raw line are not written.
with open('gobal_pointer_baseline1.txt', 'w', encoding='utf-8') as out_file:
    for sentence, tags in predict_results:
        out_file.writelines(
            f'{char} {tag}\n'
            for char, tag in zip(sentence, tags)
            if char != '\n'
        )
        out_file.write('\n')

In [11]:
# Scratch cell: build a small set of ints and display it.
s = {2, 3, 1}
s

{1, 2, 3}

In [6]:
# Remove the scratch set from the namespace (requires `s` to exist).
del s