In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import json
import jieba
import torch
import pickle
import codecs
import torch.nn as nn
import torch.optim as optim
import pandas as pd

import sys
sys.path.append('/home/shencj/workspace/code/nlp/Frame/ark-nlp-0.0.5/')

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

from ark_nlp.model.re.afea_bert import AFEABert
from ark_nlp.model.re.afea_bert import AFEABertConfig
from ark_nlp.model.re.afea_bert import Dataset
from ark_nlp.model.re.afea_bert import Task
from ark_nlp.model.re.afea_bert import get_default_model_optimizer
from ark_nlp.model.re.afea_bert import Tokenizer
from ark_nlp.factory.utils.seed import set_seed
set_seed(42)

### 一、数据读入与处理

#### 1. 数据读入

In [2]:
# 目录地址

train_data_path = '/home/shencj/workspace/data/medical/CBLUE/CMeIE/CMeIE_train.json'
dev_data_path = '/home/shencj/workspace/data/medical/CBLUE/CMeIE/CMeIE_dev.json'

In [3]:
def data_preprocess(data_path):

    data_list = []

    with codecs.open(data_path, mode='r', encoding='utf8') as f:
        lines = f.readlines()
        for index_, line_ in enumerate(lines):
            record_ = {}
            line_ = json.loads(line_.strip())
            record_['text'] = line_['text']
            record_['entities'] = []
            record_['triples'] = []
            for triple_ in line_['spo_list']:
                record_['entities'].append([
                    triple_['subject'],
                    '疾病',
                    record_['text'].index(triple_['subject']),
                    record_['text'].index(triple_['subject'])+ len(triple_['subject']) - 1,
                ])
                record_['entities'].append([
                    triple_['object']['@value'],
                    triple_['object_type']['@value'],
                    record_['text'].index(triple_['object']['@value']),
                    record_['text'].index(triple_['object']['@value']) + len(triple_['object']['@value']) - 1,
                ])
                record_['triples'].append([
                    triple_['subject'],
                    '疾病',
                    record_['text'].index(triple_['subject']),
                    record_['text'].index(triple_['subject'])+ len(triple_['subject']) - 1,
                    triple_['predicate'],
                    triple_['object']['@value'],
                    triple_['object_type']['@value'],
                    record_['text'].index(triple_['object']['@value']),
                    record_['text'].index(triple_['object']['@value']) + len(triple_['object']['@value']) - 1,
                ])
            record_['entities'] = list(set([tuple(entity) for entity in record_['entities']]))
            record_['entities'] = sorted(record_['entities'], key = lambda x: x[2])
            data_list.append(record_)
    return data_list

train_data_list = data_preprocess(train_data_path)
train_df = pd.DataFrame(train_data_list)

dev_data_list = data_preprocess(dev_data_path)
dev_df = pd.DataFrame(dev_data_list)

In [4]:
train_df

