In [1]:
import warnings
warnings.filterwarnings("ignore")

import codecs
import json
import os
import pickle

import jieba
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from ark_nlp.model.re.prgc_bert import PRGCBert
from ark_nlp.model.re.prgc_bert import PRGCBertConfig
from ark_nlp.model.re.prgc_bert import Dataset
from ark_nlp.model.re.prgc_bert import Task
from ark_nlp.model.re.prgc_bert import get_default_model_optimizer
from ark_nlp.model.re.prgc_bert import Tokenizer
from ark_nlp.model.re.prgc_bert import Predictor

In [2]:
# Dataset file locations: paths to the CMeIE training and development splits.
# Both are relative paths, so they resolve against the kernel's working directory.

train_data_path = '../data/source_datasets/CMeIE/CMeIE_train.json'
dev_data_path = '../data/source_datasets/CMeIE/CMeIE_dev.json'

#### 1. 数据读入

In [3]:
# Load the training split (one JSON object per line) into `train_data_df`,
# a DataFrame with a 'text' column and a 'label' column.
# Each entry of 'label' is a list:
#   [subject, subject_start, subject_end, relation, object, object_start, object_end]
# where start/end are inclusive character offsets of the FIRST occurrence of
# the entity in the text, and relation is '<predicate>@<object_type>'.
train_data_list = []

