In [None]:
import jieba
import torch
import pickle
import transformers 
import pandas as pd
import torch.nn as nn
import torch.optim as optim

from ark_nlp.nn import ErnieConfig
from ark_nlp.dataset import BIONERDataset
from ark_nlp.processor.tokenizer.transfomer import TokensTokenizer
from ark_nlp.nn.crf_bert import CrfBert
from ark_nlp.factory.task import CRFNERTask
from ark_nlp.factory.optimizer import get_default_crf_bert_optimizer

In [None]:
# Dataset file locations: CMeEE train / dev JSON files, given relative to the
# notebook's working directory.

train_data_path = '../CLUE/data/source_datasets/CMeEE/CMeEE_train.json'
dev_data_path = '../CLUE/data/source_datasets/CMeEE/CMeEE_dev.json'

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Load both splits into DataFrames. Later cells read the 'text' and
# 'entities' columns of each record.
train_data_df = pd.read_json(train_data_path)
dev_data_df = pd.read_json(dev_data_path)

In [None]:
def gen_label(df):
    """Build a character-level BIO tag sequence for one dataset record.

    Args:
        df: A single record (despite the name, a DataFrame row or any
            mapping) exposing ``'text'`` and ``'entities'``. Each entity is a
            mapping with ``'start_idx'``, ``'end_idx'`` and ``'type'``;
            ``end_idx`` is treated as inclusive.

    Returns:
        list[str]: Exactly one tag per character of ``df['text']``: ``'O'``
        outside entities, ``'B-<type>'`` on an entity's first character and
        ``'I-<type>'`` on its remaining characters. Where entities overlap,
        the one listed later overwrites the earlier tags.
    """
    text_len = len(df['text'])
    label = ['O'] * text_len
    for entity_ in df['entities']:
        start, end = entity_['start_idx'], entity_['end_idx']
        # Skip malformed spans (inverted offsets, or a start outside the
        # text) instead of raising or tagging from the end of the string.
        if end < start or not 0 <= start < text_len:
            continue
        # Clamp the inclusive end to the last character: a slice assignment
        # that reaches past the end of the list would silently append tags
        # and break the one-tag-per-character alignment.
        end = min(end, text_len - 1)
        label[start] = 'B-' + entity_['type']
        # For a single-character entity this slice is empty, so it is a no-op.
        label[start + 1:end + 1] = ['I-' + entity_['type']] * (end - start)
    return label

In [None]:
# Attach a per-character BIO tag list to every record of both splits
# (axis=1 hands each row to gen_label).
train_data_df['label'] = train_data_df.apply(gen_label, axis=1)
dev_data_df['label'] = dev_data_df.apply(gen_label, axis=1)

In [None]:
# Keep only the text and its tags for each split ...
train_data_df = train_data_df.loc[:, ['text', 'label']]
dev_data_df = dev_data_df.loc[:, ['text', 'label']]

# ... then store every tag list as its string representation.
train_data_df['label'] = train_data_df['label'].apply(str)
dev_data_df['label'] = dev_data_df['label'].apply(str)

In [None]:
# Wrap both splits in ark_nlp's BIONERDataset. The dev dataset is given the
# training dataset's `categories` rather than deriving its own
# (presumably so both splits share one label inventory — confirm against ark_nlp).
ner_train_dataset = BIONERDataset(train_data_df)
ner_dev_dataset = BIONERDataset(dev_data_df,
                                categories = ner_train_dataset.categories)

#### 2. 词典创建和生成分词器

In [None]:
# The vocabulary can be created first and then passed to the tokenizer
# (commented-out example below), or the tokenizer can load it automatically.
# bert_vocab = transformers.AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')
# tokenizer = TokensTokenizer(bert_vocab, max_seq_len=30)

In [None]:
# Build the tokenizer directly from the pretrained model name.
# NOTE(review): max_seq_len=100 presumably caps the tokenized length, so longer
# texts would be truncated — confirm against TokensTokenizer.
tokenizer = TokensTokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=100)

#### 3. ID化

In [None]:
# Convert both datasets to ids with the shared tokenizer. The return value is
# not used (presumably the datasets are converted in place — confirm against ark_nlp).
ner_train_dataset.convert_to_ids(tokenizer)
ner_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Load the pretrained ERNIE configuration and size the label space from the
# training dataset's cat2id mapping.
bert_config = ErnieConfig.from_pretrained('nghuyong/ernie-1.0', 
                                         num_labels=len(ner_train_dataset.cat2id))

#### 2. 模型创建

In [None]:
from ark_nlp.nn.crf_bert import CrfBert

In [None]:
# Instantiate the CrfBert module from the same pretrained checkpoint, using
# the configuration built above.
dl_module = CrfBert.from_pretrained('nghuyong/ernie-1.0', 
                                        config=bert_config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Use ark_nlp's default optimizer factory for CRF-BERT modules.
optimizer = get_default_crf_bert_optimizer(dl_module) 

#### 2. 任务创建

In [None]:
# Bundle module and optimizer into a CRF NER task.
# NOTE(review): 'ce' presumably selects the loss function — confirm against CRFNERTask.
# NOTE(review): cuda_device=0 hard-codes the device index; adjust for other setups.
model = CRFNERTask(dl_module, optimizer, 'ce', cuda_device=0)

#### 3. 训练

In [None]:
# Train for 3 epochs with batch size 16. The dev dataset is passed as the
# second argument (presumably used for validation — confirm against CRFNERTask.fit).
model.fit(ner_train_dataset,
          ner_dev_dataset,
          epochs=3,
          batch_size=16
         )

<br>

### 四、模型预测

In [None]:
from ark_nlp.factory.predictor import CRFNERPredictor

In [None]:
# Build a predictor from the task's underlying module, the tokenizer used for
# training, and the training dataset's cat2id mapping.
predictor = CRFNERPredictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [None]:
# Run the predictor on a single example sentence; as the cell's last
# expression, the result is displayed as the cell output.
predictor.predict_one_sample('电生理检查时可诱发持续性室上性心动过速者')