# **一**、SOTA 中文医疗关系抽取
kelvincjr/shared/casrel-main/
2019 年关系抽取 SOTA 论文（CasRel）的 PyTorch 复现，用于天池的中文医疗关系抽取数据集。结果：F1 = 54.6
!pip install keras==2.2.4 tensorflow==1.13.1 keras-bert==0.81.1 tensorflow-gpu==1.13.1 transformers

In [None]:
# Mount Google Drive so the dataset zip under /content/drive/MyDrive can be unzipped below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip /content/drive/MyDrive/Data/NLP/NLP_中文医疗信息处理挑战榜CBLUE/CMeIE.zip

In [None]:
import os

# Dataset directory. Assumes the notebook runs from /content (the Colab
# default), whose parent is "/", so this resolves to /content/CMeIE.
# os.path.join supplies the separator that plain string concatenation
# omitted whenever the parent directory is not "/".
base_path = os.path.dirname(os.getcwd())
data_path = os.path.join(base_path, 'content', 'CMeIE')
print(data_path)

## 数据预处理

In [None]:
import json
import re
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from transformers import BertTokenizer
from random import choice
import codecs

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present

class MyDataset(Dataset):
    """Character-level CMeIE dataset producing CasRel training targets.

    Each non-blank line of the input file is a JSON object with a ``text``
    field and (for training data) an ``spo_list`` of triples.  ``__getitem__``
    returns ``(sample, sub_start, sub_end, relation_start, relation_end,
    sub_start_single, sub_end_single)``: token ids, start/end tags of every
    located subject, per-relation object start/end tags for ONE randomly
    chosen subject, and the start/end tags of that chosen subject.
    """

    def __get_train_data(self, path):
        # One JSON object per line; blank lines are ignored.
        with open(path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f if line.strip()]

    def __get_test_data(self, path):
        # Test entries get an empty spo_list so __getitem__ treats both modes alike.
        res = self.__get_train_data(path)
        for entry in res:
            entry['spo_list'] = []
        return res

    @staticmethod
    def _build_relation_maps(mapping):
        """Return ``(relation2idx, idx2relation)`` from a relation2idx.json mapping.

        The expected orientation is relation label -> integer id.  A file in
        the inverse orientation (id -> label, ids stored as JSON string keys)
        is detected by its string values and flipped, so labels can always be
        looked up by name.
        """
        if mapping and all(isinstance(v, str) for v in mapping.values()):
            relation2idx = {label: int(idx) for idx, label in mapping.items()}
        else:
            relation2idx = dict(mapping)
        idx2relation = {idx: label for label, idx in relation2idx.items()}
        return relation2idx, idx2relation

    @staticmethod
    def _find_span(text, mention):
        """Return ``(start, end)`` (end exclusive) of the first literal occurrence
        of *mention* in *text*, or None when *mention* is empty or absent.

        Mentions are matched literally, not as regular expressions, because
        entity strings may contain metacharacters such as parentheses.
        """
        if not mention:
            return None
        start = text.find(mention)
        if start == -1:
            return None
        return start, start + len(mention)

    def __init__(self, path, config):
        super(MyDataset, self).__init__()
        self.config = config
        mode = self.config['mode']
        if mode == 'train':
            self.data = self.__get_train_data(path)
        elif mode == 'test':
            self.data = self.__get_test_data(path)
        else:
            raise ValueError("config['mode'] must be 'train' or 'test', got %r" % (mode,))
        with open(data_path + '/relation2idx.json', 'r', encoding='utf-8') as f:
            self.relation2idx, self.idx2relation = self._build_relation_maps(json.load(f))
        self.tokenizer = BertTokenizer.from_pretrained(self.config['model_name'])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        text, gold = self.data[item]['text'], self.data[item]['spo_list']
        text = text[:512]  # BERT's maximum sequence length
        relation_types = self.config['relation_types']
        # One token per character.
        sample = self.tokenizer.convert_tokens_to_ids(list(text))
        seq_len = len(sample)
        sub_start = [0] * seq_len
        sub_end = [0] * seq_len
        # dim = (seq_len, relation_types); rows must be distinct lists.
        relation_start = [[0] * relation_types for _ in range(seq_len)]
        relation_end = [[0] * relation_types for _ in range(seq_len)]
        sub_start_single = [0] * seq_len
        sub_end_single = [0] * seq_len

        # subject span -> [(object span, relation id), ...]
        s2ro_map = {}
        for entry in gold:
            sub_pos = self._find_span(text, entry['subject'])
            obj_pos = self._find_span(text, entry['object']['@value'])
            # Synonym relations are split by subject type.
            relation = '同义词-' + entry['subject_type'] if entry['predicate'] == '同义词' else entry['predicate']
            relation_idx = self.relation2idx.get(relation)
            # Skip triples whose mentions are not in the (truncated) text or
            # whose relation is unknown / outside the configured range.
            if sub_pos is None or obj_pos is None or relation_idx is None:
                continue
            if not 0 <= relation_idx < relation_types:
                continue
            sub_start[sub_pos[0]] = 1
            sub_end[sub_pos[1] - 1] = 1
            s2ro_map.setdefault(sub_pos, []).append((obj_pos, relation_idx))

        if s2ro_map:
            # The object taggers are trained on one randomly sampled subject.
            sub_pos = choice(list(s2ro_map.keys()))
            sub_start_single[sub_pos[0]] = 1
            sub_end_single[sub_pos[1] - 1] = 1
            for obj_pos, relation_idx in s2ro_map[sub_pos]:
                relation_start[obj_pos[0]][relation_idx] = 1
                relation_end[obj_pos[1] - 1][relation_idx] = 1
        return sample, sub_start, sub_end, relation_start, relation_end, sub_start_single, sub_end_single

def collate_fn(data):
    """Pad a batch of MyDataset samples into tensors placed on ``device``.

    The batch list is sorted in place, longest sample first.  Returns
    (sample, sub_start, sub_end, relation_start, relation_end, mask,
    sub_start_single, sub_end_single).  Every tensor is zero-padded to the
    longest sequence.  ``mask`` is 1 on real positions and 0 on padding.
    Shapes are (batch_size, seq_len), except relation_start / relation_end,
    which are (batch_size, seq_len, relation_types).
    """
    data.sort(key=lambda example: len(example[0]), reverse=True)
    columns = list(zip(*data))
    longest = len(columns[0][0])
    mask = torch.tensor(
        [[1] * len(tokens) + [0] * (longest - len(tokens)) for tokens in columns[0]]
    ).long().to(device)

    def _pad(column):
        tensors = [torch.tensor(values).long().to(device) for values in column]
        return pad_sequence(tensors, batch_first=True, padding_value=0)

    (sample, sub_start, sub_end, relation_start,
     relation_end, sub_start_single, sub_end_single) = (_pad(column) for column in columns)
    return sample, sub_start, sub_end, relation_start, relation_end, mask, sub_start_single, sub_end_single

def _relation_key(schema):
    """Return the relation label used for *schema* in relation2idx.json.

    MyDataset labels synonym triples '同义词-<subject_type>' (one label per
    subject type), so schema labels are built the same way.
    """
    if schema['predicate'] == '同义词':
        return '同义词-' + schema['subject_type']
    return schema['predicate']