Unnamed: 0,text,entities,triples
0,产后抑郁症@区分产后抑郁症与轻度情绪失调（产后忧郁或“婴儿忧郁”）是重要的，因为轻度情绪失调...,"[(产后抑郁症, 疾病, 0, 4), (轻度情绪失调, 疾病, 14, 19)]","[[产后抑郁症, 疾病, 0, 4, 鉴别诊断, 轻度情绪失调, 疾病, 14, 19]]"
1,类风湿关节炎@尺侧偏斜是由于MCP关节炎症造成的。,"[(尺侧偏斜, 症状, 7, 10), (MCP关节炎症, 疾病, 14, 20)]","[[MCP关节炎症, 疾病, 14, 20, 临床表现, 尺侧偏斜, 症状, 7, 10]]"
2,唇腭裂@ ### 腭瘘 | 存在差异 | 低 大约 10% 至 20% 颚成形术发生腭瘘。 ...,"[(腭瘘, 疾病, 9, 10), (婴儿伤口, 社会学, 57, 60), (营养状况, ...","[[腭瘘, 疾病, 9, 10, 风险评估因素, 婴儿伤口, 社会学, 57, 60], [..."
3,成人哮喘@ 应在低剂量 ICS 的基础上加用一种 LABA， 或将ICS增加到中等剂量。 成...,"[(成人哮喘, 疾病, 0, 3), (哮喘, 疾病, 2, 3), (ICS, 药物, 1...","[[成人哮喘, 疾病, 0, 3, 药物治疗, ICS, 药物, 12, 14], [哮喘,..."
4,口咽癌@[ 声嘶及发声障碍的评估 ](/topics/zh-cn/845) ### 手术或放...,"[(口咽癌, 疾病, 0, 2), (放疗后吞咽困难, 疾病, 45, 51), (误吸, ...","[[口咽癌, 疾病, 0, 2, 并发症, 放疗后吞咽困难, 疾病, 45, 51], [口..."
...,...,...,...
14334,"【临床表现】 原发感染灶 多数脓毒症患者都有轻重不等的原发感染灶。 肝脾大 一般仅轻度增大,...","[(原发感染灶, 症状, 7, 11), (脓毒症, 疾病, 15, 17), (肝脾大, ...","[[脓毒症, 疾病, 15, 17, 临床表现, 肝脾大, 症状, 34, 36], [脓毒..."
14335,【病理】 RS的病理改变主要表现在脑和肝脏。电镜检查可见线粒体肿胀和变形，线粒体嵴可消失，肝...,"[(RS, 疾病, 5, 6), (脑, 部位, 17, 17), (肝脏, 部位, 19,...","[[RS, 疾病, 5, 6, 临床表现, 线粒体肿胀和变形, 症状, 28, 35], [..."
14336,综上所述，目前对SSPE的诊断只要具备相应的临床表现（不一定十分典型）以及脑脊液麻疹抗体升高...,"[(SSPE, 疾病, 8, 11), (干扰素, 药物, 121, 123)]","[[SSPE, 疾病, 8, 11, 药物治疗, 干扰素, 药物, 121, 123]]"
14337,喉癌@所有最常见癌症里排名第11，上呼吸消化道的最常见癌症里排名第2。 喉癌@疾病的发生发展...,"[(喉癌, 疾病, 0, 1), (烟酒嗜好, 社会学, 48, 51)]","[[喉癌, 疾病, 0, 1, 病因, 烟酒嗜好, 社会学, 48, 51]]"


In [5]:
# 没有分类标为"None"
categories = list(set([triple[4] for triples in train_df['triples'] for triple in triples])) + ['None']

re_train_dataset = Dataset(train_df, categories=categories)
re_dev_dataset = Dataset(dev_df, categories=categories,
                         is_train=False)

#### 2. 词典创建和生成分词器

In [6]:
from transformers import AutoTokenizer
bert_vocab = AutoTokenizer.from_pretrained('nghuyong/ernie-1.0')

In [7]:
entity_categories = list(set([entity[1] for entities in train_df['entities'] for entity in entities]))
special_tokens = []
for category in entity_categories:
    special_tokens.append(f'[{category}]')
    special_tokens.append(f'[/{category}]')
bert_vocab.add_special_tokens({'additional_special_tokens': special_tokens})

22

In [8]:
tokenizer = Tokenizer(bert_vocab, max_seq_len=200)

#### 3. ID化

In [9]:
re_train_dataset.convert_to_ids(tokenizer)
re_dev_dataset.convert_to_ids(tokenizer)

<br>

### 二、模型构建

#### 1. 模型参数设置

In [10]:
bert_config = AFEABertConfig.from_pretrained('nghuyong/ernie-1.0',
                                               num_labels=len(re_train_dataset.cat2id))

#### 2. 模型创建

In [11]:
dl_module = AFEABert.from_pretrained('nghuyong/ernie-1.0',
                                       config=bert_config)

