In [None]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import jieba
import torch
import pickle
import codecs
import torch.nn as nn
import torch.optim as optim
import pandas as pd

from ark_nlp.model.re.gplinker_bert import Module
from ark_nlp.model.re.gplinker_bert import ModuleConfig
from ark_nlp.model.re.gplinker_bert import Dataset
from ark_nlp.model.re.gplinker_bert import Task
from ark_nlp.model.re.gplinker_bert import get_default_model_optimizer
from ark_nlp.model.re.gplinker_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed

In [None]:
# Seed value handed to ark_nlp's set_seed before any data loading or model construction.
RANDOM_SEED = 42
set_seed(RANDOM_SEED)

In [None]:
# Dataset file locations.
# Dataset download: https://tianchi.aliyun.com/dataset/dataDetail?dataId=95414
CMEIE_DATA_DIR = '../data/source_datasets/CMeIE'

train_data_path = f'{CMEIE_DATA_DIR}/CMeIE_train.json'
dev_data_path = f'{CMEIE_DATA_DIR}/CMeIE_dev.json'

In [None]:
# Pretrained checkpoint name (presumably a Hugging Face Hub id); it is reused below for the
# tokenizer vocabulary, the model config and the pretrained weights.
model_name = "freedomking/ernie-1.0"

### 一、数据读入与处理

#### 1. 数据读入

In [None]:
# Load the CMeIE training split (one JSON object per line) into records of the form
#   {'text': str,
#    'label': [[subject, subj_start, subj_end, 'predicate@object_type', object, obj_start, obj_end], ...]}
# Start/end are inclusive character offsets of the FIRST occurrence of each mention in the
# text, because the triples read here provide only mention strings, not positions.
# NOTE(review): a mention that occurs several times is always anchored to its first occurrence.
train_data_list = []