# Assign consecutive ids to relation labels in schema-file order (first
# occurrence wins).  The label count is expected to equal
# config['relation_types'] (53) — TODO confirm against the schema file.
predicate2id = {}
id2predicate = {}
with open(data_path + '/53_schemas.json', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        label = _relation_key(json.loads(line))
        if label not in predicate2id:
            id2predicate[len(predicate2id)] = label
            predicate2id[label] = len(predicate2id)

# MyDataset looks relations up by label (relation2idx[label]) and derives its
# id -> label map from this file, so store label -> id.
with open(data_path + '/relation2idx.json', 'w', encoding='utf-8') as fw:
    json.dump(predicate2id, fw, ensure_ascii=False)

# Sanity check: load the first test batch and show the relation start/end tag
# matrices of the sample at index 37.  In test mode spo_list is emptied, so
# these matrices are expected to be all zeros.
config = dict(
    mode='test',
    model_name='bert-base-multilingual-cased',
    batch_size=512,
    relation_types=53,
)
path = f'{data_path}/CMeIE_test.json'
data = MyDataset(path, config)
dataloader = DataLoader(data, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn)
batch_data = next(iter(dataloader))
a = batch_data[3][37]
b = batch_data[4][37]
print(len(batch_data), len(batch_data[0]))
print(len(a), a)
print(len(b), b)

## CasRel

In [None]:
import torch
import torch as t
from torch import nn
from transformers import BertModel
import numpy as np

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present

class CasRel(nn.Module):
    """CasRel (cascade binary tagging) relation extractor on top of BERT.

    Two sigmoid taggers mark subject start/end positions.  For a given
    subject, the mean of its start and end encodings is added to every token
    encoding, and two further taggers mark object start/end positions
    separately for each of ``config['relation_types']`` relations.
    """

    def __init__(self, config):
        super(CasRel, self).__init__()
        self.config = config
        self.bert_dim = 768
        self.bert_encoder = BertModel.from_pretrained(config['model_name'])
        # NOTE: "tageer" is a typo kept on purpose; renaming the attribute would
        # change the state_dict keys and break loading saved checkpoints.
        self.sub_start_tageer = nn.Linear(self.bert_dim, 1)
        self.sub_end_tagger = nn.Linear(self.bert_dim, 1)
        self.obj_start_tagger = nn.Linear(self.bert_dim, config['relation_types'])
        self.obj_end_tagger = nn.Linear(self.bert_dim, config['relation_types'])

    def get_encoded_text(self, data):
        """Encode ``data['token_ids']`` (masked by ``data['mask']``) with BERT.

        Returns the last hidden states, (batch_size, seq_len, bert_dim).
        """
        # Original note: "with torch.no_grad(): -- out of GPU Memory".  Gradients
        # are kept here, so BERT is fine-tuned together with the taggers.
        return self.bert_encoder(data['token_ids'], attention_mask=data['mask'])[0]

    def get_sub(self, encoded_text):
        """Return subject start/end probabilities, each (batch_size, seq_len, 1)."""
        pred_sub_start = torch.sigmoid(self.sub_start_tageer(encoded_text))
        pred_sub_end = torch.sigmoid(self.sub_end_tagger(encoded_text))
        return pred_sub_start, pred_sub_end

    # A batched subject-decoding step was sketched here and abandoned: samples
    # in one batch yield subject lists of different lengths.

    def get_obj(self, sub_start_mapping, sub_end_mapping, encoded_text):
        """Return per-relation object start/end probabilities for one subject per row.

        The mappings are (batch_size, 1, seq_len) one-hot selectors of the
        subject's start/end tokens; encoded_text is (batch_size, seq_len,
        bert_dim).  Both outputs are (batch_size, seq_len, relation_types).
        """
        sub_start = torch.matmul(sub_start_mapping.float(), encoded_text)
        sub_end = torch.matmul(sub_end_mapping.float(), encoded_text)
        # (batch_size, 1, bert_dim), broadcast-added to every token below.
        sub = (sub_start + sub_end) / 2
        encoded_text = encoded_text + sub
        pred_obj_start = torch.sigmoid(self.obj_start_tagger(encoded_text))
        pred_obj_end = torch.sigmoid(self.obj_end_tagger(encoded_text))
        return pred_obj_start, pred_obj_end

    def get_list(self, start, end, text, h_bar=0.5, t_bar=0.5):
        """Pair start/end tag scores into spans of *text*.

        Only the first 512 positions are considered.  Each position whose
        start score exceeds ``h_bar`` is matched with the nearest position at
        or after it whose end score exceeds ``t_bar``.  Starts with no such end
        are dropped.  Returns dicts with 'text', 'start' and 'end' (inclusive).
        """
        res = []
        start, end = start[:512], end[:512]
        start_idxs, end_idxs = [], []
        for idx in range(len(start)):
            if start[idx] > h_bar:
                start_idxs.append(idx)
            if end[idx] > t_bar:
                end_idxs.append(idx)
        for start_idx in start_idxs:
            for end_idx in end_idxs:
                if end_idx >= start_idx:
                    res.append({'text': text[start_idx: end_idx + 1], 'start': start_idx, 'end': end_idx})
                    break
        return res

    def forward(self, data):
        """Training pass.

        ``data['sub_start']`` / ``data['sub_end']`` are (batch_size, seq_len)
        one-hot tags of the subject that the object taggers are conditioned on.
        Returns (pred_sub_start, pred_sub_end, pred_obj_start, pred_obj_end).
        The subject outputs are (batch_size, seq_len, 1); the object outputs are
        (batch_size, seq_len, relation_types).
        """
        encoded_text = self.get_encoded_text(data)
        pred_sub_start, pred_sub_end = self.get_sub(encoded_text)
        # (batch_size, seq_len) -> (batch_size, 1, seq_len)
        sub_start_mapping = data['sub_start'].unsqueeze(1)
        sub_end_mapping = data['sub_end'].unsqueeze(1)
        pred_obj_start, pred_obj_end = self.get_obj(sub_start_mapping, sub_end_mapping, encoded_text)
        return pred_sub_start, pred_sub_end, pred_obj_start, pred_obj_end

    def test(self, data):
        """Decode subjects and their object tags for a single sample.

        ``data`` needs 'token_ids', 'mask' and 'text'.  The batch dimension must
        be 1, because the subject scores are squeezed to one sequence.  Returns
        ``(sub_list, pred_obj_start, pred_obj_end)``, where the object tensors are
        (len(sub_list), seq_len, relation_types), or None if no subject is found.
        """
        encoded_text = self.get_encoded_text(data)
        pred_sub_start, pred_sub_end = self.get_sub(encoded_text)
        sub_list = self.get_list(pred_sub_start.squeeze(0).squeeze(-1),
                                 pred_sub_end.squeeze(0).squeeze(-1), data['text'])
        if not sub_list:
            return None
        num_subs, seq_len = len(sub_list), encoded_text.shape[1]
        repeated_encoded_text = encoded_text.repeat(num_subs, 1, 1)
        # Build the one-hot selectors on the encoder's own device/dtype rather
        # than a global device setting, so inference follows the model.
        sub_start_mapping = encoded_text.new_zeros(num_subs, 1, seq_len)
        sub_end_mapping = encoded_text.new_zeros(num_subs, 1, seq_len)
        for idx, sub in enumerate(sub_list):
            sub_start_mapping[idx, 0, sub['start']] = 1
            sub_end_mapping[idx, 0, sub['end']] = 1
        pred_obj_start, pred_obj_end = self.get_obj(sub_start_mapping, sub_end_mapping, repeated_encoded_text)
        return sub_list, pred_obj_start, pred_obj_end

## 训练

In [None]:
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


# Run training on CPU. This rebinds the global `device` that collate_fn reads,
# so batches land on the same device as the model created with .to(device) below.
device = 'cpu' #'cuda:0'# 'cpu'
# torch.set_num_threads(6)

def get_loss(pred, gold, mask):
    """Masked mean binary cross-entropy between sigmoid outputs and 0/1 targets.

    A trailing singleton dimension on ``pred`` is squeezed off.  If the
    per-element loss then has one more dimension than ``mask`` (the
    per-relation object tags), the mask is broadcast over that last dimension.
    Returns the sum of the masked losses divided by the sum of the mask.
    """
    squeezed = pred.squeeze(-1)
    elementwise = F.binary_cross_entropy(squeezed, gold.float(), reduction='none')
    weights = mask if elementwise.shape == mask.shape else mask.unsqueeze(-1)
    return (elementwise * weights).sum() / weights.sum()


# Fine-tune CasRel on the CMeIE training set; weights are saved after each epoch.
config = {
    'mode': 'train',
    'model_name': 'bert-base-multilingual-cased',  # alternative: 'bert-base-chinese'
    'batch_size': 40,
    'epoch': 1,
    'relation_types': 53,
    'sub_weight': 1,
    'obj_weight': 1,
}
path = data_path + '/CMeIE_train.json'
data = MyDataset(path, config)
dataloader = DataLoader(data, batch_size=config['batch_size'], shuffle=True, collate_fn=collate_fn)

model = CasRel(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))
model.train()
loss_recorder = 0  # sum of batch losses since the last report, as a plain float
for epoch_index in range(config['epoch']):
    print('---> epoch_index:', epoch_index)
    time_start = time.perf_counter()
    for batch_index, (sample, sub_start, sub_end, relation_start, relation_end, mask,
                      sub_start_single, sub_end_single) in enumerate(dataloader):
        batch_data = {
            'token_ids': sample,
            'mask': mask,
            # The object taggers are conditioned on the one sampled subject.
            'sub_start': sub_start_single,
            'sub_end': sub_end_single,
        }
        pred_sub_start, pred_sub_end, pred_obj_start, pred_obj_end = model(batch_data)
        # Subject taggers are supervised with all gold subjects; object taggers
        # only with the objects of the sampled subject.
        sub_start_loss = get_loss(pred_sub_start, sub_start, mask)
        sub_end_loss = get_loss(pred_sub_end, sub_end, mask)
        obj_start_loss = get_loss(pred_obj_start, relation_start, mask)
        obj_end_loss = get_loss(pred_obj_end, relation_end, mask)
        loss = (config['sub_weight'] * (sub_start_loss + sub_end_loss)
                + config['obj_weight'] * (obj_start_loss + obj_end_loss))
        # .item() detaches the value; summing the tensor itself would chain
        # every batch's autograd graph into loss_recorder.
        loss_recorder += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch_index % 100 == 99:
            print('--->loss:', loss_recorder)
            loss_recorder = 0
    time_end = time.perf_counter()
    torch.save(model.state_dict(), data_path + '/models.pkl')
    print("successfully saved! time used = %fs." % (time_end - time_start), batch_index, loss)

In [None]:
# Standalone TensorFlow check that the 'bert-base-multilingual-cased' checkpoint
# loads and runs on a sample sentence (TFBertModel requires TensorFlow).
# NOTE: this rebinds the globals `tokenizer`, `model` and `text`.
from transformers import BertTokenizer, TFBertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = TFBertModel.from_pretrained("bert-base-multilingual-cased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='tf')
output = model(encoded_input)
output  # displayed as the cell's output

In [None]:
import codecs
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import json
import numpy as np

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # fall back to CPU when no GPU is present
# torch.set_num_threads(1)

def trans_schemas(path):
    """Read a JSON-lines schema file and map each predicate to its entity types.

    Returns ``(re2sub, re2obj)``, mapping predicate -> subject_type and
    predicate -> object_type.  A predicate listed several times keeps the
    types of its last occurrence.  Blank lines are ignored.
    """
    re2sub = {}
    re2obj = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            schema = json.loads(line)
            re2sub[schema['predicate']] = schema['subject_type']
            re2obj[schema['predicate']] = schema['object_type']
    return re2sub, re2obj


def get_list(start, end, text, h_bar=0.5, t_bar=0.5):
    """Decode tagged spans from start/end score sequences.

    Only the first 512 positions are considered.  Every position whose start
    score exceeds ``h_bar`` is paired with the nearest position at or after it
    whose end score exceeds ``t_bar``.  Starts without such an end are dropped.
    Returns dicts with 'text' (the slice of ``text``), 'start' and 'end'
    (inclusive indices).
    """
    start, end = start[:512], end[:512]
    heads = [pos for pos in range(len(start)) if start[pos] > h_bar]
    tails = [pos for pos in range(len(start)) if end[pos] > t_bar]
    spans = []
    for head in heads:
        tail = next((t for t in tails if t >= head), None)
        if tail is not None:
            spans.append({'text': text[head:tail + 1], 'start': head, 'end': tail})
    return spans

def get_text(path):
    """Load a JSON-lines file into a list of dicts, one per non-blank line."""
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]



# Inference: tag every test sentence and write one JSON result per line.
config = {
    'mode': 'test',
    'model_name': 'bert-base-multilingual-cased',
    # Must stay 1: each batch is paired with raw_data[batch_index], and
    # CasRel.test() squeezes the batch dimension away, so a larger batch both
    # misaligns texts and breaks span decoding.
    'batch_size': 1,
    'epoch': 1,
    'relation_types': 53,
    'sub_weight': 1,
    'obj_weight': 1,
}
path = data_path + '/CMeIE_test.json'
schemas_path = data_path + '/53_schemas.json'
res_path = data_path + '/CMeIE_test_res.json'
raw_data = get_text(path)
re2sub, re2obj = trans_schemas(schemas_path)

