In [1]:
import os
import sys
import random
import pickle
import numpy as np
from tqdm import tqdm
import tensorflow as tf 
from bert4keras.backend import K,keras,search_layer
from bert4keras.snippets import ViterbiDecoder,to_array

from data_load import *
from build_model import bert_bilstm_crf

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


In [2]:
# 固定随机种子
seed = 233
tf.set_random_seed(seed)
np.random.seed(seed)
os.environ['PYTHONHSHSEED'] = str(seed)

# 权重参数
epochs = 4
batch_size = 16
lstm_units = 128
drop_rate = 0.1 #有改动0.1-》0.01
learning_rate = 5e-5
max_len =168

#精细训练
fine_train_list = [0 for i in range(8275)]
train_predict_list = []

# 权重路径
config_path = './bert_weight_file/uncased_L-4_H-768_A-12/bert_config.json'
checkpoint_path = './bert_weight_file/uncased_L-4_H-768_A-12/bert_model.ckpt'

# 模型保存路径
model_save_path = './save_model/bert_bilstm_crf.weight'
CRF_save_path = './save_model/CRF.npy'

class NamedEntityRecognizer(ViterbiDecoder):
    """命名实体识别器
    """
    def recognize(self, text):
        tokens = tokenizer.tokenize(text)
        while len(tokens) > max_len:
            tokens.pop(-2)
        mapping = tokenizer.rematch(text, tokens)
        token_ids = tokenizer.tokens_to_ids(tokens)
        segment_ids = [0] * len(token_ids)
        token_ids, segment_ids = to_array([token_ids], [segment_ids]) # ndarray
        nodes = model.predict([token_ids, segment_ids])[0] # [sqe_len,23]
        labels = self.decode(nodes) # id [sqe_len,], [0 0 0 0 0 7 8 8 0 0 0 0 0 0 0]
        entities, starting = [], False
        for i, label in enumerate(labels):
            if label > 0:
                if label % 2 == 1:
                    starting = True
                    entities.append([[i], id2label[(label - 1) // 2]])
                elif starting:
                    entities[-1][0].append(i)
                else:
                    starting = False
            else:
                starting = False
        return [(text[mapping[w[0]][0]:mapping[w[-1]][-1] + 1], l) for w, l in entities]
    
#相等应加set（）中源文本的数量    
def ner_metrics(data,fine_train_list):
    X,Y,Z = 1e-6,1e-6,1e-6
    count = 0
    for d in tqdm(data):
        text = ''.join([i[0] for i in d])
        pred= NER.recognize(text)
        R = set(pred)
        T = set([tuple(i) for i in d if i[1] != 'O'])
        
        # 便于T和R做交集
        m = []
        for i in T:
            for j in i[0]:
                m.append((j,i[1]))
        T = set(m)
        
        # 填充train_predict_list,更新fine_train_list
        if len(T) > 0 :  
            if len(train_predict_list) < 8275:
                train_predict_list.append(R&T)
            else:
                if len(R&T) > fine_train_list[count]:
#                     print('text: ',text)
#                     print('T: ',T)
#                     print('R&T: ',R&T)
                    train_predict_list[count] = R&T
            if len(R&T) > fine_train_list[count]:
                fine_train_list[count] = len(R&T)
            
            
#         if len(T) < fine_train_list[count]:
#             print(False)
#             print('text: ',text)
#             print('T: ',T)
#             print('R&T: ',R&T)
#             print('fine_train_list[count]: ',fine_train_list[count])
#             print()

        X += fine_train_list[count]
        if len(R) < fine_train_list[count]:
            Y += fine_train_list[count]
        else:
            Y += len(R)
            
        Z += len(T)
        count += 1

    f1,precision,recall = 2 * X / (Y + Z),X / Y,X / Z
    return f1,precision,recall

class Evaluator(keras.callbacks.Callback):
    def __init__(self):
        super(Evaluator, self).__init__()
        self.best_val_f1 = 0
    def on_epoch_end(self, epoch,logs=None):
        NER.trans = K.eval(CRF.trans) # 可能有错
        f1, precision, recall = ner_metrics(valid_data,fine_train_list)
        if f1 > self.best_val_f1:
            model.save_weights(model_save_path)
            self.best_val_f1 = f1
            print('save model to {}'.format(checkpoint_path))
        else:
            global learning_rate
            learning_rate = learning_rate / 5
        print(
              'valid: f1: %.5f, precision: %.5f, recall: %.5f, best f1: %.5f\n' %
              (f1,precision,recall,self.best_val_f1)
        )
        
# def adversarial_training(model, embedding_name, epsilon=1):
#     """
#     给模型添加对抗训练
#     其中model是需要添加对抗训练的keras模型
#     """
#     if model.train_function is None:  # 如果还没有训练函数
#         model._make_train_function()  # 手动make
#     old_train_function = model.train_function  # 备份旧的训练函数

#     # 查找Embedding层
#     for output in model.outputs:
#         embedding_layer = search_layer(output, embedding_name)
#         if embedding_layer is not None:
#             break
#     if embedding_layer is None:
#         raise Exception('Embedding layer not found')

#     # 求Embedding梯度
#     embeddings = embedding_layer.embeddings  # Embedding矩阵
#     gradients = K.gradients(model.total_loss, [embeddings])  # Embedding梯度
#     gradients = K.zeros_like(embeddings) + gradients[0]  # 转为dense tensor

#     # 封装为函数
#     inputs = (
#         model._feed_inputs + model._feed_targets + model._feed_sample_weights
#     )  # 所有输入层
#     embedding_gradients = K.function(
#         inputs=inputs,
#         outputs=[gradients],
#         name='embedding_gradients',
#     )  # 封装为函数

#     def train_function(inputs):
#         # 重新定义训练函数
#         grads = embedding_gradients(inputs)[0]  # Embedding梯度
#         delta = epsilon * grads / (np.sqrt((grads**2).sum()) + 1e-8)  # 计算扰动
#         K.set_value(embeddings, K.eval(embeddings) + delta)  # 注入扰动
#         outputs = old_train_function(inputs)  # 梯度下降
#         K.set_value(embeddings, K.eval(embeddings) - delta)  # 删除扰动
#         return outputs
#     model.train_function = train_function  # 覆盖原训练函数        



model,CRF = bert_bilstm_crf(config_path,checkpoint_path,num_labels,lstm_units,drop_rate,learning_rate)
# adversarial_training(model,'Embedding-Token',0.5)
NER = NamedEntityRecognizer(trans=K.eval(CRF.trans), starts=[0], ends=[0])

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [4]:
if __name__ == "__main__":
    train_data,_ = load_data('./ner_data/train/train.txt',128)
    valid_data,_ = load_data('./ner_data/dev/test.txt',128)
    

    print(train_data[0:10])

[[['医生：你好我是您的接诊医生', 'O']], [['医生：宝贝最近吃奶量可以吗？下降了吗', 'O']], [['患者：没有，也没怎么', 'O'], ['哭闹', 'Symptom']], [['医生：宝妈有没有吃生冷辛辣刺激食物油腻食物来吗？', 'O']], [['医生：宝贝奶粉的话最近换过牌子吗？', 'O']], [['医生：宝贝肚子着凉来吗？', 'O']], [['患者：喝茶油腻也少，菜吃很多', 'O']], [['医生：嗯嗯，宝妈饮食一定注意，生冷辛辣刺激食物不能吃油腻食物不能吃，特别油腻食物的奥，清淡饮食为主，这个时候宝贝胃肠功能可能会有影响，能吃多少吃多少别强喂的奥！', 'O']], [['医生：宝贝最近有没有', 'O'], ['呕吐', 'Symptom'], ['症状呢？', 'O']], [['患者：', 'O'], ['呕吐', 'Symptom'], ['，有时会', 'O'], ['吐', 'Symptom'], ['，不多', 'O']]]


In [5]:
if __name__ == "__main__":
    train_data,_ = load_data('./ner_data/train/train.txt',128)
    valid_data,_ = load_data('./ner_data/dev/test.txt',128)
    
    flag = False
    count = 0
    i = 0
    while(i<len(train_data)):
        if flag==True:
            i = i-1
        if train_data[i][0][1] == 'O'and len(train_data[i])==1:
            del train_data[i]
            flag = True
            count+=1
        else:
            for j in range(count):
                train_data.append(train_data[i])
            flag = False
            count = 0
        i += 1
    
    train_generator = data_generator(train_data, batch_size)
    valid_generator = data_generator(valid_data, batch_size*5)
    
    evaluator = Evaluator()
    
    def scheduler(epoch):
        return learning_rate/(max(2*(epoch-1),1))

    lr_scheduler = keras.callbacks.LearningRateScheduler(scheduler)

    
    model.fit(
        train_generator.forfit(),
        steps_per_epoch = len(train_generator),
        validation_data = valid_generator.forfit(),
        validation_steps = len(valid_generator),
        epochs = epochs,
        callbacks = [evaluator,lr_scheduler]
    )
    
    print(K.eval(CRF.trans))
    print(K.eval(CRF.trans).shape)
    model.save_weights(model_save_path)
    np.save(CRF_save_path, K.eval(CRF.trans))

    # torch.save(model, model_save_path)
    # pickle.dump(K.eval(CRF.trans),open('./save_model/crf_trans.pkl','rb'))
    
else:
    # model = torch.load(model_save_path)
    model.load_weights(model_save_path)
    # NER.trans = pickle.load(open('./save_model/crf_trans.pkl','rb'))


Epoch 1/4


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:31<00:00, 39.07it/s]


save model to ./bert_weight_file/uncased_L-4_H-768_A-12/bert_model.ckpt
valid: f1: 0.91758, precision: 0.86984, recall: 0.97086, best f1: 0.91758

Epoch 2/4


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:30<00:00, 39.34it/s]


valid: f1: 0.86092, precision: 0.77163, recall: 0.97357, best f1: 0.91758

Epoch 3/4


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:26<00:00, 40.07it/s]


