In [None]:
import warnings
# Suppress every Python warning for the rest of the session to keep cell output short.
# NOTE(review): this is a blanket filter, so deprecation warnings raised by the
# libraries imported below are hidden as well.
warnings.filterwarnings("ignore")

import os
import json
import jieba
import torch
import pickle
import codecs
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.rc.pure_bert import PUREBert
from ark_nlp.model.rc.pure_bert import PUREBertConfig
from ark_nlp.model.rc.pure_bert import Dataset
from ark_nlp.model.rc.pure_bert import Task
from ark_nlp.model.rc.pure_bert import get_default_model_optimizer
from ark_nlp.model.rc.pure_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed
# Fix the random seed before any data or model is created.
# NOTE(review): which random number generators set_seed covers is defined inside ark_nlp — confirm.
set_seed(42)

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Input locations: CMeIE training and development splits (paths are relative to this notebook)

train_data_path = '../data/source_datasets/CMeIE/CMeIE_train.json'
dev_data_path = '../data/source_datasets/CMeIE/CMeIE_dev.json'

In [None]:
def data_preprocess(data_path):
    """Read a CMeIE JSON-lines file and build entity / triple annotations.

    Each non-blank line must be a JSON object with a ``text`` string and a
    ``spo_list`` of triples (``subject``, ``predicate``, ``object['@value']``,
    ``object_type['@value']``). Blank lines are skipped.

    Spans are located with ``str.index``, i.e. the FIRST occurrence of the
    mention in the text; end offsets are inclusive. Every subject is given the
    entity type '疾病'.

    Args:
        data_path: path to the UTF-8 encoded JSON-lines file.

    Returns:
        A list with one dict per sample:
            'text'     - the sample text,
            'entities' - unique (mention, type, start, end) tuples ordered by
                         start offset (ties broken by end, type, mention so the
                         order is the same on every run),
            'triples'  - one list per spo entry:
                         [subject, '疾病', s_start, s_end, predicate,
                          object, object_type, o_start, o_end].

    Raises:
        ValueError: if a subject or object mention does not occur in the text.
        KeyError: if a required field is missing from a record.
    """
    subject_type_ = '疾病'
    data_list = []

    with codecs.open(data_path, mode='r', encoding='utf8') as f:
        for line_ in f:
            line_ = line_.strip()
            # A trailing newline or stray blank line must not abort the whole load.
            if not line_:
                continue

            sample_ = json.loads(line_)
            text_ = sample_['text']
            entities_ = set()
            triples_ = []

            for triple_ in sample_['spo_list']:
                subject_ = triple_['subject']
                object_ = triple_['object']['@value']
                object_type_ = triple_['object_type']['@value']

                # Locate each mention once and reuse the offsets below.
                subject_start_ = text_.index(subject_)
                subject_end_ = subject_start_ + len(subject_) - 1
                object_start_ = text_.index(object_)
                object_end_ = object_start_ + len(object_) - 1

                # The set removes entities repeated across triples.
                entities_.add((subject_, subject_type_, subject_start_, subject_end_))
                entities_.add((object_, object_type_, object_start_, object_end_))
                triples_.append([
                    subject_,
                    subject_type_,
                    subject_start_,
                    subject_end_,
                    triple_['predicate'],
                    object_,
                    object_type_,
                    object_start_,
                    object_end_,
                ])

            data_list.append({
                'text': text_,
                # Set iteration order is not stable between interpreter sessions,
                # so the sort key covers the full tuple, not only the start offset.
                'entities': sorted(entities_, key=lambda x: (x[2], x[3], x[1], x[0])),
                'triples': triples_,
            })
    return data_list

# Parse the training split and put it in a DataFrame with columns text / entities / triples
train_data_list = data_preprocess(train_data_path)
train_df = pd.DataFrame(train_data_list)

# Same for the development split
dev_data_list = data_preprocess(dev_data_path)
dev_df = pd.DataFrame(dev_data_list)

In [None]:
# Relation labels seen in the training triples (index 4 of a triple is the predicate).
# Sorted because iterating a set of strings is not order-stable between interpreter
# sessions; a fixed order keeps whatever label indexing the Dataset derives from this
# list identical on every run.
# Entity pairs without a relation are labelled "None".
categories = sorted({triple[4] for triples in train_df['triples'] for triple in triples}) + ['None']

rc_train_dataset = Dataset(train_df, categories=categories)
rc_dev_dataset = Dataset(dev_df, categories=categories,
                         is_train=False)

#### 2. 词典创建和生成分词器