# Relation label -> (predicate, subject_type, object_type).  Synonym triples
# are labelled '同义词-<subject_type>' by MyDataset, while re2sub/re2obj keep
# only one entry per predicate, so take each label's types from its own
# schema line.
rel2types = {}
with open(schemas_path, encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue
        schema = json.loads(line)
        label = '同义词-' + schema['subject_type'] if schema['predicate'] == '同义词' else schema['predicate']
        rel2types.setdefault(label, (schema['predicate'], schema['subject_type'], schema['object_type']))

data = MyDataset(path, config)
dataloader = DataLoader(data, batch_size=config['batch_size'], shuffle=False, collate_fn=collate_fn)

model = CasRel(config).to(device)
model.load_state_dict(torch.load(data_path + '/models.pkl', map_location=device))
model.eval()  # disable dropout; modules start in training mode


def _relation_info(relation):
    """Return (predicate, subject_type, object_type) to report for a relation label."""
    if relation in rel2types:
        return rel2types[relation]
    return relation, re2sub[relation], re2obj[relation]


with codecs.open(res_path, 'w', encoding='utf-8') as res_file, torch.no_grad():
    for batch_index, (sample, sub_start, sub_end, relation_start, relation_end, mask, _, _) in enumerate(dataloader):
        text = raw_data[batch_index]['text']
        batch_data = {'token_ids': sample, 'mask': mask, 'text': text}
        ret = model.test(batch_data)
        spo_list = []
        if ret:
            sub_list, pred_obj_start, pred_obj_end = ret
            for idx, sub in enumerate(sub_list):
                # (seq_len, relation_types) -> (relation_types, seq_len)
                obj_start, obj_end = pred_obj_start[idx].transpose(0, 1), pred_obj_end[idx].transpose(0, 1)
                for i in range(config['relation_types']):
                    predicate, subject_type, object_type = _relation_info(data.idx2relation[i])
                    for obj in get_list(obj_start[i], obj_end[i], text):
                        spo_list.append({
                            # True when a '。' separates subject and object.
                            'Combined': '。' in text[sub['end']: obj['start']] or '。' in text[obj['end']: sub['start']],
                            'subject': sub['text'],
                            'predicate': predicate,
                            'object': {'@value': obj['text']},
                            'subject_type': subject_type,
                            'object_type': {'@value': object_type},
                        })
        json.dump({'text': text, 'spo_list': spo_list}, res_file, ensure_ascii=False)
        res_file.write('\n')
        print('--->', batch_index, spo_list, text)


# **二**、CasRel-bert 关系抽取
[{"subject": "阿司匹林", "relation": "病因", "object": "药物 因素"}, {"subject": "阿司匹林", "relation": "病因", "object": "保泰松"}, {"subject": "阿司匹林", "relation": "病因", "object": "精神 因素"}, {"subject": "消化性溃疡", "relation": "病因", "object": "吲哚美辛"}, {"subject": "消化性溃疡", "relation": "病因", "object": "肾上腺皮质激素"}, {"subject": "消化性溃疡", "relation": "病因", "object": "精神 因素"}, {"subject": "阿司匹林", "relation": "病因", "object": "吲哚美辛"}, {"subject": "消化性溃疡", "relation": "病因", "object": "药物 因素"}, {"subject": "消化性溃疡", "relation": "病因", "object": "保泰松"}, {"subject": "阿司匹林", "relation": "病因", "object": "肾上腺皮质激素"}]

In [None]:
!git clone https://github.com/longlongman/CasRel-pytorch-reimplement.git
!cp -r /content/CasRel-pytorch-reimplement/CasRel-reimplement/data/CMED /content/CMED

In [None]:
# Download the chinese-bert-wwm weights and unpack them into
# /content/Models/chinese-bert-wwm, spelled to match the Colab model_path
# ('content/Models/chinese-bert-wwm') noted in the next cell.
# `mkdir -p` creates parent directories and is safe to re-run.
!gdown --id '1AQitrjbvCWc51SYiLN-cJq4e0WiNN4KY' --output /content/chinese-bert-wwm.zip
!mkdir -p /content/Models/chinese-bert-wwm
!unzip -d /content/Models/chinese-bert-wwm /content/chinese-bert-wwm.zip
# from transformers import BertModel
# bert_encoder = BertModel.from_pretrained("hfl/chinese-bert-wwm", cache_dir=data_path+'/model')

In [None]:
import os

# Local checkout layout.  The Colab alternatives were 'content/CMED' and
# 'content/Models/chinese-bert-wwm', with base_path = os.path.dirname(os.getcwd()).
base_path = '/home/aid/Github/NLP_relation_extraction'
data_path = f'{base_path}/data/CMED'
model_path = f'{base_path}/models/chinese-bert-wwm'
print(data_path)

## tokenizer

In [None]:
from keras_bert import Tokenizer
import codecs
import unicodedata


class HBTokenizer(Tokenizer):
    """keras-bert tokenizer that marks the end of every word with '[unused1]'.

    Text is split on whitespace.  Each word is WordPiece-tokenized and followed
    by an '[unused1]' token.  Characters with code point 0 or U+FFFD, and those
    flagged by ``_is_control``, are dropped.  For uncased vocabularies, combining
    marks are stripped after NFD normalization and the text is lower-cased.
    """

    def _tokenize(self, text):
        if not self._cased:
            decomposed = unicodedata.normalize('NFD', text)
            text = ''.join(ch for ch in decomposed if unicodedata.category(ch) != 'Mn').lower()
        cleaned = ''.join(
            ch for ch in text
            if not (ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch))
        )
        pieces = []
        for word in cleaned.strip().split():
            pieces.extend(self._word_piece_tokenize(word))
            pieces.append('[unused1]')
        return pieces


def get_tokenizer(vocab_path):
    """Build a cased HBTokenizer from a BERT ``vocab.txt`` file.

    Each stripped line becomes a token whose id is the size of the vocabulary
    dict at insertion time.  That equals its line number when the file has no
    duplicate tokens.
    """
    vocab = {}
    with codecs.open(vocab_path, 'r', 'utf8') as reader:
        for raw_line in reader:
            vocab[raw_line.strip()] = len(vocab)
    return HBTokenizer(vocab, cased=True)


## data_loader

In [None]:
from torch.utils.data import DataLoader, Dataset
import json
import os
import torch
import numpy as np
from random import choice

# Module-level keras-bert tokenizer that get_loader() hands to CMEDDataset.
tokenizer = get_tokenizer(model_path+'/vocab.txt')
# Cap on the number of tokens kept per sample (BERT's maximum sequence length).
BERT_MAX_LEN = 512


def find_head_idx(source, target):
    """Return the index of the first occurrence of sequence *target* in *source*.

    Returns -1 when *target* does not occur or is empty.  Without that check,
    an empty target would "match" at index 0 and produce an invalid span.
    """
    target_len = len(target)
    if target_len == 0:
        return -1
    # Only starting offsets where the whole target still fits can match.
    for i in range(len(source) - target_len + 1):
        if source[i:i + target_len] == target:
            return i
    return -1


class CMEDDataset(Dataset):
    """CMED relation-extraction dataset for CasRel, tokenized with keras-bert.

    Loads ``<config.data_path>/<prefix>.json`` (a JSON list of records with
    'text' and 'triple_list') and the relation -> id map from ``rel2id.json``.
    Each item is the tuple (token_ids, masks, text_len, sub_heads, sub_tails,
    sub_head, sub_tail, obj_heads, obj_tails, triple_list, tokens).  In
    training mode, a record none of whose triples can be located in the
    tokenized text yields None instead; cmed_collate_fn drops those.  Test
    items carry all-zero tag arrays.
    """

    def __init__(self, config, prefix, is_test, tokenizer):
        self.config = config
        self.prefix = prefix
        self.is_test = is_test
        self.tokenizer = tokenizer
        # Context managers close the files; JSON is read as UTF-8.
        with open(os.path.join(self.config.data_path, prefix + '.json'), encoding='utf-8') as f:
            json_data = json.load(f)
        # Debug mode keeps only the first 500 records.
        self.json_data = json_data[:500] if self.config.debug else json_data
        with open(os.path.join(self.config.data_path, 'rel2id.json'), encoding='utf-8') as f:
            # Element [1] is the relation -> id map (presumably [0] is id -> relation).
            self.rel2id = json.load(f)[1]

    def __len__(self):
        return len(self.json_data)

    def _build_s2ro_map(self, triple_list, tokens):
        """Map each located subject span (head, tail) to its [(obj_head, obj_tail, rel_id), ...]."""
        s2ro_map = {}
        for triple in triple_list:
            # [1:-1] drops the boundary tokens (presumably [CLS]/[SEP]) added by tokenize().
            sub_tokens = self.tokenizer.tokenize(triple[0])[1:-1]
            obj_tokens = self.tokenizer.tokenize(triple[2])[1:-1]
            sub_head_idx = find_head_idx(tokens, sub_tokens)
            obj_head_idx = find_head_idx(tokens, obj_tokens)
            if sub_head_idx == -1 or obj_head_idx == -1:
                continue
            sub = (sub_head_idx, sub_head_idx + len(sub_tokens) - 1)
            s2ro_map.setdefault(sub, []).append(
                (obj_head_idx, obj_head_idx + len(obj_tokens) - 1, self.rel2id[triple[1]]))
        return s2ro_map

    def _encode(self, text, text_len):
        """Return (token_ids, masks) as numpy arrays truncated to *text_len*."""
        token_ids, segment_ids = self.tokenizer.encode(first=text)
        token_ids, segment_ids = token_ids[:text_len], segment_ids[:text_len]
        # Segment ids + 1 is used as the attention mask (presumably all ones,
        # since a single sentence has segment id 0 everywhere).
        return np.array(token_ids), np.array(segment_ids) + 1

    def __getitem__(self, idx):
        ins_json_data = self.json_data[idx]
        text = ' '.join(ins_json_data['text'].split()[:self.config.max_len])
        tokens = self.tokenizer.tokenize(text)[:BERT_MAX_LEN]
        text_len = len(tokens)
        rel_num = self.config.rel_num
        sub_heads, sub_tails = np.zeros(text_len), np.zeros(text_len)
        sub_head, sub_tail = np.zeros(text_len), np.zeros(text_len)
        obj_heads, obj_tails = np.zeros((text_len, rel_num)), np.zeros((text_len, rel_num))

        if not self.is_test:
            s2ro_map = self._build_s2ro_map(ins_json_data['triple_list'], tokens)
            if not s2ro_map:
                return None
            for s_head, s_tail in s2ro_map:
                sub_heads[s_head] = 1
                sub_tails[s_tail] = 1
            # Object targets are built for one randomly chosen subject.
            sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys()))
            sub_head[sub_head_idx] = 1
            sub_tail[sub_tail_idx] = 1
            for o_head, o_tail, rel_id in s2ro_map[(sub_head_idx, sub_tail_idx)]:
                obj_heads[o_head][rel_id] = 1
                obj_tails[o_tail][rel_id] = 1

        token_ids, masks = self._encode(text, text_len)
        return (token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail,
                obj_heads, obj_tails, ins_json_data['triple_list'], tokens)