valid: f1: 0.89976, precision: 0.83598, recall: 0.97408, best f1: 0.91758

Epoch 4/4


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:27<00:00, 39.97it/s]


valid: f1: 0.90911, precision: 0.85226, recall: 0.97408, best f1: 0.91758

[[ 0.57122225 -0.6133416  -0.0752689   0.11904959 -0.41914526 -0.5894722
  -0.48531607 -0.34697372 -0.05795071 -0.6759177  -0.3357175 ]
 [-0.3182494  -0.2311095  -0.37499154  0.2069319   0.19512522 -0.5219225
  -0.3076908  -0.22989158  0.2468979  -0.01747333 -0.08698532]
 [-0.5956915  -0.24547535 -0.18386608 -0.3560527  -0.36180395 -0.33080438
   0.45021257 -0.05060589 -0.49123704 -0.14037913 -0.17109317]
 [-0.61277467  0.04253093  0.21743886 -0.00321927 -0.0349032   0.00765893
   0.18040866  0.00577142 -0.2869069  -0.64182186 -0.1522031 ]
 [-0.65195215 -0.21891981  0.37502876 -0.19458427  0.07435843 -0.49571252
  -0.51378894  0.12189484  0.4857217  -0.39254132  0.51117957]
 [ 0.30123577  0.3590206   0.11038376  0.14211111 -0.14814666 -0.2882339
  -0.33709186 -0.50399256 -0.60252005 -0.01256281 -0.3170192 ]
 [-0.7756505   0.08338065  0.26491472 -0.35381716  0.34544346 -0.21493115
  -0.02121825 -0.06578448  0.229

In [17]:
print(fine_train_list[0:100])
print(train_predict_list[0:10])

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 0, 0, 0, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 8, 3, 0, 0, 0, 0, 2, 0, 0, 4, 0, 1, 1, 0, 4, 2, 0, 0, 0, 0, 4, 0, 0, 0, 0, 2, 2]
[set(), {('肚', 'Symptom'), ('拉', 'Symptom'), ('子', 'Symptom')}, set(), set(), set(), set(), set(), set(), set(), {('肚', 'Symptom'), ('拉', 'Symptom'), ('子', 'Symptom')}]


In [6]:
import numpy as np
# 保存矩阵
fine=np.array(fine_train_list)
tpl = np.array(train_predict_list)
np.save('./fine_train_list.npy',fine)
np.save('./train_predict_list.npy',tpl)

In [7]:
model.save_weights(model_save_path)

In [3]:
# 下载矩阵
fine = np.load('./fine_train_list.npy')
fine_train_list = fine.tolist()
tpl = np.load('./train_predict_list.npy',allow_pickle=True)
train_predict_list = tpl.tolist()

In [4]:
model.load_weights(model_save_path)

In [5]:
if __name__ == "__main__":
    epochs = 2
    
    train_data,_ = load_data('./ner_data/train/train.txt',128)
    valid_data,_ = load_data('./ner_data/dev/test.txt',128)
    

    train_generator = data_generator(train_data, batch_size)
    valid_generator = data_generator(valid_data, batch_size*5)
    
    checkpoint = keras.callbacks.ModelCheckpoint(
        model_save_path,
        monitor = 'val_sparse_accuracy',
        verbose = 1,
        save_best_only = True,
        mode = 'max'
    )
    evaluator = Evaluator()
    
#     def scheduler(epoch):
#         return learning_rate/(max(2*(epoch-1),1))

#     lr_scheduler = keras.callbacks.LearningRateScheduler(scheduler)

    
    model.fit(
        train_generator.forfit(),
        steps_per_epoch = len(train_generator),
        validation_data = valid_generator.forfit(),
        validation_steps = len(valid_generator),
        epochs = epochs,
        callbacks = [evaluator]
    )
    
    print(K.eval(CRF.trans))
    print(K.eval(CRF.trans).shape)
    model.save_weights(model_save_path)
    np.save(CRF_save_path, K.eval(CRF.trans))

    # torch.save(model, model_save_path)
    # pickle.dump(K.eval(CRF.trans),open('./save_model/crf_trans.pkl','rb'))
    
else:
    # model = torch.load(model_save_path)
    model.load_weights(model_save_path)
    # NER.trans = pickle.load(open('./save_model/crf_trans.pkl','rb'))


Epoch 1/2


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:31<00:00, 39.11it/s]