Some weights of the model checkpoint at nghuyong/ernie-1.0 were not used when initializing AFEABert: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing AFEABert from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AFEABert from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of AFEABert were not initialized from the model checkpoint at nghuyong/ernie-1.0 and are newly initialized: ['linear.weight', 'linear.bias', 'classifier.weight', 'classifie

<br>

### 三、任务构建

#### 1. 任务参数和必要部件设定

In [12]:
optimizer = get_default_model_optimizer(dl_module)

#### 2. 任务创建

In [13]:
model = Task(dl_module, optimizer, 'ce', cuda_device=0)

#### 3. 训练

In [14]:
model.fit(
    re_train_dataset,
    re_dev_dataset,
    lr=2e-5,
    epochs=10,
    batch_size=32
)

 22%|██▏       | 100/449 [00:51<02:59,  1.94it/s]

[99/449],train loss is:1.236724,train evaluation is:10.461778


 45%|████▍     | 200/449 [01:43<02:09,  1.93it/s]

[199/449],train loss is:0.959788,train evaluation is:11.022111


 67%|██████▋   | 300/449 [02:35<01:17,  1.92it/s]

[299/449],train loss is:0.825182,train evaluation is:11.249333


 89%|████████▉ | 400/449 [03:27<00:25,  1.91it/s]

[399/449],train loss is:0.725161,train evaluation is:11.644833


100%|██████████| 449/449 [03:52<00:00,  1.93it/s]


epoch:[0],train loss is:0.689928,train evaluation is:11.622371 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       0.00      0.00      0.00        15
        传播途径       0.00      0.00      0.00        14
          阶段       0.00      0.00      0.00        38
        病理生理       0.00      0.00      0.00        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.00      0.00      0.00        92
        预后状况       0.00      0.00      0.00        48
        发病部位       0.74      0.45      0.56       307
          化疗       0.00      0.00      0.00        34
        多发地区       0.00      0.00      0.00        35
        鉴别诊断       0.85      0.04      0.07       285
        病理分型       0.70      0.21      0.33       518
          病因       0.69      0.44      0.54       776
      风险评估因素       0.00      0.00      0.00       121
         同义词       0.78      0.39      0.52  

 22%|██▏       | 100/449 [00:52<03:02,  1.91it/s]

[99/449],train loss is:0.354710,train evaluation is:11.505778


 45%|████▍     | 200/449 [01:44<02:10,  1.91it/s]

[199/449],train loss is:0.325554,train evaluation is:12.092889


 67%|██████▋   | 300/449 [02:37<01:18,  1.91it/s]

[299/449],train loss is:0.306895,train evaluation is:12.300593


 89%|████████▉ | 400/449 [03:29<00:25,  1.91it/s]

[399/449],train loss is:0.290578,train evaluation is:12.551056


100%|██████████| 449/449 [03:54<00:00,  1.91it/s]


epoch:[1],train loss is:0.284249,train evaluation is:12.484088 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       0.00      0.00      0.00        15
        传播途径       0.00      0.00      0.00        14
          阶段       1.00      0.03      0.05        38
        病理生理       0.00      0.00      0.00        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.80      0.04      0.08        92
        预后状况       1.00      0.27      0.43        48
        发病部位       0.76      0.83      0.79       307
          化疗       0.00      0.00      0.00        34
        多发地区       0.97      0.86      0.91        35
        鉴别诊断       0.67      0.64      0.65       285
        病理分型       0.68      0.57      0.62       518
          病因       0.77      0.79      0.78       776
      风险评估因素       0.50      0.01      0.02       121
         同义词       0.80      0.53      0.64  

 22%|██▏       | 100/449 [00:52<03:04,  1.89it/s]

[99/449],train loss is:0.199258,train evaluation is:13.177333


 45%|████▍     | 200/449 [01:45<02:10,  1.90it/s]

[199/449],train loss is:0.197066,train evaluation is:12.930000


 67%|██████▋   | 300/449 [02:37<01:18,  1.90it/s]

[299/449],train loss is:0.197421,train evaluation is:12.730000


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.195366,train evaluation is:12.837778


100%|██████████| 449/449 [03:55<00:00,  1.90it/s]


epoch:[2],train loss is:0.193381,train evaluation is:12.722989 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       0.00      0.00      0.00        15
        传播途径       1.00      0.07      0.13        14
          阶段       1.00      0.68      0.81        38
        病理生理       0.00      0.00      0.00        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.77      0.43      0.56        92
        预后状况       0.92      0.71      0.80        48
        发病部位       0.72      0.87      0.79       307
          化疗       0.54      0.56      0.55        34
        多发地区       1.00      0.91      0.96        35
        鉴别诊断       0.75      0.66      0.70       285
        病理分型       0.71      0.59      0.65       518
          病因       0.75      0.84      0.79       776
      风险评估因素       0.33      0.02      0.03       121
         同义词       0.82      0.68      0.74  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.152358,train evaluation is:13.060222


 45%|████▍     | 200/449 [01:45<02:10,  1.90it/s]

[199/449],train loss is:0.154860,train evaluation is:12.831222


 67%|██████▋   | 300/449 [02:37<01:18,  1.90it/s]

[299/449],train loss is:0.153856,train evaluation is:12.836074


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.151646,train evaluation is:12.899722


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[3],train loss is:0.151095,train evaluation is:12.856174 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       0.91      0.67      0.77        15
        传播途径       1.00      0.36      0.53        14
          阶段       1.00      0.82      0.90        38
        病理生理       0.00      0.00      0.00        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.78      0.55      0.65        92
        预后状况       0.83      0.79      0.81        48
        发病部位       0.79      0.88      0.83       307
          化疗       0.66      0.62      0.64        34
        多发地区       1.00      0.94      0.97        35
        鉴别诊断       0.73      0.72      0.72       285
        病理分型       0.75      0.64      0.69       518
          病因       0.83      0.77      0.80       776
      风险评估因素       0.30      0.10      0.15       121
         同义词       0.82      0.69      0.75  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.130522,train evaluation is:12.952667


 45%|████▍     | 200/449 [01:45<02:10,  1.90it/s]

[199/449],train loss is:0.124710,train evaluation is:12.776222


 67%|██████▋   | 300/449 [02:37<01:18,  1.90it/s]

[299/449],train loss is:0.125742,train evaluation is:12.854444


 89%|████████▉ | 400/449 [03:30<00:25,  1.91it/s]

[399/449],train loss is:0.126414,train evaluation is:13.003778


100%|██████████| 449/449 [03:55<00:00,  1.90it/s]


epoch:[4],train loss is:0.126512,train evaluation is:12.943232 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       1.00      0.60      0.75        15
        传播途径       1.00      0.64      0.78        14
          阶段       0.86      0.82      0.84        38
        病理生理       0.00      0.00      0.00        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.79      0.50      0.61        92
        预后状况       0.90      0.77      0.83        48
        发病部位       0.84      0.84      0.84       307
          化疗       0.72      0.53      0.61        34
        多发地区       1.00      0.94      0.97        35
        鉴别诊断       0.84      0.60      0.70       285
        病理分型       0.72      0.67      0.69       518
          病因       0.82      0.81      0.81       776
      风险评估因素       0.25      0.17      0.20       121
         同义词       0.84      0.70      0.77  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.105923,train evaluation is:13.904222


 45%|████▍     | 200/449 [01:45<02:11,  1.90it/s]

[199/449],train loss is:0.106391,train evaluation is:13.421000


 67%|██████▋   | 300/449 [02:38<01:18,  1.90it/s]

[299/449],train loss is:0.107408,train evaluation is:13.195852


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.106876,train evaluation is:13.097333


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[5],train loss is:0.107906,train evaluation is:13.006780 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.33      0.17      0.22         6
         死亡率       1.00      0.67      0.80        15
        传播途径       1.00      0.64      0.78        14
          阶段       0.73      0.87      0.80        38
        病理生理       1.00      0.08      0.14        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.83      0.58      0.68        92
        预后状况       0.89      0.83      0.86        48
        发病部位       0.81      0.87      0.84       307
          化疗       0.73      0.65      0.69        34
        多发地区       1.00      0.94      0.97        35
        鉴别诊断       0.82      0.74      0.78       285
        病理分型       0.75      0.64      0.69       518
          病因       0.84      0.82      0.83       776
      风险评估因素       0.45      0.27      0.34       121
         同义词       0.87      0.67      0.76  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.101697,train evaluation is:12.865778


 45%|████▍     | 200/449 [01:45<02:11,  1.90it/s]

[199/449],train loss is:0.095450,train evaluation is:12.993111


 67%|██████▋   | 300/449 [02:38<01:18,  1.90it/s]

[299/449],train loss is:0.093971,train evaluation is:12.806148


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.093441,train evaluation is:13.007667


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[6],train loss is:0.093083,train evaluation is:13.065479 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.00      0.00      0.00         6
         死亡率       1.00      0.80      0.89        15
        传播途径       0.80      0.86      0.83        14
          阶段       0.79      0.87      0.82        38
        病理生理       1.00      0.31      0.47        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.90      0.39      0.55        92
        预后状况       0.89      0.88      0.88        48
        发病部位       0.79      0.87      0.83       307
          化疗       0.74      0.74      0.74        34
        多发地区       1.00      1.00      1.00        35
        鉴别诊断       0.77      0.80      0.78       285
        病理分型       0.75      0.65      0.70       518
          病因       0.81      0.87      0.83       776
      风险评估因素       0.58      0.06      0.11       121
         同义词       0.84      0.73      0.78  

 22%|██▏       | 100/449 [00:52<03:04,  1.89it/s]

[99/449],train loss is:0.089917,train evaluation is:12.986000


 45%|████▍     | 200/449 [01:45<02:10,  1.90it/s]

[199/449],train loss is:0.085954,train evaluation is:13.003333


 67%|██████▋   | 300/449 [02:38<01:18,  1.89it/s]

[299/449],train loss is:0.083432,train evaluation is:13.169259


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.083254,train evaluation is:13.116722


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[7],train loss is:0.082980,train evaluation is:13.098243 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.50      0.17      0.25         6
         死亡率       0.92      0.80      0.86        15
        传播途径       0.73      0.57      0.64        14
          阶段       0.77      0.87      0.81        38
        病理生理       1.00      0.31      0.47        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.83      0.60      0.70        92
        预后状况       0.91      0.90      0.91        48
        发病部位       0.79      0.89      0.84       307
          化疗       0.77      0.71      0.74        34
        多发地区       1.00      0.94      0.97        35
        鉴别诊断       0.79      0.79      0.79       285
        病理分型       0.76      0.64      0.69       518
          病因       0.82      0.85      0.83       776
      风险评估因素       0.42      0.33      0.37       121
         同义词       0.89      0.67      0.77  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.071823,train evaluation is:13.029111


 45%|████▍     | 200/449 [01:45<02:11,  1.89it/s]

[199/449],train loss is:0.072761,train evaluation is:12.920667


 67%|██████▋   | 300/449 [02:38<01:18,  1.89it/s]

[299/449],train loss is:0.072104,train evaluation is:13.106963


 89%|████████▉ | 400/449 [03:30<00:25,  1.89it/s]

[399/449],train loss is:0.072230,train evaluation is:13.181389


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[8],train loss is:0.071438,train evaluation is:13.137639 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.67      0.33      0.44         6
         死亡率       0.86      0.80      0.83        15
        传播途径       0.80      0.86      0.83        14
          阶段       0.92      0.87      0.89        38
        病理生理       1.00      0.46      0.63        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.86      0.62      0.72        92
        预后状况       0.90      0.79      0.84        48
        发病部位       0.81      0.86      0.83       307
          化疗       0.81      0.76      0.79        34
        多发地区       1.00      1.00      1.00        35
        鉴别诊断       0.78      0.79      0.78       285
        病理分型       0.77      0.69      0.73       518
          病因       0.86      0.84      0.85       776
      风险评估因素       0.42      0.33      0.37       121
         同义词       0.93      0.65      0.76  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.063540,train evaluation is:12.916444


 45%|████▍     | 200/449 [01:45<02:11,  1.90it/s]

[199/449],train loss is:0.062000,train evaluation is:12.912667


 67%|██████▋   | 300/449 [02:38<01:18,  1.90it/s]

[299/449],train loss is:0.062694,train evaluation is:13.089630


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.062734,train evaluation is:13.252944


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[9],train loss is:0.062784,train evaluation is:13.167632 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.25      0.17      0.20         6
         死亡率       0.92      0.80      0.86        15
        传播途径       0.80      0.86      0.83        14
          阶段       0.92      0.87      0.89        38
        病理生理       1.00      0.46      0.63        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.83      0.60      0.70        92
        预后状况       0.93      0.88      0.90        48
        发病部位       0.83      0.89      0.86       307
          化疗       0.79      0.76      0.78        34
        多发地区       1.00      1.00      1.00        35
        鉴别诊断       0.77      0.82      0.80       285
        病理分型       0.73      0.73      0.73       518
          病因       0.83      0.84      0.84       776
      风险评估因素       0.35      0.50      0.41       121
         同义词       0.82      0.75      0.79  

In [16]:
model.fit(
    re_train_dataset,
    re_dev_dataset,
    epochs=5,
    batch_size=32
)

 22%|██▏       | 100/449 [00:51<03:00,  1.93it/s]

[99/449],train loss is:0.058835,train evaluation is:12.428222


 45%|████▍     | 200/449 [01:43<02:10,  1.91it/s]

[199/449],train loss is:0.057734,train evaluation is:12.950889


 67%|██████▋   | 300/449 [02:35<01:18,  1.91it/s]

[299/449],train loss is:0.056343,train evaluation is:13.010444


 89%|████████▉ | 400/449 [03:28<00:25,  1.91it/s]

[399/449],train loss is:0.055600,train evaluation is:13.146889


100%|██████████| 449/449 [03:53<00:00,  1.92it/s]


epoch:[0],train loss is:0.055835,train evaluation is:13.196041 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.29      0.33      0.31         6
         死亡率       1.00      0.80      0.89        15
        传播途径       0.80      0.86      0.83        14
          阶段       0.89      0.87      0.88        38
        病理生理       1.00      0.46      0.63        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.84      0.66      0.74        92
        预后状况       0.93      0.88      0.90        48
        发病部位       0.82      0.91      0.86       307
          化疗       0.90      0.76      0.83        34
        多发地区       1.00      0.97      0.99        35
        鉴别诊断       0.77      0.81      0.79       285
        病理分型       0.74      0.73      0.73       518
          病因       0.83      0.86      0.85       776
      风险评估因素       0.46      0.40      0.43       121
         同义词       0.74      0.82      0.78  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.054022,train evaluation is:13.175556


 45%|████▍     | 200/449 [01:45<02:11,  1.90it/s]

[199/449],train loss is:0.051770,train evaluation is:13.300444


 67%|██████▋   | 300/449 [02:37<01:18,  1.90it/s]

[299/449],train loss is:0.051319,train evaluation is:13.435556


 89%|████████▉ | 400/449 [03:30<00:25,  1.89it/s]

[399/449],train loss is:0.052516,train evaluation is:13.333833


100%|██████████| 449/449 [03:55<00:00,  1.90it/s]


epoch:[1],train loss is:0.052851,train evaluation is:13.209552 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.50      0.17      0.25         6
         死亡率       1.00      0.80      0.89        15
        传播途径       1.00      0.64      0.78        14
          阶段       0.91      0.84      0.88        38
        病理生理       1.00      0.46      0.63        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.87      0.59      0.70        92
        预后状况       0.91      0.83      0.87        48
        发病部位       0.87      0.84      0.85       307
          化疗       0.85      0.85      0.85        34
        多发地区       0.97      1.00      0.99        35
        鉴别诊断       0.79      0.79      0.79       285
        病理分型       0.76      0.68      0.72       518
          病因       0.83      0.83      0.83       776
      风险评估因素       0.57      0.25      0.34       121
         同义词       0.94      0.66      0.77  

 22%|██▏       | 100/449 [00:52<03:03,  1.91it/s]

[99/449],train loss is:0.050055,train evaluation is:12.992889


 45%|████▍     | 200/449 [01:45<02:10,  1.90it/s]

[199/449],train loss is:0.047577,train evaluation is:13.227444


 67%|██████▋   | 300/449 [02:38<01:18,  1.90it/s]

[299/449],train loss is:0.047608,train evaluation is:13.344741


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.048061,train evaluation is:13.204667


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[2],train loss is:0.047853,train evaluation is:13.225043 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.25      0.17      0.20         6
         死亡率       1.00      0.80      0.89        15
        传播途径       1.00      0.50      0.67        14
          阶段       0.89      0.84      0.86        38
        病理生理       1.00      0.31      0.47        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.85      0.61      0.71        92
        预后状况       0.91      0.88      0.89        48
        发病部位       0.83      0.91      0.87       307
          化疗       0.93      0.79      0.86        34
        多发地区       0.97      1.00      0.99        35
        鉴别诊断       0.70      0.82      0.76       285
        病理分型       0.75      0.71      0.73       518
          病因       0.80      0.88      0.84       776
      风险评估因素       0.54      0.41      0.47       121
         同义词       0.86      0.74      0.80  

 22%|██▏       | 100/449 [00:52<03:03,  1.90it/s]

[99/449],train loss is:0.045359,train evaluation is:13.706444


 45%|████▍     | 200/449 [01:45<02:11,  1.90it/s]

[199/449],train loss is:0.043236,train evaluation is:13.771222


 67%|██████▋   | 300/449 [02:37<01:19,  1.88it/s]

[299/449],train loss is:0.043686,train evaluation is:13.533704


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.044401,train evaluation is:13.216611


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[3],train loss is:0.044196,train evaluation is:13.240535 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.40      0.33      0.36         6
         死亡率       1.00      0.80      0.89        15
        传播途径       1.00      0.64      0.78        14
          阶段       0.82      0.87      0.85        38
        病理生理       1.00      0.46      0.63        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.84      0.64      0.73        92
        预后状况       0.94      0.92      0.93        48
        发病部位       0.81      0.92      0.86       307
          化疗       0.77      0.79      0.78        34
        多发地区       0.97      0.94      0.96        35
        鉴别诊断       0.83      0.80      0.81       285
        病理分型       0.77      0.69      0.73       518
          病因       0.83      0.86      0.84       776
      风险评估因素       0.43      0.49      0.46       121
         同义词       0.84      0.77      0.80  

 22%|██▏       | 100/449 [00:52<03:04,  1.89it/s]

[99/449],train loss is:0.040861,train evaluation is:13.304000


 45%|████▍     | 200/449 [01:45<02:11,  1.89it/s]

[199/449],train loss is:0.042164,train evaluation is:13.520222


 67%|██████▋   | 300/449 [02:37<01:18,  1.90it/s]

[299/449],train loss is:0.042174,train evaluation is:13.208519


 89%|████████▉ | 400/449 [03:30<00:25,  1.90it/s]

[399/449],train loss is:0.041089,train evaluation is:13.206500


100%|██████████| 449/449 [03:56<00:00,  1.90it/s]


epoch:[4],train loss is:0.041317,train evaluation is:13.248057 

classification_report: 
               precision    recall  f1-score   support

        发病机制       0.25      0.17      0.20         6
         死亡率       0.92      0.80      0.86        15
        传播途径       1.00      0.64      0.78        14
          阶段       0.89      0.82      0.85        38
        病理生理       1.00      0.54      0.70        13
 侵及周围组织转移的症状       0.00      0.00      0.00        12
       组织学检查       0.94      0.51      0.66        92
        预后状况       0.95      0.88      0.91        48
        发病部位       0.83      0.88      0.86       307
          化疗       0.85      0.68      0.75        34
        多发地区       0.97      1.00      0.99        35
        鉴别诊断       0.80      0.81      0.81       285
        病理分型       0.76      0.69      0.73       518
          病因       0.87      0.80      0.83       776
      风险评估因素       0.51      0.45      0.48       121
         同义词       0.89      0.72      0.80  

In [15]:
assert(1>2)

AssertionError: 

<br>

### 四、模型预测

In [17]:
from tqdm import tqdm
from ark_nlp.model.re.afea_bert import Predictor

afea_re_predictor_instance = Predictor(model.module, tokenizer, re_train_dataset.cat2id)

In [18]:
text = '急性胰腺炎@有研究显示，进行早期 ERCP （24 小时内）可以降低梗阻性胆总管结石患者的并发症发生率和死亡率； 但是，对于无胆总管梗阻的胆汁性急性胰腺炎患者，不需要进行早期 ERCP。'
entities = [('急性胰腺炎', '疾病', 0, 4), ('ERCP', '检查', 17, 20)]

In [19]:
afea_re_predictor_instance.predict_one_sample(text, entities, topk=1)

[[('急性胰腺炎', '疾病', 0, 4), '影像学检查', ('ERCP', '检查', 17, 20)],
 [('ERCP', '检查', 17, 20), 'None', ('急性胰腺炎', '疾病', 0, 4)]]

<br>

### 五、CMeIE结果输出

In [20]:
test_df = pd.read_json('output_datasets/CMeEE_test.json')

In [21]:
test_df['entities'] = test_df['entities'].apply(lambda x: [(entity['entity'], entity['type'], int(entity['start_idx']), int(entity['end_idx'])) for entity in x])

In [22]:
result = []
for text, entities in tqdm(zip(test_df['text'], test_df['entities'])):
    result.append(afea_re_predictor_instance.predict_one_sample(text, entities, topk=1))

4482it [01:03, 70.61it/s]


In [23]:
test_df['triples'] = result

In [24]:
test_df

Unnamed: 0,text,entities,triples
0,痔@肛镜检查方法简便安全，可全方位观察肛管及所有痔组织。痔@另一种替代的方法为纤维内窥镜反转...,"[(肛镜检查, 检查, 2, 5), (纤维内窥镜, 检查, 39, 43), (纤维内窥镜...","[[(肛镜检查, 检查, 2, 5), None, (纤维内窥镜, 检查, 39, 43)]..."
1,慢性肾病@进行眼底检查十分必要，因为其有助于诊断糖尿病性或高血压性视网膜病变，有助于判断在肾...,"[(眼底检查, 检查, 7, 10), (慢性肾病, 疾病, 0, 3), (高血压性视网膜...","[[(眼底检查, 检查, 7, 10), None, (慢性肾病, 疾病, 0, 3)], ..."
2,5.药物因素 引起消化性溃疡的药物中较重要的有三类：①阿司匹林（ASA）；②非甾体抗炎药物（...,"[(消化性溃疡, 疾病, 9, 13), (阿司匹林, 疾病, 27, 30), (非甾体抗...","[[(消化性溃疡, 疾病, 9, 13), 病因, (阿司匹林, 疾病, 27, 30)],..."
3,"登革热@大约 90% 的登革出血热 (dengue haemorrhagic fever, ...","[(5 岁以下儿童, 流行病学, 55, 61), (成人, 流行病学, 78, 79), ...","[[(5 岁以下儿童, 流行病学, 55, 61), None, (成人, 流行病学, 78..."
4,"【病因】 缺氧是HIE发病的核心,其中围生期窒息是最主要的病因。 【发病机制】 脑血流改变 ...","[(HIE, 疾病, 8, 10), (缺氧, 社会学, 5, 6), (围生期窒息, 社会...","[[(HIE, 疾病, 8, 10), 病因, (缺氧, 社会学, 5, 6)], [(HI..."
...,...,...,...
4477,铅中毒@由于总体的毒性作用及缺乏确认的疗效，在非妊娠患者中青霉胺是三线药物。,"[(铅中毒, 疾病, 0, 2), (青霉胺, 药物, 29, 31)]","[[(铅中毒, 疾病, 0, 2), 药物治疗, (青霉胺, 药物, 29, 31)], [..."
4478,狂犬病@可通过直接荧光抗体检测出狂犬病病毒抗原或PCR检测出狂犬病病毒RNA 。,"[(荧光抗体, 检查, 9, 12), (PCR, 检查, 24, 26), (狂犬病, 疾...","[[(荧光抗体, 检查, 9, 12), None, (PCR, 检查, 24, 26)],..."
4479,（二）中型至大型缺损 患儿常在生后1～2个月肺循环阻力下降时出现临床表现。由于肺循环流量大产...,"[(中型至大型缺损, 疾病, 3, 9), (肺静脉压力增高, 症状, 51, 57), (...","[[(中型至大型缺损, 疾病, 3, 9), 临床表现, (肺静脉压力增高, 症状, 51,..."
4480,狂犬病@ 麻痹型狂犬病 * 前驱期发热症状后快速进展为弛缓性瘫痪。狂犬病@之后发生行为改变。,"[(狂犬病, 疾病, 0, 2), (麻痹型狂犬病, 疾病, 5, 10), (弛缓性瘫痪,...","[[(狂犬病, 疾病, 0, 2), 病理分型, (麻痹型狂犬病, 疾病, 5, 10)],..."


In [25]:
result_ = []
for triples in result:
    
    triples_ = []
    for triple in triples:
        
        if triple[1] == 'None' or triple[0][1] != '疾病':
            continue
        predicate = triple[1] + '@' + triple[2][1]
        
        triples_.append((triple[0][0], predicate, triple[2][0]))
    result_.append(triples_)

In [26]:
test_data_path = '/home/shencj/workspace/data/medical/CBLUE/CMeIE/CMeIE_test.json'
schemas_data_path = '/home/shencj/workspace/data/medical/CBLUE/CMeIE/53_schemas.json'
output_data_path = 'output_datasets/CMeIE_test.json'

In [27]:
all_subject_type = []
all_predicate = []
all_shcemas = []
predicate2subject = {}
with open(schemas_data_path, 'r', encoding='utf-8') as fs:
    for jsonstr in fs.readlines():
        jsonstr = json.loads(jsonstr)
        # all_shcemas.append(jsonstr)
        
        predicate2subject[jsonstr['predicate']+'@'+jsonstr['object_type']] = jsonstr['subject_type']
        
    fs.close()
    
all_predicate = set(all_predicate)
with open(output_data_path, 'w', encoding='utf-8') as fw:
    with open(test_data_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        for index_, jsonstr in tqdm(enumerate(lines)):
            line = json.loads(jsonstr)
            results_len = []
            sentence = line['text']
            dict_list = result_[index_]
            new = []
            for list_ in dict_list:
                for predicate_ in predicate2subject:
                    if list_[1] == predicate_:
                        if list_[-1] != '' and list_[-1] != '[UNK]':
                            result_dict = {
                                'predicate': predicate_.split('@')[0],
                                "subject": list_[0],
                                'subject_type': predicate2subject[predicate_],
                                "object": {"@value": list_[-1]},
                                'object_type': {"@value":predicate_.split('@')[-1]}
                                }
                        else:
                            continue
                        if result_dict not in new:
                            new.append(result_dict)
            if sum([item.count('。') for item in sentence]) >= 2:
                for item in new:
                    item['Combined'] = True
            else:
                for item in new:
                    item['Combined'] = False

            if len(new) == 0:
                new = [{
                    "Combined": '',
                    "predicate": '',
                    "subject": '',
                    "subject_type": '',
                    "object": {"@value": ""},
                    "object_type": {"@value": ""},
                }]
                pred_dict = {
                    "text": ''.join(sentence),
                    "spo_list": new,
                }
            else:

                pred_dict = {
                    "text": ''.join(sentence),
                    "spo_list": new,
                }
            fw.write(json.dumps(pred_dict, ensure_ascii=False) + '\n')
f.close()
fw.close()

4482it [00:00, 22121.37it/s]