def cmed_collate_fn(batch):
    """Collate CMEDDataset items into zero-padded tensors.

    ``None`` items (training records without locatable triples) are dropped,
    and the rest are sorted longest-first.  The relation count is read from
    the items' object-tag arrays instead of being hard-coded.  Returns a dict
    with 'token_ids' and 'mask' as int64 tensors, the tag tensors as float
    tensors (obj_* are (batch, seq_len, rel_num)), plus the raw 'triples' and
    'tokens' tuples.

    Raises:
        ValueError: if every item in the batch is None.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        raise ValueError('cmed_collate_fn received a batch with no usable samples')
    batch.sort(key=lambda item: item[2], reverse=True)
    (token_ids, masks, text_len, sub_heads, sub_tails, sub_head, sub_tail,
     obj_heads, obj_tails, triples, tokens) = zip(*batch)
    cur_batch = len(batch)
    max_text_len = max(text_len)
    rel_num = obj_heads[0].shape[1]

    batch_token_ids = torch.zeros(cur_batch, max_text_len, dtype=torch.long)
    batch_masks = torch.zeros(cur_batch, max_text_len, dtype=torch.long)
    batch_sub_heads = torch.zeros(cur_batch, max_text_len)
    batch_sub_tails = torch.zeros(cur_batch, max_text_len)
    batch_sub_head = torch.zeros(cur_batch, max_text_len)
    batch_sub_tail = torch.zeros(cur_batch, max_text_len)
    batch_obj_heads = torch.zeros(cur_batch, max_text_len, rel_num)
    batch_obj_tails = torch.zeros(cur_batch, max_text_len, rel_num)

    for i in range(cur_batch):
        n = text_len[i]
        batch_token_ids[i, :n].copy_(torch.from_numpy(token_ids[i]))
        batch_masks[i, :n].copy_(torch.from_numpy(masks[i]))
        batch_sub_heads[i, :n].copy_(torch.from_numpy(sub_heads[i]))
        batch_sub_tails[i, :n].copy_(torch.from_numpy(sub_tails[i]))
        batch_sub_head[i, :n].copy_(torch.from_numpy(sub_head[i]))
        batch_sub_tail[i, :n].copy_(torch.from_numpy(sub_tail[i]))
        batch_obj_heads[i, :n, :].copy_(torch.from_numpy(obj_heads[i]))
        batch_obj_tails[i, :n, :].copy_(torch.from_numpy(obj_tails[i]))

    return {'token_ids': batch_token_ids,
            'mask': batch_masks,
            'sub_heads': batch_sub_heads,
            'sub_tails': batch_sub_tails,
            'sub_head': batch_sub_head,
            'sub_tail': batch_sub_tail,
            'obj_heads': batch_obj_heads,
            'obj_tails': batch_obj_tails,
            'triples': triples,
            'tokens': tokens}


def get_loader(config, prefix, is_test=False, num_workers=0, collate_fn=cmed_collate_fn):
    """Build a DataLoader over CMEDDataset using the module-level tokenizer.

    Training loaders shuffle with ``config.batch_size``.  Test loaders keep
    file order and use batch size 1.  Memory is pinned in both cases.
    """
    dataset = CMEDDataset(config, prefix, is_test, tokenizer)
    batch_size, shuffle = (1, False) if is_test else (config.batch_size, True)
    return DataLoader(dataset=dataset,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      pin_memory=True,
                      num_workers=num_workers,
                      collate_fn=collate_fn)


class DataPreFetcher(object):
    """Iterate a loader of dict batches one batch ahead.

    When CUDA is available, the tensor values of the upcoming batch are copied
    to the GPU on a side stream while the current batch is in use.  Without
    CUDA, batches are passed through unchanged.  ``next()`` returns None once
    the loader is exhausted.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.preload()

    def preload(self):
        """Fetch the next batch and, on GPU, start its asynchronous host-to-device copy."""
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return
        if self.stream is None:
            return
        with torch.cuda.stream(self.stream):
            for k, v in self.next_data.items():
                if isinstance(v, torch.Tensor):
                    self.next_data[k] = v.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch (or None) and start prefetching the following one."""
        data = self.next_data
        if self.stream is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
            if data is not None:
                # The tensors were allocated on the side stream but are consumed
                # on the current one; record that so the caching allocator does
                # not reuse their memory too early.
                for v in data.values():
                    if isinstance(v, torch.Tensor):
                        v.record_stream(torch.cuda.current_stream())
        self.preload()
        return data


## models

In [None]:
from torch import nn
from transformers import BertModel, BertTokenizer
import torch


class Casrel(nn.Module):
    """CasRel tagger on the 'hfl/chinese-bert-wwm' encoder (cached under model_path).

    Subject head/tail positions are tagged with per-token sigmoids.  Object
    head/tail positions are tagged per relation (``config.rel_num``) after the
    averaged head/tail encoding of one given subject is added to every token.
    """

    def __init__(self, config):
        super(Casrel, self).__init__()
        self.config = config
        self.bert_dim = 768
        self.bert_encoder = BertModel.from_pretrained("hfl/chinese-bert-wwm", cache_dir=model_path)
        self.sub_heads_linear = nn.Linear(self.bert_dim, 1)
        self.sub_tails_linear = nn.Linear(self.bert_dim, 1)
        self.obj_heads_linear = nn.Linear(self.bert_dim, self.config.rel_num)
        self.obj_tails_linear = nn.Linear(self.bert_dim, self.config.rel_num)

    def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
        """Tag object heads/tails per relation for the subject picked by the mappings.

        The mappings are (batch_size, 1, seq_len) selectors.  The mean of the
        selected head and tail encodings, (batch_size, 1, bert_dim), is
        broadcast-added to every token before the object taggers.  Returns two
        (batch_size, seq_len, rel_num) probability tensors.
        """
        subject_repr = (torch.matmul(sub_head_mapping, encoded_text)
                        + torch.matmul(sub_tail_mapping, encoded_text)) / 2
        conditioned = encoded_text + subject_repr
        head_probs = torch.sigmoid(self.obj_heads_linear(conditioned))
        tail_probs = torch.sigmoid(self.obj_tails_linear(conditioned))
        return head_probs, tail_probs

    def get_encoded_text(self, token_ids, mask):
        """Return BERT's last hidden states, (batch_size, seq_len, bert_dim)."""
        outputs = self.bert_encoder(token_ids, attention_mask=mask)
        return outputs[0]

    def get_subs(self, encoded_text):
        """Return per-token subject head/tail probabilities, each (batch_size, seq_len, 1)."""
        head_probs = torch.sigmoid(self.sub_heads_linear(encoded_text))
        tail_probs = torch.sigmoid(self.sub_tails_linear(encoded_text))
        return head_probs, tail_probs

    def forward(self, data):
        """Run one training batch.

        ``data`` supplies 'token_ids' and 'mask' (batch_size, seq_len), plus the
        one-hot 'sub_head' / 'sub_tail' vectors (batch_size, seq_len) of the
        subject that conditions the object taggers.  Returns (pred_sub_heads,
        pred_sub_tails, pred_obj_heads, pred_obj_tails).
        """
        encoded_text = self.get_encoded_text(data['token_ids'], data['mask'])
        pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
        # (batch_size, seq_len) -> (batch_size, 1, seq_len) selectors
        head_selector = data['sub_head'].unsqueeze(1)
        tail_selector = data['sub_tail'].unsqueeze(1)
        pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(head_selector, tail_selector, encoded_text)
        return pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails


In [None]:
import argparse
import os
import numpy as np
import random
import torch.optim as optim
from torch import nn
import torch.nn.functional as F
import torch
import json
import time


