In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import gc
import json
import jieba
import torch
import pickle
import codecs
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBert
from ark_nlp.model.ner.global_pointer_bert import GlobalPointerBertConfig
from ark_nlp.model.ner.global_pointer_bert import Dataset
from ark_nlp.model.ner.global_pointer_bert import Task
from ark_nlp.model.ner.global_pointer_bert import get_default_model_optimizer
from ark_nlp.model.ner.global_pointer_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed
# Fix the seed once, right after the imports, so repeated runs are comparable.
# NOTE(review): presumably seeds the Python / NumPy / PyTorch RNGs — confirm against ark_nlp's set_seed.
set_seed(42)

In [None]:
# Locations of the CMeIE train / dev splits (relative to this notebook)
train_data_path, dev_data_path = (
    '../../data/source_datasets/CMeIE/CMeIE_train.json',
    '../../data/source_datasets/CMeIE/CMeIE_dev.json',
)

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
def data_preprocess(data_path):
    """Read a CMeIE JSON-lines file and convert its SPO triples to NER records.

    Every non-blank line is parsed as a JSON object holding a ``text`` string
    and a ``spo_list`` of triples. The subject and the object of each triple
    become entity mentions located at their FIRST occurrence in ``text``.

    Args:
        data_path: Path of a UTF-8 encoded file with one JSON object per line.

    Returns:
        A list with one dict per non-blank line, containing:
          * ``'text'``: the sentence;
          * ``'entities'``: de-duplicated ``(mention, type, start, end)``
            tuples with an inclusive ``end``, ordered deterministically by
            ``(start, end, type, mention)``;
          * ``'label'``: the same entities as dicts with the keys
            ``entity`` / ``type`` / ``start_idx`` / ``end_idx``.

    Raises:
        ValueError: a line is not valid JSON, or a subject/object string does
            not occur in its sentence (raised by ``str.index``).
        KeyError: a record or triple lacks one of the expected fields.
    """

    def _span(text, mention):
        # Inclusive (start, end) offsets of the first occurrence of `mention`.
        start = text.index(mention)
        return start, start + len(mention) - 1

    data_list = []

    # Built-in open() only breaks lines on \n, \r and \r\n.
    # codecs.open(...).readlines() also breaks on Unicode separators such as
    # U+2028, which may legally appear unescaped inside a JSON string and
    # would cut a record in half.
    with open(data_path, mode='r', encoding='utf8') as f:
        for line_ in f:
            line_ = line_.strip()
            if not line_:
                # Tolerate blank / trailing empty lines instead of crashing.
                continue

            sample_ = json.loads(line_)
            text_ = sample_['text']

            # A set removes mentions repeated across triples.
            entities_ = set()
            for triple_ in sample_['spo_list']:
                subject_ = triple_['subject']
                # NOTE(review): the subject type is hard-coded to '疾病';
                # any subject type carried by the record is not consulted.
                entities_.add((subject_, '疾病') + _span(text_, subject_))

                object_ = triple_['object']['@value']
                object_type_ = triple_['object_type']['@value']
                entities_.add((object_, object_type_) + _span(text_, object_))

            # Set iteration order depends on string hashing, so sort on a
            # total key; sorting on the start offset alone leaves ties in a
            # run-dependent order.
            entities_ = sorted(entities_, key=lambda e: (e[2], e[3], e[1], e[0]))

            data_list.append({
                'text': text_,
                'entities': entities_,
                'label': [
                    {'entity': name_, 'type': type_, 'start_idx': start_, 'end_idx': end_}
                    for name_, type_, start_, end_ in entities_
                ],
            })
    return data_list

# Parse both splits, then wrap each list of records in a DataFrame.
train_data_list, dev_data_list = (
    data_preprocess(split_path_)
    for split_path_ in (train_data_path, dev_data_path)
)

train_data_df = pd.DataFrame(train_data_list)
dev_data_df = pd.DataFrame(dev_data_list)

In [None]:
# Show the label list of the first training record as a quick sanity check.
train_data_df.loc[0, 'label']

In [None]:
# Wrap the train / dev DataFrames in ark_nlp Dataset objects.
# The training dataset's cat2id is used later for the model config and the predictor.
ner_train_dataset = Dataset(train_data_df)
ner_dev_dataset = Dataset(dev_data_df)

#### 2. 词典创建和生成分词器

In [None]:
# The vocabulary can be created first and then loaded into the tokenizer,
# or the tokenizer can load it automatically.
# bert_vocab = transformers.AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
# tokenizer = TransfomerTokenizer(bert_vocab, max_seq_len=30)

In [None]:
# Let the tokenizer load the 'nghuyong/ernie-1.0' vocabulary by name.
# NOTE(review): max_seq_len=128 presumably caps the tokenized length — confirm how longer texts are handled.
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=128)

#### 3. ID化

In [None]:
# Convert both datasets to ids with the same tokenizer.
# NOTE(review): the return values are discarded, so this presumably updates the Dataset objects in place — confirm.
ner_train_dataset.convert_to_ids(tokenizer)
ner_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Model config from the pretrained checkpoint; num_labels is the number of
# entries in the training dataset's cat2id (presumably the entity-type -> id map — confirm).
config = GlobalPointerBertConfig.from_pretrained('nghuyong/ernie-1.0', 
                                                 num_labels=len(ner_train_dataset.cat2id))

#### 2. 模型创建

In [None]:
# Release unoccupied cached GPU memory before the model is instantiated
# (useful when this cell is re-run in a kernel that already held a model).
torch.cuda.empty_cache()

In [None]:
# Instantiate the GlobalPointer model from the same pretrained checkpoint name
# used for the tokenizer and the config, passing in the config built earlier.
dl_module = GlobalPointerBert.from_pretrained('nghuyong/ernie-1.0', 
                                              config=config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# 设置运行次数
num_epoches = 1
batch_size = 32

In [None]:
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [None]:
model = Task(dl_module, optimizer, 'gpce', cuda_device=0)

#### 3. 训练

In [None]:
model.fit(ner_train_dataset, 
          ner_dev_dataset,
          lr=5e-5,
          epochs=7, 
          batch_size=batch_size,
         )

<br>

### 四、生成提交数据

In [None]:
import json
from ark_nlp.model.ner.global_pointer_bert import Predictor

In [None]:
# Build a predictor from the task's underlying module, reusing the training
# tokenizer and the training dataset's cat2id mapping.
ner_predictor_instance = Predictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [None]:
# Read the CMeIE test split (one JSON object per line).
test_df = pd.read_json('../../data/source_datasets/CMeIE/CMeIE_test.json', lines=True)

# Pair every test sentence with the entities predicted for it.
submit = [
    {
        'text': _text,
        'entities': ner_predictor_instance.predict_one_sample(_text),
    }
    for _text in test_df['text'].to_list()
]

In [None]:
output_path = '../../data/output_datasets/CMeIE_test_entities.json'

# Create the output directory if it is missing, so a finished train/predict run
# is not lost to a FileNotFoundError at the very last step.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Write all predictions as one JSON array; ensure_ascii=False keeps the
# Chinese text human-readable instead of \uXXXX escapes.
with open(output_path, 'w', encoding='utf-8') as f:
    f.write(json.dumps(submit, ensure_ascii=False))