In [1]:
import json
import re

In [2]:
data_dir = './CMeIE/'

In [3]:
relation_schemas = []
for line in open(data_dir + '53_schemas.json',encoding='utf-8'):
    predicate = json.loads(line)['predicate']
    if predicate not in relation_schemas:
        relation_schemas.append(predicate)

In [4]:
label_mapping = {
    'DO':['预防','阶段','就诊科室'], # Disease_Others
    'DOT':['辅助治疗','化疗','放射治疗'], # Disease_Other Therapy
    'DST':['手术治疗'], # Disease_Surgical Therapy
    'DT':['实验室检查','影像学检查','辅助检查','组织学检查','内窥镜检查','筛查'], # Disease_Test
    'DE':['多发群体','发病率','发病年龄','多发地区','发病性别倾向','死亡率','传播途径','多发季节'], # Disease_Epidemiology
    'DDi':['并发症','病理分型','相关（导致）','鉴别诊断','相关（转化）','相关（症状）'], # Disease_Disease
    'DSy':['临床表现','治疗后症状','侵及周围组织转移的症状'], # Disease_Symptoms
    'DSo':['病因','高危因素','风险评估因素','病史','遗传因素','发病机制','病理生理'], # Disease_Sociology
    'DDr':['药物治疗'], # Disease_Drugs
    'DB':['发病部位','转移部位','外侵部位'], # Disease_Body
    'DP':['预后状况','预后生存率'], # Disease_Prognosis
    'Synonyms':['同义词'] # Synonyms
}

In [6]:
len(label_mapping.values())

12

In [218]:
# 统计出现次数最多的关系
# labels = list(label_mapping.keys())
# label_count = [0] * len(label_mapping)
# sequence_count = [0] * len(label_mapping)

In [219]:
Disease_Disease_Label_Mapping = {
    '并发症':'C',
    '病理分型':'PA',
    '相关（导致）':'RO',
    '鉴别诊断':'DD',
    '相关（转化）':'RT',
    '相关（症状）':'RS'
}

In [220]:
# labels = list(Disease_Disease_Label_Mapping.keys())
# label_count = [0] * len(Disease_Disease_Label_Mapping)
# sequence_count = [0] * len(Disease_Disease_Label_Mapping)

In [221]:
def labeling(text_index,label_tmp,start,end,template,error_index,flag):
    if flag == 'subject':
        if label_tmp[start:end] == ['O']*(end-start) or (label_tmp[start:start+1] == ['B-'+template+'-1'] and label_tmp[start+1:end] == ['I-'+template+'-1']*(end-start-1)):
            label_tmp = label_tmp[0:start] + ['B-'+template+'-1']*1 + label_tmp[start+1:]
            label_tmp = label_tmp[0:start+1] + ['I-'+template+'-1']*(end-start-1) + label_tmp[end:]
        else:
            if text_index not in error_index:
                error_index.append(text_index)
    else:
        if label_tmp[start:end] == ['O']*(end-start) or (label_tmp[start:start+1] == ['B-'+template+'-2'] and label_tmp[start+1:end] == ['I-'+template+'-2']*(end-start-1)):
            label_tmp = label_tmp[0:start] + ['B-'+template+'-2']*1 + label_tmp[start+1:]
            label_tmp = label_tmp[0:start+1] + ['I-'+template+'-2']*(end-start-1) + label_tmp[end:]
        else:
            if text_index not in error_index:
                error_index.append(text_index)
    return label_tmp,error_index

In [222]:
def dealing_escape_tag(escape_tag,subject_):
    for tag in escape_tag:
        escape_index = subject_.find(tag)
        if escape_index > -1:
            subject_list = list(subject_)
            if escape_index > 0:
                subject_list.insert(escape_index-1, '\\')
                subject_ = ''.join(subject_list)
            else:
                subject_ = '\\' + subject_
    return subject_      

In [223]:
def generate_dataset(dataset_name):
    train = {
        'text':[],
        'label':[]
    }
    train_final = {
        'text':[],
        'label':[]
    }
    text_index = 0
    escape_tag = ['（','）','(',')']
    ignore_index = []
    error_index = []
    for line in open(data_dir + dataset_name,encoding='utf-8'):
        text = json.loads(line)['text']
        train['text'].append(text)

        spo_list = json.loads(line)['spo_list']

        label_tmp = ['O']*len(text)

        predicates = list(set(list(map(lambda x:x['predicate'],spo_list))))
        predicates_intersection = [item for item in predicates if item in list(Disease_Disease_Label_Mapping.keys())]

        if len(predicates_intersection) == 0:
            ignore_index.append(text_index)
        else:
            for item in spo_list:
                relation_sub_label = item['predicate']
                subject_ = item['subject']
                object_ = item['object']['@value']
                if relation_sub_label in Disease_Disease_Label_Mapping:
                    subject_ = dealing_escape_tag(escape_tag,subject_)
                    object_ =  dealing_escape_tag(escape_tag,object_)
                    try:
                        subject_index = [(m.start(0), m.end(0)) for m in re.finditer(subject_, text)]
                        object_index = [(m.start(0), m.end(0)) for m in re.finditer(object_, text)]
                    except Exception:
        #                 print('escape_error',text_index)
                        if text_index not in error_index:
                            error_index.append(text_index)
                    for item in subject_index:
                        label_tmp,error_index = labeling(text_index,label_tmp,item[0],item[1],Disease_Disease_Label_Mapping[relation_sub_label],error_index,'subject')
                    for item in object_index:
                        label_tmp,error_index = labeling(text_index,label_tmp,item[0],item[1],Disease_Disease_Label_Mapping[relation_sub_label],error_index,'object')
        #             print(text)
        #             print(subject,object_)
        text_index += 1
        train['label'].append(label_tmp)

    for index,item in enumerate(train['text']):
        if index not in ignore_index and index not in error_index:
            train_final['text'].append(train['text'][index])
            train_final['label'].append(train['label'][index])
    return train_final  