class Framework(object):
    """Train and evaluate a CasRel-style relation-extraction model.

    ``con`` is the ``argparse.Namespace`` built by the config cells. Each
    method's docstring lists the config attributes it reads.
    """

    def __init__(self, con):
        self.config = con

    def logging(self, s, print_=True, log_=True):
        """Print ``s`` and/or append it as one line to the run's log file.

        The log file is ``config.log_dir/config.log_save_name``. The directory
        must already exist; ``train`` creates it before it logs anything.
        """
        if print_:
            print(s)
        if log_:
            log_path = os.path.join(self.config.log_dir, self.config.log_save_name)
            with open(log_path, 'a+') as f_log:
                f_log.write(s + '\n')

    @staticmethod
    def _detokenize(tokens):
        """Join sub-word tokens back into a string.

        A leading ``##`` word-piece marker is removed from each token, and each
        ``[unused1]`` separator becomes a single space. Only the exact ``##``
        prefix is removed, so literal ``#`` tokens survive. (The original
        ``lstrip("##")`` stripped every leading ``#``.)
        """
        text = ''.join([tok[2:] if tok.startswith('##') else tok for tok in tokens])
        return ' '.join(text.split('[unused1]'))

    def train(self, model_pattern):
        """Train ``model_pattern(config)`` and keep the checkpoint with the best dev F1.

        Reads ``learning_rate``, ``multi_gpu``, ``checkpoint_dir``, ``log_dir``,
        ``train_prefix``, ``dev_prefix``, ``max_epoch``, ``period``,
        ``test_epoch``, ``model_save_name`` and ``debug`` from the config.

        Evaluation and checkpointing only run on epochs where
        ``(epoch + 1) % test_epoch == 0``. If ``test_epoch`` is larger than
        ``max_epoch``, no checkpoint is written.
        """
        # initialize the model
        ori_model = model_pattern(self.config)
        ori_model.cuda()

        # optimise only the trainable parameters
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, ori_model.parameters()),
                               lr=self.config.learning_rate)

        # optionally replicate across GPUs; ori_model stays unwrapped for saving
        if self.config.multi_gpu:
            model = nn.DataParallel(ori_model)
        else:
            model = ori_model

        def loss(gold, pred, mask):
            """Binary cross-entropy summed over unmasked positions, divided by the number of unmasked tokens."""
            pred = pred.squeeze(-1)
            los = F.binary_cross_entropy(pred, gold, reduction='none')
            # object losses carry an extra rel_num axis: broadcast the token mask over it
            if los.shape != mask.shape:
                mask = mask.unsqueeze(-1)
            return torch.sum(los * mask) / torch.sum(mask)

        # make sure the output directories exist (including any missing parents)
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)
        os.makedirs(self.config.log_dir, exist_ok=True)

        # get the data loaders
        train_data_loader = get_loader(self.config, prefix=self.config.train_prefix)
        dev_data_loader = get_loader(self.config, prefix=self.config.dev_prefix, is_test=True)

        model.train()
        global_step = 0
        loss_sum = 0

        best_f1_score = 0
        best_precision = 0
        best_recall = 0

        best_epoch = 0
        init_time = time.time()
        start_time = time.time()

        # the training loop
        for epoch in range(self.config.max_epoch):
            train_data_prefetcher = DataPreFetcher(train_data_loader)
            data = train_data_prefetcher.next()
            while data is not None:
                pred_sub_heads, pred_sub_tails, pred_obj_heads, pred_obj_tails = model(data)

                sub_heads_loss = loss(data['sub_heads'], pred_sub_heads, data['mask'])
                sub_tails_loss = loss(data['sub_tails'], pred_sub_tails, data['mask'])
                obj_heads_loss = loss(data['obj_heads'], pred_obj_heads, data['mask'])
                obj_tails_loss = loss(data['obj_tails'], pred_obj_tails, data['mask'])
                total_loss = (sub_heads_loss + sub_tails_loss) + (obj_heads_loss + obj_tails_loss)

                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                global_step += 1
                loss_sum += total_loss.item()

                # report the mean loss over the last `period` steps
                if global_step % self.config.period == 0:
                    cur_loss = loss_sum / self.config.period
                    elapsed = time.time() - start_time
                    self.logging("epoch: {:3d}, step: {:4d}, speed: {:5.2f}ms/b, train loss: {:5.3f}".
                                 format(epoch, global_step, elapsed * 1000 / self.config.period, cur_loss))
                    loss_sum = 0
                    start_time = time.time()

                data = train_data_prefetcher.next()

            if (epoch + 1) % self.config.test_epoch == 0:
                eval_start_time = time.time()
                model.eval()
                precision, recall, f1_score = self.test(dev_data_loader, model)
                model.train()
                self.logging('epoch {:3d}, eval time: {:5.2f}s, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.
                             format(epoch, time.time() - eval_start_time, f1_score, precision, recall))

                if f1_score > best_f1_score:
                    best_f1_score = f1_score
                    best_epoch = epoch
                    best_precision = precision
                    best_recall = recall
                    self.logging("saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
                                 format(best_epoch, best_f1_score, precision, recall))
                    # save the unwrapped model so the state_dict keys have no "module." prefix
                    path = os.path.join(self.config.checkpoint_dir, self.config.model_save_name)
                    if not self.config.debug:
                        torch.save(ori_model.state_dict(), path)

            # manually release the unused cache
            torch.cuda.empty_cache()

        self.logging("finish training")
        self.logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}, total time: {:5.2f}s".format(
            best_epoch, best_f1_score, best_precision, best_recall, time.time() - init_time))

    def test(self, test_data_loader, model, output=False, h_bar=0.5, t_bar=0.5):
        """Decode triples from ``test_data_loader`` and score them against the gold triples.

        Only the first sample of every batch is decoded (``[0]`` indexing), so
        the loader presumably yields batches of size 1. TODO: confirm in
        get_loader.

        Args:
            test_data_loader: loader wrapped by ``DataPreFetcher``. Batches
                provide ``token_ids``, ``tokens``, ``mask`` and ``triples``.
            model: the CasRel model. It may be wrapped in ``nn.DataParallel``.
            output: when true, write one JSON line per sample to
                ``config.result_dir/config.result_save_name``.
            h_bar / t_bar: probability thresholds for head / tail positions.

        Returns:
            ``(precision, recall, f1_score)``.
        """
        # nn.DataParallel does not forward custom methods such as get_encoded_text
        if isinstance(model, nn.DataParallel):
            model = model.module

        orders = ['subject', 'relation', 'object']

        def to_tup(triple_list):
            return [tuple(triple) for triple in triple_list]

        with open(os.path.join(self.config.data_path, 'rel2id.json'), encoding='utf-8') as f_rel:
            id2rel = json.load(f_rel)[0]

        fw = None
        if output:
            os.makedirs(self.config.result_dir, exist_ok=True)
            path = os.path.join(self.config.result_dir, self.config.result_save_name)
            # results contain Chinese text (ensure_ascii=False), so force utf-8
            fw = open(path, 'w', encoding='utf-8')

        correct_num, predict_num, gold_num = 0, 0, 0
        test_data_prefetcher = DataPreFetcher(test_data_loader)
        data = test_data_prefetcher.next()

        try:
            with torch.no_grad():
                while data is not None:
                    token_ids = data['token_ids']
                    tokens = data['tokens'][0]
                    mask = data['mask']
                    encoded_text = model.get_encoded_text(token_ids, mask)
                    pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text)
                    sub_heads = np.where(pred_sub_heads.cpu()[0] > h_bar)[0]
                    sub_tails = np.where(pred_sub_tails.cpu()[0] > t_bar)[0]

                    subjects = []
                    for sub_head in sub_heads:
                        # pair each head with the nearest tail at or after it
                        candidate_tails = sub_tails[sub_tails >= sub_head]
                        if len(candidate_tails) > 0:
                            sub_tail = candidate_tails[0]
                            # NOTE(review): end-exclusive slice; presumably the tail points
                            # at a trailing [unused1] separator -- confirm against the loader.
                            subjects.append((tokens[sub_head: sub_tail], sub_head, sub_tail))

                    pred_triples = set()
                    if subjects:
                        # [subject_num, seq_len, bert_dim]
                        repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1)
                        # [subject_num, 1, seq_len] one-hot head / tail position per subject
                        seq_len = encoded_text.size(1)
                        sub_head_mapping = torch.zeros(len(subjects), 1, seq_len)
                        sub_tail_mapping = torch.zeros(len(subjects), 1, seq_len)
                        for subject_idx, (_, sub_head, sub_tail) in enumerate(subjects):
                            sub_head_mapping[subject_idx][0][sub_head] = 1
                            sub_tail_mapping[subject_idx][0][sub_tail] = 1
                        sub_tail_mapping = sub_tail_mapping.to(repeated_encoded_text)
                        sub_head_mapping = sub_head_mapping.to(repeated_encoded_text)
                        pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(
                            sub_head_mapping, sub_tail_mapping, repeated_encoded_text)
                        # move to CPU once instead of once per subject
                        pred_obj_heads = pred_obj_heads.cpu()
                        pred_obj_tails = pred_obj_tails.cpu()
                        for subject_idx, (sub_tokens, _, _) in enumerate(subjects):
                            sub = self._detokenize(sub_tokens)
                            obj_heads = np.where(pred_obj_heads[subject_idx] > h_bar)
                            obj_tails = np.where(pred_obj_tails[subject_idx] > t_bar)
                            for obj_head, rel_head in zip(*obj_heads):
                                for obj_tail, rel_tail in zip(*obj_tails):
                                    # first tail at/after the head that shares its relation
                                    if obj_head <= obj_tail and rel_head == rel_tail:
                                        rel = id2rel[str(int(rel_head))]
                                        obj = self._detokenize(tokens[obj_head: obj_tail])
                                        pred_triples.add((sub, rel, obj))
                                        break

                    gold_triples = set(to_tup(data['triples'][0]))

                    correct_num += len(pred_triples & gold_triples)
                    predict_num += len(pred_triples)
                    gold_num += len(gold_triples)

                    if fw is not None:
                        result = json.dumps({
                            'triple_list_gold': [
                                dict(zip(orders, triple)) for triple in gold_triples
                            ],
                            'triple_list_pred': [
                                dict(zip(orders, triple)) for triple in pred_triples
                            ],
                            'new': [
                                dict(zip(orders, triple)) for triple in pred_triples - gold_triples
                            ],
                            'lack': [
                                dict(zip(orders, triple)) for triple in gold_triples - pred_triples
                            ]
                        }, ensure_ascii=False)
                        fw.write(result + '\n')

                    data = test_data_prefetcher.next()
        finally:
            if fw is not None:
                fw.close()

        print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num))

        # epsilons avoid division by zero when nothing is predicted / no gold triples
        precision = correct_num / (predict_num + 1e-10)
        recall = correct_num / (gold_num + 1e-10)
        f1_score = 2 * precision * recall / (precision + recall + 1e-10)
        return precision, recall, f1_score

    def testall(self, model_pattern):
        """Load the saved checkpoint and evaluate it on ``config.test_prefix``.

        Per-sample results are written via ``test(..., output=True)``.
        """
        model = model_pattern(self.config)
        path = os.path.join(self.config.checkpoint_dir, self.config.model_save_name)
        model.load_state_dict(torch.load(path))
        model.cuda()
        model.eval()
        test_data_loader = get_loader(self.config, prefix=self.config.test_prefix, is_test=True)
        precision, recall, f1_score = self.test(test_data_loader, model, True)
        print("f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".format(f1_score, precision, recall))