# Builtin open() splits lines only on real newlines. codecs' readlines() also splits on
# U+2028/U+2029/U+0085, which may appear unescaped inside JSON strings.
with open(train_data_path, mode='r', encoding='utf8') as f:
    for raw_line_ in f:
        raw_line_ = raw_line_.strip()
        if not raw_line_:
            # Skip blank lines (e.g. an extra newline at EOF) instead of failing in json.loads.
            continue
        sample_ = json.loads(raw_line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # str.index raises ValueError if a mention does not occur in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
        train_data_list.append({'text': text_, 'label': labels_})

train_df = pd.DataFrame(train_data_list)

In [None]:
# Load the CMeIE dev split (one JSON object per line) into records of the form
#   {'text': str,
#    'label': [[subject, subj_start, subj_end, 'predicate@object_type', object, obj_start, obj_end], ...]}
# Start/end are inclusive character offsets of the FIRST occurrence of each mention in the
# text, because the triples read here provide only mention strings, not positions.
dev_data_list = []

# Builtin open() splits lines only on real newlines. codecs' readlines() also splits on
# U+2028/U+2029/U+0085, which may appear unescaped inside JSON strings.
with open(dev_data_path, mode='r', encoding='utf8') as f:
    for raw_line_ in f:
        raw_line_ = raw_line_.strip()
        if not raw_line_:
            # Skip blank lines (e.g. an extra newline at EOF) instead of failing in json.loads.
            continue
        sample_ = json.loads(raw_line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # str.index raises ValueError if a mention does not occur in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
        dev_data_list.append({'text': text_, 'label': labels_})

# Total number of dev triples (kept for inspection; not used by later cells shown here).
counter = sum(len(record_['label']) for record_ in dev_data_list)

dev_df = pd.DataFrame(dev_data_list)

In [None]:
# Wrap both splits in ark_nlp datasets. The dev split is given the training split's
# categories so both use the same label inventory.
re_train_dataset = Dataset(train_df)
re_dev_dataset = Dataset(
    dev_df,
    categories=re_train_dataset.categories,
    is_train=False,
)

#### 2. 词典创建和生成分词器

In [None]:
# Passed as max_seq_len; presumably the maximum sequence length in tokens — confirm in ark_nlp.
MAX_SEQ_LEN = 100
tokenizer = Tokenizer(
    vocab=model_name,
    max_seq_len=MAX_SEQ_LEN,
)

#### 3. ID化

In [None]:
# Convert the training split, then the dev split, to model input ids with the same tokenizer.
for split_dataset in (re_train_dataset, re_dev_dataset):
    split_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [None]:
# Model config loaded from the pretrained checkpoint. num_labels is the number of training
# categories (presumably the 'predicate@object_type' labels built during data loading).
bert_config = ModuleConfig.from_pretrained(
    model_name,
    num_labels=len(re_train_dataset.cat2id),
)

#### 2. 模型创建

In [None]:
# entity_type_num: number of entity types. 2 here, i.e. head (subject) and tail (object)
#   entities (per the original note, the library default is also 2).
# relation_type_num: number of relation types. Typically a relation label combines subject
#   type + relation + object type; in this notebook labels are built as 'predicate@object_type',
#   so it is set to the number of training categories (original note: library default is 2).
num_relation_types = len(re_train_dataset.cat2id)
dl_module = Module.from_pretrained(
    model_name,
    config=bert_config,
    entity_type_num=2,
    relation_type_num=num_relation_types,
)

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [None]:
# Optimizer for the module, built by ark_nlp's default optimizer factory.
optimizer = get_default_model_optimizer(
    dl_module
)

#### 2. 任务创建

In [None]:
# Newer ark_nlp versions provide a default loss, so None can be passed for the loss here.
task_loss_function = None
# NOTE(review): cuda_device=0 is presumably the CUDA device index — confirm.
model = Task(
    dl_module,
    optimizer,
    task_loss_function,
    cuda_device=0,
)

#### 3. 训练

In [None]:
# In newer ark_nlp versions, fit() also accepts every parameter from task_utils;
# they mainly control when results are printed and when the model is saved.

# FAQ:

# Q1: Why is train_loss larger than loss during training?
# A1: In the new version, train loss and loss are computed with different formulas.

# Note: when evaluate_during_training_step is 0, the best model is NOT saved automatically.

model.fit(re_train_dataset,
          re_dev_dataset,
          epoch_num=10,
          batch_size=32,
          # Presumably the evaluation interval in training steps (see the note above).
          evaluate_during_training_step=200,
          do_save_best_module=True,
          # Metric used to select the best module to save.
          save_best_module_metric='f1-score',
         )

<br>

### 四、模型预测

#### 1. 模型验证

In [None]:
from ark_nlp.model.re.gplinker_bert import Predictor

# Predictor built from the trained module, the tokenizer and the training label mapping.
# (The variable name is kept as-is because later cells refer to it.)
gpliner_re_predictor_instance = Predictor(
    model.module,
    tokenizer,
    re_train_dataset.cat2id,
)

In [None]:
# Sample sentence used to sanity-check the predictor.
text = (
    '骨性关节炎@在其他关节（如踝关节和腕关节），'
    '骨性关节炎比较少见，并且一般有潜在的病因（如结晶性关节病、创伤）'
)

In [None]:
# Run the predictor on the single sample; the last expression displays its predictions.
sample_prediction = gpliner_re_predictor_instance.predict_one_sample(text)
sample_prediction

#### 2. 多样本验证

In [None]:
# Inputs for batch prediction (test split + relation schemas) and the prediction output file.
cmeie_source_dir = '../data/source_datasets/CMeIE'
test_data_path = f'{cmeie_source_dir}/CMeIE_test.json'
schemas_data_path = f'{cmeie_source_dir}/53_schemas.json'
output_data_path = '../data/output_datasets/CMeIE_test.jsonl'

In [None]:
# Predict relations for every line of the CMeIE test file (one JSON object per line).
# result[i] holds the predictions for line i; later cells index result by line number.
result = []

with open(test_data_path, 'r', encoding='utf-8') as f:
    for line_ in f:
        # Parse with json.loads, not eval(): eval executes arbitrary code from the file and
        # rejects valid JSON literals such as true/false/null.
        sample_text = json.loads(line_)['text']
        result.append(gpliner_re_predictor_instance.predict_one_sample(sample_text))

In [None]:
from tqdm import tqdm

# Write the test-set predictions to output_data_path: one JSON object per test sentence,
# {"text": ..., "spo_list": [...]}.

# Map every "predicate@object_type" label (the relation label format built during data
# loading) to its subject type, using the schema file (one JSON object per line).
predicate2subject = {}
with open(schemas_data_path, 'r', encoding='utf-8') as fs:
    for schema_line in fs:
        if not schema_line.strip():
            continue  # skip blank lines instead of failing in json.loads
        schema = json.loads(schema_line)
        predicate2subject[schema['predicate'] + '@' + schema['object_type']] = schema['subject_type']

with open(test_data_path, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# `result` must hold exactly one prediction list per test-file line. A mismatch means it is
# stale (e.g. produced from another file), so fail loudly instead of misaligning rows.
if len(result) != len(lines):
    raise ValueError(
        f'result has {len(result)} predictions but {test_data_path} has {len(lines)} lines; '
        're-run the prediction cell'
    )

# Create the output directory if it does not exist yet.
output_dir = os.path.dirname(output_data_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_data_path, 'w', encoding='utf-8') as fw:
    for index_, jsonstr in tqdm(enumerate(lines), total=len(lines)):
        sentence = json.loads(jsonstr)['text']
        new = []
        for list_ in result[index_]:
            # Assumes each prediction is indexable as (subject, 'predicate@object_type', ..., object)
            # — TODO confirm against Predictor.predict_one_sample.
            predicate_ = list_[1]
            object_text = list_[-1]
            # Keep only labels present in the schema and objects that decoded to real text.
            if predicate_ not in predicate2subject or object_text in ('', '[UNK]'):
                continue
            result_dict = {
                'predicate': predicate_.split('@')[0],
                "subject": list_[0],
                'subject_type': predicate2subject[predicate_],
                "object": {"@value": object_text},
                'object_type': {"@value": predicate_.split('@')[-1]},
            }
            if result_dict not in new:  # drop duplicate triples, keep first-seen order
                new.append(result_dict)

        # "Combined" is True when the sentence contains two or more '。' characters.
        combined = sentence.count('。') >= 2
        for item in new:
            item['Combined'] = combined

        if not new:
            # Placeholder entry written when no valid triple was predicted for the sentence.
            new = [{
                "Combined": '',
                "predicate": '',
                "subject": '',
                "subject_type": '',
                "object": {"@value": ""},
                "object_type": {"@value": ""},
            }]

        pred_dict = {
            "text": sentence,
            "spo_list": new,
        }
        fw.write(json.dumps(pred_dict, ensure_ascii=False) + '\n')