In [None]:
from transformers import AutoTokenizer
# Vocabulary of the 'nghuyong/ernie-1.0' checkpoint, the same checkpoint the model config
# and weights are loaded from further down.
bert_vocab = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')

In [None]:
# Entity types seen in the training data. Sorted so the marker tokens below are
# registered in the same order on every run (set iteration order is not stable
# between interpreter sessions).
entity_categories = sorted({entity[1] for entities in train_df['entities'] for entity in entities})

# One opening and one closing marker token per entity type, e.g. '[疾病]' and '[/疾病]'
special_tokens = []
for category in entity_categories:
    special_tokens.extend([f'[{category}]', f'[/{category}]'])
bert_vocab.add_special_tokens({'additional_special_tokens': special_tokens})

In [None]:
# Wrap the extended vocabulary in ark_nlp's PURE tokenizer.
# NOTE(review): max_seq_len=200 presumably bounds the encoded sequence length; how longer
# inputs are handled is decided inside the library — confirm.
tokenizer = Tokenizer(bert_vocab, max_seq_len=200)

#### 3. ID化

In [None]:
# Encode both datasets with the tokenizer. The return value is not used here, so the
# conversion presumably happens in place on the Dataset objects — confirm.
rc_train_dataset.convert_to_ids(tokenizer)
rc_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Model configuration from the pretrained checkpoint; the number of output labels is the
# size of the training dataset's cat2id mapping.
bert_config = PUREBertConfig.from_pretrained('nghuyong/ernie-1.0',
                                               num_labels=len(rc_train_dataset.cat2id))

#### 2. 模型创建

In [None]:
# Instantiate the PURE relation model from the same checkpoint, using the config built above
dl_module = PUREBert.from_pretrained('nghuyong/ernie-1.0',
                                       config=bert_config)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Build the optimizer through ark_nlp's default factory for this model
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [None]:
# Assemble the training task from the model, the optimizer, the loss identifier 'ce'
# and CUDA device index 0.
# NOTE(review): 'ce' presumably selects cross-entropy, and cuda_device=0 hard-codes the
# first GPU — confirm both, and check the behaviour on a CPU-only machine.
model = Task(dl_module, optimizer, 'ce', cuda_device=0)

#### 3. 训练

In [None]:
# Train on the training set, passing the development set as the second argument
# (presumably used for validation — confirm), with the hyper-parameters listed below.
model.fit(
    rc_train_dataset,
    rc_dev_dataset,
    lr=2e-5,
    epochs=10,
    batch_size=32
)

<br>

### 四、模型预测

In [None]:
from tqdm import tqdm
from ark_nlp.model.rc.pure_bert import Predictor

# The predictor is given the trained module, the tokenizer used for training and the
# training dataset's cat2id mapping.
pure_rc_predictor_instance = Predictor(model.module, tokenizer, rc_train_dataset.cat2id)

In [None]:
# Hand-written example: a text and its entities as (mention, type, start, end).
# End offsets are inclusive: 'ERCP' occupies characters 17 to 20 of the text.
text = '急性胰腺炎@有研究显示，进行早期 ERCP （24 小时内）可以降低梗阻性胆总管结石患者的并发症发生率和死亡率； 但是，对于无胆总管梗阻的胆汁性急性胰腺炎患者，不需要进行早期 ERCP。'
entities = [('急性胰腺炎', '疾病', 0, 4), ('ERCP', '检查', 17, 20)]

In [None]:
# Predict relations for the example above.
# NOTE(review): topk=1 presumably keeps only the best-scoring relation per entity pair — confirm.
pure_rc_predictor_instance.predict_one_sample(text, entities, topk=1)

<br>

### 五、CMeIE结果输出

In [None]:
# Load the test samples together with their entities. The file sits under
# ../data/output_datasets, so it is presumably produced by an upstream entity-recognition
# step — confirm it exists before running.
test_df = pd.read_json('../data/output_datasets/CMeIE_test_entities.json')

In [None]:
def _as_entity_tuple(entity):
    """Return one entity as the (mention, type, start, end) tuple passed to the predictor.

    Accepts the raw dict form read from JSON (keys 'entity', 'type', 'start_idx',
    'end_idx') as well as an already converted sequence, so the cell can be run
    again without a TypeError.
    """
    if isinstance(entity, dict):
        return (entity['entity'], entity['type'], int(entity['start_idx']), int(entity['end_idx']))
    # Already converted by an earlier run of this cell.
    return tuple(entity)