## 训练

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
config = {
    'model_name' : 'Casrel',
    'lr' : 1e-5,
    'learning_rate' : 1e-5,
    'multi_gpu' : False,
    'dataset' : 'CMED',
    'batch_size' : 4,
    'max_epoch' : 5,
    # Evaluate on the dev set (and save the best checkpoint) after every epoch.
    # Framework.train only evaluates when (epoch + 1) % test_epoch == 0. The
    # previous value 20 was larger than max_epoch, so no checkpoint was ever
    # written for the test cell to load.
    'test_epoch' : 1,
    'train_prefix' : 'my_train_triples',
    'dev_prefix' : 'dev_triples',
    'test_prefix' : 'my_train_triples',
    'max_len' : 512,
    'rel_num' : 44,
    'period' : 50,
    'debug' : False,
    'bert_path' : model_path
}
config['checkpoint_dir'] = base_path + '/checkpoint'
config['log_dir'] = base_path + '/log'
config['data_path'] = data_path
config['result_dir'] = data_path
# Shared run tag. The test cell must build the same tag (same lr / batch_size)
# to find the checkpoint.
run_tag = config['model_name'] + '_DATASET_' + config['dataset'] + "_LR_" + str(config['lr']) + "_BS_" + str(config['batch_size'])
config['model_save_name'] = run_tag
config['log_save_name'] = "LOG_" + run_tag
config['result_save_name'] = "RESULT_" + run_tag

# Seed every RNG in use for reproducible runs.
seed = 1234
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
config = argparse.Namespace(**config)
framework = Framework(config)
framework.train(Casrel)


## 测试

In [None]:
import argparse
config = {
    'model_name' : 'Casrel',
    'lr' : 1e-5,
    'learning_rate' : 1e-5,
    'multi_gpu' : False,
    'dataset' : 'CMED',
    # Must equal the training batch_size: it is part of model_save_name, which
    # is the checkpoint file testall() loads. The previous value 6 pointed at a
    # "_BS_6" file that training (batch_size 4) never wrote.
    'batch_size' : 4,
    'max_epoch' : 5,
    'test_epoch' : 5,
    'train_prefix' : 'my_train_triples',
    'dev_prefix' : 'dev_triples',
    'test_prefix' : 'my_test_triples',
    'max_len' : 150,
    'rel_num' : 44,
    'period' : 50,
    'debug' : False,
    'bert_path' : model_path
}
config['checkpoint_dir'] = base_path + '/checkpoint'
config['log_dir'] = base_path + '/log'
config['data_path'] = data_path
config['result_dir'] = data_path
run_tag = config['model_name'] + '_DATASET_' + config['dataset'] + "_LR_" + str(config['lr']) + "_BS_" + str(config['batch_size'])
config['model_save_name'] = run_tag
config['log_save_name'] = "LOG_" + run_tag
config['result_save_name'] = "RESULT_" + run_tag
config = argparse.Namespace(**config)

framework = Framework(config)
framework.testall(Casrel)

# test_data_loader = get_loader(config, prefix=config.test_prefix)
# print(test_data_loader)
# framework = Framework(config)
# framework.test(test_data_loader=config.test_prefix, model=Casrel)

## 数据生成

In [None]:
import json
import re
import jieba

def remove_punctuation(line):
    """Convert ``line`` to ``str``; return ``''`` if it is blank.

    Despite the name, punctuation stripping is currently disabled. Non-blank
    input is returned unchanged, surrounding whitespace included.
    """
    text = str(line)
    return text if text.strip() else ''

import ast

# Each line of the raw file is a dict literal whose values are text fields.
with open(base_path+'/data/MyData/茂名人民医院重症专病入院记录_额外原文.txt', 'r', encoding='utf-8') as fr:
    text_list = fr.readlines()

text_data = []
for line in text_list:
    # literal_eval parses the dict literal without executing arbitrary code
    # (the original used eval() on file contents).
    try:
        record = ast.literal_eval(line)
        joined = ' '.join(record.values())
    except (ValueError, SyntaxError, TypeError, AttributeError):
        continue  # skip blank, malformed or non-dict / non-text lines
    text_data.append(" ".join(jieba.cut(remove_punctuation(joined))))

# Unlabelled samples: the triple list only carries placeholders.
test_data = [
    {'text': text, 'triple_list': [['subject_placeholder', 'relation_placeholder', 'object_placeholder']]}
    for text in text_data
]
if test_data:
    print(test_data[0])
print(len(test_data))

with open(base_path+'/data/CMED/my_test_triples.json', 'w', encoding='utf-8') as fw:
    json.dump(test_data, fw, ensure_ascii=False)

In [None]:
import json
import jieba

with open(base_path+'/data/CMeIE/CMeIE_train.json', 'r', encoding='utf-8') as fr:
    train_list = fr.readlines()

with open(base_path+'/data/CMeIE/CMeIE_dev.json', 'r', encoding='utf-8') as fr:
    add_dev_list = fr.readlines()

# NOTE(review): this file does not feed set_data. It is only appended to
# train_list after set_data is built (as in the original); confirm whether it
# was meant to be part of the training set.
with open(base_path+'/data/CMED/train_triples.json', 'r', encoding='utf-8') as fr:
    add_train_list = fr.readlines()

# Build the training set from CMeIE train + dev.
train_list.extend(add_dev_list)
print(train_list[0])

# One CMeIE record per line:
#   {"text": ..., "spo_list": [{"subject": ..., "predicate": ..., "object": {"@value": ...}}, ...]}
# Every spo of a text is kept, in CasRel's [subject, relation, object] order.
# Framework.test reads triples in that order. The original kept only the last
# spo, stored it as [object, predicate, subject], and reused stale values for
# records with an empty spo_list.
set_data = []
skipped = 0
for line in train_list:
    try:
        record = json.loads(line)
        triple_list = [
            [" ".join(jieba.cut(spo['subject'])),
             spo['predicate'],
             " ".join(jieba.cut(spo['object']['@value']))]
            for spo in record['spo_list']
        ]
        t_text = " ".join(jieba.cut(record['text']))
    except (ValueError, KeyError, TypeError):
        skipped += 1  # blank / malformed line or record
        continue
    if triple_list:
        set_data.append({'text': t_text, 'triple_list': triple_list})
print('skipped lines:', skipped)

train_list.extend(add_train_list)
print(len(set_data), set_data[0] if set_data else None)
with open(base_path+'/data/CMED/my_train_triples.json', 'w', encoding='utf-8') as fw:
    json.dump(set_data, fw, ensure_ascii=False)

In [None]:
# NOTE: despite its name, rel2id holds the generated training samples
# ({'text', 'triple_list'}); the relation labels are collected from them below.
with open(base_path+'/data/CMED/my_train_triples.json', encoding='utf-8') as fr:
    rel2id = json.load(fr)
print(len(rel2id))
# Collect relations from every triple, not just the first one of each sample,
# and sort them so the printed list is deterministic.
rel_id = sorted({triple[1] for sample in rel2id for triple in sample['triple_list']})
print(rel_id)
print(len(rel_id))


# 四、苏神bert4keras
模型设计过程如下：基于“半指针-半标注”的方式来做抽取，顺序是先抽取s，然后传入s来抽取o、p，不同的只是将模型的整体架构换成了bert：
1、原始序列转id后，传入bert的编码器，得到编码序列；
2、编码序列接两个二分类器，预测s；
3、根据传入的s，从编码序列中抽取出s的首和尾对应的编码向量；
4、以s的编码向量作为条件，对编码序列做一次条件Layer Norm；
5、用条件Layer Norm后的序列来预测该s对应的o、p。

In [1]:
import os
# Local paths. Colab alternatives: base_path = os.path.dirname(os.getcwd()),
# data 'content/CMED', model 'content/Models/chinese-bert-wwm'.
base_path = '/home/aid/Github/NLP_relation_extraction'
data_path = '{}/data/CMeIE'.format(base_path)
model_path = '{}/models/chinese-bert-wwm'.format(base_path)
print(data_path)

/home/aid/Github/NLP_relation_extraction/data/CMeIE


In [None]:
'''
    bert4keras==0.7.8
    Keras==2.2.4
    numpy==1.16.4
    tensorflow-gpu==1.14.0
    tqdm==4.43.0
'''
import keras
from keras.models import Model
from keras.layers import Input, Dense, Lambda, Reshape
from bert4keras.snippets import open
from bert4keras.optimizers import Adam, extend_with_exponential_moving_average
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.layers import LayerNormalization
from bert4keras.backend import K, batch_gather
from bert4keras.snippets import DataGenerator, sequence_padding
from tqdm import tqdm
import numpy as np
import json
import os

# Directory this code lives in. __file__ is undefined inside Jupyter/Colab
# notebooks, so fall back to the current working directory there.
try:
    rootPath = os.path.dirname(os.path.abspath(__file__))
except NameError:
    rootPath = os.getcwd()
modelsPath = rootPath + '/models_tmp/'
dataPath = rootPath + '/data'