In [224]:
train = generate_dataset('CMeIE_train.json')
dev = generate_dataset('CMeIE_dev.json')

In [225]:
import pickle
with open(data_dir + 'CMeIE_train', 'wb') as f_train:
    pickle.dump(train, f_train)
with open(data_dir + 'CMeIE_dev', 'wb') as f_val:
    pickle.dump(dev, f_val)

In [228]:
def dataset_to_txt(target_txt_name,dataset):
    with open(data_dir + target_txt_name,'w') as f:
        res = ''
        for t_index,item in enumerate(dataset['text']):
            char_list = list(item)
            for c_index,char in enumerate(char_list):
                res += (char+'\t')
                res += dataset['label'][t_index][c_index]
                res += '\n'
            res += '\n'
        f.write(res)

In [229]:
dataset_to_txt('CMeIE_train.txt',train)
dataset_to_txt('CMeIE_dev.txt',dev)

In [64]:
train_spo_list = []
for line in open(data_dir + 'CMeIE_train.json',encoding='utf-8'):
    spo_list = json.loads(line)['spo_list']
    train_spo_list.append(spo_list)

In [65]:
from sklearn import model_selection

pool, labeled = model_selection.train_test_split(train_spo_list,test_size=0.2 ,random_state=1213)

In [76]:
labels = list(Disease_Disease_Label_Mapping.keys())

In [77]:
labels

['并发症', '病理分型', '相关（导致）', '鉴别诊断', '相关（转化）', '相关（症状）']

In [78]:
label_count = [0] * len(labels)
sequence_count = [0] * len(labels)

In [79]:
# 统计“疾病-疾病”子关系关系出现的次数
for spo_list in labeled:
    for item in spo_list:
        relation_sub_label = item['predicate']
        for index,label in enumerate(labels):
            if label == relation_sub_label:
                label_count[index] += 1

In [80]:
label_count

[447, 351, 284, 245, 173, 76]

In [48]:
# 统计关系出现的次数
for line in open(data_dir + 'CMeIE_train.json',encoding='utf-8'):
    text = json.loads(line)['text']
    spo_list = json.loads(line)['spo_list']
    for item in spo_list:
        relation_sub_label = item['predicate']
        relation_label = list(filter(lambda x:relation_sub_label in x[1],label_mapping.items()))[0][0] 
        for index,label in enumerate(labels):
            if label == relation_label:
                label_count[index] += 1
    for index,label in enumerate(labels):
        if label == relation_label:
            sequence_count[index] += 1

In [49]:
labels

['DO',
 'DOT',
 'DST',
 'DT',
 'DE',
 'DDi',
 'DSy',
 'DSo',
 'DDr',
 'DB',
 'DP',
 'Synonyms']

In [50]:
label_count

[641, 1833, 923, 4433, 1857, 7879, 11761, 4617, 4570, 1536, 265, 3345]

In [51]:
sequence_count

[289, 804, 454, 1925, 780, 2335, 2658, 1566, 1879, 413, 101, 1135]

In [21]:
relation_schemas

['预防',
 '阶段',
 '就诊科室',
 '同义词',
 '辅助治疗',
 '化疗',
 '放射治疗',
 '手术治疗',
 '实验室检查',
 '影像学检查',
 '辅助检查',
 '组织学检查',
 '内窥镜检查',
 '筛查',
 '多发群体',
 '发病率',
 '发病年龄',
 '多发地区',
 '发病性别倾向',
 '死亡率',
 '多发季节',
 '传播途径',
 '并发症',
 '病理分型',
 '相关（导致）',
 '鉴别诊断',
 '相关（转化）',
 '相关（症状）',
 '临床表现',
 '治疗后症状',
 '侵及周围组织转移的症状',
 '病因',
 '高危因素',
 '风险评估因素',
 '病史',
 '遗传因素',
 '发病机制',
 '病理生理',
 '药物治疗',
 '发病部位',
 '转移部位',
 '外侵部位',
 '预后状况',
 '预后生存率']