test_df['entities'] = test_df['entities'].apply(lambda x: [_as_entity_tuple(entity) for entity in x])

In [None]:
# Run relation prediction on every test row, keeping one result per row in row order.
# zip() has no length, so tqdm is given the total explicitly to show a real progress bar and ETA.
result = []
for text, entities in tqdm(zip(test_df['text'], test_df['entities']), total=len(test_df)):
    result.append(pure_rc_predictor_instance.predict_one_sample(text, entities, topk=1))

In [None]:
# Attach the raw predictions to the frame: one entry per row, in the order the rows were predicted
test_df['triples'] = result

In [None]:
# Reduce every predicted triple to (subject mention, 'predicate@object_type', object mention).
# A triple is kept only when its relation is not 'None' and its subject entity is typed '疾病'.
# Element 0 of a triple is the subject entity, element 1 the predicted relation and
# element 2 the object entity; for an entity, index 0 is the mention and index 1 the type.
result_ = [
    [
        (triple[0][0], triple[1] + '@' + triple[2][1], triple[2][0])
        for triple in triples
        if triple[1] != 'None' and triple[0][1] == '疾病'
    ]
    for triples in result
]

In [None]:
# Raw test file (supplies the text of each sample), the relation schema file and the
# destination of the JSON-lines submission. All paths are relative to this notebook.
test_data_path = '../data/source_datasets/CMeIE/CMeIE_test.json'
schemas_data_path = '../data/source_datasets/CMeIE/53_schemas.json'
output_data_path = '../data/output_datasets/CMeIE_test.jsonl'

In [None]:
# Not used in this cell; kept defined only in case cells outside this view reference them.
all_subject_type = []
all_shcemas = []
all_predicate = set()


def _build_spo_list(sample_triples, sentence, predicate2subject):
    """Convert one sample's filtered triples into CMeIE ``spo_list`` entries.

    Args:
        sample_triples: iterable of (subject, 'predicate@object_type', object) items.
        sentence: the sample text (a str); used only to derive the ``Combined`` flag.
        predicate2subject: maps 'predicate@object_type' to the schema's subject type.

    Returns:
        De-duplicated list of spo dicts in first-seen order. Triples whose key is
        not in the schema, or whose object is '' or '[UNK]', are dropped.
    """
    # All entries of a sample share the flag: True when the text contains two or more '。'.
    combined = sentence.count('。') >= 2
    spo_list = []
    for triple in sample_triples:
        predicate_key = triple[1]
        # Direct dict lookup instead of scanning every schema key.
        if predicate_key not in predicate2subject:
            continue
        if triple[-1] == '' or triple[-1] == '[UNK]':
            continue
        key_parts = predicate_key.split('@')
        spo = {
            'predicate': key_parts[0],
            'subject': triple[0],
            'subject_type': predicate2subject[predicate_key],
            'object': {'@value': triple[-1]},
            'object_type': {'@value': key_parts[-1]},
            'Combined': combined,
        }
        if spo not in spo_list:
            spo_list.append(spo)
    return spo_list


# Schema lookup: 'predicate@object_type' -> subject_type (one JSON object per line).
predicate2subject = {}
with open(schemas_data_path, 'r', encoding='utf-8') as fs:
    for jsonstr in fs:
        jsonstr = jsonstr.strip()
        if not jsonstr:
            continue
        schema = json.loads(jsonstr)
        predicate2subject[schema['predicate'] + '@' + schema['object_type']] = schema['subject_type']

# Read the inputs before opening the output, so a missing test file does not leave a
# truncated submission behind. Blank lines are ignored.
with open(test_data_path, 'r', encoding='utf-8') as f:
    lines = [line_ for line_ in f if line_.strip()]

# The i-th test record is paired with the i-th prediction in result_; both are assumed
# to be in the same order — verify against the file that test_df was loaded from.
with open(output_data_path, 'w', encoding='utf-8') as fw:
    for index_, jsonstr in tqdm(enumerate(lines), total=len(lines)):
        line = json.loads(jsonstr)
        sentence = line['text']
        new = _build_spo_list(result_[index_], sentence, predicate2subject)

        if len(new) == 0:
            # Samples without any prediction still get one empty placeholder entry.
            new = [{
                "Combined": '',
                "predicate": '',
                "subject": '',
                "subject_type": '',
                "object": {"@value": ""},
                "object_type": {"@value": ""},
            }]

        pred_dict = {
            "text": sentence,
            "spo_list": new,
        }
        fw.write(json.dumps(pred_dict, ensure_ascii=False) + '\n')