class ReextractBertTrainHandler():
    """bert4keras relation extractor.

    It first predicts subject spans, then predicts (object, predicate) pairs
    for each subject.

    ``params`` is a dict. The keys read are ``gpu_id`` and
    ``memory_fraction``, plus, when ``Train`` is true, ``train_data_path``,
    ``valid_data_path``, ``maxlen`` (default 128), ``batch_size`` (default 32)
    and ``epoch``. With ``Train=False``, ``maxlen`` and the label dictionaries
    are read from ``params.json`` and the saved weights are loaded.
    """

    def __init__(self, params, Train=False):
        self.bert_config_path = modelsPath + "chinese_L-12_H-768_A-12/bert_config.json"
        self.bert_checkpoint_path = modelsPath + "chinese_L-12_H-768_A-12/bert_model.ckpt"
        self.bert_vocab_path = modelsPath + "chinese_L-12_H-768_A-12/vocab.txt"
        self.tokenizer = Tokenizer(self.bert_vocab_path, do_lower_case=True)
        self.model_path = modelsPath + "best_model.weights"
        self.params_path = modelsPath + 'params.json'
        gpu_id = params.get("gpu_id", None)
        print("-----> 选择的gpu_id===", gpu_id)
        self._set_gpu_id(gpu_id)  # select the GPU used for training
        self.memory_fraction = params.get('memory_fraction')
        if Train:
            self.train_data_file_path = params.get('train_data_path')
            self.valid_data_file_path = params.get('valid_data_path')
            self.maxlen = params.get('maxlen', 128)
            self.batch_size = params.get('batch_size', 32)
            self.epoch = params.get('epoch')
            self.data_process()
        else:
            load_params = json.load(open(self.params_path, encoding='utf-8'))
            self.maxlen = load_params.get('maxlen')
            self.num_classes = load_params.get('num_classes')
            self.p2s_dict = load_params.get('p2s_dict')
            self.i2p_dict = load_params.get('i2p_dict')
            self.p2o_dict = load_params.get('p2o_dict')
        self.build_model()
        if not Train:
            self.load_model()

    def _set_gpu_id(self, gpu_id):
        """Restrict the visible CUDA devices to ``gpu_id`` when one is given.

        ``0`` is a valid device id, so the check is against ``None`` rather
        than truthiness. The previous value is read with ``.get`` because the
        variable is often unset; indexing ``os.environ`` raised KeyError there.
        """
        if gpu_id is not None:
            print("---> gpu_id:", gpu_id, os.environ.get("CUDA_VISIBLE_DEVICES"))
            os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    def data_process(self):
        """Load train/valid data and label dictionaries; build the training generator."""
        self.train_data, self.valid_data, self.p2s_dict, self.p2o_dict, self.i2p_dict, self.p2i_dict = data_process(
            self.train_data_file_path, self.valid_data_file_path, self.maxlen, self.params_path)
        self.num_classes = len(self.i2p_dict)
        self.train_generator = Data_Generator(self.train_data, self.batch_size, self.tokenizer, self.p2i_dict,
                                              self.maxlen)

    def extrac_subject(self, inputs):
        """Pick the subject's vector representation out of the encoder output.

        ``inputs`` is ``(output, subject_ids)``, where ``subject_ids[:, 0]`` /
        ``subject_ids[:, 1]`` are the start / end positions. The vectors at
        those two positions are concatenated along the last axis.
        """
        output, subject_ids = inputs
        subject_ids = K.cast(subject_ids, 'int32')
        start = batch_gather(output, subject_ids[:, :1])
        end = batch_gather(output, subject_ids[:, 1:])
        subject = K.concatenate([start, end], 2)
        return subject[:, 0]

    def build_model(self):
        """Configure the TF session and build the subject, object and training models."""
        import tensorflow as tf
        from keras.backend.tensorflow_backend import set_session
        config = tf.ConfigProto()
        # A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc.
        config.gpu_options.allocator_type = 'BFC'
        if self.memory_fraction:
            config.gpu_options.per_process_gpu_memory_fraction = self.memory_fraction
            config.gpu_options.allow_growth = False
        else:
            print('------> no memory_fraction')
            # allocate GPU memory on demand
            config.gpu_options.allow_growth = True
            # cap GPU memory use at 80% of the device
            config.gpu_options.per_process_gpu_memory_fraction = 0.8
        set_session(tf.Session(config=config))

        # extra inputs used only for training / object prediction
        subject_labels = Input(shape=(None, 2), name='Subject-Labels')
        subject_ids = Input(shape=(2,), name='Subject-Ids')
        object_labels = Input(
            shape=(None, self.num_classes, 2), name='Object-Labels')
        # load the pre-trained BERT
        bert = build_transformer_model(
            config_path=self.bert_config_path,
            checkpoint_path=self.bert_checkpoint_path,
            return_keras_model=False,
        )
        # predict subject start/end probabilities per token
        output = Dense(units=2,
                       activation='sigmoid',
                       kernel_initializer=bert.initializer)(bert.model.output)
        # raising the sigmoid output to a power pushes small probabilities toward 0
        subject_preds = Lambda(lambda x: x ** 2)(output)
        self.subject_model = Model(bert.model.inputs, subject_preds)
        # given a subject, predict objects: the subject is injected into the
        # object prediction through Conditional Layer Normalization
        output = bert.model.layers[-2].get_output_at(-1)
        subject = Lambda(self.extrac_subject)([output, subject_ids])
        output = LayerNormalization(conditional=True)([output, subject])
        output = Dense(units=self.num_classes * 2,
                       activation='sigmoid',
                       kernel_initializer=bert.initializer)(output)
        output = Lambda(lambda x: x ** 4)(output)
        object_preds = Reshape((-1, self.num_classes, 2))(output)
        self.object_model = Model(
            bert.model.inputs + [subject_ids], object_preds)
        # training model: both heads, loss masked by the token padding mask
        self.train_model = Model(bert.model.inputs + [subject_labels, subject_ids, object_labels],
                                 [subject_preds, object_preds])
        mask = bert.model.get_layer('Embedding-Token').output_mask
        mask = K.cast(mask, K.floatx())
        subject_loss = K.binary_crossentropy(subject_labels, subject_preds)
        subject_loss = K.mean(subject_loss, 2)
        subject_loss = K.sum(subject_loss * mask) / K.sum(mask)
        object_loss = K.binary_crossentropy(object_labels, object_preds)
        object_loss = K.sum(K.mean(object_loss, 3), 2)
        object_loss = K.sum(object_loss * mask) / K.sum(mask)
        self.train_model.add_loss(subject_loss + object_loss)
        AdamEMA = extend_with_exponential_moving_average(Adam, name='AdamEMA')
        self.optimizer = AdamEMA(lr=1e-4)
        self.train_model.compile(optimizer=self.optimizer)

    def load_model(self):
        """Load the saved training-model weights from ``self.model_path``."""
        self.train_model.load_weights(self.model_path)

    def predict(self, text):
        """Extract the triples contained in ``text``.

        Example input: '<离开>是由张宇谱曲，演唱'.

        Returns a list of tuples::

            ([subject_text, p2s_dict[predicate]], predicate,
             [object_text, p2o_dict[predicate]],
             (subject_start, subject_end_exclusive),
             (object_start, object_end_exclusive))

        Positions are token indices. ``p2s_dict`` / ``p2o_dict`` map a
        predicate to a subject / object label (presumably entity types --
        confirm in data_process).
        """
        tokens = self.tokenizer.tokenize(text, max_length=self.maxlen)
        token_ids, segment_ids = self.tokenizer.encode(
            text, max_length=self.maxlen)
        # extract subjects
        subject_preds = self.subject_model.predict(
            [[token_ids], [segment_ids]])
        start = np.where(subject_preds[0, :, 0] > 0.6)[0]
        end = np.where(subject_preds[0, :, 1] > 0.5)[0]
        subjects = []
        for i in start:
            # pair each start with the nearest end at or after it
            j = end[end >= i]
            if len(j) > 0:
                j = j[0]
                subjects.append((i, j))
        if subjects:
            spoes = []
            token_ids = np.repeat([token_ids], len(subjects), 0)
            segment_ids = np.repeat([segment_ids], len(subjects), 0)
            subjects = np.array(subjects)
            # feed each subject back in to extract objects and predicates
            object_preds = self.object_model.predict(
                [token_ids, segment_ids, subjects])
            for subject, object_pred in zip(subjects, object_preds):
                start = np.where(object_pred[:, :, 0] > 0.6)
                end = np.where(object_pred[:, :, 1] > 0.5)
                for _start, predicate1 in zip(*start):
                    for _end, predicate2 in zip(*end):
                        if _start <= _end and predicate1 == predicate2:
                            spoes.append((subject, predicate1, (_start, _end)))
                            break
            # NOTE(review): relies on i2p_dict's insertion order matching the
            # predicate index order.
            i2p_values = []
            for k, v in self.i2p_dict.items():
                i2p_values.append(v)

            return [
                (
                    [self.tokenizer.decode(token_ids[0, s[0]:s[1] + 1], tokens[s[0]:s[1] + 1]),
                     self.p2s_dict[i2p_values[p]]],
                    i2p_values[p],
                    [self.tokenizer.decode(token_ids[0, o[0]:o[1] + 1], tokens[o[0]:o[1] + 1]),
                     self.p2o_dict[i2p_values[p]]],
                    (s[0], s[1] + 1),
                    (o[0], o[1] + 1)
                ) for s, p, o in spoes
            ]
        else:
            return []

    def train(self):
        """Fit the training model, evaluating and checkpointing after each epoch."""
        evaluator = Evaluator(self.train_model, self.model_path, self.tokenizer, self.predict, self.optimizer,
                              self.valid_data)

        self.train_model.fit_generator(self.train_generator.forfit(),
                                       steps_per_epoch=len(
                                           self.train_generator),
                                       epochs=self.epoch,
                                       callbacks=[evaluator])
        