save model to ./bert_weight_file/uncased_L-4_H-768_A-12/bert_model.ckpt
valid: f1: 0.95462, precision: 0.93437, recall: 0.97576, best f1: 0.95462

Epoch 2/2


100%|██████████████████████████████████████████████████████████████████████████████| 8275/8275 [03:27<00:00, 39.79it/s]


valid: f1: 0.94940, precision: 0.92415, recall: 0.97606, best f1: 0.95462

[[ 0.645299   -0.7370785  -0.56542546  0.00503931 -0.90727884 -0.68640953
  -0.9690469  -0.39257857 -0.5478962  -0.8409375  -0.8255535 ]
 [-0.4483786  -0.15622358 -0.6728855   0.10875657 -0.1081326  -0.6022599
  -0.60791546 -0.27157855 -0.05456834 -0.07525794 -0.39097717]
 [-1.0685419  -0.55251616 -0.18916054 -0.59706146 -0.36600417 -0.6172872
   0.4415291  -0.17096259 -0.4977     -0.5430737  -0.17884265]
 [-0.7189193  -0.073939   -0.01607681  0.06212772 -0.26827216 -0.0518255
  -0.05279429 -0.05428631 -0.51619667 -0.7174855  -0.38692203]
 [-1.1253781  -0.5278463   0.36783957 -0.4331767   0.06927758 -0.7758725
  -0.51649743  0.00273154  0.47274798 -0.78524625  0.49940804]
 [ 0.1995825   0.28520766 -0.13622653  0.07774224 -0.39409238 -0.20836465
  -0.5750944  -0.6027726  -0.8602776  -0.08078734 -0.5711158 ]
 [-1.2439601  -0.22419953  0.25872597 -0.5886075   0.33913046 -0.49701262
  -0.02538386 -0.18118031  0.2200

In [4]:
a = {('者', 'Drug'), ('六', 'Medical_Examination'), ('今', 'Medical_Examination'), ('医院，', 'Medical_Examination'), ('药', 'Medical_Examination'), ('服医院的', 'Medical_Examination'), ('口', 'Medical_Examination'), ('：', 'Medical_Examination'), ('天', 'Medical_Examination'), ('去的', 'Medical_Examination'), ('次', 'Medical_Examination'), ('是', 'Medical_Examination'), ('中', 'Medical_Examination'), ('明', 'Medical_Examination'), ('周', 'Medical_Examination'), ('午', 'Medical_Examination'), ('上', 'Medical_Examination')}
v = {('题', 'Medical_Examination'), ('么', 'Medical_Examination'), ('可', 'Medical_Examination'), ('留', 'Medical_Examination'), ('以给我', 'Medical_Examination'), ('什', 'Medical_Examination'), ('问', 'Medical_Examination')}
c = []
m = {('题', 'Medical_Examination'), ('么', 'Medical_Examination'), ('可', 'Medical_Examination')}
h = set()
c.append(a)
c.append(m)
c.append(v)
c.append(h)
c[1] & c[2]

{('么', 'Medical_Examination'),
 ('可', 'Medical_Examination'),
 ('题', 'Medical_Examination')}

In [5]:
c[0] = {('题', 'Medical_Examination'), ('么', 'Medical_Examination'), ('可', 'Medical_Examination')}
c

[{('么', 'Medical_Examination'),
  ('可', 'Medical_Examination'),
  ('题', 'Medical_Examination')},
 {('么', 'Medical_Examination'),
  ('可', 'Medical_Examination'),
  ('题', 'Medical_Examination')},
 {('么', 'Medical_Examination'),
  ('什', 'Medical_Examination'),
  ('以给我', 'Medical_Examination'),
  ('可', 'Medical_Examination'),
  ('留', 'Medical_Examination'),
  ('问', 'Medical_Examination'),
  ('题', 'Medical_Examination')},
 set()]

In [6]:
# 学习率分层：f1：0.40678

In [7]:
import re
import numpy as np
from tqdm import tqdm
import math


In [11]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict

import numpy as np


def get_entities(seq, suffix=False):
    """Gets entities from sequence.

    Args:
        seq (list): sequence of labels.

    Returns:
        list: list of (chunk_type, chunk_start, chunk_end).

    Example:
        # >>> from seqeval.metrics.sequence_labeling import get_entities
        # >>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        # >>> get_entities(seq)
        [('PER', 0, 1), ('LOC', 3, 3)]
    """
    # for nested list
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]

    prev_tag = 'O'
    prev_type = ''
    begin_offset = 0
    chunks = []
    # print(seq)
    for i, chunk in enumerate(seq + ['O']):
        # print(i,chunk)
        if suffix:
            tag = chunk[-1]
            type_ = chunk.split('-')[0]
        else:
            try:
                tag = chunk[0]
                type_ = chunk.split('-')[-1]
            except IndexError:
                tag = 'O'
                type_ = 'O'
        if end_of_chunk(prev_tag, tag, prev_type, type_):
            chunks.append((prev_type, begin_offset, i-1))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            begin_offset = i
        prev_tag = tag
        prev_type = type_

    return chunks


def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Checks if a chunk ended between the previous and current word.

    Args:
        prev_tag: previous chunk tag.
        tag: current chunk tag.
        prev_type: previous type.
        type_: current type.

    Returns:
        chunk_end: boolean.
    """
    chunk_end = False

    if prev_tag == 'E': chunk_end = True
    if prev_tag == 'S': chunk_end = True

    if prev_tag == 'B' and tag == 'B': chunk_end = True
    if prev_tag == 'B' and tag == 'S': chunk_end = True
    if prev_tag == 'B' and tag == 'O': chunk_end = True
    if prev_tag == 'I' and tag == 'B': chunk_end = True
    if prev_tag == 'I' and tag == 'S': chunk_end = True
    if prev_tag == 'I' and tag == 'O': chunk_end = True

    if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
        chunk_end = True

    return chunk_end


def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Checks if a chunk started between the previous and current word.

    Args:
        prev_tag: previous chunk tag.
        tag: current chunk tag.
        prev_type: previous type.
        type_: current type.

    Returns:
        chunk_start: boolean.
    """
    chunk_start = False

    if tag == 'B': chunk_start = True
    if tag == 'S': chunk_start = True

    if prev_tag == 'E' and tag == 'E': chunk_start = True
    if prev_tag == 'E' and tag == 'I': chunk_start = True
    if prev_tag == 'S' and tag == 'E': chunk_start = True
    if prev_tag == 'S' and tag == 'I': chunk_start = True
    if prev_tag == 'O' and tag == 'E': chunk_start = True
    if prev_tag == 'O' and tag == 'I': chunk_start = True

    if tag != 'O' and tag != '.' and prev_type != type_:
        chunk_start = True

    return chunk_start


def f1_score(y_true: object, y_pred: object, average: object = 'micro', suffix: object = False) -> object:
    """Compute the F1 score.

    The F1 score can be interpreted as a weighted average of the precision and
    recall, where an F1 score reaches its best value at 1 and worst score at 0.
    The relative contribution of precision and recall to the F1 score are
    equal. The formula for the F1 score is::

        F1 = 2 * (precision * recall) / (precision + recall)

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a tagger.

    Returns:
        score : float.

    Example:
        # >>> from seqeval.metrics import f1_score
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> f1_score(y_true, y_pred)
        0.50
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)
    nb_true = len(true_entities)

    p = nb_correct / nb_pred if nb_pred > 0 else 0
    r = nb_correct / nb_true if nb_true > 0 else 0
    score = 2 * p * r / (p + r) if p + r > 0 else 0

    return score


def accuracy_score(y_true, y_pred):
    """Accuracy classification score.

    In multilabel classification, this function computes subset accuracy:
    the set of labels predicted for a sample must *exactly* match the
    corresponding set of labels in y_true.

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a tagger.

    Returns:
        score : float.

    Example:
        # >>> from seqeval.metrics import accuracy_score
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> accuracy_score(y_true, y_pred)
        0.80
    """
    if any(isinstance(s, list) for s in y_true):
        y_true = [item for sublist in y_true for item in sublist]
        y_pred = [item for sublist in y_pred for item in sublist]

    nb_correct = sum(y_t==y_p for y_t, y_p in zip(y_true, y_pred))
    nb_true = len(y_true)

    score = nb_correct / nb_true

    return score


def precision_score(y_true, y_pred, average='micro', suffix=False):
    """Compute the precision.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
    true positives and ``fp`` the number of false positives. The precision is
    intuitively the ability of the classifier not to label as positive a sample.

    The best value is 1 and the worst value is 0.

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a tagger.

    Returns:
        score : float.

    Example:
        # >>> from seqeval.metrics import precision_score
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> precision_score(y_true, y_pred)
        0.50
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    nb_correct = len(true_entities & pred_entities)
    nb_pred = len(pred_entities)

    score = nb_correct / nb_pred if nb_pred > 0 else 0

    return score


def recall_score(y_true, y_pred, average='micro', suffix=False):
    """Compute the recall.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive samples.

    The best value is 1 and the worst value is 0.

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a tagger.

    Returns:
        score : float.

    Example:
        # >>> from seqeval.metrics import recall_score
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> recall_score(y_true, y_pred)
        0.50
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    nb_correct = len(true_entities & pred_entities)
    nb_true = len(true_entities)

    score = nb_correct / nb_true if nb_true > 0 else 0

    return score


def performance_measure(y_true, y_pred):
    """
    Compute the performance metrics: TP, FP, FN, TN

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a tagger.

    Returns:
        performance_dict : dict

    Example:
        # >>> from seqeval.metrics import performance_measure
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'O', 'B-ORG'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> performance_measure(y_true, y_pred)
        (3, 3, 1, 4)
    """
    performace_dict = dict()
    if any(isinstance(s, list) for s in y_true):
        y_true = [item for sublist in y_true for item in sublist]
        y_pred = [item for sublist in y_pred for item in sublist]
    performace_dict['TP'] = sum(y_t == y_p for y_t, y_p in zip(y_true, y_pred)
                                if ((y_t != 'O') or (y_p != 'O')))
    performace_dict['FP'] = sum(y_t != y_p for y_t, y_p in zip(y_true, y_pred))
    performace_dict['FN'] = sum(((y_t != 'O') and (y_p == 'O'))
                                for y_t, y_p in zip(y_true, y_pred))
    performace_dict['TN'] = sum((y_t == y_p == 'O')
                                for y_t, y_p in zip(y_true, y_pred))

    return performace_dict


def classification_report(y_true, y_pred, digits=2, suffix=False):
    """Build a text report showing the main classification metrics.

    Args:
        y_true : 2d array. Ground truth (correct) target values.
        y_pred : 2d array. Estimated targets as returned by a classifier.
        digits : int. Number of digits for formatting output floating point values.

    Returns:
        report : string. Text summary of the precision, recall, F1 score for each class.

    Examples:
        # >>> from seqeval.metrics import classification_report
        # >>> y_true = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> y_pred = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        # >>> print(classification_report(y_true, y_pred))
                     precision    recall  f1-score   support
        <BLANKLINE>
               MISC       0.00      0.00      0.00         1
                PER       1.00      1.00      1.00         1
        <BLANKLINE>
          micro avg       0.50      0.50      0.50         2
          macro avg       0.50      0.50      0.50         2
        <BLANKLINE>
    """
    true_entities = set(get_entities(y_true, suffix))
    pred_entities = set(get_entities(y_pred, suffix))

    name_width = 0
    d1 = defaultdict(set)
    d2 = defaultdict(set)
    for e in true_entities:
        d1[e[0]].add((e[1], e[2]))
        name_width = max(name_width, len(e[0]))
    for e in pred_entities:
        d2[e[0]].add((e[1], e[2]))

    last_line_heading = 'macro avg'
    width = max(name_width, len(last_line_heading), digits)

    headers = ["precision", "recall", "f1-score", "support"]
    head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers)
    report = head_fmt.format(u'', *headers, width=width)
    report += u'\n\n'

    row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n'

    ps, rs, f1s, s = [], [], [], []
    for type_name, true_entities in d1.items():
        pred_entities = d2[type_name]
        nb_correct = len(true_entities & pred_entities)
        nb_pred = len(pred_entities)
        nb_true = len(true_entities)

        p = nb_correct / nb_pred if nb_pred > 0 else 0
        r = nb_correct / nb_true if nb_true > 0 else 0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0

        report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits)

        ps.append(p)
        rs.append(r)
        f1s.append(f1)
        s.append(nb_true)

    report += u'\n'

    # compute averages
    report += row_fmt.format('micro avg',
                             precision_score(y_true, y_pred, suffix=suffix),
                             recall_score(y_true, y_pred, suffix=suffix),
                             f1_score(y_true, y_pred, suffix=suffix),
                             np.sum(s),
                             width=width, digits=digits)
    report += row_fmt.format(last_line_heading,
                             np.average(ps, weights=s),
                             np.average(rs, weights=s),
                             np.average(f1s, weights=s),
                             np.sum(s),
                             width=width, digits=digits)

    return report


def report_span_accuracy(_true, _pred):
    """
    calculate span accuracy, namely ignore class label. Just check whether the predicted span is right.
    :param _true:
    :param _pred:
    :return:
    """
    y_true, y_pred = [], []
    for _t in _true:
        y_true.append([tag if tag == 'O' else '%s-d' % tag[0] for tag in _t])  # add dummy class-label: B-a --> B-d
    for _p in _pred:
        y_pred.append([tag if tag == 'O' else '%s-d' % tag[0] for tag in _p])  # add dummy class-label: B-a --> B-d
    print('Span accuracy:')
    print(classification_report(y_true, y_pred))

 
    
def load_eval_data(data_path,max_len):
    X = []
    y = []
    sentence = []
    labels = []
    split_pattern = re.compile(r'[；;。，、？！\.\?,! ]')
    with open(data_path,'r',encoding = 'utf8') as f:
        for line in f.readlines():
            #每行为一个字符和其tag，中间用tab或者空格隔开
            # sentence = [w1,w2,w3,...,wn], labels=[B-xx,I-xxx,,,...,O]
            line = line.strip().split()
            if(not line or len(line) < 2): 
                X.append(sentence)
                y.append(labels)
                sentence = []
                labels = []
                continue
            #word, tag = line[0], line[1].replace('_','-').replace('M','I').replace('E','I').replace('S','B') # BMES -> BIO
            word, tag = line[0], line[1]
            if split_pattern.match(word) and len(sentence)+8 >= max_len:
                sentence.append(word)
                labels.append(tag)
                X.append(sentence)
                y.append(labels)
                sentence = []
                labels = []
            else:
                sentence.append(word)
                labels.append(tag)
    if len(sentence):
        X.append(sentence)
        sentence = []
        y.append(labels)
        labels = []
    return X,y

# def predict_label(data,y_true):
#     y_pred = []
#     for d in data:
#         text = ''.join([i[0] for i in d])
#         entity_mentions = NER.recognize(text)
#         pred = ['O' for _ in range(len(text))]
#         b = 0

#         for i in range(1,len(entity_mentions[1])-1):
#             item = entity_mentions[1][i]
#             print(item)
#             word,typ = item[0],item[1]
#             print(word)
#             start = text.where(word,b)
#             end = start + len(word)
#             pred[start] = 'B-' + typ
#             for i in range(start + 1, end):
#                 pred[i] = 'I-' + typ
#             b += len(word)
#         y_pred.append(pred)

#     return y_pred
def predict_label(data,y_true):
    y_pred = []
    for d in data:
        text = ''.join([i[0] for i in d])
        entity_mentions = NER.recognize(text)
        # print("entity_mentions",entity_mentions)
        pred = ['O' for _ in range(len(text))]
        b = 0
        for item in entity_mentions:
            word,typ = item[0],item[1]
            start = text.find(word,b)
            end = start + len(word)
            pred[start] = 'B-' + typ
            for i in range(start + 1, end):
                pred[i] = 'I-' + typ
            b += len(word)
        y_pred.append(pred)

    return y_pred

def evaluate():
    
    eval_path = './ner_data/dev/test.txt'
    test_data,y_true = load_eval_data(eval_path,max_len)
    y_pred = predict_label(test_data,y_true)

    f1 = f1_score(y_true,y_pred)
    p = precision_score(y_true,y_pred)
    r = recall_score(y_true,y_pred)

    acc = accuracy_score(y_true,y_pred)

    print("f1_score: {:.4f}, precision_score: {:.4f}, recall_score: {:.4f}, accuracy_score: {:.4f}".format(f1,p,r,acc))
    print(classification_report(y_true, y_pred, digits=4, suffix=False))

if __name__ == '__main__':
    evaluate()

f1_score: 0.0076, precision_score: 0.0065, recall_score: 0.0089, accuracy_score: 0.8839
                     precision    recall  f1-score   support

            Symptom     0.0151    0.0133    0.0142      3007
      Drug_Category     0.0000    0.0000    0.0000       552
               Drug     0.0000    0.0000    0.0000       658
Medical_Examination     0.0063    0.0083    0.0072       846
          Operation     0.0000    0.0000    0.0000       189

          micro avg     0.0065    0.0089    0.0076      5252
          macro avg     0.0097    0.0089    0.0093      5252