# Built-in open() is used instead of the superseded codecs.open(): the codecs
# reader splits lines with str.splitlines semantics, which also breaks on
# separators such as U+2028 that may appear unescaped inside a JSON string.
with open(train_data_path, mode='r', encoding='utf8') as f:
    # Stream the file line by line instead of materialising it with readlines().
    for line_ in f:
        line_ = line_.strip()
        if not line_:
            # Skip blank lines; json.loads('') would raise.
            continue
        sample_ = json.loads(line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # Look each entity up once; str.index raises ValueError if the
            # entity does not occur in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
        train_data_list.append({'text': text_, 'label': labels_})

train_data_df = pd.DataFrame(train_data_list)

In [4]:
# Load the development split (one JSON object per line) into `dev_data_df`,
# a DataFrame with a 'text' column and a 'label' column.
# Each entry of 'label' is a list:
#   [subject, subject_start, subject_end, relation, object, object_start, object_end]
# where start/end are inclusive character offsets of the FIRST occurrence of
# the entity in the text, and relation is '<predicate>@<object_type>'.
# `counter` ends up holding the total number of triples read from the file.
dev_data_list = []
counter = 0

# Built-in open() is used instead of the superseded codecs.open(): the codecs
# reader splits lines with str.splitlines semantics, which also breaks on
# separators such as U+2028 that may appear unescaped inside a JSON string.
with open(dev_data_path, mode='r', encoding='utf8') as f:
    # Stream the file line by line instead of materialising it with readlines().
    for line_ in f:
        line_ = line_.strip()
        if not line_:
            # Skip blank lines; json.loads('') would raise.
            continue
        sample_ = json.loads(line_)
        text_ = sample_['text']
        labels_ = []
        for triple_ in sample_['spo_list']:
            subject_ = triple_['subject']
            object_ = triple_['object']['@value']
            # Look each entity up once; str.index raises ValueError if the
            # entity does not occur in the text.
            subject_start_ = text_.index(subject_)
            object_start_ = text_.index(object_)
            labels_.append([
                subject_,
                subject_start_,
                subject_start_ + len(subject_) - 1,
                triple_['predicate'] + '@' + triple_['object_type']['@value'],
                object_,
                object_start_,
                object_start_ + len(object_) - 1,
            ])
            counter += 1
        dev_data_list.append({'text': text_, 'label': labels_})

dev_data_df = pd.DataFrame(dev_data_list)

In [5]:
# Wrap the training DataFrame in the PRGC Dataset.
# NOTE(review): the meaning of is_retain_dataset is defined by ark_nlp — confirm.
re_train_dataset = Dataset(train_data_df, is_retain_dataset=True)
# The dev Dataset is given the categories of the training Dataset rather than
# deriving its own, and is flagged as non-training via is_train=False.
re_dev_dataset = Dataset(dev_data_df,
                         categories = re_train_dataset.categories,
                         is_train=False)

#### 2. 词典创建和生成分词器

In [6]:
# Build the tokenizer from the 'nghuyong/ernie-1.0' vocabulary.
# max_seq_len=100 presumably caps the encoded sequence length — confirm against ark_nlp.
tokenizer = Tokenizer(vocab='nghuyong/ernie-1.0', max_seq_len=100)

#### 3. ID化

In [7]:
# Run the ID conversion on both datasets with the shared tokenizer.
# The return values are discarded, so this presumably updates the datasets in place — confirm.
re_train_dataset.convert_to_ids(tokenizer)
re_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [8]:
# Load the model configuration from the 'nghuyong/ernie-1.0' checkpoint and set
# num_labels to the number of categories in the training dataset's cat2id mapping.
bert_config = PRGCBertConfig.from_pretrained('nghuyong/ernie-1.0',
                                               num_labels=len(re_train_dataset.cat2id))

#### 2. 模型创建

In [9]:
# Instantiate the PRGC model from the 'nghuyong/ernie-1.0' checkpoint using the
# configuration built above. The log printed below reports which checkpoint
# weights were unused and which model weights were newly initialised.
dl_module = PRGCBert.from_pretrained('nghuyong/ernie-1.0',
                                       config=bert_config)

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing PRGCBert: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing PRGCBert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing PRGCBert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of PRGCBert were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['sequence_tagging_sub.hidden2tag.bias', 'sequence_tagging_sum.h

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [10]:
# Obtain ark_nlp's default optimizer for this model.
optimizer = get_default_model_optimizer(dl_module) 

#### 2. 任务创建

In [11]:
# Bundle the network and optimizer into a Task bound to CUDA device 0.
# NOTE(review): the third positional argument is passed as None; presumably it
# is the loss function and None selects the task's default — confirm against ark_nlp.
model = Task(dl_module, optimizer, None, cuda_device=0)

#### 3. 训练

In [None]:
# Train on the training dataset, passing the dev dataset as the second argument
# (presumably used for validation — confirm), for 30 epochs with a batch size of 6.
model.fit(re_train_dataset,
          re_dev_dataset,
          epochs=30, 
          batch_size=6)

<br>

### 四、模型预测

#### 1. 模型验证

In [13]:
# Predictor is imported in the notebook's import cell together with the other
# ark_nlp names, so that all dependencies are declared in one place.

# Build the inference wrapper from the Task's underlying module, the tokenizer
# used for training, and the training dataset's cat2id mapping.
prgc_re_predictor_instance = Predictor(model.module, tokenizer, re_train_dataset.cat2id)

In [14]:
# Single sample used for a smoke test of the predictor. The string starts with
# an entity name followed by '@' and then the sentence.
text = '骨性关节炎@在其他关节（如踝关节和腕关节），骨性关节炎比较少见，并且一般有潜在的病因（如结晶性关节病、创伤）'

In [15]:
# Predict on the single sample. As the last expression of the cell, the result
# is displayed; the output shown below is a list of (subject, relation, object) tuples.
prgc_re_predictor_instance.predict_one_sample(text)

[('骨性关节炎', '相关（导致）@疾病', '踝关节'),
 ('骨性关节炎', '病因@社会学', '踝关节'),
 ('骨性关节炎', '发病部位@部位', '踝关节'),
 ('骨性关节炎', '相关（导致）@疾病', '腕关节'),
 ('骨性关节炎', '发病部位@部位', '腕关节'),
 ('骨性关节炎', '病因@社会学', '腕关节')]

#### 2. 多样本验证

In [16]:
# Run the predictor over every sample of the CMeIE test split (one JSON object
# per line). `record_` collects one [text, prediction] pair per sample, where
# prediction is whatever predict_one_sample returns for that text.
record_ = []

# Built-in open() is used instead of the superseded codecs.open(): the codecs
# reader splits lines with str.splitlines semantics, which also breaks on
# separators such as U+2028 that may appear unescaped inside a JSON string.
with open('../data/source_datasets/CMeIE/CMeIE_test.json', mode='r', encoding='utf8') as f:
    # Stream the file line by line instead of materialising it with readlines().
    for line_ in f:
        line_ = line_.strip()
        if not line_:
            # Skip blank lines; json.loads('') would raise.
            continue
        text_ = json.loads(line_)['text']
        record_.append([text_, prgc_re_predictor_instance.predict_one_sample(text_)])

<br>

### 五、模型测试报告

In [None]:
# Test report (kept as comments so that running this code cell does not raise
# a SyntaxError during "Restart Kernel -> Run All"):
# 1. Basic functionality test: passed
# 2. One-sample predict: passed
# 3. Multi-sample validation: passed