In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import jieba
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.ner.w2ner_bert import W2NERBert
from ark_nlp.model.ner.w2ner_bert import W2NERBertConfig
from ark_nlp.model.ner.w2ner_bert import Dataset
from ark_nlp.model.ner.w2ner_bert import Task
from ark_nlp.model.ner.w2ner_bert import get_default_model_optimizer
from ark_nlp.factory.lr_scheduler import get_default_linear_schedule_with_warmup
from ark_nlp.model.ner.w2ner_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed
# Fix the global random seed so training runs are reproducible
# (ark_nlp helper; presumably seeds Python/NumPy/PyTorch RNGs — TODO confirm).
set_seed(42)

In [None]:
# Dataset file locations: CMeEE training and validation splits (JSON).
train_data_path, dev_data_path = (
    '../data/source_datasets/CMeEE/CMeEE_train.json',
    '../data/source_datasets/CMeEE/CMeEE_dev.json',
)

#### 1. 数据读入

In [None]:
# Load the raw CMeEE splits; each record provides a 'text' and an 'entities' field.
train_data_df, dev_data_df = (
    pd.read_json(split_path) for split_path in (train_data_path, dev_data_path)
)

In [None]:
def get_label(x):
    """Convert one record's CMeEE entity annotations into the label format
    later fed to the W2NER ``Dataset``.

    Args:
        x: iterable of entity dicts, each with 'start_idx', 'end_idx'
            (inclusive end), 'type' and 'entity' keys. Other keys are ignored.

    Returns:
        A list, in input order, of dicts with keys (in this order) 'idx'
        (every character position from start_idx through end_idx),
        'type' and 'entity'. Key order matters because the labels are later
        serialised with ``str()``.
    """
    return [
        {
            # end_idx is inclusive in CMeEE, hence the +1.
            'idx': [*range(ent['start_idx'], ent['end_idx'] + 1)],
            'type': ent['type'],
            'entity': ent['entity'],
        }
        for ent in x
    ]

In [None]:
# Derive the 'label' column of each split from its raw 'entities' annotations.
train_data_df['label'] = train_data_df['entities'].apply(get_label)
dev_data_df['label'] = dev_data_df['entities'].apply(get_label)

In [None]:
# Keep only 'text' and 'label', and serialise each label list to its repr
# string (presumably the form the ark_nlp W2NER Dataset expects — confirm).
train_data_df = (
    train_data_df
    .loc[:, ['text', 'label']]
    .assign(label=lambda frame: frame['label'].map(str))
)
dev_data_df = (
    dev_data_df
    .loc[:, ['text', 'label']]
    .assign(label=lambda frame: frame['label'].map(str))
)

In [None]:
# Build the training dataset first; the dev dataset reuses its categories so
# both splits presumably share one label-to-id mapping.
ner_train_dataset = Dataset(train_data_df)
ner_dev_dataset = Dataset(
    dev_data_df,
    categories=ner_train_dataset.categories,
)

#### 2. 词典创建和生成分词器

In [None]:
# Tokenizer built from the ERNIE 1.0 vocabulary with max_seq_len=128.
# NOTE(review): longer CMeEE sentences are presumably truncated; confirm no
# entities beyond the limit are silently lost.
tokenizer = Tokenizer(
    vocab='nghuyong/ernie-1.0',
    max_seq_len=128,
)

#### 3. ID化

In [None]:
# Convert both splits to ids with the same tokenizer (train first, then dev).
for _split_dataset in (ner_train_dataset, ner_dev_dataset):
    _split_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Load the ERNIE 1.0 config with num_labels set to the number of training categories.
config = W2NERBertConfig.from_pretrained(
    'nghuyong/ernie-1.0',
    num_labels=len(ner_train_dataset.cat2id),
)

#### 2. 模型创建

In [None]:
# Release unused cached GPU memory held by PyTorch's caching allocator before loading the model.
torch.cuda.empty_cache()

In [None]:
# Instantiate the W2NER model from ERNIE 1.0 weights, using the config built above.
dl_module = W2NERBert.from_pretrained(
    'nghuyong/ernie-1.0',
    config=config,
)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Training run settings: number of epochs and mini-batch size.
num_epoches, batch_size = 6, 8

In [None]:
# ark_nlp's default optimizer for the W2NER module (its hyperparameters are set inside the library, not here).
optimizer = get_default_model_optimizer(dl_module)

In [None]:
# Note: the LR decay horizon (t_total) must match the number of epochs
# actually trained.
# Assuming the training DataLoader keeps the final partial batch (PyTorch's
# default drop_last=False), each epoch takes ceil(len / batch_size) steps.
# Plain floor division undercounted, so the linear schedule would hit 0 early
# and the last steps would presumably train at lr=0.
steps_per_epoch = -(-len(ner_train_dataset) // batch_size)  # ceiling division
t_total = steps_per_epoch * num_epoches
scheduler = get_default_linear_schedule_with_warmup(optimizer, t_total, warmup_ratio=0.1)

#### 2. 任务创建

In [None]:
# Wrap module, optimizer and scheduler in an ark_nlp training task
# ('ce' presumably selects cross-entropy loss; grad_clip=5.0 caps gradients).
# The GPU-index keyword is `cuda_device`; the previous misspelling
# `cude_device` was silently absorbed and ignored by the Task constructor.
model = Task(dl_module, optimizer, 'ce', cuda_device=0, scheduler=scheduler, grad_clip=5.0)

#### 3. 训练

In [None]:
# Train for the same number of epochs the LR scheduler was sized for.
# This was previously hard-coded to 1 while t_total assumed num_epoches,
# so the linear decay only covered about 1/num_epoches of its schedule.
model.fit(
    ner_train_dataset,
    ner_dev_dataset,
    epochs=num_epoches,
    batch_size=batch_size,
)

<br>

### 四、生成提交数据

In [None]:
import json
from tqdm import tqdm
from ark_nlp.model.ner.w2ner_bert import Predictor

In [None]:
# Build the predictor from the trained module and the training label map (cat2id),
# presumably so predicted ids decode to the same categories used in training.
ner_predictor_instance = Predictor(model.module, tokenizer, ner_train_dataset.cat2id)

In [None]:
test_df = pd.read_json('../data/source_datasets/CMeEE/CMeEE_test.json')

# Predict entities for every test sentence and convert them back to the CMeEE
# submission format. Each predicted entity carries a list of character
# positions in 'idx'; CMeEE spans use an inclusive end_idx (see get_label),
# so the first and last positions give start_idx/end_idx.
submit = []
for _text in tqdm(test_df['text'].to_list()):
    entities = ner_predictor_instance.predict_one_sample(_text)

    entities_ = [
        {
            "start_idx": entity_['idx'][0],
            "end_idx": entity_['idx'][-1],
            "entity": entity_['entity'],
            # Fix: this previously copied the entity text instead of its label type.
            "type": entity_['type'],
        }
        for entity_ in entities
    ]

    submit.append({
        'text': _text,
        'entities': entities_,
    })

In [None]:
output_path = '../data/output_datasets/CMeEE_test.json'

# open() does not create missing folders, so make sure the output directory
# exists before writing the submission file.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# ensure_ascii=False keeps the Chinese text human-readable in the output file.
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(submit, f, ensure_ascii=False)