class Data_Generator(DataGenerator):
    """Batch generator producing the inputs of ``train_model``.

    Each sample ``d`` provides ``d['text']`` and ``d['new_spo_list']``, whose
    items look like ``{'s': {'entity': ...}, 'p': {'entity': ...},
    'o': {'entity': ...}}``. Every batch yields token ids, segment ids,
    subject labels, one sampled subject span per sample, and that subject's
    object labels.
    """

    def __init__(self, data, batch_size, tokenizer, p2i_dict, maxlen):
        super().__init__(data, batch_size=batch_size)
        self.tokenizer = tokenizer
        self.p2i_dict = p2i_dict
        self.maxlen = maxlen

    def sample(self, random=False):
        """Yield ``(is_end, sample)`` pairs; ``is_end`` is True only for the last sample.

        Empty data yields nothing. Previously ``next()`` raised StopIteration
        inside this generator, which Python 3.7+ turns into RuntimeError.
        """
        if random:
            if self.steps is None:

                def generator():
                    # The bounded shuffle buffer (buffer_size) is disabled: the
                    # whole stream is cached, then emitted in random order.
                    caches = []
                    for d in self.data:
                        caches.append(d)
                    while caches:
                        i = np.random.randint(len(caches))
                        yield caches.pop(i)

            else:

                def generator():
                    indices = list(range(len(self.data)))
                    np.random.shuffle(indices)
                    for i in indices:
                        yield self.data[i]

            data = generator()
        else:
            data = iter(self.data)

        try:
            d_current = next(data)
        except StopIteration:
            return
        # look one sample ahead so the last one can be flagged
        for d_next in data:
            yield False, d_current
            d_current = d_next

        yield True, d_current

    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = self.tokenizer.encode(first_text=d['text'], max_length=self.maxlen)
            # group triples as {(s_start, s_end): [(o_start, o_end, p_id), ...]}
            spoes = {}
            for spo in d['new_spo_list']:
                s = spo['s']
                p = spo['p']
                o = spo['o']
                # [1:-1] drops the [CLS] / [SEP] ids added by encode()
                s_token = self.tokenizer.encode(s['entity'])[0][1:-1]
                p = self.p2i_dict[p['entity']]
                o_token = self.tokenizer.encode(o['entity'])[0][1:-1]
                s_idx = search(s_token, token_ids)  # start index of s
                o_idx = search(o_token, token_ids)  # start index of o
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s_token) - 1)  # inclusive start/end of s
                    o = (o_idx, o_idx + len(o_token) - 1, p)  # inclusive start/end of o and predicate id
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels: column 0 marks starts, column 1 marks ends
                subject_labels = np.zeros((len(token_ids), 2))
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
                # sample one subject (start, then an end at or after it)
                start, end = np.array(list(spoes.keys())).T
                start = np.random.choice(start)
                end = np.random.choice(end[end >= start])
                subject_ids = (start, end)
                # object labels for the sampled subject
                object_labels = np.zeros((len(token_ids), len(self.p2i_dict), 2))
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
            # Flush outside `if spoes` so the trailing partial batch is not lost
            # when the last sample has no matched triple.
            if batch_token_ids and (len(batch_token_ids) == self.batch_size or is_end):
                batch_token_ids = sequence_padding(batch_token_ids)
                batch_segment_ids = sequence_padding(batch_segment_ids)
                batch_subject_labels = sequence_padding(batch_subject_labels, padding=np.zeros(2))
                batch_subject_ids = np.array(batch_subject_ids)
                # padding rows must match the (num_predicates, 2) label shape;
                # the previous hard-coded (3, 2) only worked with exactly 3 predicates
                batch_object_labels = sequence_padding(batch_object_labels,
                                                       padding=np.zeros((len(self.p2i_dict), 2)))
                yield [
                          batch_token_ids, batch_segment_ids,
                          batch_subject_labels, batch_subject_ids, batch_object_labels
                      ], None
                batch_token_ids, batch_segment_ids = [], []
                batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
                    

class Evaluator(keras.callbacks.Callback):
    """Evaluate on the validation data after each epoch and save the best weights.

    Evaluation runs with the optimizer's EMA weights applied. Weights are
    saved when F1 ties or beats the best so far. The raw weights are then
    restored, even if evaluation raises.
    """

    def __init__(self, model, model_path, tokenizer, predict, optimizer, valid_data):
        super().__init__()
        self.EMAer = optimizer
        self.best_val_f1 = 0.
        self.model = model
        self.model_path = model_path
        self.tokenizer = tokenizer
        self.predict = predict
        self.valid_data = valid_data

    def on_epoch_end(self, epoch, logs=None):
        self.EMAer.apply_ema_weights()
        try:
            f1, precision, recall = evaluate(self.tokenizer, self.valid_data, self.predict)
            if f1 >= self.best_val_f1:
                self.best_val_f1 = f1
                self.model.save_weights(self.model_path)
        finally:
            # always swap the raw weights back so training continues from them
            self.EMAer.reset_old_weights()
        print('f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
              (f1, precision, recall, self.best_val_f1))

def search(pattern, sequence):
    """Find the first occurrence of the sub-list ``pattern`` in ``sequence``.

    Returns the start index of the first match, or -1 when there is none.
    An empty ``pattern`` also returns -1. Callers turn the result into the
    inclusive span ``(start, start + len(pattern) - 1)``, and an empty entity
    would otherwise produce the bogus span (0, -1) that labels the last token.
    """
    n = len(pattern)
    if n == 0:
        return -1
    # windows starting past len(sequence) - n are too short to match
    for i in range(len(sequence) - n + 1):
        if sequence[i:i + n] == pattern:
            return i
    return -1

def evaluate(tokenizer, data, predict, output_path='dev_pred.json'):
    """Compute micro-averaged F1, precision and recall of ``predict`` over ``data``.

    For every sample, the predicted triples ``predict(d['text'])`` and the gold
    triples ``d['new_spo_list']`` are compared. Before comparison, the subject
    and object strings are tokenized with ``tokenizer``. For each sample, a JSON
    record with the gold, predicted, spurious ("new") and missing ("lack")
    triples is appended to ``output_path``, with records separated by newlines.

    Args:
        tokenizer: object with a ``tokenize(str)`` method.
        data: iterable of dicts with ``'text'`` and ``'new_spo_list'`` keys.
            Each gold triple is read as ``spo['s']['entity']``,
            ``spo['p']['entity']`` and ``spo['o']['entity']``.
        predict: callable mapping a text to predicted triples. These are read as
            ``spo[0][0]`` (subject), ``spo[1]`` (predicate) and ``spo[2][0]``
            (object).
        output_path: file that receives the per-sample dump (overwritten).

    Returns:
        Tuple ``(f1, precision, recall)``. All three are 0.0 when ``data`` is
        empty.
    """
    # Epsilon start values keep the ratios below free of division by zero.
    X, Y, Z = 1e-10, 1e-10, 1e-10
    # Returned unchanged when ``data`` is empty.
    f1, precision, recall = 0., 0., 0.

    class SPO(tuple):
        """A (subject tokens, predicate, object tokens) triple.

        Behaves like a tuple, but hashes and compares on ``spox``. In ``spox``,
        the subject and object token lists are converted to tuples, so triples
        become hashable and compare by value.
        """

        def __init__(self, spo):
            self.spox = (
                tuple(spo[0]),
                spo[1],
                tuple(spo[2]),
            )

        def __hash__(self):
            return self.spox.__hash__()

        def __eq__(self, spo):
            return self.spox == spo.spox

    with open(output_path, 'w', encoding='utf-8') as f, tqdm() as pbar:
        for d in data:
            R = set(
                SPO([tokenizer.tokenize(spo[0][0]), spo[1],
                     tokenizer.tokenize(spo[2][0])])
                for spo in predict(d['text'])
            )
            T = set(
                SPO([tokenizer.tokenize(spo['s']['entity']), spo['p']['entity'],
                     tokenizer.tokenize(spo['o']['entity'])])
                for spo in d['new_spo_list']
            )
            X += len(R & T)  # true positives
            Y += len(R)      # predicted
            Z += len(T)      # gold
            f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
            pbar.update()
            pbar.set_description('f1: %.5f, precision: %.5f, recall: %.5f' %
                                 (f1, precision, recall))
            s = json.dumps(
                {
                    'text': d['text'],
                    'spo_list': list(T),
                    'spo_list_pred': list(R),
                    'new': list(R - T),
                    'lack': list(T - R),
                },
                ensure_ascii=False,
                indent=4)
            f.write(s + '\n')
    return f1, precision, recall

def data_process(train_data_file_path, valid_data_file_path, max_len, params_path):
    """Load the train/validation data and derive the predicate schema.

    Args:
        train_data_file_path: JSON file holding a list of training records.
            Each record has a ``'new_spo_list'`` whose items provide
            ``spo['s']['type']``, ``spo['p']['entity']`` and ``spo['o']['type']``.
        valid_data_file_path: JSON file with validation records. If falsy, the
            last 20% of the training records are held out instead.
        max_len: maximum sequence length, stored in the params file.
        params_path: destination of the JSON params file. It holds
            ``p2s_dict``, ``i2p_dict``, ``p2o_dict``, ``maxlen`` and
            ``num_classes``.

    Returns:
        ``(train_data, valid_data, p2s_dict, p2o_dict, i2p_dict, p2i_dict)``.
        Predicate ids follow first-appearance order in the training file.
    """
    with open(train_data_file_path, encoding='utf-8') as fh:
        train_data = json.load(fh)

    if valid_data_file_path:
        train_data_ret = train_data
        with open(valid_data_file_path, encoding='utf-8') as fh:
            valid_data_ret = json.load(fh)
    else:
        split = int(len(train_data) * 0.8)
        train_data_ret, valid_data_ret = train_data[:split], train_data[split:]

    p2s_dict = {}
    p2o_dict = {}
    p2i_dict = {}
    # The schema is collected from the whole training file, including any
    # held-out split, so every predicate gets an id.
    for content in train_data:
        for spo in content.get('new_spo_list') or []:
            s_type = spo.get('s').get('type')
            p_key = spo.get('p').get('entity')
            o_type = spo.get('o').get('type')
            # The first occurrence of a predicate fixes its subject/object type and id.
            p2s_dict.setdefault(p_key, s_type)
            p2o_dict.setdefault(p_key, o_type)
            p2i_dict.setdefault(p_key, len(p2i_dict))
    i2p_dict = {i: key for key, i in p2i_dict.items()}

    save_params = {
        'p2s_dict': p2s_dict,
        # Note: JSON turns these int keys into strings on disk.
        'i2p_dict': i2p_dict,
        'p2o_dict': p2o_dict,
        'maxlen': max_len,
        'num_classes': len(i2p_dict),
    }
    with open(params_path, 'w', encoding='utf-8') as fh:
        json.dump(save_params, fh, ensure_ascii=False, indent=4)
    return train_data_ret, valid_data_ret, p2s_dict, p2o_dict, i2p_dict, p2i_dict




if __name__ == "__main__":
    # ---- Training -------------------------------------------------------
    # NOTE(review): the notebook header defines `data_path`, but this block
    # reads `dataPath`; confirm `dataPath` is defined in an earlier cell.
    params = dict(
        gpu_id=0,
        maxlen=128,
        batch_size=32,
        epoch=10,
        train_data_path=dataPath + "/train_data.json",
        # valid_data_path=dataPath + "/valid_test.json",  (disabled option)
        dev_data_path=dataPath + "/valid_data.json",
    )
    model = ReextractBertTrainHandler(params, Train=True)
    model.train()

    # ---- Quick inference check on one sample sentence ------------------
    text = "胃壁、肠管壁未见明确增厚及肿块影，肠腔未见异常扩张"
    print('-->结果：', model.predict(text))